mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Compare commits
4 Commits
refactor/g
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8c5648bb7b | ||
|
|
b7e98acd1f | ||
|
|
433bc4ead9 | ||
|
|
011cc81678 |
@@ -59,7 +59,6 @@ func init() {
|
|||||||
|
|
||||||
// Client struct manage the life circle of background service
|
// Client struct manage the life circle of background service
|
||||||
type Client struct {
|
type Client struct {
|
||||||
cfgFile string
|
|
||||||
tunAdapter device.TunAdapter
|
tunAdapter device.TunAdapter
|
||||||
iFaceDiscover IFaceDiscover
|
iFaceDiscover IFaceDiscover
|
||||||
recorder *peer.Status
|
recorder *peer.Status
|
||||||
@@ -68,18 +67,16 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
stateFile string
|
|
||||||
|
|
||||||
connectClient *internal.ConnectClient
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
func NewClient(androidSDKVersion int, deviceName string, uiVersion string, tunAdapter TunAdapter, iFaceDiscover IFaceDiscover, networkChangeListener NetworkChangeListener) *Client {
|
||||||
execWorkaround(androidSDKVersion)
|
execWorkaround(androidSDKVersion)
|
||||||
|
|
||||||
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
net.SetAndroidProtectSocketFn(tunAdapter.ProtectSocket)
|
||||||
return &Client{
|
return &Client{
|
||||||
cfgFile: platformFiles.ConfigurationFilePath(),
|
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
uiVersion: uiVersion,
|
uiVersion: uiVersion,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
@@ -87,15 +84,20 @@ func NewClient(platformFiles PlatformFiles, androidSDKVersion int, deviceName st
|
|||||||
recorder: peer.NewRecorder(""),
|
recorder: peer.NewRecorder(""),
|
||||||
ctxCancelLock: &sync.Mutex{},
|
ctxCancelLock: &sync.Mutex{},
|
||||||
networkChangeListener: networkChangeListener,
|
networkChangeListener: networkChangeListener,
|
||||||
stateFile: platformFiles.StateFilePath(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run start the internal client. It is a blocker function
|
// Run start the internal client. It is a blocker function
|
||||||
func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
|
|
||||||
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
|
||||||
|
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: cfgFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -122,16 +124,22 @@ func (c *Client) Run(urlOpener URLOpener, isAndroidTV bool, dns *DNSList, dnsRea
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
// In this case make no sense handle registration steps.
|
// In this case make no sense handle registration steps.
|
||||||
func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsReadyListener DnsReadyListener, envList *EnvList) error {
|
||||||
exportEnvList(envList)
|
exportEnvList(envList)
|
||||||
|
|
||||||
|
cfgFile := platformFiles.ConfigurationFilePath()
|
||||||
|
stateFile := platformFiles.StateFilePath()
|
||||||
|
|
||||||
|
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||||
|
|
||||||
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
cfg, err := profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||||
ConfigPath: c.cfgFile,
|
ConfigPath: cfgFile,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -149,8 +157,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, c.stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
|
|||||||
257
client/android/profile_manager.go
Normal file
257
client/android/profile_manager.go
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Android-specific config filename (different from desktop default.json)
|
||||||
|
defaultConfigFilename = "netbird.cfg"
|
||||||
|
// Subdirectory for non-default profiles (must match Java Preferences.java)
|
||||||
|
profilesSubdir = "profiles"
|
||||||
|
// Android uses a single user context per app (non-empty username required by ServiceManager)
|
||||||
|
androidUsername = "android"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Profile represents a profile for gomobile
|
||||||
|
type Profile struct {
|
||||||
|
Name string
|
||||||
|
IsActive bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProfileArray wraps profiles for gomobile compatibility
|
||||||
|
type ProfileArray struct {
|
||||||
|
items []*Profile
|
||||||
|
}
|
||||||
|
|
||||||
|
// Length returns the number of profiles
|
||||||
|
func (p *ProfileArray) Length() int {
|
||||||
|
return len(p.items)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the profile at index i
|
||||||
|
func (p *ProfileArray) Get(i int) *Profile {
|
||||||
|
if i < 0 || i >= len(p.items) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return p.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
/data/data/io.netbird.client/files/ ← configDir parameter
|
||||||
|
├── netbird.cfg ← Default profile config
|
||||||
|
├── state.json ← Default profile state
|
||||||
|
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
|
||||||
|
└── profiles/ ← Subdirectory for non-default profiles
|
||||||
|
├── work.json ← Work profile config
|
||||||
|
├── work.state.json ← Work profile state
|
||||||
|
├── personal.json ← Personal profile config
|
||||||
|
└── personal.state.json ← Personal profile state
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ProfileManager manages profiles for Android
|
||||||
|
// It wraps the internal profilemanager to provide Android-specific behavior
|
||||||
|
type ProfileManager struct {
|
||||||
|
configDir string
|
||||||
|
serviceMgr *profilemanager.ServiceManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProfileManager creates a new profile manager for Android
|
||||||
|
func NewProfileManager(configDir string) *ProfileManager {
|
||||||
|
// Set the default config path for Android (stored in root configDir, not profiles/)
|
||||||
|
defaultConfigPath := filepath.Join(configDir, defaultConfigFilename)
|
||||||
|
|
||||||
|
// Set global paths for Android
|
||||||
|
profilemanager.DefaultConfigPathDir = configDir
|
||||||
|
profilemanager.DefaultConfigPath = defaultConfigPath
|
||||||
|
profilemanager.ActiveProfileStatePath = filepath.Join(configDir, "active_profile.json")
|
||||||
|
|
||||||
|
// Create ServiceManager with profiles/ subdirectory
|
||||||
|
// This avoids modifying the global ConfigDirOverride for profile listing
|
||||||
|
profilesDir := filepath.Join(configDir, profilesSubdir)
|
||||||
|
serviceMgr := profilemanager.NewServiceManagerWithProfilesDir(defaultConfigPath, profilesDir)
|
||||||
|
|
||||||
|
return &ProfileManager{
|
||||||
|
configDir: configDir,
|
||||||
|
serviceMgr: serviceMgr,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListProfiles returns all available profiles
|
||||||
|
func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||||
|
// Use ServiceManager (looks in profiles/ directory, checks active_profile.json for IsActive)
|
||||||
|
internalProfiles, err := pm.serviceMgr.ListProfiles(androidUsername)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to list profiles: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert internal profiles to Android Profile type
|
||||||
|
var profiles []*Profile
|
||||||
|
for _, p := range internalProfiles {
|
||||||
|
profiles = append(profiles, &Profile{
|
||||||
|
Name: p.Name,
|
||||||
|
IsActive: p.IsActive,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ProfileArray{items: profiles}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveProfile returns the currently active profile name
|
||||||
|
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||||
|
// Use ServiceManager to stay consistent with ListProfiles
|
||||||
|
// ServiceManager uses active_profile.json
|
||||||
|
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return activeState.Name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SwitchProfile switches to a different profile
|
||||||
|
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||||
|
// Use ServiceManager to stay consistent with ListProfiles
|
||||||
|
// ServiceManager uses active_profile.json
|
||||||
|
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||||
|
Name: profileName,
|
||||||
|
Username: androidUsername,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to switch profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("switched to profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddProfile creates a new profile
|
||||||
|
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||||
|
// Use ServiceManager (creates profile in profiles/ directory)
|
||||||
|
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
|
||||||
|
return fmt.Errorf("failed to add profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("created new profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogoutProfile logs out from a profile (clears authentication)
|
||||||
|
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
|
||||||
|
configPath, err := pm.getProfileConfigPath(profileName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if profile exists
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read current config using internal profilemanager
|
||||||
|
config, err := profilemanager.ReadConfig(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read profile config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear authentication by removing private key and SSH key
|
||||||
|
config.PrivateKey = ""
|
||||||
|
config.SSHKey = ""
|
||||||
|
|
||||||
|
// Save config using internal profilemanager
|
||||||
|
if err := profilemanager.WriteOutConfig(configPath, config); err != nil {
|
||||||
|
return fmt.Errorf("failed to save config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("logged out from profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveProfile deletes a profile
|
||||||
|
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||||
|
// Use ServiceManager (removes profile from profiles/ directory)
|
||||||
|
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||||
|
return fmt.Errorf("failed to remove profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("removed profile: %s", profileName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getProfileConfigPath returns the config file path for a profile
|
||||||
|
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||||
|
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||||
|
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||||
|
// Android uses netbird.cfg for default profile instead of default.json
|
||||||
|
// Default profile is stored in root configDir, not in profiles/
|
||||||
|
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-default profiles are stored in profiles subdirectory
|
||||||
|
// This matches the Java Preferences.java expectation
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||||
|
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfigPath returns the config file path for a given profile
|
||||||
|
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||||
|
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||||
|
return pm.getProfileConfigPath(profileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetStateFilePath returns the state file path for a given profile
|
||||||
|
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||||
|
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||||
|
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||||
|
return filepath.Join(pm.configDir, "state.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
profileName = sanitizeProfileName(profileName)
|
||||||
|
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||||
|
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||||
|
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.configFile()
|
||||||
|
func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||||
|
activeProfile, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return pm.GetConfigPath(activeProfile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||||
|
// Java should call this instead of Preferences.getActiveProfileName() + Preferences.stateFile()
|
||||||
|
func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||||
|
activeProfile, err := pm.GetActiveProfile()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||||
|
}
|
||||||
|
return pm.GetStateFilePath(activeProfile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeProfileName removes invalid characters from profile name
|
||||||
|
func sanitizeProfileName(name string) string {
|
||||||
|
// Keep only alphanumeric, underscore, and hyphen
|
||||||
|
var result strings.Builder
|
||||||
|
for _, r := range name {
|
||||||
|
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||||
|
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||||
|
result.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result.String()
|
||||||
|
}
|
||||||
136
client/cmd/kubeconfig.go
Normal file
136
client/cmd/kubeconfig.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
kubeconfigOutput string
|
||||||
|
kubeconfigCluster string
|
||||||
|
kubeconfigContext string
|
||||||
|
kubeconfigUser string
|
||||||
|
kubeconfigServer string
|
||||||
|
kubeconfigNamespace string
|
||||||
|
)
|
||||||
|
|
||||||
|
var kubeconfigCmd = &cobra.Command{
|
||||||
|
Use: "kubeconfig",
|
||||||
|
Short: "Generate kubeconfig for accessing Kubernetes via NetBird",
|
||||||
|
Long: `Generate a kubeconfig file that points to a Kubernetes cluster accessible via NetBird.
|
||||||
|
|
||||||
|
The generated kubeconfig uses a dummy bearer token for authentication when the
|
||||||
|
cluster's auth proxy is running in 'auth' mode. The actual authentication is
|
||||||
|
handled by the NetBird network - the auth proxy identifies users by their
|
||||||
|
NetBird peer IP and impersonates them in the Kubernetes API.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
netbird kubeconfig --server https://k8s.example.netbird.cloud:6443 --cluster my-cluster
|
||||||
|
netbird kubeconfig --server https://10.100.0.1:6443 -o ~/.kube/netbird-config`,
|
||||||
|
RunE: kubeconfigFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
// init configures command-line flags for the kubeconfig command.
|
||||||
|
// It registers flags for output path, cluster, context, user, server, and namespace
|
||||||
|
// and marks the server flag as required.
|
||||||
|
func init() {
|
||||||
|
kubeconfigCmd.Flags().StringVarP(&kubeconfigOutput, "output", "o", "", "Output file path (default: stdout)")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigCluster, "cluster", "netbird-cluster", "Cluster name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigContext, "context", "netbird", "Context name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigUser, "user", "netbird-user", "User name in kubeconfig")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigServer, "server", "", "Kubernetes API server URL (required)")
|
||||||
|
kubeconfigCmd.Flags().StringVar(&kubeconfigNamespace, "namespace", "default", "Default namespace")
|
||||||
|
_ = kubeconfigCmd.MarkFlagRequired("server")
|
||||||
|
}
|
||||||
|
|
||||||
|
// kubeconfigFunc generates a kubeconfig file for accessing Kubernetes via the NetBird auth proxy.
|
||||||
|
// KUBECONFIG and running kubectl.
|
||||||
|
func kubeconfigFunc(cmd *cobra.Command, args []string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Get current NetBird status to verify connection
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: Could not connect to NetBird daemon: %v\n", err)
|
||||||
|
cmd.PrintErrln("Generating kubeconfig anyway, but make sure NetBird is running before using it.")
|
||||||
|
} else {
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: Could not get NetBird status: %v\n", status.Convert(err).Message())
|
||||||
|
} else if resp.Status != "Connected" {
|
||||||
|
cmd.PrintErrf("Warning: NetBird is not connected (status: %s)\n", resp.Status)
|
||||||
|
cmd.PrintErrln("Make sure to run 'netbird up' before using the generated kubeconfig.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kubeconfig := generateKubeconfig(kubeconfigServer, kubeconfigCluster, kubeconfigContext, kubeconfigUser, kubeconfigNamespace)
|
||||||
|
|
||||||
|
if kubeconfigOutput == "" {
|
||||||
|
fmt.Println(kubeconfig)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expand ~ in path
|
||||||
|
if strings.HasPrefix(kubeconfigOutput, "~/") {
|
||||||
|
home, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get home directory: %w", err)
|
||||||
|
}
|
||||||
|
kubeconfigOutput = filepath.Join(home, kubeconfigOutput[2:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create directory if needed
|
||||||
|
dir := filepath.Dir(kubeconfigOutput)
|
||||||
|
if err := os.MkdirAll(dir, 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory %s: %w", dir, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(kubeconfigOutput, []byte(kubeconfig), 0600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write kubeconfig: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Kubeconfig written to %s\n", kubeconfigOutput)
|
||||||
|
cmd.PrintErrln("\nWarning: TLS verification is disabled (insecure-skip-tls-verify: true).")
|
||||||
|
cmd.PrintErrln("This is safe when traffic is encrypted via NetBird's WireGuard tunnel.")
|
||||||
|
cmd.Printf("\nTo use this kubeconfig:\n")
|
||||||
|
cmd.Printf(" export KUBECONFIG=%s\n", kubeconfigOutput)
|
||||||
|
cmd.Printf(" kubectl get nodes\n")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateKubeconfig creates a kubeconfig YAML string with the given parameters.
|
||||||
|
// generateKubeconfig generates a kubeconfig YAML for accessing the specified Kubernetes API server via NetBird.
|
||||||
|
// The returned config sets the current context to the provided context, includes the given cluster, user, and namespace,
|
||||||
|
// enables `insecure-skip-tls-verify: true`, and embeds the static token `netbird-auth-proxy`.
|
||||||
|
func generateKubeconfig(server, cluster, context, user, namespace string) string {
|
||||||
|
return fmt.Sprintf(`apiVersion: v1
|
||||||
|
kind: Config
|
||||||
|
clusters:
|
||||||
|
- cluster:
|
||||||
|
insecure-skip-tls-verify: true
|
||||||
|
server: %s
|
||||||
|
name: %s
|
||||||
|
contexts:
|
||||||
|
- context:
|
||||||
|
cluster: %s
|
||||||
|
namespace: %s
|
||||||
|
user: %s
|
||||||
|
name: %s
|
||||||
|
current-context: %s
|
||||||
|
users:
|
||||||
|
- name: %s
|
||||||
|
user:
|
||||||
|
token: netbird-auth-proxy
|
||||||
|
`, server, cluster, cluster, namespace, user, context, context, user)
|
||||||
|
}
|
||||||
@@ -88,6 +88,13 @@ func Execute() error {
|
|||||||
return rootCmd.Execute()
|
return rootCmd.Execute()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init initializes package-level defaults and configures the root CLI command.
|
||||||
|
// It sets OS-specific default paths for configuration and logs, determines the default
|
||||||
|
// daemon address, registers persistent CLI flags (daemon address, management/admin URLs,
|
||||||
|
// logging, setup key options, pre-shared key, hostname, anonymization, and config path),
|
||||||
|
// and wires up all top-level and nested subcommands. It also defines upCmd-specific
|
||||||
|
// flags for external IP mapping, custom DNS resolver address, Rosenpass options,
|
||||||
|
// auto-connect control, and lazy connection.
|
||||||
func init() {
|
func init() {
|
||||||
defaultConfigPathDir = "/etc/netbird/"
|
defaultConfigPathDir = "/etc/netbird/"
|
||||||
defaultLogFileDir = "/var/log/netbird/"
|
defaultLogFileDir = "/var/log/netbird/"
|
||||||
@@ -141,6 +148,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
rootCmd.AddCommand(profileCmd)
|
rootCmd.AddCommand(profileCmd)
|
||||||
|
rootCmd.AddCommand(kubeconfigCmd)
|
||||||
|
|
||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
@@ -393,4 +401,4 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
176
client/cmd/signer/artifactkey.go
Normal file
176
client/cmd/signer/artifactkey.go
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bundlePubKeysRootPrivKeyFile string
|
||||||
|
bundlePubKeysPubKeyFiles []string
|
||||||
|
bundlePubKeysFile string
|
||||||
|
|
||||||
|
createArtifactKeyRootPrivKeyFile string
|
||||||
|
createArtifactKeyPrivKeyFile string
|
||||||
|
createArtifactKeyPubKeyFile string
|
||||||
|
createArtifactKeyExpiration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createArtifactKeyCmd = &cobra.Command{
|
||||||
|
Use: "create-artifact-key",
|
||||||
|
Short: "Create a new artifact signing key",
|
||||||
|
Long: `Generate a new artifact signing key pair signed by the root private key.
|
||||||
|
The artifact key will be used to sign software artifacts/updates.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if createArtifactKeyExpiration <= 0 {
|
||||||
|
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleCreateArtifactKey(cmd, createArtifactKeyRootPrivKeyFile, createArtifactKeyPrivKeyFile, createArtifactKeyPubKeyFile, createArtifactKeyExpiration); err != nil {
|
||||||
|
return fmt.Errorf("failed to create artifact key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var bundlePubKeysCmd = &cobra.Command{
|
||||||
|
Use: "bundle-pub-keys",
|
||||||
|
Short: "Bundle multiple artifact public keys into a signed package",
|
||||||
|
Long: `Bundle one or more artifact public keys into a signed package using the root private key.
|
||||||
|
This command is typically used to distribute or authorize a set of valid artifact signing keys.`,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if len(bundlePubKeysPubKeyFiles) == 0 {
|
||||||
|
return fmt.Errorf("at least one --artifact-pub-key-file must be provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleBundlePubKeys(cmd, bundlePubKeysRootPrivKeyFile, bundlePubKeysPubKeyFiles, bundlePubKeysFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to bundle public keys: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createArtifactKeyCmd)
|
||||||
|
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the artifact key")
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPrivKeyFile, "artifact-priv-key-file", "", "Path where the artifact private key will be saved")
|
||||||
|
createArtifactKeyCmd.Flags().StringVar(&createArtifactKeyPubKeyFile, "artifact-pub-key-file", "", "Path where the artifact public key will be saved")
|
||||||
|
createArtifactKeyCmd.Flags().DurationVar(&createArtifactKeyExpiration, "expiration", 0, "Expiration duration for the artifact key (e.g., 720h, 365d, 8760h)")
|
||||||
|
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-priv-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-priv-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := createArtifactKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark expiration as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.AddCommand(bundlePubKeysCmd)
|
||||||
|
|
||||||
|
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysRootPrivKeyFile, "root-private-key-file", "", "Path to the root private key file used to sign the bundle")
|
||||||
|
bundlePubKeysCmd.Flags().StringArrayVar(&bundlePubKeysPubKeyFiles, "artifact-pub-key-file", nil, "Path(s) to the artifact public key files to include in the bundle (can be repeated)")
|
||||||
|
bundlePubKeysCmd.Flags().StringVar(&bundlePubKeysFile, "bundle-pub-key-file", "", "Path where the public keys will be saved")
|
||||||
|
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("root-private-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-private-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("artifact-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := bundlePubKeysCmd.MarkFlagRequired("bundle-pub-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark bundle-pub-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleCreateArtifactKey(cmd *cobra.Command, rootPrivKeyFile, artifactPrivKeyFile, artifactPubKeyFile string, expiration time.Duration) error {
|
||||||
|
cmd.Println("Creating new artifact signing key...")
|
||||||
|
|
||||||
|
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root private key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactKey, privPEM, pubPEM, signature, err := reposign.GenerateArtifactKey(privateRootKey, expiration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate artifact key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(artifactPrivKeyFile, privPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write private key file (%s): %w", artifactPrivKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(artifactPubKeyFile, pubPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public key file (%s): %w", artifactPubKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureFile := artifactPubKeyFile + ".sig"
|
||||||
|
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Artifact key created successfully.\n")
|
||||||
|
cmd.Printf("%s\n", artifactKey.String())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleBundlePubKeys(cmd *cobra.Command, rootPrivKeyFile string, artifactPubKeyFiles []string, bundlePubKeysFile string) error {
|
||||||
|
cmd.Println("📦 Bundling public keys into signed package...")
|
||||||
|
|
||||||
|
privKeyPEM, err := os.ReadFile(rootPrivKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root private key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKeys := make([]reposign.PublicKey, 0, len(artifactPubKeyFiles))
|
||||||
|
for _, pubFile := range artifactPubKeyFiles {
|
||||||
|
pubPem, err := os.ReadFile(pubFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read public key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pk, err := reposign.ParseArtifactPubKey(pubPem)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact key: %w", err)
|
||||||
|
}
|
||||||
|
publicKeys = append(publicKeys, pk)
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedKeys, signature, err := reposign.BundleArtifactKeys(privateRootKey, publicKeys)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("bundle artifact keys: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.WriteFile(bundlePubKeysFile, parsedKeys, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public keys file (%s): %w", bundlePubKeysFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signatureFile := bundlePubKeysFile + ".sig"
|
||||||
|
if err := os.WriteFile(signatureFile, signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", signatureFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Bundle created with %d public keys.\n", len(artifactPubKeyFiles))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
276
client/cmd/signer/artifactsign.go
Normal file
276
client/cmd/signer/artifactsign.go
Normal file
@@ -0,0 +1,276 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envArtifactPrivateKey = "NB_ARTIFACT_PRIV_KEY"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
signArtifactPrivKeyFile string
|
||||||
|
signArtifactArtifactFile string
|
||||||
|
|
||||||
|
verifyArtifactPubKeyFile string
|
||||||
|
verifyArtifactFile string
|
||||||
|
verifyArtifactSignatureFile string
|
||||||
|
|
||||||
|
verifyArtifactKeyPubKeyFile string
|
||||||
|
verifyArtifactKeyRootPubKeyFile string
|
||||||
|
verifyArtifactKeySignatureFile string
|
||||||
|
verifyArtifactKeyRevocationFile string
|
||||||
|
)
|
||||||
|
|
||||||
|
var signArtifactCmd = &cobra.Command{
|
||||||
|
Use: "sign-artifact",
|
||||||
|
Short: "Sign an artifact using an artifact private key",
|
||||||
|
Long: `Sign a software artifact (e.g., update bundle or binary) using the artifact's private key.
|
||||||
|
This command produces a detached signature that can be verified using the corresponding artifact public key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleSignArtifact(cmd, signArtifactPrivKeyFile, signArtifactArtifactFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to sign artifact: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyArtifactCmd = &cobra.Command{
|
||||||
|
Use: "verify-artifact",
|
||||||
|
Short: "Verify an artifact signature using an artifact public key",
|
||||||
|
Long: `Verify a software artifact signature using the artifact's public key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleVerifyArtifact(cmd, verifyArtifactPubKeyFile, verifyArtifactFile, verifyArtifactSignatureFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to verify artifact: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyArtifactKeyCmd = &cobra.Command{
|
||||||
|
Use: "verify-artifact-key",
|
||||||
|
Short: "Verify an artifact public key was signed by a root key",
|
||||||
|
Long: `Verify that an artifact public key (or bundle) was properly signed by a root key.
|
||||||
|
This validates the chain of trust from the root key to the artifact key.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := handleVerifyArtifactKey(cmd, verifyArtifactKeyPubKeyFile, verifyArtifactKeyRootPubKeyFile, verifyArtifactKeySignatureFile, verifyArtifactKeyRevocationFile); err != nil {
|
||||||
|
return fmt.Errorf("failed to verify artifact key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(signArtifactCmd)
|
||||||
|
rootCmd.AddCommand(verifyArtifactCmd)
|
||||||
|
rootCmd.AddCommand(verifyArtifactKeyCmd)
|
||||||
|
|
||||||
|
signArtifactCmd.Flags().StringVar(&signArtifactPrivKeyFile, "artifact-key-file", "", fmt.Sprintf("Path to the artifact private key file used for signing (or set %s env var)", envArtifactPrivateKey))
|
||||||
|
signArtifactCmd.Flags().StringVar(&signArtifactArtifactFile, "artifact-file", "", "Path to the artifact to be signed")
|
||||||
|
|
||||||
|
// artifact-file is required, but artifact-key-file can come from env var
|
||||||
|
if err := signArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactPubKeyFile, "artifact-public-key-file", "", "Path to the artifact public key file")
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactFile, "artifact-file", "", "Path to the artifact to be verified")
|
||||||
|
verifyArtifactCmd.Flags().StringVar(&verifyArtifactSignatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("artifact-public-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-public-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("artifact-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyPubKeyFile, "artifact-key-file", "", "Path to the artifact public key file or bundle")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRootPubKeyFile, "root-key-file", "", "Path to the root public key file or bundle")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeySignatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
verifyArtifactKeyCmd.Flags().StringVar(&verifyArtifactKeyRevocationFile, "revocation-file", "", "Path to the revocation list file (optional)")
|
||||||
|
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("artifact-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark artifact-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("root-key-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark root-key-file as required: %w", err))
|
||||||
|
}
|
||||||
|
if err := verifyArtifactKeyCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(fmt.Errorf("mark signature-file as required: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleSignArtifact(cmd *cobra.Command, privKeyFile, artifactFile string) error {
|
||||||
|
cmd.Println("🖋️ Signing artifact...")
|
||||||
|
|
||||||
|
// Load private key from env var or file
|
||||||
|
var privKeyPEM []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if envKey := os.Getenv(envArtifactPrivateKey); envKey != "" {
|
||||||
|
// Use key from environment variable
|
||||||
|
privKeyPEM = []byte(envKey)
|
||||||
|
} else if privKeyFile != "" {
|
||||||
|
// Fall back to file
|
||||||
|
privKeyPEM, err = os.ReadFile(privKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read private key file: %w", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("artifact private key must be provided via %s environment variable or --artifact-key-file flag", envArtifactPrivateKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := reposign.ParseArtifactKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactData, err := os.ReadFile(artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.SignData(privateKey, artifactData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sign artifact: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sigFile := artifactFile + ".sig"
|
||||||
|
if err := os.WriteFile(artifactFile+".sig", signature, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write signature file (%s): %w", sigFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("✅ Artifact signed successfully.\n")
|
||||||
|
cmd.Printf("Signature file: %s\n", sigFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyArtifact(cmd *cobra.Command, pubKeyFile, artifactFile, signatureFile string) error {
|
||||||
|
cmd.Println("🔍 Verifying artifact...")
|
||||||
|
|
||||||
|
// Read artifact public key
|
||||||
|
pubKeyPEM, err := os.ReadFile(pubKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read public key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := reposign.ParseArtifactPubKey(pubKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse artifact public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read artifact data
|
||||||
|
artifactData, err := os.ReadFile(artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate artifact
|
||||||
|
if err := reposign.ValidateArtifact([]reposign.PublicKey{publicKey}, artifactData, *signature); err != nil {
|
||||||
|
return fmt.Errorf("artifact verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Artifact signature is valid")
|
||||||
|
cmd.Printf("Artifact: %s\n", artifactFile)
|
||||||
|
cmd.Printf("Signed by key: %s\n", signature.KeyID)
|
||||||
|
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyArtifactKey(cmd *cobra.Command, artifactKeyFile, rootKeyFile, signatureFile, revocationFile string) error {
|
||||||
|
cmd.Println("🔍 Verifying artifact key...")
|
||||||
|
|
||||||
|
// Read artifact key data
|
||||||
|
artifactKeyData, err := os.ReadFile(artifactKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read artifact key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read root public key(s)
|
||||||
|
rootKeyData, err := os.ReadFile(rootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rootPublicKeys, err := parseRootPublicKeys(rootKeyData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse root public key(s): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read optional revocation list
|
||||||
|
var revocationList *reposign.RevocationList
|
||||||
|
if revocationFile != "" {
|
||||||
|
revData, err := os.ReadFile(revocationFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read revocation file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
revocationList, err = reposign.ParseRevocationList(revData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate artifact key(s)
|
||||||
|
validKeys, err := reposign.ValidateArtifactKeys(rootPublicKeys, artifactKeyData, *signature, revocationList)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("artifact key verification failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Artifact key(s) verified successfully")
|
||||||
|
cmd.Printf("Signed by root key: %s\n", signature.KeyID)
|
||||||
|
cmd.Printf("Signature timestamp: %s\n", signature.Timestamp.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
cmd.Printf("\nValid artifact keys (%d):\n", len(validKeys))
|
||||||
|
for i, key := range validKeys {
|
||||||
|
cmd.Printf(" [%d] Key ID: %s\n", i+1, key.Metadata.ID)
|
||||||
|
cmd.Printf(" Created: %s\n", key.Metadata.CreatedAt.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
if !key.Metadata.ExpiresAt.IsZero() {
|
||||||
|
cmd.Printf(" Expires: %s\n", key.Metadata.ExpiresAt.Format("2006-01-02 15:04:05 MST"))
|
||||||
|
} else {
|
||||||
|
cmd.Printf(" Expires: Never\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseRootPublicKeys parses a root public key from PEM data
|
||||||
|
func parseRootPublicKeys(data []byte) ([]reposign.PublicKey, error) {
|
||||||
|
key, err := reposign.ParseRootPublicKey(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return []reposign.PublicKey{key}, nil
|
||||||
|
}
|
||||||
21
client/cmd/signer/main.go
Normal file
21
client/cmd/signer/main.go
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var rootCmd = &cobra.Command{
|
||||||
|
Use: "signer",
|
||||||
|
Short: "A CLI tool for managing cryptographic keys and artifacts",
|
||||||
|
Long: `signer is a command-line tool that helps you manage
|
||||||
|
root keys, artifact keys, and revocation lists securely.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
rootCmd.Println(err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
220
client/cmd/signer/revocation.go
Normal file
220
client/cmd/signer/revocation.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultRevocationListExpiration = 365 * 24 * time.Hour // 1 year
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
keyID string
|
||||||
|
revocationListFile string
|
||||||
|
privateRootKeyFile string
|
||||||
|
publicRootKeyFile string
|
||||||
|
signatureFile string
|
||||||
|
expirationDuration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "create-revocation-list",
|
||||||
|
Short: "Create a new revocation list signed by the private root key",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleCreateRevocationList(cmd, revocationListFile, privateRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var extendRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "extend-revocation-list",
|
||||||
|
Short: "Extend an existing revocation list with a given key ID",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleExtendRevocationList(cmd, keyID, revocationListFile, privateRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var verifyRevocationListCmd = &cobra.Command{
|
||||||
|
Use: "verify-revocation-list",
|
||||||
|
Short: "Verify a revocation list signature using the public root key",
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
return handleVerifyRevocationList(cmd, revocationListFile, signatureFile, publicRootKeyFile)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createRevocationListCmd)
|
||||||
|
rootCmd.AddCommand(extendRevocationListCmd)
|
||||||
|
rootCmd.AddCommand(verifyRevocationListCmd)
|
||||||
|
|
||||||
|
createRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||||
|
createRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||||
|
createRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||||
|
if err := createRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&keyID, "key-id", "", "ID of the key to extend the revocation list for")
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the existing revocation list file")
|
||||||
|
extendRevocationListCmd.Flags().StringVar(&privateRootKeyFile, "private-root-key", "", "Path to the private root key PEM file")
|
||||||
|
extendRevocationListCmd.Flags().DurationVar(&expirationDuration, "expiration", defaultRevocationListExpiration, "Expiration duration for the revocation list (e.g., 8760h for 1 year)")
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("key-id"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := extendRevocationListCmd.MarkFlagRequired("private-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&revocationListFile, "revocation-list-file", "", "Path to the revocation list file")
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&signatureFile, "signature-file", "", "Path to the signature file")
|
||||||
|
verifyRevocationListCmd.Flags().StringVar(&publicRootKeyFile, "public-root-key", "", "Path to the public root key PEM file")
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("revocation-list-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("signature-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := verifyRevocationListCmd.MarkFlagRequired("public-root-key"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleCreateRevocationList(cmd *cobra.Command, revocationListFile string, privateRootKeyFile string) error {
|
||||||
|
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlBytes, sigBytes, err := reposign.CreateRevocationList(*privateRootKey, expirationDuration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", rlBytes, sigBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to write output files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Revocation list created successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleExtendRevocationList(cmd *cobra.Command, keyID, revocationListFile, privateRootKeyFile string) error {
|
||||||
|
privKeyPEM, err := os.ReadFile(privateRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read private root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateRootKey, err := reposign.ParseRootKey(privKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse private root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlBytes, err := os.ReadFile(revocationListFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rl, err := reposign.ParseRevocationList(rlBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kid, err := reposign.ParseKeyID(keyID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid key ID: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newRLBytes, sigBytes, err := reposign.ExtendRevocationList(*privateRootKey, *rl, kid, expirationDuration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to extend revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := writeOutputFiles(revocationListFile, revocationListFile+".sig", newRLBytes, sigBytes); err != nil {
|
||||||
|
return fmt.Errorf("failed to write output files: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("✅ Revocation list extended successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleVerifyRevocationList(cmd *cobra.Command, revocationListFile, signatureFile, publicRootKeyFile string) error {
|
||||||
|
// Read revocation list file
|
||||||
|
rlBytes, err := os.ReadFile(revocationListFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read signature file
|
||||||
|
sigBytes, err := os.ReadFile(signatureFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read signature file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read public root key file
|
||||||
|
pubKeyPEM, err := os.ReadFile(publicRootKeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read public root key file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse public root key
|
||||||
|
publicKey, err := reposign.ParseRootPublicKey(pubKeyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse public root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse signature
|
||||||
|
signature, err := reposign.ParseSignature(sigBytes)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate revocation list
|
||||||
|
rl, err := reposign.ValidateRevocationList([]reposign.PublicKey{publicKey}, rlBytes, *signature)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to validate revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display results
|
||||||
|
cmd.Println("✅ Revocation list signature is valid")
|
||||||
|
cmd.Printf("Last Updated: %s\n", rl.LastUpdated.Format(time.RFC3339))
|
||||||
|
cmd.Printf("Expires At: %s\n", rl.ExpiresAt.Format(time.RFC3339))
|
||||||
|
cmd.Printf("Number of revoked keys: %d\n", len(rl.Revoked))
|
||||||
|
|
||||||
|
if len(rl.Revoked) > 0 {
|
||||||
|
cmd.Println("\nRevoked Keys:")
|
||||||
|
for keyID, revokedTime := range rl.Revoked {
|
||||||
|
cmd.Printf(" - %s (revoked at: %s)\n", keyID, revokedTime.Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOutputFiles(rlPath, sigPath string, rlBytes, sigBytes []byte) error {
|
||||||
|
if err := os.WriteFile(rlPath, rlBytes, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write revocation list file: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(sigPath, sigBytes, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("failed to write signature file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
74
client/cmd/signer/rootkey.go
Normal file
74
client/cmd/signer/rootkey.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
privKeyFile string
|
||||||
|
pubKeyFile string
|
||||||
|
rootExpiration time.Duration
|
||||||
|
)
|
||||||
|
|
||||||
|
var createRootKeyCmd = &cobra.Command{
|
||||||
|
Use: "create-root-key",
|
||||||
|
Short: "Create a new root key pair",
|
||||||
|
Long: `Create a new root key pair and specify an expiration time for it.`,
|
||||||
|
SilenceUsage: true,
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Validate expiration
|
||||||
|
if rootExpiration <= 0 {
|
||||||
|
return fmt.Errorf("--expiration must be a positive duration (e.g., 720h, 365d, 8760h)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run main logic
|
||||||
|
if err := handleGenerateRootKey(cmd, privKeyFile, pubKeyFile, rootExpiration); err != nil {
|
||||||
|
return fmt.Errorf("failed to generate root key: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(createRootKeyCmd)
|
||||||
|
createRootKeyCmd.Flags().StringVar(&privKeyFile, "priv-key-file", "", "Path to output private key file")
|
||||||
|
createRootKeyCmd.Flags().StringVar(&pubKeyFile, "pub-key-file", "", "Path to output public key file")
|
||||||
|
createRootKeyCmd.Flags().DurationVar(&rootExpiration, "expiration", 0, "Expiration time for the root key (e.g., 720h,)")
|
||||||
|
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("priv-key-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("pub-key-file"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if err := createRootKeyCmd.MarkFlagRequired("expiration"); err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleGenerateRootKey(cmd *cobra.Command, privKeyFile, pubKeyFile string, expiration time.Duration) error {
|
||||||
|
rk, privPEM, pubPEM, err := reposign.GenerateRootKey(expiration)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate root key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write private key
|
||||||
|
if err := os.WriteFile(privKeyFile, privPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write private key file (%s): %w", privKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write public key
|
||||||
|
if err := os.WriteFile(pubKeyFile, pubPEM, 0o600); err != nil {
|
||||||
|
return fmt.Errorf("write public key file (%s): %w", pubKeyFile, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("%s\n\n", rk.String())
|
||||||
|
cmd.Printf("✅ Root key pair generated successfully.\n")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil)
|
return connectClient.Run(nil)
|
||||||
|
|||||||
13
client/cmd/update.go
Normal file
13
client/cmd/update.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
//go:build !windows && !darwin
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
)
|
||||||
|
|
||||||
|
var updateCmd *cobra.Command
|
||||||
|
|
||||||
|
func isUpdateBinary() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
75
client/cmd/update_supported.go
Normal file
75
client/cmd/update_supported.go
Normal file
@@ -0,0 +1,75 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
updateCmd = &cobra.Command{
|
||||||
|
Use: "update",
|
||||||
|
Short: "Update the NetBird client application",
|
||||||
|
RunE: updateFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tempDirFlag string
|
||||||
|
installerFile string
|
||||||
|
serviceDirFlag string
|
||||||
|
dryRunFlag bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
updateCmd.Flags().StringVar(&tempDirFlag, "temp-dir", "", "temporary dir")
|
||||||
|
updateCmd.Flags().StringVar(&installerFile, "installer-file", "", "installer file")
|
||||||
|
updateCmd.Flags().StringVar(&serviceDirFlag, "service-dir", "", "service directory")
|
||||||
|
updateCmd.Flags().BoolVar(&dryRunFlag, "dry-run", false, "dry run the update process without making any changes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isUpdateBinary checks if the current executable is named "update" or "update.exe"
|
||||||
|
func isUpdateBinary() bool {
|
||||||
|
// Remove extension for cross-platform compatibility
|
||||||
|
execPath, err := os.Executable()
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
baseName := filepath.Base(execPath)
|
||||||
|
name := strings.TrimSuffix(baseName, filepath.Ext(baseName))
|
||||||
|
|
||||||
|
return name == installer.UpdaterBinaryNameWithoutExtension()
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateFunc(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := setupLogToFile(tempDirFlag); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("updater started: %s", serviceDirFlag)
|
||||||
|
updater := installer.NewWithDir(tempDirFlag)
|
||||||
|
if err := updater.Setup(context.Background(), dryRunFlag, installerFile, serviceDirFlag); err != nil {
|
||||||
|
log.Errorf("failed to update application: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupLogToFile(dir string) error {
|
||||||
|
logFile := filepath.Join(dir, installer.LogFile)
|
||||||
|
|
||||||
|
if _, err := os.Stat(logFile); err == nil {
|
||||||
|
if err := os.Remove(logFile); err != nil {
|
||||||
|
log.Errorf("failed to remove existing log file: %v\n", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return util.InitLog(logLevel, util.LogConsole, logFile)
|
||||||
|
}
|
||||||
@@ -173,7 +173,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder)
|
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
// TODO: make after-startup backoff err available
|
// TODO: make after-startup backoff err available
|
||||||
|
|||||||
@@ -24,10 +24,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
@@ -39,11 +43,13 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ConnectClient struct {
|
type ConnectClient struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
config *profilemanager.Config
|
config *profilemanager.Config
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
engine *Engine
|
doInitialAutoUpdate bool
|
||||||
engineMutex sync.Mutex
|
|
||||||
|
engine *Engine
|
||||||
|
engineMutex sync.Mutex
|
||||||
|
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
}
|
}
|
||||||
@@ -52,13 +58,15 @@ func NewConnectClient(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
|
doInitalAutoUpdate bool,
|
||||||
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
engineMutex: sync.Mutex{},
|
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||||
|
engineMutex: sync.Mutex{},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -162,6 +170,33 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var path string
|
||||||
|
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
||||||
|
// On mobile, use the provided state file path directly
|
||||||
|
if !fileExists(mobileDependency.StateFilePath) {
|
||||||
|
if err := createFile(mobileDependency.StateFilePath); err != nil {
|
||||||
|
log.Errorf("failed to create state file: %v", err)
|
||||||
|
// we are not exiting as we can run without the state manager
|
||||||
|
}
|
||||||
|
}
|
||||||
|
path = mobileDependency.StateFilePath
|
||||||
|
} else {
|
||||||
|
sm := profilemanager.NewServiceManager("")
|
||||||
|
path = sm.GetStatePath()
|
||||||
|
}
|
||||||
|
stateManager := statemanager.New(path)
|
||||||
|
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
|
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||||
|
if err == nil {
|
||||||
|
updateManager.CheckUpdateSuccess(c.ctx)
|
||||||
|
|
||||||
|
inst := installer.New()
|
||||||
|
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||||
|
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
operation := func() error {
|
operation := func() error {
|
||||||
// if context cancelled we not start new backoff cycle
|
// if context cancelled we not start new backoff cycle
|
||||||
@@ -273,7 +308,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks)
|
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
c.engine = engine
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
@@ -283,6 +318,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||||
|
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||||
|
if c.doInitialAutoUpdate {
|
||||||
|
log.Infof("start engine by ui, run auto-update check")
|
||||||
|
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||||
|
c.doInitialAutoUpdate = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@@ -362,6 +363,10 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add systemd logs: %v", err)
|
log.Errorf("failed to add systemd logs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addUpdateLogs(); err != nil {
|
||||||
|
log.Errorf("failed to add updater logs: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -650,6 +655,29 @@ func (g *BundleGenerator) addStateFile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addUpdateLogs() error {
|
||||||
|
inst := installer.New()
|
||||||
|
logFiles := inst.LogFiles()
|
||||||
|
if len(logFiles) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("adding updater logs")
|
||||||
|
for _, logFile := range logFiles {
|
||||||
|
data, err := os.ReadFile(logFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to read update log file %s: %v", logFile, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
baseName := filepath.Base(logFile)
|
||||||
|
if err := g.addFileToZip(bytes.NewReader(data), filepath.Join("update-logs", baseName)); err != nil {
|
||||||
|
return fmt.Errorf("add update log file %s to zip: %w", baseName, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
func (g *BundleGenerator) addCorruptedStateFiles() error {
|
||||||
sm := profilemanager.NewServiceManager("")
|
sm := profilemanager.NewServiceManager("")
|
||||||
pattern := sm.GetStatePath()
|
pattern := sm.GetStatePath()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -26,6 +27,11 @@ type Resolver struct {
|
|||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ipsResponse struct {
|
||||||
|
ips []netip.Addr
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
@@ -99,9 +105,9 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("resolve domain %s: %w", d.SafeString(), err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var aRecords, aaaaRecords []dns.RR
|
var aRecords, aaaaRecords []dns.RR
|
||||||
@@ -159,6 +165,36 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||||
|
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
|
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
|
resultChan := make(chan *ipsResponse, 1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
|
resultChan <- &ipsResponse{
|
||||||
|
err: err,
|
||||||
|
ips: ips,
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var resp *ipsResponse
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||||
|
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||||
|
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp = <-resultChan:
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||||
|
}
|
||||||
|
return resp.ips, nil
|
||||||
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
func (m *Resolver) PopulateFromConfig(ctx context.Context, mgmtURL *url.URL) error {
|
||||||
if mgmtURL == nil {
|
if mgmtURL == nil {
|
||||||
|
|||||||
@@ -42,14 +42,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||||
"github.com/netbirdio/netbird/client/internal/peerstore"
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
"github.com/netbirdio/netbird/client/internal/rosenpass"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
@@ -73,6 +72,7 @@ const (
|
|||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
connInitLimit = 200
|
connInitLimit = 200
|
||||||
|
disableAutoUpdate = "disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
@@ -201,6 +201,9 @@ type Engine struct {
|
|||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
|
// auto-update
|
||||||
|
updateManager *updatemanager.Manager
|
||||||
|
|
||||||
// WireGuard interface monitor
|
// WireGuard interface monitor
|
||||||
wgIfaceMonitor *WGIfaceMonitor
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
|
|
||||||
@@ -221,17 +224,7 @@ type localIpUpdater interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewEngine creates a new Connection Engine with probes attached
|
// NewEngine creates a new Connection Engine with probes attached
|
||||||
func NewEngine(
|
func NewEngine(clientCtx context.Context, clientCancel context.CancelFunc, signalClient signal.Client, mgmClient mgm.Client, relayManager *relayClient.Manager, config *EngineConfig, mobileDep MobileDependency, statusRecorder *peer.Status, checks []*mgmProto.Checks, stateManager *statemanager.Manager) *Engine {
|
||||||
clientCtx context.Context,
|
|
||||||
clientCancel context.CancelFunc,
|
|
||||||
signalClient signal.Client,
|
|
||||||
mgmClient mgm.Client,
|
|
||||||
relayManager *relayClient.Manager,
|
|
||||||
config *EngineConfig,
|
|
||||||
mobileDep MobileDependency,
|
|
||||||
statusRecorder *peer.Status,
|
|
||||||
checks []*mgmProto.Checks,
|
|
||||||
) *Engine {
|
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
@@ -247,28 +240,12 @@ func NewEngine(
|
|||||||
TURNs: []*stun.URI{},
|
TURNs: []*stun.URI{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
|
stateManager: stateManager,
|
||||||
checks: checks,
|
checks: checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
}
|
}
|
||||||
|
|
||||||
sm := profilemanager.NewServiceManager("")
|
|
||||||
|
|
||||||
path := sm.GetStatePath()
|
|
||||||
if runtime.GOOS == "ios" || runtime.GOOS == "android" {
|
|
||||||
if !fileExists(mobileDep.StateFilePath) {
|
|
||||||
err := createFile(mobileDep.StateFilePath)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to create state file: %v", err)
|
|
||||||
// we are not exiting as we can run without the state manager
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
path = mobileDep.StateFilePath
|
|
||||||
}
|
|
||||||
engine.stateManager = statemanager.New(path)
|
|
||||||
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
return engine
|
return engine
|
||||||
}
|
}
|
||||||
@@ -308,6 +285,10 @@ func (e *Engine) Stop() error {
|
|||||||
e.srWatcher.Close()
|
e.srWatcher.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if e.updateManager != nil {
|
||||||
|
e.updateManager.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
log.Info("cleaning up status recorder states")
|
log.Info("cleaning up status recorder states")
|
||||||
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
e.statusRecorder.ReplaceOfflinePeers([]peer.State{})
|
||||||
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
e.statusRecorder.UpdateDNSStates([]peer.NSGroupState{})
|
||||||
@@ -541,6 +522,13 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) createFirewall() error {
|
func (e *Engine) createFirewall() error {
|
||||||
if e.config.DisableFirewall {
|
if e.config.DisableFirewall {
|
||||||
log.Infof("firewall is disabled")
|
log.Infof("firewall is disabled")
|
||||||
@@ -749,6 +737,41 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||||
|
if autoUpdateSettings == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||||
|
|
||||||
|
// Stop and cleanup if disabled
|
||||||
|
if e.updateManager != nil && disabled {
|
||||||
|
log.Infof("auto-update is disabled, stopping update manager")
|
||||||
|
e.updateManager.Stop()
|
||||||
|
e.updateManager = nil
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||||
|
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||||
|
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start manager if needed
|
||||||
|
if e.updateManager == nil {
|
||||||
|
log.Infof("starting auto-update manager")
|
||||||
|
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
e.updateManager = updateManager
|
||||||
|
e.updateManager.Start(e.ctx)
|
||||||
|
}
|
||||||
|
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||||
|
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||||
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
@@ -758,6 +781,10 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return e.ctx.Err()
|
return e.ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||||
|
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||||
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
wCfg := update.GetNetbirdConfig()
|
wCfg := update.GetNetbirdConfig()
|
||||||
err := e.updateTURNs(wCfg.GetTurns())
|
err := e.updateTURNs(wCfg.GetTurns())
|
||||||
|
|||||||
@@ -253,6 +253,7 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
peer.NewRecorder("https://mgm"),
|
||||||
nil,
|
nil,
|
||||||
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -414,21 +415,13 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(
|
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||||
ctx, cancel,
|
WgIfaceName: "utun102",
|
||||||
&signal.MockClient{},
|
WgAddr: "100.64.0.1/24",
|
||||||
&mgmt.MockClient{},
|
WgPrivateKey: key,
|
||||||
relayMgr,
|
WgPort: 33100,
|
||||||
&EngineConfig{
|
MTU: iface.DefaultMTU,
|
||||||
WgIfaceName: "utun102",
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
WgAddr: "100.64.0.1/24",
|
|
||||||
WgPrivateKey: key,
|
|
||||||
WgPort: 33100,
|
|
||||||
MTU: iface.DefaultMTU,
|
|
||||||
},
|
|
||||||
MobileDependency{},
|
|
||||||
peer.NewRecorder("https://mgm"),
|
|
||||||
nil)
|
|
||||||
|
|
||||||
wgIface := &MockWGIface{
|
wgIface := &MockWGIface{
|
||||||
NameFunc: func() string { return "utun102" },
|
NameFunc: func() string { return "utun102" },
|
||||||
@@ -647,7 +640,7 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -812,7 +805,7 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1014,7 +1007,7 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil)
|
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
@@ -1540,7 +1533,7 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil), nil
|
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||||
e.ctx = ctx
|
e.ctx = ctx
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
|
"os/user"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
"reflect"
|
||||||
"runtime"
|
"runtime"
|
||||||
@@ -165,19 +166,26 @@ func getConfigDir() (string, error) {
|
|||||||
if ConfigDirOverride != "" {
|
if ConfigDirOverride != "" {
|
||||||
return ConfigDirOverride, nil
|
return ConfigDirOverride, nil
|
||||||
}
|
}
|
||||||
configDir, err := os.UserConfigDir()
|
|
||||||
|
base, err := baseConfigDir()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
configDir = filepath.Join(configDir, "netbird")
|
configDir := filepath.Join(base, "netbird")
|
||||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
if err := os.MkdirAll(configDir, 0o755); err != nil {
|
||||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
return "", err
|
||||||
return "", err
|
}
|
||||||
|
return configDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func baseConfigDir() (string, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
if u, err := user.Current(); err == nil && u.HomeDir != "" {
|
||||||
|
return filepath.Join(u.HomeDir, "Library", "Application Support"), nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return os.UserConfigDir()
|
||||||
return configDir, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getConfigDirForUser(username string) (string, error) {
|
func getConfigDirForUser(username string) (string, error) {
|
||||||
|
|||||||
@@ -76,6 +76,7 @@ func (a *ActiveProfileState) FilePath() (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ServiceManager struct {
|
type ServiceManager struct {
|
||||||
|
profilesDir string // If set, overrides ConfigDirOverride for profile operations
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
||||||
@@ -85,6 +86,17 @@ func NewServiceManager(defaultConfigPath string) *ServiceManager {
|
|||||||
return &ServiceManager{}
|
return &ServiceManager{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewServiceManagerWithProfilesDir creates a ServiceManager with a specific profiles directory
|
||||||
|
// This allows setting the profiles directory without modifying the global ConfigDirOverride
|
||||||
|
func NewServiceManagerWithProfilesDir(defaultConfigPath string, profilesDir string) *ServiceManager {
|
||||||
|
if defaultConfigPath != "" {
|
||||||
|
DefaultConfigPath = defaultConfigPath
|
||||||
|
}
|
||||||
|
return &ServiceManager{
|
||||||
|
profilesDir: profilesDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
|
func (s *ServiceManager) CopyDefaultProfileIfNotExists() (bool, error) {
|
||||||
|
|
||||||
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
|
if err := os.MkdirAll(DefaultConfigPathDir, 0600); err != nil {
|
||||||
@@ -240,7 +252,7 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get config directory: %w", err)
|
return fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -270,7 +282,7 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get config directory: %w", err)
|
return fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -302,7 +314,7 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||||
configDir, err := getConfigDirForUser(username)
|
configDir, err := s.getConfigDir(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||||
}
|
}
|
||||||
@@ -361,7 +373,7 @@ func (s *ServiceManager) GetStatePath() string {
|
|||||||
return defaultStatePath
|
return defaultStatePath
|
||||||
}
|
}
|
||||||
|
|
||||||
configDir, err := getConfigDirForUser(activeProf.Username)
|
configDir, err := s.getConfigDir(activeProf.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
|
log.Warnf("failed to get config directory for user %s: %v", activeProf.Username, err)
|
||||||
return defaultStatePath
|
return defaultStatePath
|
||||||
@@ -369,3 +381,12 @@ func (s *ServiceManager) GetStatePath() string {
|
|||||||
|
|
||||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||||
|
func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||||
|
if s.profilesDir != "" {
|
||||||
|
return s.profilesDir, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return getConfigDirForUser(username)
|
||||||
|
}
|
||||||
|
|||||||
35
client/internal/updatemanager/doc.go
Normal file
35
client/internal/updatemanager/doc.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
// Package updatemanager provides automatic update management for the NetBird client.
|
||||||
|
// It monitors for new versions, handles update triggers from management server directives,
|
||||||
|
// and orchestrates the download and installation of client updates.
|
||||||
|
//
|
||||||
|
// # Overview
|
||||||
|
//
|
||||||
|
// The update manager operates as a background service that continuously monitors for
|
||||||
|
// available updates and automatically initiates the update process when conditions are met.
|
||||||
|
// It integrates with the installer package to perform the actual installation.
|
||||||
|
//
|
||||||
|
// # Update Flow
|
||||||
|
//
|
||||||
|
// The complete update process follows these steps:
|
||||||
|
//
|
||||||
|
// 1. Manager receives update directive via SetVersion() or detects new version
|
||||||
|
// 2. Manager validates update should proceed (version comparison, rate limiting)
|
||||||
|
// 3. Manager publishes "updating" event to status recorder
|
||||||
|
// 4. Manager persists UpdateState to track update attempt
|
||||||
|
// 5. Manager downloads installer file (.msi or .exe) to temporary directory
|
||||||
|
// 6. Manager triggers installation via installer.RunInstallation()
|
||||||
|
// 7. Installer package handles the actual installation process
|
||||||
|
// 8. On next startup, CheckUpdateSuccess() verifies update completion
|
||||||
|
// 9. Manager publishes success/failure event to status recorder
|
||||||
|
// 10. Manager cleans up UpdateState
|
||||||
|
//
|
||||||
|
// # State Management
|
||||||
|
//
|
||||||
|
// Update state is persisted across restarts to track update attempts:
|
||||||
|
//
|
||||||
|
// - PreUpdateVersion: Version before update attempt
|
||||||
|
// - TargetVersion: Version attempting to update to
|
||||||
|
//
|
||||||
|
// This enables verification of successful updates and appropriate user notification
|
||||||
|
// after the client restarts with the new version.
|
||||||
|
package updatemanager
|
||||||
138
client/internal/updatemanager/downloader/downloader.go
Normal file
138
client/internal/updatemanager/downloader/downloader.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
userAgent = "NetBird agent installer/%s"
|
||||||
|
DefaultRetryDelay = 3 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
func DownloadToFile(ctx context.Context, retryDelay time.Duration, url, dstFile string) error {
|
||||||
|
log.Debugf("starting download from %s", url)
|
||||||
|
|
||||||
|
out, err := os.Create(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create destination file %q: %w", dstFile, err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := out.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing file %q: %v", dstFile, cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// First attempt
|
||||||
|
err = downloadToFileOnce(ctx, url, out)
|
||||||
|
if err == nil {
|
||||||
|
log.Infof("successfully downloaded file to %s", dstFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If retryDelay is 0, don't retry
|
||||||
|
if retryDelay == 0 {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("download failed, retrying after %v: %v", retryDelay, err)
|
||||||
|
|
||||||
|
// Sleep before retry
|
||||||
|
if sleepErr := sleepWithContext(ctx, retryDelay); sleepErr != nil {
|
||||||
|
return fmt.Errorf("download cancelled during retry delay: %w", sleepErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Truncate file before retry
|
||||||
|
if err := out.Truncate(0); err != nil {
|
||||||
|
return fmt.Errorf("failed to truncate file on retry: %w", err)
|
||||||
|
}
|
||||||
|
if _, err := out.Seek(0, 0); err != nil {
|
||||||
|
return fmt.Errorf("failed to seek to beginning of file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second attempt
|
||||||
|
if err := downloadToFileOnce(ctx, url, out); err != nil {
|
||||||
|
return fmt.Errorf("download failed after retry: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("successfully downloaded file to %s", dstFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DownloadToMemory(ctx context.Context, url string, limit int64) ([]byte, error) {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add User-Agent header
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := resp.Body.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing response body: %v", cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := io.ReadAll(io.LimitReader(resp.Body, limit))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func downloadToFileOnce(ctx context.Context, url string, out *os.File) error {
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add User-Agent header
|
||||||
|
req.Header.Set("User-Agent", fmt.Sprintf(userAgent, version.NetbirdVersion()))
|
||||||
|
|
||||||
|
resp, err := http.DefaultClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to perform HTTP request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if cerr := resp.Body.Close(); cerr != nil {
|
||||||
|
log.Warnf("error closing response body: %v", cerr)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("unexpected HTTP status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, resp.Body); err != nil {
|
||||||
|
return fmt.Errorf("failed to write response body to file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func sleepWithContext(ctx context.Context, duration time.Duration) error {
|
||||||
|
select {
|
||||||
|
case <-time.After(duration):
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
199
client/internal/updatemanager/downloader/downloader_test.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
retryDelay = 100 * time.Millisecond
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDownloadToFile_Success(t *testing.T) {
|
||||||
|
// Create a test server that responds successfully
|
||||||
|
content := "test file content"
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(content))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, server.URL, dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected no error, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file content
|
||||||
|
data, err := os.ReadFile(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read downloaded file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(data) != content {
|
||||||
|
t.Errorf("expected content %q, got %q", content, string(data))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_SuccessAfterRetry(t *testing.T) {
|
||||||
|
content := "test file content after retry"
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that fails on first attempt, succeeds on second
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attempt := attemptCount.Add(1)
|
||||||
|
if attempt == 1 {
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte(content))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file (should succeed after retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err != nil {
|
||||||
|
t.Fatalf("expected no error after retry, got: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify the file content
|
||||||
|
data, err := os.ReadFile(dstFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to read downloaded file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(data) != content {
|
||||||
|
t.Errorf("expected content %q, got %q", content, string(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it took 2 attempts
|
||||||
|
if attemptCount.Load() != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_FailsAfterRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file (should fail after retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 10*time.Millisecond, server.URL, dstFile); err == nil {
|
||||||
|
t.Fatal("expected error after retry, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it tried 2 times
|
||||||
|
if attemptCount.Load() != 2 {
|
||||||
|
t.Errorf("expected 2 attempts, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_ContextCancellationDuringRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Create a context that will be cancelled during retry delay
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
// Cancel after a short delay (during the retry sleep)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Download the file (should fail due to context cancellation during retry)
|
||||||
|
err := DownloadToFile(ctx, 1*time.Second, server.URL, dstFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error due to context cancellation, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have only made 1 attempt (cancelled during retry delay)
|
||||||
|
if attemptCount.Load() != 1 {
|
||||||
|
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_InvalidURL(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, "://invalid-url", dstFile)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid URL, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_InvalidDestination(t *testing.T) {
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write([]byte("test"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Use an invalid destination path
|
||||||
|
err := DownloadToFile(context.Background(), retryDelay, server.URL, "/invalid/path/that/does/not/exist/file.txt")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid destination, got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDownloadToFile_NoRetry(t *testing.T) {
|
||||||
|
var attemptCount atomic.Int32
|
||||||
|
|
||||||
|
// Create a test server that always fails
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
attemptCount.Add(1)
|
||||||
|
w.WriteHeader(http.StatusInternalServerError)
|
||||||
|
_, _ = w.Write([]byte("error"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Create a temporary file for download
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
dstFile := filepath.Join(tempDir, "downloaded.txt")
|
||||||
|
|
||||||
|
// Download the file with retryDelay = 0 (should not retry)
|
||||||
|
if err := DownloadToFile(context.Background(), 0, server.URL, dstFile); err == nil {
|
||||||
|
t.Fatal("expected error, got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it only made 1 attempt (no retry)
|
||||||
|
if attemptCount.Load() != 1 {
|
||||||
|
t.Errorf("expected 1 attempt, got %d", attemptCount.Load())
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
func UpdaterBinaryNameWithoutExtension() string {
|
||||||
|
return updaterBinary
|
||||||
|
}
|
||||||
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
11
client/internal/updatemanager/installer/binary_windows.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func UpdaterBinaryNameWithoutExtension() string {
|
||||||
|
ext := filepath.Ext(updaterBinary)
|
||||||
|
return strings.TrimSuffix(updaterBinary, ext)
|
||||||
|
}
|
||||||
111
client/internal/updatemanager/installer/doc.go
Normal file
111
client/internal/updatemanager/installer/doc.go
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
// Package installer provides functionality for managing NetBird application
|
||||||
|
// updates and installations across Windows, macOS. It handles
|
||||||
|
// the complete update lifecycle including artifact download, cryptographic verification,
|
||||||
|
// installation execution, process management, and result reporting.
|
||||||
|
//
|
||||||
|
// # Architecture
|
||||||
|
//
|
||||||
|
// The installer package uses a two-process architecture to enable self-updates:
|
||||||
|
//
|
||||||
|
// 1. Service Process: The main NetBird daemon process that initiates updates
|
||||||
|
// 2. Updater Process: A detached child process that performs the actual installation
|
||||||
|
//
|
||||||
|
// This separation is critical because:
|
||||||
|
// - The service binary cannot update itself while running
|
||||||
|
// - The installer (EXE/MSI/PKG) will terminate the service during installation
|
||||||
|
// - The updater process survives service termination and restarts it after installation
|
||||||
|
// - Results can be communicated back to the service after it restarts
|
||||||
|
//
|
||||||
|
// # Update Flow
|
||||||
|
//
|
||||||
|
// Service Process (RunInstallation):
|
||||||
|
//
|
||||||
|
// 1. Validates target version format (semver)
|
||||||
|
// 2. Determines installer type (EXE, MSI, PKG, or Homebrew)
|
||||||
|
// 3. Downloads installer file from GitHub releases (if applicable)
|
||||||
|
// 4. Verifies installer signature using reposign package (cryptographic verification in service process before
|
||||||
|
// launching updater)
|
||||||
|
// 5. Copies service binary to tempDir as "updater" (or "updater.exe" on Windows)
|
||||||
|
// 6. Launches updater process with detached mode:
|
||||||
|
// - --temp-dir: Temporary directory path
|
||||||
|
// - --service-dir: Service installation directory
|
||||||
|
// - --installer-file: Path to downloaded installer (if applicable)
|
||||||
|
// - --dry-run: Optional flag to test without actually installing
|
||||||
|
// 7. Service process continues running (will be terminated by installer later)
|
||||||
|
// 8. Service can watch for result.json using ResultHandler.Watch() to detect completion
|
||||||
|
//
|
||||||
|
// Updater Process (Setup):
|
||||||
|
//
|
||||||
|
// 1. Receives parameters from service via command-line arguments
|
||||||
|
// 2. Runs installer with appropriate silent/quiet flags:
|
||||||
|
// - Windows EXE: installer.exe /S
|
||||||
|
// - Windows MSI: msiexec.exe /i installer.msi /quiet /qn /l*v msi.log
|
||||||
|
// - macOS PKG: installer -pkg installer.pkg -target /
|
||||||
|
// - macOS Homebrew: brew upgrade netbirdio/tap/netbird
|
||||||
|
// 3. Installer terminates daemon and UI processes
|
||||||
|
// 4. Installer replaces binaries with new version
|
||||||
|
// 5. Updater waits for installer to complete
|
||||||
|
// 6. Updater restarts daemon:
|
||||||
|
// - Windows: netbird.exe service start
|
||||||
|
// - macOS/Linux: netbird service start
|
||||||
|
// 7. Updater restarts UI:
|
||||||
|
// - Windows: Launches netbird-ui.exe as active console user using CreateProcessAsUser
|
||||||
|
// - macOS: Uses launchctl asuser to launch NetBird.app for console user
|
||||||
|
// - Linux: Not implemented (UI typically auto-starts)
|
||||||
|
// 8. Updater writes result.json with success/error status
|
||||||
|
// 9. Updater process exits
|
||||||
|
//
|
||||||
|
// # Result Communication
|
||||||
|
//
|
||||||
|
// The ResultHandler (result.go) manages communication between updater and service:
|
||||||
|
//
|
||||||
|
// Result Structure:
|
||||||
|
//
|
||||||
|
// type Result struct {
|
||||||
|
// Success bool // true if installation succeeded
|
||||||
|
// Error string // error message if Success is false
|
||||||
|
// ExecutedAt time.Time // when installation completed
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Result files are automatically cleaned up after being read.
|
||||||
|
//
|
||||||
|
// # File Locations
|
||||||
|
//
|
||||||
|
// Temporary Directory (platform-specific):
|
||||||
|
//
|
||||||
|
// Windows:
|
||||||
|
// - Path: %ProgramData%\Netbird\tmp-install
|
||||||
|
// - Example: C:\ProgramData\Netbird\tmp-install
|
||||||
|
//
|
||||||
|
// macOS:
|
||||||
|
// - Path: /var/lib/netbird/tmp-install
|
||||||
|
// - Requires root permissions
|
||||||
|
//
|
||||||
|
// Files created during installation:
|
||||||
|
//
|
||||||
|
// tmp-install/
|
||||||
|
// installer.log
|
||||||
|
// updater[.exe] # Copy of service binary
|
||||||
|
// netbird_installer_*.[exe|msi|pkg] # Downloaded installer
|
||||||
|
// result.json # Installation result
|
||||||
|
// msi.log # MSI verbose log (Windows MSI only)
|
||||||
|
//
|
||||||
|
// # API Reference
|
||||||
|
//
|
||||||
|
// # Cleanup
|
||||||
|
//
|
||||||
|
// CleanUpInstallerFiles() removes temporary files after successful installation:
|
||||||
|
// - Downloaded installer files (*.exe, *.msi, *.pkg)
|
||||||
|
// - Updater binary copy
|
||||||
|
// - Does NOT remove result.json (cleaned by ResultHandler after read)
|
||||||
|
// - Does NOT remove msi.log (kept for debugging)
|
||||||
|
//
|
||||||
|
// # Dry-Run Mode
|
||||||
|
//
|
||||||
|
// Dry-run mode allows testing the update process without actually installing:
|
||||||
|
//
|
||||||
|
// Enable via environment variable:
|
||||||
|
//
|
||||||
|
// export NB_AUTO_UPDATE_DRY_RUN=true
|
||||||
|
// netbird service install-update 0.29.0
|
||||||
|
package installer
|
||||||
50
client/internal/updatemanager/installer/installer.go
Normal file
50
client/internal/updatemanager/installer/installer.go
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
//go:build !windows && !darwin
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
updaterBinary = "updater"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Installer struct {
|
||||||
|
tempDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New used by the service
|
||||||
|
func New() *Installer {
|
||||||
|
return &Installer{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||||
|
func NewWithDir(tempDir string) *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: tempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) TempDir() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Installer) LogFiles() []string {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) CleanUpInstallerFiles() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) error {
|
||||||
|
return fmt.Errorf("unsupported platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state
|
||||||
|
// This will be run by the updater process
|
||||||
|
func (u *Installer) Setup(ctx context.Context, dryRun bool, targetVersion string, daemonFolder string) (resultErr error) {
|
||||||
|
return fmt.Errorf("unsupported platform")
|
||||||
|
}
|
||||||
293
client/internal/updatemanager/installer/installer_common.go
Normal file
293
client/internal/updatemanager/installer/installer_common.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
goversion "github.com/hashicorp/go-version"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Installer struct {
|
||||||
|
tempDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// New used by the service
|
||||||
|
func New() *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: defaultTempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewWithDir used by the updater process, get the tempDir from the service via cmd line
|
||||||
|
func NewWithDir(tempDir string) *Installer {
|
||||||
|
return &Installer{
|
||||||
|
tempDir: tempDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RunInstallation starts the updater process to run the installation
|
||||||
|
// This will run by the original service process
|
||||||
|
func (u *Installer) RunInstallation(ctx context.Context, targetVersion string) (err error) {
|
||||||
|
resultHandler := NewResultHandler(u.tempDir)
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
if writeErr := resultHandler.WriteErr(err); writeErr != nil {
|
||||||
|
log.Errorf("failed to write error result: %v", writeErr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := validateTargetVersion(targetVersion); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := u.mkTempDir(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var installerFile string
|
||||||
|
// Download files only when not using any third-party store
|
||||||
|
if installerType := TypeOfInstaller(ctx); installerType.Downloadable() {
|
||||||
|
log.Infof("download installer")
|
||||||
|
var err error
|
||||||
|
installerFile, err = u.downloadInstaller(ctx, installerType, targetVersion)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to download installer: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactVerify, err := reposign.NewArtifactVerify(DefaultSigningKeysBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to create artifact verify: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := artifactVerify.Verify(ctx, targetVersion, installerFile); err != nil {
|
||||||
|
log.Errorf("artifact verification error: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("running installer")
|
||||||
|
updaterPath, err := u.copyUpdater()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// the directory where the service has been installed
|
||||||
|
workspace, err := getServiceDir()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
args := []string{
|
||||||
|
"--temp-dir", u.tempDir,
|
||||||
|
"--service-dir", workspace,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isDryRunEnabled() {
|
||||||
|
args = append(args, "--dry-run=true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if installerFile != "" {
|
||||||
|
args = append(args, "--installer-file", installerFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateCmd := exec.Command(updaterPath, args...)
|
||||||
|
log.Infof("starting updater process: %s", updateCmd.String())
|
||||||
|
|
||||||
|
// Configure the updater to run in a separate session/process group
|
||||||
|
// so it survives the parent daemon being stopped
|
||||||
|
setUpdaterProcAttr(updateCmd)
|
||||||
|
|
||||||
|
// Start the updater process asynchronously
|
||||||
|
if err := updateCmd.Start(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
pid := updateCmd.Process.Pid
|
||||||
|
log.Infof("updater started with PID %d", pid)
|
||||||
|
|
||||||
|
// Release the process so the OS can fully detach it
|
||||||
|
if err := updateCmd.Process.Release(); err != nil {
|
||||||
|
log.Warnf("failed to release updater process: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanUpInstallerFiles
|
||||||
|
// - the installer file (pkg, exe, msi)
|
||||||
|
// - the selfcopy updater.exe
|
||||||
|
func (u *Installer) CleanUpInstallerFiles() error {
|
||||||
|
// Check if tempDir exists
|
||||||
|
info, err := os.Stat(u.tempDir)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := os.Remove(filepath.Join(u.tempDir, updaterBinary)); err != nil && !os.IsNotExist(err) {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to remove updater binary: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(u.tempDir)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entry := range entries {
|
||||||
|
if entry.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
name := entry.Name()
|
||||||
|
for _, ext := range binaryExtensions {
|
||||||
|
if strings.HasSuffix(strings.ToLower(name), strings.ToLower(ext)) {
|
||||||
|
if err := os.Remove(filepath.Join(u.tempDir, name)); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("failed to remove %s: %w", name, err))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return merr.ErrorOrNil()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) downloadInstaller(ctx context.Context, installerType Type, targetVersion string) (string, error) {
|
||||||
|
fileURL := urlWithVersionArch(installerType, targetVersion)
|
||||||
|
|
||||||
|
// Clean up temp directory on error
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
if err := os.RemoveAll(u.tempDir); err != nil {
|
||||||
|
log.Errorf("error cleaning up temporary directory: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
fileName := path.Base(fileURL)
|
||||||
|
if fileName == "." || fileName == "/" || fileName == "" {
|
||||||
|
return "", fmt.Errorf("invalid file URL: %s", fileURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
outputFilePath := filepath.Join(u.tempDir, fileName)
|
||||||
|
if err := downloader.DownloadToFile(ctx, downloader.DefaultRetryDelay, fileURL, outputFilePath); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
success = true
|
||||||
|
return outputFilePath, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) TempDir() string {
|
||||||
|
return u.tempDir
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) mkTempDir() error {
|
||||||
|
if err := os.MkdirAll(u.tempDir, 0o755); err != nil {
|
||||||
|
log.Debugf("failed to create tempdir: %s", u.tempDir)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Installer) copyUpdater() (string, error) {
|
||||||
|
src, err := getServiceBinary()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("failed to get updater binary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dst := filepath.Join(u.tempDir, updaterBinary)
|
||||||
|
if err := copyFile(src, dst); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to copy updater binary: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Chmod(dst, 0o755); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to set permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return dst, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateTargetVersion(targetVersion string) error {
|
||||||
|
if targetVersion == "" {
|
||||||
|
return fmt.Errorf("target version cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := goversion.NewVersion(targetVersion)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid target version %q: %w", targetVersion, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyFile(src, dst string) error {
|
||||||
|
log.Infof("copying %s to %s", src, dst)
|
||||||
|
in, err := os.Open(src)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open source: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := in.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close source file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
out, err := os.Create(dst)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create destination: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := out.Close(); err != nil {
|
||||||
|
log.Warnf("failed to close destination file: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
if _, err := io.Copy(out, in); err != nil {
|
||||||
|
return fmt.Errorf("copy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServiceDir returns the directory containing the currently running
// executable.
func getServiceDir() (string, error) {
	binPath, err := os.Executable()
	if err != nil {
		return "", err
	}
	dir := filepath.Dir(binPath)
	return dir, nil
}
|
||||||
|
|
||||||
|
// getServiceBinary returns the path of the currently running executable;
// copyUpdater duplicates this binary into the temp directory.
func getServiceBinary() (string, error) {
	return os.Executable()
}
|
||||||
|
|
||||||
|
func isDryRunEnabled() bool {
|
||||||
|
return strings.EqualFold(strings.TrimSpace(os.Getenv("NB_AUTO_UPDATE_DRY_RUN")), "true")
|
||||||
|
}
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogFiles returns the update-run log files on this platform:
// only the shared installer log in the temp directory.
func (u *Installer) LogFiles() []string {
	return []string{
		filepath.Join(u.tempDir, LogFile),
	}
}
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LogFiles returns the update-run log files on this platform:
// the msiexec verbose log and the shared installer log, both in the temp
// directory.
func (u *Installer) LogFiles() []string {
	return []string{
		filepath.Join(u.tempDir, msiLogFile),
		filepath.Join(u.tempDir, LogFile),
	}
}
|
||||||
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
238
client/internal/updatemanager/installer/installer_run_darwin.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// daemonName is the service binary started again after installation.
	daemonName = "netbird"
	// updaterBinary is the name given to the copied updater executable.
	updaterBinary = "updater"
	// uiBinary is the macOS app bundle opened to restart the UI.
	uiBinary = "/Applications/NetBird.app"

	// defaultTempDir stages downloaded artifacts and the updater copy.
	defaultTempDir = "/var/lib/netbird/tmp-install"

	// pkgDownloadURL is the artifact URL template; %version and %arch are
	// substituted at download time. NOTE(review): this points at a fork
	// (mlsmaycon) rather than the netbirdio org — confirm before release.
	pkgDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_%version_darwin_%arch.pkg"
)
|
||||||
|
|
||||||
|
var (
	// binaryExtensions lists downloaded-artifact extensions that are
	// removed during temp-directory cleanup.
	binaryExtensions = []string{"pkg"}
)
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state.
// This will be run by the updater process.
//
// The outcome (success or resultErr) is always persisted via ResultHandler so
// the restarted daemon can pick it up; afterwards the daemon and UI are
// started again unless dry-run mode is enabled.
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
	resultHandler := NewResultHandler(u.tempDir)

	// Always ensure daemon and UI are restarted after setup.
	// The result is written first, so a restart failure cannot prevent the
	// outcome from being recorded.
	defer func() {
		log.Infof("write out result")
		var err error
		if resultErr == nil {
			err = resultHandler.WriteSuccess()
		} else {
			err = resultHandler.WriteErr(resultErr)
		}
		if err != nil {
			log.Errorf("failed to write update result: %v", err)
		}

		// skip service restart if dry-run mode is enabled
		if dryRun {
			return
		}

		log.Infof("starting daemon back")
		if err := u.startDaemon(daemonFolder); err != nil {
			log.Errorf("failed to start daemon: %v", err)
		}

		log.Infof("starting UI back")
		if err := u.startUIAsUser(); err != nil {
			log.Errorf("failed to start UI: %v", err)
		}

	}()

	if dryRun {
		// Simulated work; the deliberate error marks the dry run in the result.
		time.Sleep(7 * time.Second)
		log.Infof("dry-run mode enabled, skipping actual installation")
		resultErr = fmt.Errorf("dry-run mode enabled")
		return
	}

	// Pick the install path matching how NetBird was originally installed.
	switch TypeOfInstaller(ctx) {
	case TypePKG:
		resultErr = u.installPkgFile(ctx, installerFile)
	case TypeHomebrew:
		resultErr = u.updateHomeBrew(ctx)
	}

	return resultErr
}
|
||||||
|
|
||||||
|
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||||
|
log.Infof("starting netbird service")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Warnf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Infof("netbird service started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startUIAsUser launches the NetBird UI app bundle in the GUI session of the
// currently logged-in console user. It fails when no non-root user owns the
// console (e.g. at the login screen).
func (u *Installer) startUIAsUser() error {
	log.Infof("starting netbird-ui: %s", uiBinary)

	// Get the current console user (owner of /dev/console)
	cmd := exec.Command("stat", "-f", "%Su", "/dev/console")
	output, err := cmd.Output()
	if err != nil {
		return fmt.Errorf("failed to get console user: %w", err)
	}

	username := strings.TrimSpace(string(output))
	if username == "" || username == "root" {
		return fmt.Errorf("no active user session found")
	}

	log.Infof("starting UI for user: %s", username)

	// Get user's UID
	userInfo, err := user.Lookup(username)
	if err != nil {
		return fmt.Errorf("failed to lookup user %s: %w", username, err)
	}

	// Start the UI process as the console user using launchctl
	// This ensures the app runs in the user's context with proper GUI access
	launchCmd := exec.Command("launchctl", "asuser", userInfo.Uid, "open", "-a", uiBinary)
	log.Infof("launchCmd: %s", launchCmd.String())
	// Set the user's home directory for proper macOS app behavior
	launchCmd.Env = append(os.Environ(), "HOME="+userInfo.HomeDir)
	log.Infof("set HOME environment variable: %s", userInfo.HomeDir)

	if err := launchCmd.Start(); err != nil {
		return fmt.Errorf("failed to start UI process: %w", err)
	}

	// Release the process so it can run independently
	if err := launchCmd.Process.Release(); err != nil {
		log.Warnf("failed to release UI process: %v", err)
	}

	log.Infof("netbird-ui started successfully for user %s", username)
	return nil
}
|
||||||
|
|
||||||
|
func (u *Installer) installPkgFile(ctx context.Context, path string) error {
|
||||||
|
log.Infof("installing pkg file: %s", path)
|
||||||
|
|
||||||
|
// Kill any existing UI processes before installation
|
||||||
|
// This ensures the postinstall script's "open $APP" will start the new version
|
||||||
|
u.killUI()
|
||||||
|
|
||||||
|
volume := "/"
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "installer", "-pkg", path, "-target", volume)
|
||||||
|
if err := cmd.Start(); err != nil {
|
||||||
|
return fmt.Errorf("error running pkg file: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("installer started with PID %d", cmd.Process.Pid)
|
||||||
|
if err := cmd.Wait(); err != nil {
|
||||||
|
return fmt.Errorf("error running pkg file: %w", err)
|
||||||
|
}
|
||||||
|
log.Infof("pkg file installed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateHomeBrew upgrades the NetBird Homebrew formula (and netbird-ui when
// installed) to the latest release. Homebrew cannot pin a specific version,
// so the upgrade always goes to latest regardless of any requested target.
func (u *Installer) updateHomeBrew(ctx context.Context) error {
	log.Infof("updating homebrew")

	// Kill any existing UI processes before upgrade
	// This ensures the new version will be started after upgrade
	u.killUI()

	// Homebrew must be run as a non-root user
	// To find out which user installed NetBird using HomeBrew we can check the owner of our brew tap directory
	// Check both Apple Silicon and Intel Mac paths
	brewTapPath := "/opt/homebrew/Library/Taps/netbirdio/homebrew-tap/"
	brewBinPath := "/opt/homebrew/bin/brew"
	if _, err := os.Stat(brewTapPath); os.IsNotExist(err) {
		// Try Intel Mac path
		brewTapPath = "/usr/local/Homebrew/Library/Taps/netbirdio/homebrew-tap/"
		brewBinPath = "/usr/local/bin/brew"
	}

	fileInfo, err := os.Stat(brewTapPath)
	if err != nil {
		return fmt.Errorf("error getting homebrew installation path info: %w", err)
	}

	// Stat_t exposes the numeric UID of the tap directory's owner.
	fileSysInfo, ok := fileInfo.Sys().(*syscall.Stat_t)
	if !ok {
		return fmt.Errorf("error checking file owner, sysInfo type is %T not *syscall.Stat_t", fileInfo.Sys())
	}

	// Get username from UID
	brewUser, err := user.LookupId(fmt.Sprintf("%d", fileSysInfo.Uid))
	if err != nil {
		return fmt.Errorf("error looking up brew installer user: %w", err)
	}
	userName := brewUser.Username
	// Get user HOME, required for brew to run correctly
	// https://github.com/Homebrew/brew/issues/15833
	homeDir := brewUser.HomeDir

	// Check if netbird-ui is installed (must run as the brew user, not root)
	checkUICmd := exec.CommandContext(ctx, "sudo", "-u", userName, brewBinPath, "list", "--formula", "netbirdio/tap/netbird-ui")
	checkUICmd.Env = append(os.Environ(), "HOME="+homeDir)
	uiInstalled := checkUICmd.Run() == nil

	// Homebrew does not support installing specific versions
	// Thus it will always update to latest and ignore targetVersion
	upgradeArgs := []string{"-u", userName, brewBinPath, "upgrade", "netbirdio/tap/netbird"}
	if uiInstalled {
		upgradeArgs = append(upgradeArgs, "netbirdio/tap/netbird-ui")
	}

	cmd := exec.CommandContext(ctx, "sudo", upgradeArgs...)
	cmd.Env = append(os.Environ(), "HOME="+homeDir)

	if output, err := cmd.CombinedOutput(); err != nil {
		return fmt.Errorf("error running brew upgrade: %w, output: %s", err, string(output))
	}

	log.Infof("homebrew updated successfully")
	return nil
}
|
||||||
|
|
||||||
|
func (u *Installer) killUI() {
|
||||||
|
log.Infof("killing existing netbird-ui processes")
|
||||||
|
cmd := exec.Command("pkill", "-x", "netbird-ui")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
// pkill returns exit code 1 if no processes matched, which is fine
|
||||||
|
log.Debugf("pkill netbird-ui result: %v, output: %s", err, string(output))
|
||||||
|
} else {
|
||||||
|
log.Infof("netbird-ui processes killed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func urlWithVersionArch(_ Type, version string) string {
|
||||||
|
url := strings.ReplaceAll(pkgDownloadURL, "%version", version)
|
||||||
|
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||||
|
}
|
||||||
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
213
client/internal/updatemanager/installer/installer_run_windows.go
Normal file
@@ -0,0 +1,213 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// Binary names used when restarting the daemon and UI after install.
	daemonName = "netbird.exe"
	uiName     = "netbird-ui.exe"
	// updaterBinary is the name given to the copied updater executable.
	updaterBinary = "updater.exe"

	// msiLogFile receives msiexec's verbose (/l*v) log output.
	msiLogFile = "msi.log"

	// Download URL templates; %version and %arch are substituted at
	// download time. NOTE(review): these point at a fork (mlsmaycon)
	// rather than the netbirdio org — confirm before release.
	msiDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.msi"
	exeDownloadURL = "https://github.com/mlsmaycon/netbird/releases/download/v%version/netbird_installer_%version_windows_%arch.exe"
)
|
||||||
|
|
||||||
|
var (
	// defaultTempDir stages downloaded artifacts under %ProgramData%.
	defaultTempDir = filepath.Join(os.Getenv("ProgramData"), "Netbird", "tmp-install")

	// for the cleanup: downloaded-artifact extensions removed from tempDir
	binaryExtensions = []string{"msi", "exe"}
)
|
||||||
|
|
||||||
|
// Setup runs the installer with appropriate arguments and manages the daemon/UI state.
// This will be run by the updater process.
//
// The deferred block restarts the daemon and UI first, then persists the
// outcome via ResultHandler. NOTE(review): the darwin variant writes the
// result before restarting and skips the restart on dry-run; confirm the
// ordering difference here is intentional.
func (u *Installer) Setup(ctx context.Context, dryRun bool, installerFile string, daemonFolder string) (resultErr error) {
	resultHandler := NewResultHandler(u.tempDir)

	// Always ensure daemon and UI are restarted after setup
	defer func() {
		log.Infof("starting daemon back")
		if err := u.startDaemon(daemonFolder); err != nil {
			log.Errorf("failed to start daemon: %v", err)
		}

		log.Infof("starting UI back")
		if err := u.startUIAsUser(daemonFolder); err != nil {
			log.Errorf("failed to start UI: %v", err)
		}

		log.Infof("write out result")
		var err error
		if resultErr == nil {
			err = resultHandler.WriteSuccess()
		} else {
			err = resultHandler.WriteErr(resultErr)
		}
		if err != nil {
			log.Errorf("failed to write update result: %v", err)
		}
	}()

	if dryRun {
		// The deliberate error marks the dry run in the recorded result.
		log.Infof("dry-run mode enabled, skipping actual installation")
		resultErr = fmt.Errorf("dry-run mode enabled")
		return
	}

	// Choose the silent EXE or msiexec flow based on the file extension.
	installerType, err := typeByFileExtension(installerFile)
	if err != nil {
		log.Debugf("%v", err)
		resultErr = err
		return
	}

	var cmd *exec.Cmd
	switch installerType {
	case TypeExe:
		log.Infof("run exe installer: %s", installerFile)
		cmd = exec.CommandContext(ctx, installerFile, "/S")
	default:
		installerDir := filepath.Dir(installerFile)
		logPath := filepath.Join(installerDir, msiLogFile)
		log.Infof("run msi installer: %s", installerFile)
		cmd = exec.CommandContext(ctx, "msiexec.exe", "/i", filepath.Base(installerFile), "/quiet", "/qn", "/l*v", logPath)
	}

	// Run from the installer's directory so the relative /i path resolves.
	cmd.Dir = filepath.Dir(installerFile)

	if resultErr = cmd.Start(); resultErr != nil {
		log.Errorf("error starting installer: %v", resultErr)
		return
	}

	log.Infof("installer started with PID %d", cmd.Process.Pid)
	if resultErr = cmd.Wait(); resultErr != nil {
		log.Errorf("installer process finished with error: %v", resultErr)
		return
	}

	return nil
}
|
||||||
|
|
||||||
|
func (u *Installer) startDaemon(daemonFolder string) error {
|
||||||
|
log.Infof("starting netbird service")
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, filepath.Join(daemonFolder, daemonName), "service", "start")
|
||||||
|
if output, err := cmd.CombinedOutput(); err != nil {
|
||||||
|
log.Debugf("failed to start netbird service: %v, output: %s", err, string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Infof("netbird service started successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// startUIAsUser launches netbird-ui.exe in the interactive desktop session of
// the currently logged-in console user: it takes that session's token,
// duplicates it into a primary token, and calls CreateProcessAsUser.
// It fails when no console session is active (e.g. nobody logged in).
func (u *Installer) startUIAsUser(daemonFolder string) error {
	uiPath := filepath.Join(daemonFolder, uiName)
	log.Infof("starting netbird-ui: %s", uiPath)

	// Get the active console session ID
	sessionID := windows.WTSGetActiveConsoleSessionId()
	if sessionID == 0xFFFFFFFF {
		// 0xFFFFFFFF means no session is attached to the physical console.
		return fmt.Errorf("no active user session found")
	}

	// Get the user token for that session
	var userToken windows.Token
	err := windows.WTSQueryUserToken(sessionID, &userToken)
	if err != nil {
		return fmt.Errorf("failed to query user token: %w", err)
	}
	defer func() {
		if err := userToken.Close(); err != nil {
			log.Warnf("failed to close user token: %v", err)
		}
	}()

	// Duplicate the token to a primary token (required by CreateProcessAsUser)
	var primaryToken windows.Token
	err = windows.DuplicateTokenEx(
		userToken,
		windows.MAXIMUM_ALLOWED,
		nil,
		windows.SecurityImpersonation,
		windows.TokenPrimary,
		&primaryToken,
	)
	if err != nil {
		return fmt.Errorf("failed to duplicate token: %w", err)
	}
	defer func() {
		if err := primaryToken.Close(); err != nil {
			log.Warnf("failed to close token: %v", err)
		}
	}()

	// Prepare startup info targeting the user's interactive desktop
	var si windows.StartupInfo
	si.Cb = uint32(unsafe.Sizeof(si))
	si.Desktop = windows.StringToUTF16Ptr("winsta0\\default")

	var pi windows.ProcessInformation

	// Quote the path so spaces in the install folder survive parsing.
	cmdLine, err := windows.UTF16PtrFromString(fmt.Sprintf("\"%s\"", uiPath))
	if err != nil {
		return fmt.Errorf("failed to convert path to UTF16: %w", err)
	}

	creationFlags := uint32(0x00000200 | 0x00000008 | 0x00000400) // CREATE_NEW_PROCESS_GROUP | DETACHED_PROCESS | CREATE_UNICODE_ENVIRONMENT

	err = windows.CreateProcessAsUser(
		primaryToken,
		nil,
		cmdLine,
		nil,
		nil,
		false,
		creationFlags,
		nil,
		nil,
		&si,
		&pi,
	)
	if err != nil {
		return fmt.Errorf("CreateProcessAsUser failed: %w", err)
	}

	// Close handles; the UI keeps running independently.
	if err := windows.CloseHandle(pi.Process); err != nil {
		log.Warnf("failed to close process handle: %v", err)
	}
	if err := windows.CloseHandle(pi.Thread); err != nil {
		log.Warnf("failed to close thread handle: %v", err)
	}

	log.Infof("netbird-ui started successfully in session %d", sessionID)
	return nil
}
|
||||||
|
|
||||||
|
func urlWithVersionArch(it Type, version string) string {
|
||||||
|
var url string
|
||||||
|
if it == TypeExe {
|
||||||
|
url = exeDownloadURL
|
||||||
|
} else {
|
||||||
|
url = msiDownloadURL
|
||||||
|
}
|
||||||
|
url = strings.ReplaceAll(url, "%version", version)
|
||||||
|
return strings.ReplaceAll(url, "%arch", runtime.GOARCH)
|
||||||
|
}
|
||||||
5
client/internal/updatemanager/installer/log.go
Normal file
5
client/internal/updatemanager/installer/log.go
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
const (
	// LogFile is the shared updater log file name written into the
	// installer temp directory on all platforms.
	LogFile = "installer.log"
)
|
||||||
15
client/internal/updatemanager/installer/procattr_darwin.go
Normal file
15
client/internal/updatemanager/installer/procattr_darwin.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setUpdaterProcAttr configures the updater process to run in a new session,
// making it independent of the parent daemon process. This ensures the updater
// survives when the daemon is stopped during the pkg installation.
func setUpdaterProcAttr(cmd *exec.Cmd) {
	// Setsid detaches the child from the daemon's session and controlling
	// terminal, so stopping the daemon does not signal the updater.
	cmd.SysProcAttr = &syscall.SysProcAttr{
		Setsid: true,
	}
}
|
||||||
14
client/internal/updatemanager/installer/procattr_windows.go
Normal file
14
client/internal/updatemanager/installer/procattr_windows.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os/exec"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setUpdaterProcAttr configures the updater process to run detached from the parent,
|
||||||
|
// making it independent of the parent daemon process.
|
||||||
|
func setUpdaterProcAttr(cmd *exec.Cmd) {
|
||||||
|
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||||
|
CreationFlags: syscall.CREATE_NEW_PROCESS_GROUP | 0x00000008, // 0x00000008 is DETACHED_PROCESS
|
||||||
|
}
|
||||||
|
}
|
||||||
7
client/internal/updatemanager/installer/repourl_dev.go
Normal file
7
client/internal/updatemanager/installer/repourl_dev.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build devartifactsign
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
const (
	// DefaultSigningKeysBaseURL points at a local signing server for
	// development builds (devartifactsign build tag).
	DefaultSigningKeysBaseURL = "http://192.168.0.10:9089/signrepo"
)
|
||||||
7
client/internal/updatemanager/installer/repourl_prod.go
Normal file
7
client/internal/updatemanager/installer/repourl_prod.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !devartifactsign
|
||||||
|
|
||||||
|
package installer
|
||||||
|
|
||||||
|
const (
	// DefaultSigningKeysBaseURL is the production endpoint serving the
	// public keys used to verify artifact signatures.
	DefaultSigningKeysBaseURL = "https://publickeys.netbird.io/artifact-signatures"
)
|
||||||
230
client/internal/updatemanager/installer/result.go
Normal file
230
client/internal/updatemanager/installer/result.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
resultFile = "result.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Result is the outcome of an update attempt, serialized to JSON and
// exchanged between the updater process and the daemon via the result file.
type Result struct {
	Success    bool      // true when the installation completed without error
	Error      string    // failure reason, set when Success is false
	ExecutedAt time.Time // when the result was written
}
|
||||||
|
|
||||||
|
// ResultHandler handles reading and writing update results
// exchanged as a JSON file between the updater process and the daemon.
type ResultHandler struct {
	// resultFile is the full path of the JSON result file.
	resultFile string
}
|
||||||
|
|
||||||
|
// NewResultHandler creates a new communicator with the given directory path
|
||||||
|
// The result file will be created as "result.json" in the specified directory
|
||||||
|
func NewResultHandler(installerDir string) *ResultHandler {
|
||||||
|
// Create it if it doesn't exist
|
||||||
|
// do not care if already exists
|
||||||
|
_ = os.MkdirAll(installerDir, 0o700)
|
||||||
|
|
||||||
|
rh := &ResultHandler{
|
||||||
|
resultFile: filepath.Join(installerDir, resultFile),
|
||||||
|
}
|
||||||
|
return rh
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rh *ResultHandler) GetErrorResultReason() string {
|
||||||
|
result, err := rh.tryReadResult()
|
||||||
|
if err == nil && !result.Success {
|
||||||
|
return result.Error
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := rh.cleanup(); err != nil {
|
||||||
|
log.Warnf("failed to cleanup result file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rh *ResultHandler) WriteSuccess() error {
|
||||||
|
result := Result{
|
||||||
|
Success: true,
|
||||||
|
ExecutedAt: time.Now(),
|
||||||
|
}
|
||||||
|
return rh.write(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rh *ResultHandler) WriteErr(errReason error) error {
|
||||||
|
result := Result{
|
||||||
|
Success: false,
|
||||||
|
Error: errReason.Error(),
|
||||||
|
ExecutedAt: time.Now(),
|
||||||
|
}
|
||||||
|
return rh.write(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Watch blocks until an update result becomes available or ctx is cancelled.
// It first checks for an already-written result, then waits for the result
// directory to exist, and finally watches the directory for the result file.
func (rh *ResultHandler) Watch(ctx context.Context) (Result, error) {
	log.Infof("start watching result: %s", rh.resultFile)

	// Check if file already exists (updater finished before we started watching)
	if result, err := rh.tryReadResult(); err == nil {
		log.Infof("installer result: %v", result)
		return result, nil
	}

	dir := filepath.Dir(rh.resultFile)

	// The updater may not have created the temp directory yet.
	if err := rh.waitForDirectory(ctx, dir); err != nil {
		return Result{}, err
	}

	return rh.watchForResultFile(ctx, dir)
}
|
||||||
|
|
||||||
|
func (rh *ResultHandler) waitForDirectory(ctx context.Context, dir string) error {
|
||||||
|
ticker := time.NewTicker(300 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-ticker.C:
|
||||||
|
if info, err := os.Stat(dir); err == nil && info.IsDir() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// watchForResultFile sets up an fsnotify watcher on dir and blocks until the
// result file is created and consumed, the watcher fails, or ctx is cancelled.
func (rh *ResultHandler) watchForResultFile(ctx context.Context, dir string) (Result, error) {
	watcher, err := fsnotify.NewWatcher()
	if err != nil {
		log.Error(err)
		return Result{}, err
	}

	defer func() {
		if err := watcher.Close(); err != nil {
			log.Warnf("failed to close watcher: %v", err)
		}
	}()

	if err := watcher.Add(dir); err != nil {
		return Result{}, fmt.Errorf("failed to watch directory: %v", err)
	}

	// Check again after setting up watcher to avoid race condition
	// (file could have been created between initial check and watcher setup)
	if result, err := rh.tryReadResult(); err == nil {
		log.Infof("installer result: %v", result)
		return result, nil
	}

	for {
		select {
		case <-ctx.Done():
			return Result{}, ctx.Err()
		case event, ok := <-watcher.Events:
			if !ok {
				return Result{}, errors.New("watcher closed unexpectedly")
			}

			if result, done := rh.handleWatchEvent(event); done {
				return result, nil
			}
		case err, ok := <-watcher.Errors:
			if !ok {
				return Result{}, errors.New("watcher closed unexpectedly")
			}
			return Result{}, fmt.Errorf("watcher error: %w", err)
		}
	}
}
|
||||||
|
|
||||||
|
// handleWatchEvent inspects a single fsnotify event. It returns done=true
// once a Create event for the result file has been processed; events for
// other files or other operations are ignored.
func (rh *ResultHandler) handleWatchEvent(event fsnotify.Event) (Result, bool) {
	if event.Name != rh.resultFile {
		return Result{}, false
	}

	if event.Has(fsnotify.Create) {
		result, err := rh.tryReadResult()
		if err != nil {
			// NOTE(review): a read/parse failure still ends the watch with a
			// zero Result and no error surfaced to the caller — confirm this
			// is intended rather than waiting for a subsequent event.
			log.Debugf("error while reading result: %v", err)
			return result, true
		}
		log.Infof("installer result: %v", result)
		return result, true
	}

	return Result{}, false
}
|
||||||
|
|
||||||
|
// Write writes the update result to a file for the UI to read
|
||||||
|
func (rh *ResultHandler) write(result Result) error {
|
||||||
|
log.Infof("write out installer result to: %s", rh.resultFile)
|
||||||
|
// Ensure directory exists
|
||||||
|
dir := filepath.Dir(rh.resultFile)
|
||||||
|
if err := os.MkdirAll(dir, 0o755); err != nil {
|
||||||
|
log.Errorf("failed to create directory %s: %v", dir, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(result)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to a temporary file first, then rename for atomic operation
|
||||||
|
tmpPath := rh.resultFile + ".tmp"
|
||||||
|
if err := os.WriteFile(tmpPath, data, 0o600); err != nil {
|
||||||
|
log.Errorf("failed to create temp file: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Atomic rename
|
||||||
|
if err := os.Rename(tmpPath, rh.resultFile); err != nil {
|
||||||
|
if cleanupErr := os.Remove(tmpPath); cleanupErr != nil {
|
||||||
|
log.Warnf("Failed to remove temp result file: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rh *ResultHandler) cleanup() error {
|
||||||
|
err := os.Remove(rh.resultFile)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.Debugf("delete installer result file: %s", rh.resultFile)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryReadResult attempts to read and validate the result file.
// On a successful parse the file is deleted (destructive read), so each
// result can only be consumed once.
func (rh *ResultHandler) tryReadResult() (Result, error) {
	data, err := os.ReadFile(rh.resultFile)
	if err != nil {
		return Result{}, err
	}

	var result Result
	if err := json.Unmarshal(data, &result); err != nil {
		return Result{}, fmt.Errorf("invalid result format: %w", err)
	}

	// Remove the consumed file so stale results are not re-read later.
	if err := rh.cleanup(); err != nil {
		log.Warnf("failed to cleanup result file: %v", err)
	}

	return result, nil
}
|
||||||
14
client/internal/updatemanager/installer/types.go
Normal file
14
client/internal/updatemanager/installer/types.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
// Type identifies an installation mechanism (e.g. pkg/Homebrew on macOS,
// MSI/EXE on Windows).
type Type struct {
	name         string // human-readable label
	downloadable bool   // whether an installer artifact can be downloaded for this type
}
|
||||||
|
|
||||||
|
// String returns the human-readable name of the installer type.
func (t Type) String() string {
	return t.name
}
|
||||||
|
|
||||||
|
// Downloadable reports whether an installer artifact can be downloaded for
// this type.
func (t Type) Downloadable() bool {
	return t.downloadable
}
|
||||||
22
client/internal/updatemanager/installer/types_darwin.go
Normal file
22
client/internal/updatemanager/installer/types_darwin.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os/exec"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
	// TypeHomebrew marks brew-managed installs; no artifact is downloaded,
	// the upgrade runs through brew itself.
	TypeHomebrew = Type{name: "Homebrew", downloadable: false}
	// TypePKG marks installs from a downloadable macOS .pkg artifact.
	TypePKG = Type{name: "pkg", downloadable: true}
)
|
||||||
|
|
||||||
|
func TypeOfInstaller(ctx context.Context) Type {
|
||||||
|
cmd := exec.CommandContext(ctx, "pkgutil", "--pkg-info", "io.netbird.client")
|
||||||
|
_, err := cmd.Output()
|
||||||
|
if err != nil && cmd.ProcessState.ExitCode() == 1 {
|
||||||
|
// Not installed using pkg file, thus installed using Homebrew
|
||||||
|
|
||||||
|
return TypeHomebrew
|
||||||
|
}
|
||||||
|
return TypePKG
|
||||||
|
}
|
||||||
51
client/internal/updatemanager/installer/types_windows.go
Normal file
51
client/internal/updatemanager/installer/types_windows.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package installer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows/registry"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
uninstallKeyPath64 = `SOFTWARE\WOW6432Node\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||||
|
uninstallKeyPath32 = `SOFTWARE\Microsoft\Windows\CurrentVersion\Uninstall\Netbird`
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
TypeExe = Type{name: "EXE", downloadable: true}
|
||||||
|
TypeMSI = Type{name: "MSI", downloadable: true}
|
||||||
|
)
|
||||||
|
|
||||||
|
func TypeOfInstaller(_ context.Context) Type {
|
||||||
|
paths := []string{uninstallKeyPath64, uninstallKeyPath32}
|
||||||
|
|
||||||
|
for _, path := range paths {
|
||||||
|
k, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := k.Close(); err != nil {
|
||||||
|
log.Warnf("Error closing registry key: %v", err)
|
||||||
|
}
|
||||||
|
return TypeExe
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("No registry entry found for Netbird, assuming MSI installation")
|
||||||
|
return TypeMSI
|
||||||
|
}
|
||||||
|
|
||||||
|
func typeByFileExtension(filePath string) (Type, error) {
|
||||||
|
switch {
|
||||||
|
case strings.HasSuffix(strings.ToLower(filePath), ".exe"):
|
||||||
|
return TypeExe, nil
|
||||||
|
case strings.HasSuffix(strings.ToLower(filePath), ".msi"):
|
||||||
|
return TypeMSI, nil
|
||||||
|
default:
|
||||||
|
return Type{}, fmt.Errorf("unsupported installer type for file: %s", filePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
374
client/internal/updatemanager/manager.go
Normal file
374
client/internal/updatemanager/manager.go
Normal file
@@ -0,0 +1,374 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package updatemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v "github.com/hashicorp/go-version"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
latestVersion = "latest"
|
||||||
|
// this version will be ignored
|
||||||
|
developmentVersion = "development"
|
||||||
|
)
|
||||||
|
|
||||||
|
var errNoUpdateState = errors.New("no update state found")
|
||||||
|
|
||||||
|
type UpdateState struct {
|
||||||
|
PreUpdateVersion string
|
||||||
|
TargetVersion string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u UpdateState) Name() string {
|
||||||
|
return "autoUpdate"
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
stateManager *statemanager.Manager
|
||||||
|
|
||||||
|
lastTrigger time.Time
|
||||||
|
mgmUpdateChan chan struct{}
|
||||||
|
updateChannel chan struct{}
|
||||||
|
currentVersion string
|
||||||
|
update UpdateInterface
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
cancel context.CancelFunc
|
||||||
|
|
||||||
|
expectedVersion *v.Version
|
||||||
|
updateToLatestVersion bool
|
||||||
|
|
||||||
|
// updateMutex protect update and expectedVersion fields
|
||||||
|
updateMutex sync.Mutex
|
||||||
|
|
||||||
|
triggerUpdateFn func(context.Context, string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||||
|
if runtime.GOOS == "darwin" {
|
||||||
|
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
|
||||||
|
if isBrew {
|
||||||
|
log.Warnf("auto-update disabled on Home Brew installation")
|
||||||
|
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return newManager(statusRecorder, stateManager)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||||
|
manager := &Manager{
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
stateManager: stateManager,
|
||||||
|
mgmUpdateChan: make(chan struct{}, 1),
|
||||||
|
updateChannel: make(chan struct{}, 1),
|
||||||
|
currentVersion: version.NetbirdVersion(),
|
||||||
|
update: version.NewUpdate("nb/client"),
|
||||||
|
}
|
||||||
|
manager.triggerUpdateFn = manager.triggerUpdate
|
||||||
|
|
||||||
|
stateManager.RegisterState(&UpdateState{})
|
||||||
|
|
||||||
|
return manager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUpdateSuccess checks if the update was successful and send a notification.
|
||||||
|
// It works without to start the update manager.
|
||||||
|
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||||
|
reason := m.lastResultErrReason()
|
||||||
|
if reason != "" {
|
||||||
|
m.statusRecorder.PublishEvent(
|
||||||
|
cProto.SystemEvent_ERROR,
|
||||||
|
cProto.SystemEvent_SYSTEM,
|
||||||
|
"Auto-update failed",
|
||||||
|
fmt.Sprintf("Auto-update failed: %s", reason),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
updateState, err := m.loadAndDeleteUpdateState(ctx)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, errNoUpdateState) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Errorf("failed to load update state: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("auto-update state loaded, %v", *updateState)
|
||||||
|
|
||||||
|
if updateState.TargetVersion == m.currentVersion {
|
||||||
|
m.statusRecorder.PublishEvent(
|
||||||
|
cProto.SystemEvent_INFO,
|
||||||
|
cProto.SystemEvent_SYSTEM,
|
||||||
|
"Auto-update completed",
|
||||||
|
fmt.Sprintf("Your NetBird Client was auto-updated to version %s", m.currentVersion),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Start(ctx context.Context) {
|
||||||
|
if m.cancel != nil {
|
||||||
|
log.Errorf("Manager already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.update.SetDaemonVersion(m.currentVersion)
|
||||||
|
m.update.SetOnUpdateListener(func() {
|
||||||
|
select {
|
||||||
|
case m.updateChannel <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
})
|
||||||
|
go m.update.StartFetcher()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
m.cancel = cancel
|
||||||
|
|
||||||
|
m.wg.Add(1)
|
||||||
|
go m.updateLoop(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) SetVersion(expectedVersion string) {
|
||||||
|
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||||
|
if m.cancel == nil {
|
||||||
|
log.Errorf("manager not started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateMutex.Lock()
|
||||||
|
defer m.updateMutex.Unlock()
|
||||||
|
|
||||||
|
if expectedVersion == "" {
|
||||||
|
log.Errorf("empty expected version provided")
|
||||||
|
m.expectedVersion = nil
|
||||||
|
m.updateToLatestVersion = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedVersion == latestVersion {
|
||||||
|
m.updateToLatestVersion = true
|
||||||
|
m.expectedVersion = nil
|
||||||
|
} else {
|
||||||
|
expectedSemVer, err := v.NewVersion(expectedVersion)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error parsing version: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if m.expectedVersion != nil && m.expectedVersion.Equal(expectedSemVer) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.expectedVersion = expectedSemVer
|
||||||
|
m.updateToLatestVersion = false
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case m.mgmUpdateChan <- struct{}{}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) Stop() {
|
||||||
|
if m.cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.cancel()
|
||||||
|
m.updateMutex.Lock()
|
||||||
|
if m.update != nil {
|
||||||
|
m.update.StopWatch()
|
||||||
|
m.update = nil
|
||||||
|
}
|
||||||
|
m.updateMutex.Unlock()
|
||||||
|
|
||||||
|
m.wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) onContextCancel() {
|
||||||
|
if m.cancel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.updateMutex.Lock()
|
||||||
|
defer m.updateMutex.Unlock()
|
||||||
|
if m.update != nil {
|
||||||
|
m.update.StopWatch()
|
||||||
|
m.update = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateLoop(ctx context.Context) {
|
||||||
|
defer m.wg.Done()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
m.onContextCancel()
|
||||||
|
return
|
||||||
|
case <-m.mgmUpdateChan:
|
||||||
|
case <-m.updateChannel:
|
||||||
|
log.Infof("fetched new version info")
|
||||||
|
}
|
||||||
|
|
||||||
|
m.handleUpdate(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleUpdate(ctx context.Context) {
|
||||||
|
var updateVersion *v.Version
|
||||||
|
|
||||||
|
m.updateMutex.Lock()
|
||||||
|
if m.update == nil {
|
||||||
|
m.updateMutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedVersion := m.expectedVersion
|
||||||
|
useLatest := m.updateToLatestVersion
|
||||||
|
curLatestVersion := m.update.LatestVersion()
|
||||||
|
m.updateMutex.Unlock()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
// Resolve "latest" to actual version
|
||||||
|
case useLatest:
|
||||||
|
if curLatestVersion == nil {
|
||||||
|
log.Tracef("latest version not fetched yet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateVersion = curLatestVersion
|
||||||
|
// Update to specific version
|
||||||
|
case expectedVersion != nil:
|
||||||
|
updateVersion = expectedVersion
|
||||||
|
default:
|
||||||
|
log.Debugf("no expected version information set")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||||
|
if !m.shouldUpdate(updateVersion) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.lastTrigger = time.Now()
|
||||||
|
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||||
|
m.statusRecorder.PublishEvent(
|
||||||
|
cProto.SystemEvent_CRITICAL,
|
||||||
|
cProto.SystemEvent_SYSTEM,
|
||||||
|
"Automatically updating client",
|
||||||
|
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.statusRecorder.PublishEvent(
|
||||||
|
cProto.SystemEvent_CRITICAL,
|
||||||
|
cProto.SystemEvent_SYSTEM,
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
map[string]string{"progress_window": "show", "version": updateVersion.String()},
|
||||||
|
)
|
||||||
|
|
||||||
|
updateState := UpdateState{
|
||||||
|
PreUpdateVersion: m.currentVersion,
|
||||||
|
TargetVersion: updateVersion.String(),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.stateManager.UpdateState(updateState); err != nil {
|
||||||
|
log.Warnf("failed to update state: %v", err)
|
||||||
|
} else {
|
||||||
|
if err = m.stateManager.PersistState(ctx); err != nil {
|
||||||
|
log.Warnf("failed to persist state: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
|
||||||
|
log.Errorf("Error triggering auto-update: %v", err)
|
||||||
|
m.statusRecorder.PublishEvent(
|
||||||
|
cProto.SystemEvent_ERROR,
|
||||||
|
cProto.SystemEvent_SYSTEM,
|
||||||
|
"Auto-update failed",
|
||||||
|
fmt.Sprintf("Auto-update failed: %v", err),
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
|
||||||
|
// Returns nil if no state exists.
|
||||||
|
func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, error) {
|
||||||
|
stateType := &UpdateState{}
|
||||||
|
|
||||||
|
m.stateManager.RegisterState(stateType)
|
||||||
|
if err := m.stateManager.LoadState(stateType); err != nil {
|
||||||
|
return nil, fmt.Errorf("load state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := m.stateManager.GetState(stateType)
|
||||||
|
if state == nil {
|
||||||
|
return nil, errNoUpdateState
|
||||||
|
}
|
||||||
|
|
||||||
|
updateState, ok := state.(*UpdateState)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("failed to cast state to UpdateState")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.stateManager.DeleteState(updateState); err != nil {
|
||||||
|
return nil, fmt.Errorf("delete state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.stateManager.PersistState(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("persist state: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return updateState, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||||
|
if m.currentVersion == developmentVersion {
|
||||||
|
log.Debugf("skipping auto-update, running development version")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
currentVersion, err := v.NewVersion(m.currentVersion)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("error checking for update, error parsing version `%s`: %v", m.currentVersion, err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if currentVersion.GreaterThanOrEqual(updateVersion) {
|
||||||
|
log.Infof("current version (%s) is equal to or higher than auto-update version (%s)", m.currentVersion, updateVersion)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Since(m.lastTrigger) < 5*time.Minute {
|
||||||
|
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) lastResultErrReason() string {
|
||||||
|
inst := installer.New()
|
||||||
|
result := installer.NewResultHandler(inst.TempDir())
|
||||||
|
return result.GetErrorResultReason()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||||
|
inst := installer.New()
|
||||||
|
return inst.RunInstallation(ctx, targetVersion)
|
||||||
|
}
|
||||||
214
client/internal/updatemanager/manager_test.go
Normal file
214
client/internal/updatemanager/manager_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
//go:build windows || darwin
|
||||||
|
|
||||||
|
package updatemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"path"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
v "github.com/hashicorp/go-version"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
type versionUpdateMock struct {
|
||||||
|
latestVersion *v.Version
|
||||||
|
onUpdate func()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v versionUpdateMock) StopWatch() {}
|
||||||
|
|
||||||
|
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||||
|
v.onUpdate = updateFn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||||
|
return v.latestVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
func (v versionUpdateMock) StartFetcher() {}
|
||||||
|
|
||||||
|
func Test_LatestVersion(t *testing.T) {
|
||||||
|
testMatrix := []struct {
|
||||||
|
name string
|
||||||
|
daemonVersion string
|
||||||
|
initialLatestVersion *v.Version
|
||||||
|
latestVersion *v.Version
|
||||||
|
shouldUpdateInit bool
|
||||||
|
shouldUpdateLater bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||||
|
daemonVersion: "1.0.0",
|
||||||
|
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||||
|
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||||
|
shouldUpdateInit: true,
|
||||||
|
shouldUpdateLater: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||||
|
daemonVersion: "1.0.0",
|
||||||
|
initialLatestVersion: nil,
|
||||||
|
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||||
|
shouldUpdateInit: false,
|
||||||
|
shouldUpdateLater: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for idx, c := range testMatrix {
|
||||||
|
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||||
|
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||||
|
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||||
|
m.update = mockUpdate
|
||||||
|
|
||||||
|
targetVersionChan := make(chan string, 1)
|
||||||
|
|
||||||
|
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||||
|
targetVersionChan <- targetVersion
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
m.currentVersion = c.daemonVersion
|
||||||
|
m.Start(context.Background())
|
||||||
|
m.SetVersion("latest")
|
||||||
|
var triggeredInit bool
|
||||||
|
select {
|
||||||
|
case targetVersion := <-targetVersionChan:
|
||||||
|
if targetVersion != c.initialLatestVersion.String() {
|
||||||
|
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||||
|
}
|
||||||
|
triggeredInit = true
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
triggeredInit = false
|
||||||
|
}
|
||||||
|
if triggeredInit != c.shouldUpdateInit {
|
||||||
|
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||||
|
}
|
||||||
|
|
||||||
|
mockUpdate.latestVersion = c.latestVersion
|
||||||
|
mockUpdate.onUpdate()
|
||||||
|
|
||||||
|
var triggeredLater bool
|
||||||
|
select {
|
||||||
|
case targetVersion := <-targetVersionChan:
|
||||||
|
if targetVersion != c.latestVersion.String() {
|
||||||
|
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||||
|
}
|
||||||
|
triggeredLater = true
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
triggeredLater = false
|
||||||
|
}
|
||||||
|
if triggeredLater != c.shouldUpdateLater {
|
||||||
|
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_HandleUpdate(t *testing.T) {
|
||||||
|
testMatrix := []struct {
|
||||||
|
name string
|
||||||
|
daemonVersion string
|
||||||
|
latestVersion *v.Version
|
||||||
|
expectedVersion string
|
||||||
|
shouldUpdate bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: nil,
|
||||||
|
expectedVersion: "0.56.0",
|
||||||
|
shouldUpdate: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update to specific version should not update if version matches",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: nil,
|
||||||
|
expectedVersion: "0.55.0",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update to specific version should not update if current version is newer",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: nil,
|
||||||
|
expectedVersion: "0.54.0",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update to latest version should update if latest is newer",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||||
|
expectedVersion: "latest",
|
||||||
|
shouldUpdate: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Update to latest version should not update if latest == current",
|
||||||
|
daemonVersion: "0.56.0",
|
||||||
|
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||||
|
expectedVersion: "latest",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should not update if daemon version is invalid",
|
||||||
|
daemonVersion: "development",
|
||||||
|
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||||
|
expectedVersion: "latest",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should not update if expecting latest and latest version is unavailable",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: nil,
|
||||||
|
expectedVersion: "latest",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Should not update if expected version is invalid",
|
||||||
|
daemonVersion: "0.55.0",
|
||||||
|
latestVersion: nil,
|
||||||
|
expectedVersion: "development",
|
||||||
|
shouldUpdate: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for idx, c := range testMatrix {
|
||||||
|
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||||
|
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||||
|
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||||
|
targetVersionChan := make(chan string, 1)
|
||||||
|
|
||||||
|
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||||
|
targetVersionChan <- targetVersion
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
m.currentVersion = c.daemonVersion
|
||||||
|
m.Start(context.Background())
|
||||||
|
m.SetVersion(c.expectedVersion)
|
||||||
|
|
||||||
|
var updateTriggered bool
|
||||||
|
select {
|
||||||
|
case targetVersion := <-targetVersionChan:
|
||||||
|
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||||
|
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||||
|
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||||
|
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||||
|
}
|
||||||
|
updateTriggered = true
|
||||||
|
case <-time.After(10 * time.Millisecond):
|
||||||
|
updateTriggered = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if updateTriggered != c.shouldUpdate {
|
||||||
|
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||||
|
}
|
||||||
|
m.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
39
client/internal/updatemanager/manager_unsupported.go
Normal file
39
client/internal/updatemanager/manager_unsupported.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build !windows && !darwin
|
||||||
|
|
||||||
|
package updatemanager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager is a no-op stub for unsupported platforms
|
||||||
|
type Manager struct{}
|
||||||
|
|
||||||
|
// NewManager returns a no-op manager for unsupported platforms
|
||||||
|
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||||
|
return nil, fmt.Errorf("update manager is not supported on this platform")
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUpdateSuccess is a no-op on unsupported platforms
|
||||||
|
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start is a no-op on unsupported platforms
|
||||||
|
func (m *Manager) Start(ctx context.Context) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetVersion is a no-op on unsupported platforms
|
||||||
|
func (m *Manager) SetVersion(expectedVersion string) {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop is a no-op on unsupported platforms
|
||||||
|
func (m *Manager) Stop() {
|
||||||
|
// no-op
|
||||||
|
}
|
||||||
302
client/internal/updatemanager/reposign/artifact.go
Normal file
302
client/internal/updatemanager/reposign/artifact.go
Normal file
@@ -0,0 +1,302 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"hash"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
tagArtifactPrivate = "ARTIFACT PRIVATE KEY"
|
||||||
|
tagArtifactPublic = "ARTIFACT PUBLIC KEY"
|
||||||
|
|
||||||
|
maxArtifactKeySignatureAge = 10 * 365 * 24 * time.Hour
|
||||||
|
maxArtifactSignatureAge = 10 * 365 * 24 * time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
|
// ArtifactHash wraps a hash.Hash and counts bytes written
|
||||||
|
type ArtifactHash struct {
|
||||||
|
hash.Hash
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewArtifactHash returns an initialized ArtifactHash using BLAKE2s
|
||||||
|
func NewArtifactHash() *ArtifactHash {
|
||||||
|
h, err := blake2s.New256(nil)
|
||||||
|
if err != nil {
|
||||||
|
panic(err) // Should never happen with nil Key
|
||||||
|
}
|
||||||
|
return &ArtifactHash{Hash: h}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ah *ArtifactHash) Write(b []byte) (int, error) {
|
||||||
|
return ah.Hash.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ArtifactKey is a signing Key used to sign artifacts
|
||||||
|
type ArtifactKey struct {
|
||||||
|
PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k ArtifactKey) String() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"ArtifactKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
|
||||||
|
k.Metadata.ID,
|
||||||
|
k.Metadata.CreatedAt.Format(time.RFC3339),
|
||||||
|
k.Metadata.ExpiresAt.Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateArtifactKey(rootKey *RootKey, expiration time.Duration) (*ArtifactKey, []byte, []byte, []byte, error) {
|
||||||
|
// Verify root key is still valid
|
||||||
|
if !rootKey.Metadata.ExpiresAt.IsZero() && time.Now().After(rootKey.Metadata.ExpiresAt) {
|
||||||
|
return nil, nil, nil, nil, fmt.Errorf("root key has expired on %s", rootKey.Metadata.ExpiresAt.Format(time.RFC3339))
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
expirationTime := now.Add(expiration)
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, fmt.Errorf("generate ed25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: now.UTC(),
|
||||||
|
ExpiresAt: expirationTime.UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
ak := &ArtifactKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: metadata,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal PrivateKey struct to JSON
|
||||||
|
privJSON, err := json.Marshal(ak.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal PublicKey struct to JSON
|
||||||
|
pubKey := PublicKey{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: metadata,
|
||||||
|
}
|
||||||
|
pubJSON, err := json.Marshal(pubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to PEM with metadata embedded in bytes
|
||||||
|
privPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagArtifactPrivate,
|
||||||
|
Bytes: privJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagArtifactPublic,
|
||||||
|
Bytes: pubJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Sign the public key with the root key
|
||||||
|
signature, err := SignArtifactKey(*rootKey, pubPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, nil, fmt.Errorf("failed to sign artifact key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ak, privPEM, pubPEM, signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseArtifactKey(privKeyPEM []byte) (ArtifactKey, error) {
|
||||||
|
pk, err := parsePrivateKey(privKeyPEM, tagArtifactPrivate)
|
||||||
|
if err != nil {
|
||||||
|
return ArtifactKey{}, fmt.Errorf("failed to parse artifact Key: %w", err)
|
||||||
|
}
|
||||||
|
return ArtifactKey{pk}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseArtifactPubKey(data []byte) (PublicKey, error) {
|
||||||
|
pk, _, err := parsePublicKey(data, tagArtifactPublic)
|
||||||
|
return pk, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func BundleArtifactKeys(rootKey *RootKey, keys []PublicKey) ([]byte, []byte, error) {
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil, nil, errors.New("no keys to bundle")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create bundle by concatenating PEM-encoded keys
|
||||||
|
var pubBundle []byte
|
||||||
|
|
||||||
|
for _, pk := range keys {
|
||||||
|
// Marshal PublicKey struct to JSON
|
||||||
|
pubJSON, err := json.Marshal(pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to PEM
|
||||||
|
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagArtifactPublic,
|
||||||
|
Bytes: pubJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
pubBundle = append(pubBundle, pubPEM...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign the entire bundle with the root key
|
||||||
|
signature, err := SignArtifactKey(*rootKey, pubBundle)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to sign artifact key bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return pubBundle, signature, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidateArtifactKeys(publicRootKeys []PublicKey, data []byte, signature Signature, revocationList *RevocationList) ([]PublicKey, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
if signature.Timestamp.After(now.Add(maxClockSkew)) {
|
||||||
|
err := fmt.Errorf("signature timestamp is in the future: %v", signature.Timestamp)
|
||||||
|
log.Debugf("artifact signature error: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if now.Sub(signature.Timestamp) > maxArtifactKeySignatureAge {
|
||||||
|
err := fmt.Errorf("signature is too old: %v (created %v)", now.Sub(signature.Timestamp), signature.Timestamp)
|
||||||
|
log.Debugf("artifact signature error: %v", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reconstruct the signed message: artifact_key_data || timestamp
|
||||||
|
msg := make([]byte, 0, len(data)+8)
|
||||||
|
msg = append(msg, data...)
|
||||||
|
msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))
|
||||||
|
|
||||||
|
if !verifyAny(publicRootKeys, msg, signature.Signature) {
|
||||||
|
return nil, errors.New("failed to verify signature of artifact keys")
|
||||||
|
}
|
||||||
|
|
||||||
|
pubKeys, err := parsePublicKeyBundle(data, tagArtifactPublic)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to parse public keys: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
validKeys := make([]PublicKey, 0, len(pubKeys))
|
||||||
|
for _, pubKey := range pubKeys {
|
||||||
|
// Filter out expired keys
|
||||||
|
if !pubKey.Metadata.ExpiresAt.IsZero() && now.After(pubKey.Metadata.ExpiresAt) {
|
||||||
|
log.Debugf("Key %s is expired at %v (current time %v)",
|
||||||
|
pubKey.Metadata.ID, pubKey.Metadata.ExpiresAt, now)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if revocationList != nil {
|
||||||
|
if revTime, revoked := revocationList.Revoked[pubKey.Metadata.ID]; revoked {
|
||||||
|
log.Debugf("Key %s is revoked as of %v (created %v)",
|
||||||
|
pubKey.Metadata.ID, revTime, pubKey.Metadata.CreatedAt)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
validKeys = append(validKeys, pubKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(validKeys) == 0 {
|
||||||
|
log.Debugf("no valid public keys found for artifact keys")
|
||||||
|
return nil, fmt.Errorf("all %d artifact keys are revoked", len(pubKeys))
|
||||||
|
}
|
||||||
|
|
||||||
|
return validKeys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateArtifact verifies a release artifact against its detached
// signature using the downloaded artifact public keys.
//
// It rejects signatures whose timestamp lies in the future beyond
// maxClockSkew or is older than maxArtifactSignatureAge, rehashes the
// artifact with NewArtifactHash, reconstructs the signed message
// (hash || length || timestamp, integers little-endian), and verifies it with
// the key from artifactPubKeys whose ID matches signature.KeyID.
// Returns nil only when verification succeeds.
func ValidateArtifact(artifactPubKeys []PublicKey, data []byte, signature Signature) error {
	// Validate signature timestamp
	now := time.Now().UTC()
	if signature.Timestamp.After(now.Add(maxClockSkew)) {
		err := fmt.Errorf("artifact signature timestamp is in the future: %v", signature.Timestamp)
		log.Debugf("failed to verify signature of artifact: %s", err)
		return err
	}
	if now.Sub(signature.Timestamp) > maxArtifactSignatureAge {
		return fmt.Errorf("artifact signature is too old: %v (created %v)",
			now.Sub(signature.Timestamp), signature.Timestamp)
	}

	h := NewArtifactHash()
	if _, err := h.Write(data); err != nil {
		return fmt.Errorf("failed to hash artifact: %w", err)
	}
	hash := h.Sum(nil)

	// Reconstruct the signed message: hash || length || timestamp
	msg := make([]byte, 0, len(hash)+8+8)
	msg = append(msg, hash...)
	msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
	msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))

	// Find matching Key and verify
	for _, keyInfo := range artifactPubKeys {
		if keyInfo.Metadata.ID == signature.KeyID {
			// Check Key expiration: the signature must predate the key's expiry.
			if !keyInfo.Metadata.ExpiresAt.IsZero() &&
				signature.Timestamp.After(keyInfo.Metadata.ExpiresAt) {
				return fmt.Errorf("signing Key %s expired at %v, signature from %v",
					signature.KeyID, keyInfo.Metadata.ExpiresAt, signature.Timestamp)
			}

			if ed25519.Verify(keyInfo.Key, msg, signature.Signature) {
				log.Debugf("artifact verified successfully with Key: %s", signature.KeyID)
				return nil
			}
			// Matching ID but bad signature: fail hard, do not try other keys.
			return fmt.Errorf("signature verification failed for Key %s", signature.KeyID)
		}
	}

	return fmt.Errorf("no signing Key found with ID %s", signature.KeyID)
}
|
||||||
|
|
||||||
|
func SignData(artifactKey ArtifactKey, data []byte) ([]byte, error) {
|
||||||
|
if len(data) == 0 { // Check happens too late
|
||||||
|
return nil, fmt.Errorf("artifact length must be positive, got %d", len(data))
|
||||||
|
}
|
||||||
|
|
||||||
|
h := NewArtifactHash()
|
||||||
|
if _, err := h.Write(data); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to write artifact hash: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := time.Now().UTC()
|
||||||
|
|
||||||
|
if !artifactKey.Metadata.ExpiresAt.IsZero() && timestamp.After(artifactKey.Metadata.ExpiresAt) {
|
||||||
|
return nil, fmt.Errorf("artifact key expired at %v", artifactKey.Metadata.ExpiresAt)
|
||||||
|
}
|
||||||
|
|
||||||
|
hash := h.Sum(nil)
|
||||||
|
|
||||||
|
// Create message: hash || length || timestamp
|
||||||
|
msg := make([]byte, 0, len(hash)+8+8)
|
||||||
|
msg = append(msg, hash...)
|
||||||
|
msg = binary.LittleEndian.AppendUint64(msg, uint64(len(data)))
|
||||||
|
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
|
||||||
|
|
||||||
|
sig := ed25519.Sign(artifactKey.Key, msg)
|
||||||
|
|
||||||
|
bundle := Signature{
|
||||||
|
Signature: sig,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
KeyID: artifactKey.Metadata.ID,
|
||||||
|
Algorithm: "ed25519",
|
||||||
|
HashAlgo: "blake2s",
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(bundle)
|
||||||
|
}
|
||||||
1080
client/internal/updatemanager/reposign/artifact_test.go
Normal file
1080
client/internal/updatemanager/reposign/artifact_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,6 @@
|
|||||||
|
-----BEGIN ROOT PUBLIC KEY-----
|
||||||
|
eyJLZXkiOiJoaGIxdGRDSEZNMFBuQWp1b2w2cXJ1QXRFbWFFSlg1QjFsZUNxWmpn
|
||||||
|
V1pvPSIsIk1ldGFkYXRhIjp7ImlkIjoiOWE0OTg2NmI2MzE2MjNiNCIsImNyZWF0
|
||||||
|
ZWRfYXQiOiIyMDI1LTExLTI0VDE3OjE1OjI4LjYyNzE3MzE3MVoiLCJleHBpcmVz
|
||||||
|
X2F0IjoiMjAzNS0xMS0yMlQxNzoxNToyOC42MjcxNzMxNzFaIn19
|
||||||
|
-----END ROOT PUBLIC KEY-----
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
-----BEGIN ROOT PUBLIC KEY-----
|
||||||
|
eyJLZXkiOiJyTDByVTN2MEFOZUNmbDZraitiUUd3TE1waU5CaUJLdVBWSnZtQzgr
|
||||||
|
ZS84PSIsIk1ldGFkYXRhIjp7ImlkIjoiMTBkNjQyZTY2N2FmMDNkNCIsImNyZWF0
|
||||||
|
ZWRfYXQiOiIyMDI1LTExLTIwVDE3OjI5OjI5LjE4MDk0NjMxNloiLCJleHBpcmVz
|
||||||
|
X2F0IjoiMjAyNi0xMS0yMFQxNzoyOToyOS4xODA5NDYzMTZaIn19
|
||||||
|
-----END ROOT PUBLIC KEY-----
|
||||||
174
client/internal/updatemanager/reposign/doc.go
Normal file
174
client/internal/updatemanager/reposign/doc.go
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
// Package reposign implements a cryptographic signing and verification system
|
||||||
|
// for NetBird software update artifacts. It provides a hierarchical key
|
||||||
|
// management system with support for key rotation, revocation, and secure
|
||||||
|
// artifact distribution.
|
||||||
|
//
|
||||||
|
// # Architecture
|
||||||
|
//
|
||||||
|
// The package uses a two-tier key hierarchy:
|
||||||
|
//
|
||||||
|
// - Root Keys: Long-lived keys that sign artifact keys. These are embedded
|
||||||
|
// in the client binary and establish the root of trust. Root keys should
|
||||||
|
// be kept offline and highly secured.
|
||||||
|
//
|
||||||
|
// - Artifact Keys: Short-lived keys that sign release artifacts (binaries,
|
||||||
|
// packages, etc.). These are rotated regularly and can be revoked if
|
||||||
|
// compromised. Artifact keys are signed by root keys and distributed via
|
||||||
|
// a public repository.
|
||||||
|
//
|
||||||
|
// This separation allows for operational flexibility: artifact keys can be
|
||||||
|
// rotated frequently without requiring client updates, while root keys remain
|
||||||
|
// stable and embedded in the software.
|
||||||
|
//
|
||||||
|
// # Cryptographic Primitives
|
||||||
|
//
|
||||||
|
// The package uses strong, modern cryptographic algorithms:
|
||||||
|
// - Ed25519: Fast, secure digital signatures (no timing attacks)
|
||||||
|
// - BLAKE2s-256: Fast cryptographic hash for artifacts
|
||||||
|
// - SHA-256: Key ID generation
|
||||||
|
// - JSON: Structured key and signature serialization
|
||||||
|
// - PEM: Standard key encoding format
|
||||||
|
//
|
||||||
|
// # Security Features
|
||||||
|
//
|
||||||
|
// Timestamp Binding:
|
||||||
|
// - All signatures include cryptographically-bound timestamps
|
||||||
|
// - Prevents replay attacks and enforces signature freshness
|
||||||
|
// - Clock skew tolerance: 5 minutes
|
||||||
|
//
|
||||||
|
// Key Expiration:
|
||||||
|
// - All keys have expiration times
|
||||||
|
// - Expired keys are automatically rejected
|
||||||
|
// - Signing with an expired key fails immediately
|
||||||
|
//
|
||||||
|
// Key Revocation:
|
||||||
|
// - Compromised keys can be revoked via a signed revocation list
|
||||||
|
// - Revocation list is checked during artifact validation
|
||||||
|
// - Revoked keys are filtered out before artifact verification
|
||||||
|
//
|
||||||
|
// # File Structure
|
||||||
|
//
|
||||||
|
// The package expects the following file layout in the key repository:
|
||||||
|
//
|
||||||
|
// signrepo/
|
||||||
|
// artifact-key-pub.pem # Bundle of artifact public keys
|
||||||
|
// artifact-key-pub.pem.sig # Root signature of the bundle
|
||||||
|
// revocation-list.json # List of revoked key IDs
|
||||||
|
// revocation-list.json.sig # Root signature of revocation list
|
||||||
|
//
|
||||||
|
// And in the artifacts repository:
|
||||||
|
//
|
||||||
|
// releases/
|
||||||
|
// v0.28.0/
|
||||||
|
// netbird-linux-amd64
|
||||||
|
// netbird-linux-amd64.sig # Artifact signature
|
||||||
|
// netbird-darwin-amd64
|
||||||
|
// netbird-darwin-amd64.sig
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
// # Embedded Root Keys
|
||||||
|
//
|
||||||
|
// Root public keys are embedded in the client binary at compile time:
|
||||||
|
// - Production keys: certs/ directory
|
||||||
|
// - Development keys: certsdev/ directory
|
||||||
|
//
|
||||||
|
// The build tag determines which keys are embedded:
|
||||||
|
// - Production builds: //go:build !devartifactsign
|
||||||
|
// - Development builds: //go:build devartifactsign
|
||||||
|
//
|
||||||
|
// This ensures that development artifacts cannot be verified using production
|
||||||
|
// keys and vice versa.
|
||||||
|
//
|
||||||
|
// # Key Rotation Strategies
|
||||||
|
//
|
||||||
|
// Root Key Rotation:
|
||||||
|
//
|
||||||
|
// Root keys can be rotated without breaking existing clients by leveraging
|
||||||
|
// the multi-key verification system. The loadEmbeddedPublicKeys function
|
||||||
|
// reads ALL files from the certs/ directory and accepts signatures from ANY
|
||||||
|
// of the embedded root keys.
|
||||||
|
//
|
||||||
|
// To rotate root keys:
|
||||||
|
//
|
||||||
|
// 1. Generate a new root key pair:
|
||||||
|
// newRootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
|
||||||
|
//
|
||||||
|
// 2. Add the new public key to the certs/ directory as a new file:
|
||||||
|
// certs/
|
||||||
|
// root-pub-2024.pem # Old key (keep this!)
|
||||||
|
// root-pub-2025.pem # New key (add this)
|
||||||
|
//
|
||||||
|
// 3. Build new client versions with both keys embedded. The verification
|
||||||
|
// will accept signatures from either key.
|
||||||
|
//
|
||||||
|
// 4. Start signing new artifact keys with the new root key. Old clients
|
||||||
|
// with only the old root key will reject these, but new clients with
|
||||||
|
// both keys will accept them.
|
||||||
|
//
|
||||||
|
// Each file in certs/ can contain a single key or a bundle of keys (multiple
|
||||||
|
// PEM blocks). The system will parse all keys from all files and use them
|
||||||
|
// for verification. This provides maximum flexibility for key management.
|
||||||
|
//
|
||||||
|
// Important: Never remove all old root keys at once. Always maintain at least
|
||||||
|
// one overlapping key between releases to ensure smooth transitions.
|
||||||
|
//
|
||||||
|
// Artifact Key Rotation:
|
||||||
|
//
|
||||||
|
// Artifact keys should be rotated regularly (e.g., every 90 days) using the
|
||||||
|
// bundling mechanism. The BundleArtifactKeys function allows multiple artifact
|
||||||
|
// keys to be bundled together in a single signed package, and ValidateArtifact
|
||||||
|
// will accept signatures from ANY key in the bundle.
|
||||||
|
//
|
||||||
|
// To rotate artifact keys smoothly:
|
||||||
|
//
|
||||||
|
// 1. Generate a new artifact key while keeping the old one:
|
||||||
|
// newKey, newPrivPEM, newPubPEM, newSig, err := GenerateArtifactKey(rootKey, 90 * 24 * time.Hour)
|
||||||
|
// // Keep oldPubPEM and oldKey available
|
||||||
|
//
|
||||||
|
// 2. Create a bundle containing both old and new public keys
|
||||||
|
//
|
||||||
|
// 3. Upload the bundle and its signature to the key repository:
|
||||||
|
// signrepo/artifact-key-pub.pem # Contains both keys
|
||||||
|
// signrepo/artifact-key-pub.pem.sig # Root signature
|
||||||
|
//
|
||||||
|
// 4. Start signing new releases with the NEW key, but keep the bundle
|
||||||
|
// unchanged. Clients will download the bundle (containing both keys)
|
||||||
|
// and accept signatures from either key.
|
||||||
|
//
|
||||||
|
// Key bundle validation workflow:
|
||||||
|
// 1. Client downloads artifact-key-pub.pem and artifact-key-pub.pem.sig
|
||||||
|
// 2. ValidateArtifactKeys verifies the bundle signature with ANY embedded root key
|
||||||
|
// 3. ValidateArtifactKeys parses all public keys from the bundle
|
||||||
|
// 4. ValidateArtifactKeys filters out expired or revoked keys
|
||||||
|
// 5. When verifying an artifact, ValidateArtifact tries each key until one succeeds
|
||||||
|
//
|
||||||
|
// This multi-key acceptance model enables overlapping validity periods and
|
||||||
|
// smooth transitions without client update requirements.
|
||||||
|
//
|
||||||
|
// # Best Practices
|
||||||
|
//
|
||||||
|
// Root Key Management:
|
||||||
|
// - Generate root keys offline on an air-gapped machine
|
||||||
|
// - Store root private keys in hardware security modules (HSM) if possible
|
||||||
|
// - Use separate root keys for production and development
|
||||||
|
// - Rotate root keys infrequently (e.g., every 5-10 years)
|
||||||
|
// - Plan for root key rotation: embed multiple root public keys
|
||||||
|
//
|
||||||
|
// Artifact Key Management:
|
||||||
|
// - Rotate artifact keys regularly (e.g., every 90 days)
|
||||||
|
// - Use separate artifact keys for different release channels if needed
|
||||||
|
// - Revoke keys immediately upon suspected compromise
|
||||||
|
// - Bundle multiple artifact keys to enable smooth rotation
|
||||||
|
//
|
||||||
|
// Signing Process:
|
||||||
|
// - Sign artifacts in a secure CI/CD environment
|
||||||
|
// - Never commit private keys to version control
|
||||||
|
// - Use environment variables or secret management for keys
|
||||||
|
// - Verify signatures immediately after signing
|
||||||
|
//
|
||||||
|
// Distribution:
|
||||||
|
// - Serve keys and revocation lists from a reliable CDN
|
||||||
|
// - Use HTTPS for all key and artifact downloads
|
||||||
|
// - Monitor download failures and signature verification failures
|
||||||
|
// - Keep revocation list up to date
|
||||||
|
package reposign
|
||||||
10
client/internal/updatemanager/reposign/embed_dev.go
Normal file
10
client/internal/updatemanager/reposign/embed_dev.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build devartifactsign
|
||||||
|
|
||||||
|
package reposign
|
||||||
|
|
||||||
|
import "embed"
|
||||||
|
|
||||||
|
//go:embed certsdev
|
||||||
|
var embeddedCerts embed.FS
|
||||||
|
|
||||||
|
const embeddedCertsDir = "certsdev"
|
||||||
10
client/internal/updatemanager/reposign/embed_prod.go
Normal file
10
client/internal/updatemanager/reposign/embed_prod.go
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
//go:build !devartifactsign
|
||||||
|
|
||||||
|
package reposign
|
||||||
|
|
||||||
|
import "embed"
|
||||||
|
|
||||||
|
//go:embed certs
|
||||||
|
var embeddedCerts embed.FS
|
||||||
|
|
||||||
|
const embeddedCertsDir = "certs"
|
||||||
171
client/internal/updatemanager/reposign/key.go
Normal file
171
client/internal/updatemanager/reposign/key.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	maxClockSkew = 5 * time.Minute
)

// KeyID uniquely identifies a key: the first 8 bytes of the SHA-256 digest of
// the public key material.
type KeyID [8]byte

// computeKeyID derives the KeyID for the given Ed25519 public key.
func computeKeyID(pub ed25519.PublicKey) KeyID {
	digest := sha256.Sum256(pub)
	var id KeyID
	copy(id[:], digest[:len(id)])
	return id
}

// MarshalJSON implements json.Marshaler, encoding the KeyID as its hex string.
func (k KeyID) MarshalJSON() ([]byte, error) {
	return json.Marshal(k.String())
}

// UnmarshalJSON implements json.Unmarshaler, decoding a KeyID from its hex
// string form.
func (k *KeyID) UnmarshalJSON(data []byte) error {
	var encoded string
	if err := json.Unmarshal(data, &encoded); err != nil {
		return err
	}

	id, err := ParseKeyID(encoded)
	if err != nil {
		return err
	}

	*k = id
	return nil
}

// ParseKeyID parses a hex string (16 hex chars = 8 bytes) into a KeyID.
func ParseKeyID(s string) (KeyID, error) {
	var id KeyID
	if len(s) != 2*len(id) {
		return id, fmt.Errorf("invalid KeyID length: got %d, want 16 hex chars (8 bytes)", len(s))
	}

	raw, err := hex.DecodeString(s)
	if err != nil {
		return id, fmt.Errorf("failed to decode KeyID: %w", err)
	}

	copy(id[:], raw)
	return id, nil
}

// String returns the lowercase hex representation of the KeyID.
func (k KeyID) String() string {
	return hex.EncodeToString(k[:])
}
|
||||||
|
|
||||||
|
// KeyMetadata contains versioning and lifecycle information for a Key.
type KeyMetadata struct {
	ID        KeyID     `json:"id"`
	CreatedAt time.Time `json:"created_at"`
	ExpiresAt time.Time `json:"expires_at,omitempty"` // Optional expiration; the zero value means the key never expires
}

// PublicKey wraps a public Key with its Metadata.
type PublicKey struct {
	Key      ed25519.PublicKey
	Metadata KeyMetadata
}
|
||||||
|
|
||||||
|
func parsePublicKeyBundle(bundle []byte, typeTag string) ([]PublicKey, error) {
|
||||||
|
var keys []PublicKey
|
||||||
|
for len(bundle) > 0 {
|
||||||
|
keyInfo, rest, err := parsePublicKey(bundle, typeTag)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keys = append(keys, keyInfo)
|
||||||
|
bundle = rest
|
||||||
|
}
|
||||||
|
if len(keys) == 0 {
|
||||||
|
return nil, errors.New("no keys found in bundle")
|
||||||
|
}
|
||||||
|
return keys, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parsePublicKey(data []byte, typeTag string) (PublicKey, []byte, error) {
|
||||||
|
b, rest := pem.Decode(data)
|
||||||
|
if b == nil {
|
||||||
|
return PublicKey{}, nil, errors.New("failed to decode PEM data")
|
||||||
|
}
|
||||||
|
if b.Type != typeTag {
|
||||||
|
return PublicKey{}, nil, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal JSON-embedded format
|
||||||
|
var pub PublicKey
|
||||||
|
if err := json.Unmarshal(b.Bytes, &pub); err != nil {
|
||||||
|
return PublicKey{}, nil, fmt.Errorf("failed to unmarshal public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key length
|
||||||
|
if len(pub.Key) != ed25519.PublicKeySize {
|
||||||
|
return PublicKey{}, nil, fmt.Errorf("incorrect Ed25519 public key size: expected %d, got %d",
|
||||||
|
ed25519.PublicKeySize, len(pub.Key))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always recompute ID to ensure integrity
|
||||||
|
pub.Metadata.ID = computeKeyID(pub.Key)
|
||||||
|
|
||||||
|
return pub, rest, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivateKey wraps an Ed25519 private Key with its Metadata.
type PrivateKey struct {
	Key      ed25519.PrivateKey
	Metadata KeyMetadata
}
|
||||||
|
|
||||||
|
func parsePrivateKey(data []byte, typeTag string) (PrivateKey, error) {
|
||||||
|
b, rest := pem.Decode(data)
|
||||||
|
if b == nil {
|
||||||
|
return PrivateKey{}, errors.New("failed to decode PEM data")
|
||||||
|
}
|
||||||
|
if len(rest) > 0 {
|
||||||
|
return PrivateKey{}, errors.New("trailing PEM data")
|
||||||
|
}
|
||||||
|
if b.Type != typeTag {
|
||||||
|
return PrivateKey{}, fmt.Errorf("PEM type is %q, want %q", b.Type, typeTag)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal JSON-embedded format
|
||||||
|
var pk PrivateKey
|
||||||
|
if err := json.Unmarshal(b.Bytes, &pk); err != nil {
|
||||||
|
return PrivateKey{}, fmt.Errorf("failed to unmarshal private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate key length
|
||||||
|
if len(pk.Key) != ed25519.PrivateKeySize {
|
||||||
|
return PrivateKey{}, fmt.Errorf("incorrect Ed25519 private key size: expected %d, got %d",
|
||||||
|
ed25519.PrivateKeySize, len(pk.Key))
|
||||||
|
}
|
||||||
|
|
||||||
|
return pk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyAny(publicRootKeys []PublicKey, msg, sig []byte) bool {
|
||||||
|
// Verify with root keys
|
||||||
|
var rootKeys []ed25519.PublicKey
|
||||||
|
for _, r := range publicRootKeys {
|
||||||
|
rootKeys = append(rootKeys, r.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range rootKeys {
|
||||||
|
if ed25519.Verify(k, msg, sig) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
636
client/internal/updatemanager/reposign/key_test.go
Normal file
636
client/internal/updatemanager/reposign/key_test.go
Normal file
@@ -0,0 +1,636 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test KeyID functions
|
||||||
|
|
||||||
|
// TestComputeKeyID verifies that computeKeyID returns the first 8 bytes of
// the SHA-256 digest of the public key.
func TestComputeKeyID(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	keyID := computeKeyID(pub)

	// Verify it's the first 8 bytes of SHA-256
	h := sha256.Sum256(pub)
	expectedID := KeyID{}
	copy(expectedID[:], h[:8])

	assert.Equal(t, expectedID, keyID)
}

// TestComputeKeyID_Deterministic verifies the ID derivation is stable for a
// fixed key.
func TestComputeKeyID_Deterministic(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	// Computing KeyID multiple times should give the same result
	keyID1 := computeKeyID(pub)
	keyID2 := computeKeyID(pub)

	assert.Equal(t, keyID1, keyID2)
}

// TestComputeKeyID_DifferentKeys verifies that distinct keys yield distinct IDs.
func TestComputeKeyID_DifferentKeys(t *testing.T) {
	pub1, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pub2, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	keyID1 := computeKeyID(pub1)
	keyID2 := computeKeyID(pub2)

	// Different keys should produce different IDs
	assert.NotEqual(t, keyID1, keyID2)
}
|
||||||
|
|
||||||
|
// TestParseKeyID_Valid checks parsing of a well-formed 16-char hex ID.
func TestParseKeyID_Valid(t *testing.T) {
	hexStr := "0123456789abcdef"

	keyID, err := ParseKeyID(hexStr)
	require.NoError(t, err)

	expected := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}
	assert.Equal(t, expected, keyID)
}

// TestParseKeyID_InvalidLength checks rejection of inputs that are not
// exactly 16 hex characters.
func TestParseKeyID_InvalidLength(t *testing.T) {
	tests := []struct {
		name  string
		input string
	}{
		{"too short", "01234567"},
		{"too long", "0123456789abcdef00"},
		{"empty", ""},
		{"odd length", "0123456789abcde"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			_, err := ParseKeyID(tt.input)
			assert.Error(t, err)
			assert.Contains(t, err.Error(), "invalid KeyID length")
		})
	}
}

// TestParseKeyID_InvalidHex checks rejection of non-hex characters.
func TestParseKeyID_InvalidHex(t *testing.T) {
	invalidHex := "0123456789abcxyz" // 'xyz' are not valid hex

	_, err := ParseKeyID(invalidHex)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to decode KeyID")
}
|
||||||
|
|
||||||
|
// TestKeyID_String checks the lowercase hex rendering of a KeyID.
func TestKeyID_String(t *testing.T) {
	keyID := KeyID{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}

	str := keyID.String()
	assert.Equal(t, "0123456789abcdef", str)
}

// TestKeyID_RoundTrip checks that ParseKeyID followed by String returns the input.
func TestKeyID_RoundTrip(t *testing.T) {
	original := "fedcba9876543210"

	keyID, err := ParseKeyID(original)
	require.NoError(t, err)

	result := keyID.String()
	assert.Equal(t, original, result)
}

// TestKeyID_ZeroValue checks the rendering of the zero KeyID.
func TestKeyID_ZeroValue(t *testing.T) {
	keyID := KeyID{}
	str := keyID.String()
	assert.Equal(t, "0000000000000000", str)
}
|
||||||
|
|
||||||
|
// Test KeyMetadata
|
||||||
|
|
||||||
|
// TestKeyMetadata_JSONMarshaling checks a KeyMetadata JSON round trip.
func TestKeyMetadata_JSONMarshaling(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	metadata := KeyMetadata{
		ID:        computeKeyID(pub),
		CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
		ExpiresAt: time.Date(2025, 1, 15, 10, 30, 0, 0, time.UTC),
	}

	jsonData, err := json.Marshal(metadata)
	require.NoError(t, err)

	var decoded KeyMetadata
	err = json.Unmarshal(jsonData, &decoded)
	require.NoError(t, err)

	assert.Equal(t, metadata.ID, decoded.ID)
	assert.Equal(t, metadata.CreatedAt.Unix(), decoded.CreatedAt.Unix())
	assert.Equal(t, metadata.ExpiresAt.Unix(), decoded.ExpiresAt.Unix())
}

// TestKeyMetadata_NoExpiration checks that a zero ExpiresAt survives a JSON
// round trip as a zero time (meaning: no expiration).
func TestKeyMetadata_NoExpiration(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	metadata := KeyMetadata{
		ID:        computeKeyID(pub),
		CreatedAt: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
		ExpiresAt: time.Time{}, // Zero value = no expiration
	}

	jsonData, err := json.Marshal(metadata)
	require.NoError(t, err)

	var decoded KeyMetadata
	err = json.Unmarshal(jsonData, &decoded)
	require.NoError(t, err)

	assert.True(t, decoded.ExpiresAt.IsZero())
}
|
||||||
|
|
||||||
|
// Test PublicKey
|
||||||
|
|
||||||
|
// TestPublicKey_JSONMarshaling checks a PublicKey JSON round trip.
func TestPublicKey_JSONMarshaling(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pubKey := PublicKey{
		Key: pub,
		Metadata: KeyMetadata{
			ID:        computeKeyID(pub),
			CreatedAt: time.Now().UTC(),
			ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
		},
	}

	jsonData, err := json.Marshal(pubKey)
	require.NoError(t, err)

	var decoded PublicKey
	err = json.Unmarshal(jsonData, &decoded)
	require.NoError(t, err)

	assert.Equal(t, pubKey.Key, decoded.Key)
	assert.Equal(t, pubKey.Metadata.ID, decoded.Metadata.ID)
}
|
||||||
|
|
||||||
|
// Test parsePublicKey
|
||||||
|
|
||||||
|
// TestParsePublicKey_Valid checks parsing of a well-formed PEM-wrapped key.
func TestParsePublicKey_Valid(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	metadata := KeyMetadata{
		ID:        computeKeyID(pub),
		CreatedAt: time.Now().UTC(),
		ExpiresAt: time.Now().Add(365 * 24 * time.Hour).UTC(),
	}

	pubKey := PublicKey{
		Key:      pub,
		Metadata: metadata,
	}

	// Marshal to JSON
	jsonData, err := json.Marshal(pubKey)
	require.NoError(t, err)

	// Encode to PEM
	pemData := pem.EncodeToMemory(&pem.Block{
		Type:  tagRootPublic,
		Bytes: jsonData,
	})

	// Parse it back
	parsed, rest, err := parsePublicKey(pemData, tagRootPublic)
	require.NoError(t, err)
	assert.Empty(t, rest)
	assert.Equal(t, pub, parsed.Key)
	assert.Equal(t, metadata.ID, parsed.Metadata.ID)
}

// TestParsePublicKey_InvalidPEM checks rejection of non-PEM input.
func TestParsePublicKey_InvalidPEM(t *testing.T) {
	invalidPEM := []byte("not a PEM")

	_, _, err := parsePublicKey(invalidPEM, tagRootPublic)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to decode PEM")
}

// TestParsePublicKey_WrongType checks rejection of a mismatched PEM type tag.
func TestParsePublicKey_WrongType(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	pubKey := PublicKey{
		Key: pub,
		Metadata: KeyMetadata{
			ID:        computeKeyID(pub),
			CreatedAt: time.Now().UTC(),
		},
	}

	jsonData, err := json.Marshal(pubKey)
	require.NoError(t, err)

	// Encode with wrong type
	pemData := pem.EncodeToMemory(&pem.Block{
		Type:  "WRONG TYPE",
		Bytes: jsonData,
	})

	_, _, err = parsePublicKey(pemData, tagRootPublic)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "PEM type")
}

// TestParsePublicKey_InvalidJSON checks rejection of malformed JSON payloads.
func TestParsePublicKey_InvalidJSON(t *testing.T) {
	pemData := pem.EncodeToMemory(&pem.Block{
		Type:  tagRootPublic,
		Bytes: []byte("invalid json"),
	})

	_, _, err := parsePublicKey(pemData, tagRootPublic)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to unmarshal")
}

// TestParsePublicKey_InvalidKeySize checks rejection of keys that are not
// ed25519.PublicKeySize bytes long.
func TestParsePublicKey_InvalidKeySize(t *testing.T) {
	// Create a public key with wrong size
	pubKey := PublicKey{
		Key: []byte{0x01, 0x02, 0x03}, // Too short
		Metadata: KeyMetadata{
			ID:        KeyID{},
			CreatedAt: time.Now().UTC(),
		},
	}

	jsonData, err := json.Marshal(pubKey)
	require.NoError(t, err)

	pemData := pem.EncodeToMemory(&pem.Block{
		Type:  tagRootPublic,
		Bytes: jsonData,
	})

	_, _, err = parsePublicKey(pemData, tagRootPublic)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "incorrect Ed25519 public key size")
}

// TestParsePublicKey_IDRecomputation checks that a tampered metadata ID is
// replaced by the ID recomputed from the key material during parsing.
func TestParsePublicKey_IDRecomputation(t *testing.T) {
	pub, _, err := ed25519.GenerateKey(rand.Reader)
	require.NoError(t, err)

	// Create a public key with WRONG ID
	wrongID := KeyID{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}
	pubKey := PublicKey{
		Key: pub,
		Metadata: KeyMetadata{
			ID:        wrongID,
			CreatedAt: time.Now().UTC(),
		},
	}

	jsonData, err := json.Marshal(pubKey)
	require.NoError(t, err)

	pemData := pem.EncodeToMemory(&pem.Block{
		Type:  tagRootPublic,
		Bytes: jsonData,
	})

	// Parse should recompute the correct ID
	parsed, _, err := parsePublicKey(pemData, tagRootPublic)
	require.NoError(t, err)

	correctID := computeKeyID(pub)
	assert.Equal(t, correctID, parsed.Metadata.ID)
	assert.NotEqual(t, wrongID, parsed.Metadata.ID)
}
|
||||||
|
|
||||||
|
// Test parsePublicKeyBundle
|
||||||
|
|
||||||
|
func TestParsePublicKeyBundle_Single(t *testing.T) {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pubKey := PublicKey{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(pubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPublic,
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
keys, err := parsePublicKeyBundle(pemData, tagRootPublic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, keys, 1)
|
||||||
|
assert.Equal(t, pub, keys[0].Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePublicKeyBundle_Multiple(t *testing.T) {
|
||||||
|
var bundle []byte
|
||||||
|
|
||||||
|
// Create 3 keys
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pubKey := PublicKey{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(pubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPublic,
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
bundle = append(bundle, pemData...)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := parsePublicKeyBundle(bundle, tagRootPublic)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, keys, 3)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePublicKeyBundle_Empty(t *testing.T) {
|
||||||
|
_, err := parsePublicKeyBundle([]byte{}, tagRootPublic)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no keys found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePublicKeyBundle_Invalid(t *testing.T) {
|
||||||
|
_, err := parsePublicKeyBundle([]byte("invalid data"), tagRootPublic)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test PrivateKey
|
||||||
|
|
||||||
|
func TestPrivateKey_JSONMarshaling(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privKey := PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var decoded PrivateKey
|
||||||
|
err = json.Unmarshal(jsonData, &decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, privKey.Key, decoded.Key)
|
||||||
|
assert.Equal(t, privKey.Metadata.ID, decoded.Metadata.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test parsePrivateKey
|
||||||
|
|
||||||
|
func TestParsePrivateKey_Valid(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privKey := PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
parsed, err := parsePrivateKey(pemData, tagRootPrivate)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, priv, parsed.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePrivateKey_InvalidPEM(t *testing.T) {
|
||||||
|
_, err := parsePrivateKey([]byte("not a PEM"), tagRootPrivate)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to decode PEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePrivateKey_TrailingData(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privKey := PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add trailing data
|
||||||
|
pemData = append(pemData, []byte("extra data")...)
|
||||||
|
|
||||||
|
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "trailing PEM data")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePrivateKey_WrongType(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
privKey := PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "WRONG TYPE",
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "PEM type")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParsePrivateKey_InvalidKeySize(t *testing.T) {
|
||||||
|
privKey := PrivateKey{
|
||||||
|
Key: []byte{0x01, 0x02, 0x03}, // Too short
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: KeyID{},
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pemData := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: jsonData,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = parsePrivateKey(pemData, tagRootPrivate)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test verifyAny
|
||||||
|
|
||||||
|
func TestVerifyAny_ValidSignature(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
signature := ed25519.Sign(priv, message)
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := verifyAny(rootKeys, message, signature)
|
||||||
|
assert.True(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAny_InvalidSignature(t *testing.T) {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
invalidSignature := make([]byte, ed25519.SignatureSize)
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := verifyAny(rootKeys, message, invalidSignature)
|
||||||
|
assert.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAny_MultipleKeys(t *testing.T) {
|
||||||
|
// Create 3 key pairs
|
||||||
|
pub1, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pub3, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
signature := ed25519.Sign(priv1, message)
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
|
||||||
|
{Key: pub1, Metadata: KeyMetadata{ID: computeKeyID(pub1)}}, // Correct key in middle
|
||||||
|
{Key: pub3, Metadata: KeyMetadata{ID: computeKeyID(pub3)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := verifyAny(rootKeys, message, signature)
|
||||||
|
assert.True(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAny_NoMatchingKey(t *testing.T) {
|
||||||
|
_, priv1, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
signature := ed25519.Sign(priv1, message)
|
||||||
|
|
||||||
|
// Only include pub2, not pub1
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{Key: pub2, Metadata: KeyMetadata{ID: computeKeyID(pub2)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := verifyAny(rootKeys, message, signature)
|
||||||
|
assert.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAny_EmptyKeys(t *testing.T) {
|
||||||
|
message := []byte("test message")
|
||||||
|
signature := make([]byte, ed25519.SignatureSize)
|
||||||
|
|
||||||
|
result := verifyAny([]PublicKey{}, message, signature)
|
||||||
|
assert.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestVerifyAny_TamperedMessage(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
message := []byte("test message")
|
||||||
|
signature := ed25519.Sign(priv, message)
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{Key: pub, Metadata: KeyMetadata{ID: computeKeyID(pub)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify with different message
|
||||||
|
tamperedMessage := []byte("different message")
|
||||||
|
result := verifyAny(rootKeys, tamperedMessage, signature)
|
||||||
|
assert.False(t, result)
|
||||||
|
}
|
||||||
229
client/internal/updatemanager/reposign/revocation.go
Normal file
229
client/internal/updatemanager/reposign/revocation.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// maxRevocationSignatureAge is the oldest a revocation-list signature may
	// be before validation rejects it (roughly ten years).
	maxRevocationSignatureAge = 10 * 365 * 24 * time.Hour
	// defaultRevocationListExpiration is the validity window assigned to a
	// newly created revocation list (roughly one year).
	defaultRevocationListExpiration = 365 * 24 * time.Hour
)
|
||||||
|
|
||||||
|
// RevocationList records which signing keys have been revoked and when,
// together with freshness metadata that lets validators reject stale or
// expired lists.
type RevocationList struct {
	Revoked     map[KeyID]time.Time `json:"revoked"`      // KeyID -> revocation time
	LastUpdated time.Time           `json:"last_updated"` // When the list was last modified
	ExpiresAt   time.Time           `json:"expires_at"`   // When the list expires
}
|
||||||
|
|
||||||
|
func (rl RevocationList) MarshalJSON() ([]byte, error) {
|
||||||
|
// Convert map[KeyID]time.Time to map[string]time.Time
|
||||||
|
strMap := make(map[string]time.Time, len(rl.Revoked))
|
||||||
|
for k, v := range rl.Revoked {
|
||||||
|
strMap[k.String()] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(map[string]interface{}{
|
||||||
|
"revoked": strMap,
|
||||||
|
"last_updated": rl.LastUpdated,
|
||||||
|
"expires_at": rl.ExpiresAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *RevocationList) UnmarshalJSON(data []byte) error {
|
||||||
|
var temp struct {
|
||||||
|
Revoked map[string]time.Time `json:"revoked"`
|
||||||
|
LastUpdated time.Time `json:"last_updated"`
|
||||||
|
ExpiresAt time.Time `json:"expires_at"`
|
||||||
|
Version int `json:"version"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(data, &temp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert map[string]time.Time back to map[KeyID]time.Time
|
||||||
|
rl.Revoked = make(map[KeyID]time.Time, len(temp.Revoked))
|
||||||
|
for k, v := range temp.Revoked {
|
||||||
|
kid, err := ParseKeyID(k)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse KeyID %q: %w", k, err)
|
||||||
|
}
|
||||||
|
rl.Revoked[kid] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
rl.LastUpdated = temp.LastUpdated
|
||||||
|
rl.ExpiresAt = temp.ExpiresAt
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseRevocationList(data []byte) (*RevocationList, error) {
|
||||||
|
var rl RevocationList
|
||||||
|
if err := json.Unmarshal(data, &rl); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize the map if it's nil (in case of empty JSON object)
|
||||||
|
if rl.Revoked == nil {
|
||||||
|
rl.Revoked = make(map[KeyID]time.Time)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rl.LastUpdated.IsZero() {
|
||||||
|
return nil, fmt.Errorf("revocation list missing last_updated timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rl.ExpiresAt.IsZero() {
|
||||||
|
return nil, fmt.Errorf("revocation list missing expires_at timestamp")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rl, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateRevocationList parses data as a RevocationList and verifies both
// its freshness metadata and its Ed25519 signature against the given root
// keys. The check order is deliberate: cheap structural/time checks run
// before the signature verification. All time comparisons use UTC "now",
// with maxClockSkew tolerance where noted.
func ValidateRevocationList(publicRootKeys []PublicKey, data []byte, signature Signature) (*RevocationList, error) {
	revoList, err := ParseRevocationList(data)
	if err != nil {
		log.Debugf("failed to parse revocation list: %s", err)
		return nil, err
	}

	now := time.Now().UTC()

	// Reject signatures claiming to come from the future (beyond skew).
	if signature.Timestamp.After(now.Add(maxClockSkew)) {
		err := fmt.Errorf("revocation signature timestamp is in the future: %v", signature.Timestamp)
		log.Debugf("revocation list signature error: %v", err)
		return nil, err
	}

	// Reject signatures older than the hard upper bound (~10 years).
	if now.Sub(signature.Timestamp) > maxRevocationSignatureAge {
		err := fmt.Errorf("revocation list signature is too old: %v (created %v)",
			now.Sub(signature.Timestamp), signature.Timestamp)
		log.Debugf("revocation list signature error: %v", err)
		return nil, err
	}

	// Ensure LastUpdated is not in the future (with clock skew tolerance).
	if revoList.LastUpdated.After(now.Add(maxClockSkew)) {
		err := fmt.Errorf("revocation list LastUpdated is in the future: %v", revoList.LastUpdated)
		log.Errorf("rejecting future-dated revocation list: %v", err)
		return nil, err
	}

	// Check if the revocation list has expired.
	if now.After(revoList.ExpiresAt) {
		err := fmt.Errorf("revocation list expired at %v (current time: %v)", revoList.ExpiresAt, now)
		log.Errorf("rejecting expired revocation list: %v", err)
		return nil, err
	}

	// Cap ExpiresAt at now + maxRevocationSignatureAge: allows some clock
	// skew but prevents maliciously long expiration times.
	if revoList.ExpiresAt.After(now.Add(maxRevocationSignatureAge)) {
		err := fmt.Errorf("revocation list ExpiresAt is too far in the future: %v", revoList.ExpiresAt)
		log.Errorf("rejecting revocation list with invalid expiration: %v", err)
		return nil, err
	}

	// Require the signature timestamp to sit close to LastUpdated
	// (prevents re-signing an old list with a fresh timestamp).
	timeDiff := signature.Timestamp.Sub(revoList.LastUpdated).Abs()
	if timeDiff > maxClockSkew {
		err := fmt.Errorf("signature timestamp %v differs too much from list LastUpdated %v (diff: %v)",
			signature.Timestamp, revoList.LastUpdated, timeDiff)
		log.Errorf("timestamp mismatch in revocation list: %v", err)
		return nil, err
	}

	// Reconstruct the signed message: revocation_list_data || timestamp,
	// where timestamp is the little-endian uint64 Unix time. This must stay
	// in sync with signRevocationList. (No version field is appended.)
	msg := make([]byte, 0, len(data)+8)
	msg = append(msg, data...)
	msg = binary.LittleEndian.AppendUint64(msg, uint64(signature.Timestamp.Unix()))

	if !verifyAny(publicRootKeys, msg, signature.Signature) {
		return nil, errors.New("revocation list verification failed")
	}
	return revoList, nil
}
|
||||||
|
|
||||||
|
func CreateRevocationList(privateRootKey RootKey, expiration time.Duration) ([]byte, []byte, error) {
|
||||||
|
now := time.Now()
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: now.UTC(),
|
||||||
|
ExpiresAt: now.Add(expiration).UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := signRevocationList(privateRootKey, rl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(&rl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signData, err := json.Marshal(signature)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rlData, signData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtendRevocationList(privateRootKey RootKey, rl RevocationList, kid KeyID, expiration time.Duration) ([]byte, []byte, error) {
|
||||||
|
now := time.Now().UTC()
|
||||||
|
|
||||||
|
rl.Revoked[kid] = now
|
||||||
|
rl.LastUpdated = now
|
||||||
|
rl.ExpiresAt = now.Add(expiration)
|
||||||
|
|
||||||
|
signature, err := signRevocationList(privateRootKey, rl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to sign revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(&rl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal revocation list: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signData, err := json.Marshal(signature)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("failed to marshal signature: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rlData, signData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func signRevocationList(privateRootKey RootKey, rl RevocationList) (*Signature, error) {
|
||||||
|
data, err := json.Marshal(rl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal revocation list for signing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
timestamp := time.Now().UTC()
|
||||||
|
|
||||||
|
msg := make([]byte, 0, len(data)+8)
|
||||||
|
msg = append(msg, data...)
|
||||||
|
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
|
||||||
|
|
||||||
|
sig := ed25519.Sign(privateRootKey.Key, msg)
|
||||||
|
|
||||||
|
signature := &Signature{
|
||||||
|
Signature: sig,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
KeyID: privateRootKey.Metadata.ID,
|
||||||
|
Algorithm: "ed25519",
|
||||||
|
HashAlgo: "sha512",
|
||||||
|
}
|
||||||
|
|
||||||
|
return signature, nil
|
||||||
|
}
|
||||||
860
client/internal/updatemanager/reposign/revocation_test.go
Normal file
860
client/internal/updatemanager/reposign/revocation_test.go
Normal file
@@ -0,0 +1,860 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test RevocationList marshaling/unmarshaling
|
||||||
|
|
||||||
|
func TestRevocationList_MarshalJSON(t *testing.T) {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
keyID := computeKeyID(pub)
|
||||||
|
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
|
||||||
|
expiresAt := time.Date(2024, 4, 15, 11, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
rl := &RevocationList{
|
||||||
|
Revoked: map[KeyID]time.Time{
|
||||||
|
keyID: revokedTime,
|
||||||
|
},
|
||||||
|
LastUpdated: lastUpdated,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify it can be unmarshaled back
|
||||||
|
var decoded map[string]interface{}
|
||||||
|
err = json.Unmarshal(jsonData, &decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Contains(t, decoded, "revoked")
|
||||||
|
assert.Contains(t, decoded, "last_updated")
|
||||||
|
assert.Contains(t, decoded, "expires_at")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevocationList_UnmarshalJSON(t *testing.T) {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
keyID := computeKeyID(pub)
|
||||||
|
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
jsonData := map[string]interface{}{
|
||||||
|
"revoked": map[string]string{
|
||||||
|
keyID.String(): revokedTime.Format(time.RFC3339),
|
||||||
|
},
|
||||||
|
"last_updated": lastUpdated.Format(time.RFC3339),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(jsonData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var rl RevocationList
|
||||||
|
err = json.Unmarshal(jsonBytes, &rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, rl.Revoked, 1)
|
||||||
|
assert.Contains(t, rl.Revoked, keyID)
|
||||||
|
assert.Equal(t, lastUpdated.Unix(), rl.LastUpdated.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevocationList_MarshalUnmarshal_Roundtrip(t *testing.T) {
|
||||||
|
pub1, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
pub2, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
keyID1 := computeKeyID(pub1)
|
||||||
|
keyID2 := computeKeyID(pub2)
|
||||||
|
|
||||||
|
original := &RevocationList{
|
||||||
|
Revoked: map[KeyID]time.Time{
|
||||||
|
keyID1: time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC),
|
||||||
|
keyID2: time.Date(2024, 2, 20, 14, 45, 0, 0, time.UTC),
|
||||||
|
},
|
||||||
|
LastUpdated: time.Date(2024, 2, 20, 15, 0, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal
|
||||||
|
jsonData, err := original.MarshalJSON()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Unmarshal
|
||||||
|
var decoded RevocationList
|
||||||
|
err = decoded.UnmarshalJSON(jsonData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify
|
||||||
|
assert.Len(t, decoded.Revoked, 2)
|
||||||
|
assert.Equal(t, original.Revoked[keyID1].Unix(), decoded.Revoked[keyID1].Unix())
|
||||||
|
assert.Equal(t, original.Revoked[keyID2].Unix(), decoded.Revoked[keyID2].Unix())
|
||||||
|
assert.Equal(t, original.LastUpdated.Unix(), decoded.LastUpdated.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevocationList_UnmarshalJSON_InvalidKeyID(t *testing.T) {
|
||||||
|
jsonData := []byte(`{
|
||||||
|
"revoked": {
|
||||||
|
"invalid_key_id": "2024-01-15T10:30:00Z"
|
||||||
|
},
|
||||||
|
"last_updated": "2024-01-15T11:00:00Z"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
var rl RevocationList
|
||||||
|
err := json.Unmarshal(jsonData, &rl)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to parse KeyID")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRevocationList_EmptyRevoked(t *testing.T) {
|
||||||
|
rl := &RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: time.Now().UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := rl.MarshalJSON()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var decoded RevocationList
|
||||||
|
err = decoded.UnmarshalJSON(jsonData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Empty(t, decoded.Revoked)
|
||||||
|
assert.NotNil(t, decoded.Revoked)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ParseRevocationList
|
||||||
|
|
||||||
|
func TestParseRevocationList_Valid(t *testing.T) {
|
||||||
|
pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
keyID := computeKeyID(pub)
|
||||||
|
revokedTime := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
lastUpdated := time.Date(2024, 1, 15, 11, 0, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: map[KeyID]time.Time{
|
||||||
|
keyID: revokedTime,
|
||||||
|
},
|
||||||
|
LastUpdated: lastUpdated,
|
||||||
|
ExpiresAt: time.Date(2025, 2, 20, 14, 45, 0, 0, time.UTC),
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := rl.MarshalJSON()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsed, err := ParseRevocationList(jsonData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsed)
|
||||||
|
assert.Len(t, parsed.Revoked, 1)
|
||||||
|
assert.Equal(t, lastUpdated.Unix(), parsed.LastUpdated.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRevocationList_InvalidJSON(t *testing.T) {
|
||||||
|
invalidJSON := []byte("not valid json")
|
||||||
|
|
||||||
|
_, err := ParseRevocationList(invalidJSON)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to unmarshal")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRevocationList_MissingLastUpdated(t *testing.T) {
|
||||||
|
jsonData := []byte(`{
|
||||||
|
"revoked": {}
|
||||||
|
}`)
|
||||||
|
|
||||||
|
_, err := ParseRevocationList(jsonData)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "missing last_updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRevocationList_EmptyObject(t *testing.T) {
|
||||||
|
jsonData := []byte(`{}`)
|
||||||
|
|
||||||
|
_, err := ParseRevocationList(jsonData)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "missing last_updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRevocationList_NilRevoked(t *testing.T) {
|
||||||
|
lastUpdated := time.Now().UTC()
|
||||||
|
expiresAt := lastUpdated.Add(90 * 24 * time.Hour)
|
||||||
|
jsonData := []byte(`{
|
||||||
|
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `",
|
||||||
|
"expires_at": "` + expiresAt.Format(time.RFC3339) + `"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
parsed, err := ParseRevocationList(jsonData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsed.Revoked)
|
||||||
|
assert.Empty(t, parsed.Revoked)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRevocationList_MissingExpiresAt(t *testing.T) {
|
||||||
|
lastUpdated := time.Now().UTC()
|
||||||
|
jsonData := []byte(`{
|
||||||
|
"revoked": {},
|
||||||
|
"last_updated": "` + lastUpdated.Format(time.RFC3339) + `"
|
||||||
|
}`)
|
||||||
|
|
||||||
|
_, err := ParseRevocationList(jsonData)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "missing expires_at")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ValidateRevocationList
|
||||||
|
|
||||||
|
func TestValidateRevocationList_Valid(t *testing.T) {
|
||||||
|
// Generate root key
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list
|
||||||
|
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
signature, err := ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Validate
|
||||||
|
rl, err := ValidateRevocationList(rootKeys, rlData, *signature)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, rl)
|
||||||
|
assert.Empty(t, rl.Revoked)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_InvalidSignature(t *testing.T) {
|
||||||
|
// Generate root key
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list
|
||||||
|
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create invalid signature
|
||||||
|
invalidSig := Signature{
|
||||||
|
Signature: make([]byte, 64),
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
KeyID: computeKeyID(rootPub),
|
||||||
|
Algorithm: "ed25519",
|
||||||
|
HashAlgo: "sha512",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate should fail
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, invalidSig)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "verification failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_FutureTimestamp(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
signature, err := ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Modify timestamp to be in the future
|
||||||
|
signature.Timestamp = time.Now().UTC().Add(10 * time.Minute)
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "in the future")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_TooOld(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
signature, err := ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Modify timestamp to be too old
|
||||||
|
signature.Timestamp = time.Now().UTC().Add(-20 * 365 * 24 * time.Hour)
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *signature)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "too old")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_InvalidJSON(t *testing.T) {
|
||||||
|
rootPub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
signature := Signature{
|
||||||
|
Signature: make([]byte, 64),
|
||||||
|
Timestamp: time.Now().UTC(),
|
||||||
|
KeyID: computeKeyID(rootPub),
|
||||||
|
Algorithm: "ed25519",
|
||||||
|
HashAlgo: "sha512",
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, []byte("invalid json"), signature)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_FutureLastUpdated(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list with future LastUpdated
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: time.Now().UTC().Add(10 * time.Minute),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign it
|
||||||
|
sig, err := signRevocationList(rootKey, rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "LastUpdated is in the future")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_TimestampMismatch(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list with LastUpdated far in the past
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: time.Now().UTC().Add(-1 * time.Hour),
|
||||||
|
ExpiresAt: time.Now().UTC().Add(365 * 24 * time.Hour),
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign it with current timestamp
|
||||||
|
sig, err := signRevocationList(rootKey, rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Modify signature timestamp to differ too much from LastUpdated
|
||||||
|
sig.Timestamp = time.Now().UTC()
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "differs too much")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_Expired(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list that expired in the past
|
||||||
|
now := time.Now().UTC()
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: now.Add(-100 * 24 * time.Hour),
|
||||||
|
ExpiresAt: now.Add(-10 * 24 * time.Hour), // Expired 10 days ago
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign it
|
||||||
|
sig, err := signRevocationList(rootKey, rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Adjust signature timestamp to match LastUpdated
|
||||||
|
sig.Timestamp = rl.LastUpdated
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRevocationList_ExpiresAtTooFarInFuture(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list with ExpiresAt too far in the future (beyond maxRevocationSignatureAge)
|
||||||
|
now := time.Now().UTC()
|
||||||
|
rl := RevocationList{
|
||||||
|
Revoked: make(map[KeyID]time.Time),
|
||||||
|
LastUpdated: now,
|
||||||
|
ExpiresAt: now.Add(15 * 365 * 24 * time.Hour), // 15 years in the future
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, err := json.Marshal(rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign it
|
||||||
|
sig, err := signRevocationList(rootKey, rl)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, err = ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "too far in the future")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test CreateRevocationList
|
||||||
|
|
||||||
|
func TestCreateRevocationList_Valid(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, rlData)
|
||||||
|
assert.NotEmpty(t, sigData)
|
||||||
|
|
||||||
|
// Verify it can be parsed
|
||||||
|
rl, err := ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, rl.Revoked)
|
||||||
|
assert.False(t, rl.LastUpdated.IsZero())
|
||||||
|
|
||||||
|
// Verify signature can be parsed
|
||||||
|
sig, err := ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, sig.Signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ExtendRevocationList
|
||||||
|
|
||||||
|
func TestExtendRevocationList_AddKey(t *testing.T) {
|
||||||
|
// Generate root key
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty revocation list
|
||||||
|
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err := ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, rl.Revoked)
|
||||||
|
|
||||||
|
// Generate a key to revoke
|
||||||
|
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
revokedKeyID := computeKeyID(revokedPub)
|
||||||
|
|
||||||
|
// Extend the revocation list
|
||||||
|
newRLData, newSigData, err := ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the new list
|
||||||
|
newRL, err := ParseRevocationList(newRLData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, newRL.Revoked, 1)
|
||||||
|
assert.Contains(t, newRL.Revoked, revokedKeyID)
|
||||||
|
|
||||||
|
// Verify signature
|
||||||
|
sig, err := ParseSignature(newSigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, sig.Signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendRevocationList_MultipleKeys(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty revocation list
|
||||||
|
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err := ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add first key
|
||||||
|
key1Pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
key1ID := computeKeyID(key1Pub)
|
||||||
|
|
||||||
|
rlData, _, err = ExtendRevocationList(rootKey, *rl, key1ID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, rl.Revoked, 1)
|
||||||
|
|
||||||
|
// Add second key
|
||||||
|
key2Pub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
key2ID := computeKeyID(key2Pub)
|
||||||
|
|
||||||
|
rlData, _, err = ExtendRevocationList(rootKey, *rl, key2ID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, rl.Revoked, 2)
|
||||||
|
assert.Contains(t, rl.Revoked, key1ID)
|
||||||
|
assert.Contains(t, rl.Revoked, key2ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendRevocationList_DuplicateKey(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create empty revocation list
|
||||||
|
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err := ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add a key
|
||||||
|
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyID := computeKeyID(keyPub)
|
||||||
|
|
||||||
|
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
firstRevocationTime := rl.Revoked[keyID]
|
||||||
|
|
||||||
|
// Wait a bit
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Add the same key again
|
||||||
|
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, rl.Revoked, 1)
|
||||||
|
|
||||||
|
// The revocation time should be updated
|
||||||
|
secondRevocationTime := rl.Revoked[keyID]
|
||||||
|
assert.True(t, secondRevocationTime.After(firstRevocationTime) || secondRevocationTime.Equal(firstRevocationTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtendRevocationList_UpdatesLastUpdated(t *testing.T) {
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list
|
||||||
|
rlData, _, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err := ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
firstLastUpdated := rl.LastUpdated
|
||||||
|
|
||||||
|
// Wait a bit
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
|
||||||
|
// Extend list
|
||||||
|
keyPub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
keyID := computeKeyID(keyPub)
|
||||||
|
|
||||||
|
rlData, _, err = ExtendRevocationList(rootKey, *rl, keyID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ParseRevocationList(rlData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// LastUpdated should be updated
|
||||||
|
assert.True(t, rl.LastUpdated.After(firstLastUpdated))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Integration test
|
||||||
|
|
||||||
|
func TestRevocationList_FullWorkflow(t *testing.T) {
|
||||||
|
// Create root key
|
||||||
|
rootPub, rootPriv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: rootPriv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
rootKeys := []PublicKey{
|
||||||
|
{
|
||||||
|
Key: rootPub,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(rootPub),
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Create empty revocation list
|
||||||
|
rlData, sigData, err := CreateRevocationList(rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Step 2: Validate it
|
||||||
|
sig, err := ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err := ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, rl.Revoked)
|
||||||
|
|
||||||
|
// Step 3: Revoke a key
|
||||||
|
revokedPub, _, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
revokedKeyID := computeKeyID(revokedPub)
|
||||||
|
|
||||||
|
rlData, sigData, err = ExtendRevocationList(rootKey, *rl, revokedKeyID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Step 4: Validate the extended list
|
||||||
|
sig, err = ParseSignature(sigData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rl, err = ValidateRevocationList(rootKeys, rlData, *sig)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Len(t, rl.Revoked, 1)
|
||||||
|
assert.Contains(t, rl.Revoked, revokedKeyID)
|
||||||
|
|
||||||
|
// Step 5: Verify the revocation time is reasonable
|
||||||
|
revTime := rl.Revoked[revokedKeyID]
|
||||||
|
now := time.Now().UTC()
|
||||||
|
assert.True(t, revTime.Before(now) || revTime.Equal(now))
|
||||||
|
assert.True(t, now.Sub(revTime) < time.Minute)
|
||||||
|
}
|
||||||
120
client/internal/updatemanager/reposign/root.go
Normal file
120
client/internal/updatemanager/reposign/root.go
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// PEM block types used to label root key material on disk.
	tagRootPrivate = "ROOT PRIVATE KEY"
	tagRootPublic  = "ROOT PUBLIC KEY"
)
|
||||||
|
|
||||||
|
// RootKey is a root Key used to sign signing keys (artifact keys) and
// revocation lists. It embeds PrivateKey, inheriting its Key material
// and Metadata (ID, CreatedAt, ExpiresAt).
type RootKey struct {
	PrivateKey
}
|
||||||
|
|
||||||
|
func (k RootKey) String() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"RootKey[ID=%s, CreatedAt=%s, ExpiresAt=%s]",
|
||||||
|
k.Metadata.ID,
|
||||||
|
k.Metadata.CreatedAt.Format(time.RFC3339),
|
||||||
|
k.Metadata.ExpiresAt.Format(time.RFC3339),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseRootKey(privKeyPEM []byte) (*RootKey, error) {
|
||||||
|
pk, err := parsePrivateKey(privKeyPEM, tagRootPrivate)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse root Key: %w", err)
|
||||||
|
}
|
||||||
|
return &RootKey{pk}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseRootPublicKey parses a root public key from PEM format
|
||||||
|
func ParseRootPublicKey(pubKeyPEM []byte) (PublicKey, error) {
|
||||||
|
pk, _, err := parsePublicKey(pubKeyPEM, tagRootPublic)
|
||||||
|
if err != nil {
|
||||||
|
return PublicKey{}, fmt.Errorf("failed to parse root public key: %w", err)
|
||||||
|
}
|
||||||
|
return pk, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateRootKey generates a new root Key pair with Metadata
|
||||||
|
func GenerateRootKey(expiration time.Duration) (*RootKey, []byte, []byte, error) {
|
||||||
|
now := time.Now()
|
||||||
|
expirationTime := now.Add(expiration)
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata := KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: now.UTC(),
|
||||||
|
ExpiresAt: expirationTime.UTC(),
|
||||||
|
}
|
||||||
|
|
||||||
|
rk := &RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: metadata,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal PrivateKey struct to JSON
|
||||||
|
privJSON, err := json.Marshal(rk.PrivateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to marshal private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marshal PublicKey struct to JSON
|
||||||
|
pubKey := PublicKey{
|
||||||
|
Key: pub,
|
||||||
|
Metadata: metadata,
|
||||||
|
}
|
||||||
|
pubJSON, err := json.Marshal(pubKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode to PEM with metadata embedded in bytes
|
||||||
|
privPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: privJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
pubPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPublic,
|
||||||
|
Bytes: pubJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
return rk, privPEM, pubPEM, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SignArtifactKey(rootKey RootKey, data []byte) ([]byte, error) {
|
||||||
|
timestamp := time.Now().UTC()
|
||||||
|
|
||||||
|
// This ensures the timestamp is cryptographically bound to the signature
|
||||||
|
msg := make([]byte, 0, len(data)+8)
|
||||||
|
msg = append(msg, data...)
|
||||||
|
msg = binary.LittleEndian.AppendUint64(msg, uint64(timestamp.Unix()))
|
||||||
|
|
||||||
|
sig := ed25519.Sign(rootKey.Key, msg)
|
||||||
|
// Create signature bundle with timestamp and Metadata
|
||||||
|
bundle := Signature{
|
||||||
|
Signature: sig,
|
||||||
|
Timestamp: timestamp,
|
||||||
|
KeyID: rootKey.Metadata.ID,
|
||||||
|
Algorithm: "ed25519",
|
||||||
|
HashAlgo: "sha512",
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(bundle)
|
||||||
|
}
|
||||||
476
client/internal/updatemanager/reposign/root_test.go
Normal file
476
client/internal/updatemanager/reposign/root_test.go
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/ed25519"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/binary"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test RootKey.String()
|
||||||
|
|
||||||
|
func TestRootKey_String(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
expiresAt := time.Date(2034, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
rk := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
str := rk.String()
|
||||||
|
assert.Contains(t, str, "RootKey")
|
||||||
|
assert.Contains(t, str, computeKeyID(pub).String())
|
||||||
|
assert.Contains(t, str, "2024-01-15")
|
||||||
|
assert.Contains(t, str, "2034-01-15")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRootKey_String_NoExpiration(t *testing.T) {
|
||||||
|
pub, priv, err := ed25519.GenerateKey(rand.Reader)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
createdAt := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
rk := RootKey{
|
||||||
|
PrivateKey{
|
||||||
|
Key: priv,
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: computeKeyID(pub),
|
||||||
|
CreatedAt: createdAt,
|
||||||
|
ExpiresAt: time.Time{}, // No expiration
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
str := rk.String()
|
||||||
|
assert.Contains(t, str, "RootKey")
|
||||||
|
assert.Contains(t, str, "0001-01-01") // Zero time format
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test GenerateRootKey
|
||||||
|
|
||||||
|
func TestGenerateRootKey_Valid(t *testing.T) {
|
||||||
|
expiration := 10 * 365 * 24 * time.Hour // 10 years
|
||||||
|
|
||||||
|
rk, privPEM, pubPEM, err := GenerateRootKey(expiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, rk)
|
||||||
|
assert.NotEmpty(t, privPEM)
|
||||||
|
assert.NotEmpty(t, pubPEM)
|
||||||
|
|
||||||
|
// Verify the key has correct metadata
|
||||||
|
assert.False(t, rk.Metadata.CreatedAt.IsZero())
|
||||||
|
assert.False(t, rk.Metadata.ExpiresAt.IsZero())
|
||||||
|
assert.True(t, rk.Metadata.ExpiresAt.After(rk.Metadata.CreatedAt))
|
||||||
|
|
||||||
|
// Verify expiration is approximately correct
|
||||||
|
expectedExpiration := time.Now().Add(expiration)
|
||||||
|
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
|
||||||
|
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRootKey_ShortExpiration(t *testing.T) {
|
||||||
|
expiration := 24 * time.Hour // 1 day
|
||||||
|
|
||||||
|
rk, _, _, err := GenerateRootKey(expiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, rk)
|
||||||
|
|
||||||
|
// Verify expiration
|
||||||
|
expectedExpiration := time.Now().Add(expiration)
|
||||||
|
timeDiff := rk.Metadata.ExpiresAt.Sub(expectedExpiration)
|
||||||
|
assert.True(t, timeDiff < time.Minute && timeDiff > -time.Minute)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A zero expiration collapses ExpiresAt onto CreatedAt.
func TestGenerateRootKey_ZeroExpiration(t *testing.T) {
	rk, _, _, err := GenerateRootKey(0)
	require.NoError(t, err)
	assert.NotNil(t, rk)

	// With zero expiration, ExpiresAt should be equal to CreatedAt
	assert.Equal(t, rk.Metadata.CreatedAt, rk.Metadata.ExpiresAt)
}
|
||||||
|
|
||||||
|
func TestGenerateRootKey_PEMFormat(t *testing.T) {
|
||||||
|
rk, privPEM, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify private key PEM
|
||||||
|
privBlock, _ := pem.Decode(privPEM)
|
||||||
|
require.NotNil(t, privBlock)
|
||||||
|
assert.Equal(t, tagRootPrivate, privBlock.Type)
|
||||||
|
|
||||||
|
var privKey PrivateKey
|
||||||
|
err = json.Unmarshal(privBlock.Bytes, &privKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, rk.Key, privKey.Key)
|
||||||
|
|
||||||
|
// Verify public key PEM
|
||||||
|
pubBlock, _ := pem.Decode(pubPEM)
|
||||||
|
require.NotNil(t, pubBlock)
|
||||||
|
assert.Equal(t, tagRootPublic, pubBlock.Type)
|
||||||
|
|
||||||
|
var pubKey PublicKey
|
||||||
|
err = json.Unmarshal(pubBlock.Bytes, &pubKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, rk.Metadata.ID, pubKey.Metadata.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRootKey_KeySize(t *testing.T) {
|
||||||
|
rk, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Ed25519 private key should be 64 bytes
|
||||||
|
assert.Equal(t, ed25519.PrivateKeySize, len(rk.Key))
|
||||||
|
|
||||||
|
// Ed25519 public key should be 32 bytes
|
||||||
|
pubKey := rk.Key.Public().(ed25519.PublicKey)
|
||||||
|
assert.Equal(t, ed25519.PublicKeySize, len(pubKey))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRootKey_UniqueKeys(t *testing.T) {
|
||||||
|
rk1, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rk2, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Different keys should have different IDs
|
||||||
|
assert.NotEqual(t, rk1.Metadata.ID, rk2.Metadata.ID)
|
||||||
|
assert.NotEqual(t, rk1.Key, rk2.Key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test ParseRootKey
|
||||||
|
|
||||||
|
func TestParseRootKey_Valid(t *testing.T) {
|
||||||
|
original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsed, err := ParseRootKey(privPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotNil(t, parsed)
|
||||||
|
|
||||||
|
// Verify the parsed key matches the original
|
||||||
|
assert.Equal(t, original.Key, parsed.Key)
|
||||||
|
assert.Equal(t, original.Metadata.ID, parsed.Metadata.ID)
|
||||||
|
assert.Equal(t, original.Metadata.CreatedAt.Unix(), parsed.Metadata.CreatedAt.Unix())
|
||||||
|
assert.Equal(t, original.Metadata.ExpiresAt.Unix(), parsed.Metadata.ExpiresAt.Unix())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-PEM input must be rejected with a parse error.
func TestParseRootKey_InvalidPEM(t *testing.T) {
	_, err := ParseRootKey([]byte("not a valid PEM"))
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "failed to parse")
}
|
||||||
|
|
||||||
|
// Empty input must be rejected.
func TestParseRootKey_EmptyData(t *testing.T) {
	_, err := ParseRootKey([]byte{})
	assert.Error(t, err)
}
|
||||||
|
|
||||||
|
func TestParseRootKey_WrongType(t *testing.T) {
|
||||||
|
// Generate an artifact key instead of root key
|
||||||
|
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactKey, privPEM, _, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Try to parse artifact key as root key
|
||||||
|
_, err = ParseRootKey(privPEM)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "PEM type")
|
||||||
|
|
||||||
|
// Just to use artifactKey to avoid unused variable warning
|
||||||
|
_ = artifactKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRootKey_CorruptedJSON(t *testing.T) {
|
||||||
|
// Create PEM with corrupted JSON
|
||||||
|
corruptedPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: []byte("corrupted json data"),
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := ParseRootKey(corruptedPEM)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseRootKey_InvalidKeySize(t *testing.T) {
|
||||||
|
// Create a key with invalid size
|
||||||
|
invalidKey := PrivateKey{
|
||||||
|
Key: []byte{0x01, 0x02, 0x03}, // Too short
|
||||||
|
Metadata: KeyMetadata{
|
||||||
|
ID: KeyID{},
|
||||||
|
CreatedAt: time.Now().UTC(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
privJSON, err := json.Marshal(invalidKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
invalidPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: tagRootPrivate,
|
||||||
|
Bytes: privJSON,
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err = ParseRootKey(invalidPEM)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "incorrect Ed25519 private key size")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestParseRootKey_Roundtrip checks that generate -> parse -> re-encode ->
// re-parse preserves the key material and its ID.
func TestParseRootKey_Roundtrip(t *testing.T) {
	// Generate a key
	original, privPEM, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	// Parse it
	parsed, err := ParseRootKey(privPEM)
	require.NoError(t, err)

	// Generate PEM again from parsed key
	privJSON2, err := json.Marshal(parsed.PrivateKey)
	require.NoError(t, err)

	privPEM2 := pem.EncodeToMemory(&pem.Block{
		Type:  tagRootPrivate,
		Bytes: privJSON2,
	})

	// Parse again
	parsed2, err := ParseRootKey(privPEM2)
	require.NoError(t, err)

	// Should still match original
	assert.Equal(t, original.Key, parsed2.Key)
	assert.Equal(t, original.Metadata.ID, parsed2.Metadata.ID)
}
|
||||||
|
|
||||||
|
// Test SignArtifactKey
|
||||||
|
|
||||||
|
// TestSignArtifactKey_Valid checks that SignArtifactKey produces a parsable
// signature carrying the signing key's ID, the expected algorithm fields and
// a non-zero timestamp.
func TestSignArtifactKey_Valid(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	data := []byte("test data to sign")
	sigData, err := SignArtifactKey(*rootKey, data)
	require.NoError(t, err)
	assert.NotEmpty(t, sigData)

	// Parse and verify signature
	sig, err := ParseSignature(sigData)
	require.NoError(t, err)
	assert.NotEmpty(t, sig.Signature)
	assert.Equal(t, rootKey.Metadata.ID, sig.KeyID)
	assert.Equal(t, "ed25519", sig.Algorithm)
	assert.Equal(t, "sha512", sig.HashAlgo)
	assert.False(t, sig.Timestamp.IsZero())
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_EmptyData checks that signing an empty payload succeeds
// and still yields a parsable, non-empty signature.
func TestSignArtifactKey_EmptyData(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	sigData, err := SignArtifactKey(*rootKey, []byte{})
	require.NoError(t, err)
	assert.NotEmpty(t, sigData)

	// Should still be able to parse
	sig, err := ParseSignature(sigData)
	require.NoError(t, err)
	assert.NotEmpty(t, sig.Signature)
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_Verify reconstructs the signed message (payload bytes
// followed by the little-endian Unix timestamp) and verifies the signature
// with the root public key using raw ed25519.Verify.
func TestSignArtifactKey_Verify(t *testing.T) {
	rootKey, _, pubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	// Parse public key
	pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
	require.NoError(t, err)

	// Sign some data
	data := []byte("test data for verification")
	sigData, err := SignArtifactKey(*rootKey, data)
	require.NoError(t, err)

	// Parse signature
	sig, err := ParseSignature(sigData)
	require.NoError(t, err)

	// Reconstruct message
	msg := make([]byte, 0, len(data)+8)
	msg = append(msg, data...)
	msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))

	// Verify signature
	valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
	assert.True(t, valid)
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_DifferentData checks that signing different payloads
// with the same key yields different signature blobs.
func TestSignArtifactKey_DifferentData(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	data1 := []byte("data1")
	data2 := []byte("data2")

	sig1, err := SignArtifactKey(*rootKey, data1)
	require.NoError(t, err)

	sig2, err := SignArtifactKey(*rootKey, data2)
	require.NoError(t, err)

	// Different data should produce different signatures
	assert.NotEqual(t, sig1, sig2)
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_MultipleSignatures checks that re-signing the same
// payload at a later time yields a different blob with a later timestamp,
// because the timestamp is part of the signed message.
func TestSignArtifactKey_MultipleSignatures(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	data := []byte("test data")

	// Sign twice with a small delay
	sig1, err := SignArtifactKey(*rootKey, data)
	require.NoError(t, err)

	time.Sleep(10 * time.Millisecond)

	sig2, err := SignArtifactKey(*rootKey, data)
	require.NoError(t, err)

	// Signatures should be different due to different timestamps
	assert.NotEqual(t, sig1, sig2)

	// Parse both signatures
	parsed1, err := ParseSignature(sig1)
	require.NoError(t, err)

	parsed2, err := ParseSignature(sig2)
	require.NoError(t, err)

	// Timestamps should be different
	assert.True(t, parsed2.Timestamp.After(parsed1.Timestamp))
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_LargeData checks that a 1 MiB payload can be signed and
// the resulting signature parsed.
func TestSignArtifactKey_LargeData(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	// Create 1MB of data
	largeData := make([]byte, 1024*1024)
	for i := range largeData {
		largeData[i] = byte(i % 256)
	}

	sigData, err := SignArtifactKey(*rootKey, largeData)
	require.NoError(t, err)
	assert.NotEmpty(t, sigData)

	// Verify signature can be parsed
	sig, err := ParseSignature(sigData)
	require.NoError(t, err)
	assert.NotEmpty(t, sig.Signature)
}
|
||||||
|
|
||||||
|
// TestSignArtifactKey_TimestampInSignature checks that the embedded signature
// timestamp falls within the signing window (with a one-second tolerance for
// clock truncation).
func TestSignArtifactKey_TimestampInSignature(t *testing.T) {
	rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	beforeSign := time.Now().UTC()
	data := []byte("test data")
	sigData, err := SignArtifactKey(*rootKey, data)
	require.NoError(t, err)
	afterSign := time.Now().UTC()

	sig, err := ParseSignature(sigData)
	require.NoError(t, err)

	// Timestamp should be between before and after
	assert.True(t, sig.Timestamp.After(beforeSign.Add(-time.Second)))
	assert.True(t, sig.Timestamp.Before(afterSign.Add(time.Second)))
}
|
||||||
|
|
||||||
|
// Integration test
|
||||||
|
|
||||||
|
// TestRootKey_FullWorkflow exercises the whole signing chain end to end:
// generate a root key, derive a signed artifact key, verify the artifact
// key's signature with the root public key, then sign and validate payload
// data with the artifact key.
func TestRootKey_FullWorkflow(t *testing.T) {
	// Step 1: Generate root key
	rootKey, privPEM, pubPEM, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
	require.NoError(t, err)
	assert.NotNil(t, rootKey)
	assert.NotEmpty(t, privPEM)
	assert.NotEmpty(t, pubPEM)

	// Step 2: Parse the private key back
	parsedRootKey, err := ParseRootKey(privPEM)
	require.NoError(t, err)
	assert.Equal(t, rootKey.Key, parsedRootKey.Key)
	assert.Equal(t, rootKey.Metadata.ID, parsedRootKey.Metadata.ID)

	// Step 3: Generate an artifact key using root key
	artifactKey, _, artifactPubPEM, artifactSig, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
	require.NoError(t, err)
	assert.NotNil(t, artifactKey)

	// Step 4: Verify the artifact key signature
	pubKey, _, err := parsePublicKey(pubPEM, tagRootPublic)
	require.NoError(t, err)

	sig, err := ParseSignature(artifactSig)
	require.NoError(t, err)

	artifactPubKey, _, err := parsePublicKey(artifactPubPEM, tagArtifactPublic)
	require.NoError(t, err)

	// Reconstruct message - SignArtifactKey signs the PEM, not the JSON
	msg := make([]byte, 0, len(artifactPubPEM)+8)
	msg = append(msg, artifactPubPEM...)
	msg = binary.LittleEndian.AppendUint64(msg, uint64(sig.Timestamp.Unix()))

	// Verify with root public key
	valid := ed25519.Verify(pubKey.Key, msg, sig.Signature)
	assert.True(t, valid, "Artifact key signature should be valid")

	// Step 5: Use artifact key to sign data
	testData := []byte("This is test artifact data")
	dataSig, err := SignData(*artifactKey, testData)
	require.NoError(t, err)
	assert.NotEmpty(t, dataSig)

	// Step 6: Verify the artifact data signature
	dataSigParsed, err := ParseSignature(dataSig)
	require.NoError(t, err)

	err = ValidateArtifact([]PublicKey{artifactPubKey}, testData, *dataSigParsed)
	assert.NoError(t, err, "Artifact data signature should be valid")
}
|
||||||
|
|
||||||
|
// TestRootKey_ExpiredKeyWorkflow checks that an expired root key can no
// longer be used to derive artifact keys.
func TestRootKey_ExpiredKeyWorkflow(t *testing.T) {
	// Generate a root key that expires very soon
	rootKey, _, _, err := GenerateRootKey(1 * time.Millisecond)
	require.NoError(t, err)

	// Wait for expiration
	time.Sleep(10 * time.Millisecond)

	// Try to generate artifact key with expired root key
	_, _, _, _, err = GenerateArtifactKey(rootKey, 30*24*time.Hour)
	assert.Error(t, err)
	assert.Contains(t, err.Error(), "expired")
}
|
||||||
24
client/internal/updatemanager/reposign/signature.go
Normal file
24
client/internal/updatemanager/reposign/signature.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Signature contains a detached signature together with the metadata needed
// to verify it: when it was produced, which key produced it and which
// algorithms were used.
type Signature struct {
	Signature []byte    `json:"signature"`
	Timestamp time.Time `json:"timestamp"`
	KeyID     KeyID     `json:"key_id"`
	Algorithm string    `json:"algorithm"` // "ed25519"
	HashAlgo  string    `json:"hash_algo"` // "blake2s" or "sha512"
}
|
||||||
|
|
||||||
|
// ParseSignature decodes a JSON-encoded Signature. It returns an error when
// the payload is not valid JSON for the Signature shape, including a
// malformed key_id or timestamp field. Missing fields are left at their zero
// values; unknown fields are ignored (standard encoding/json behavior).
func ParseSignature(data []byte) (*Signature, error) {
	var signature Signature
	if err := json.Unmarshal(data, &signature); err != nil {
		return nil, err
	}

	return &signature, nil
}
|
||||||
277
client/internal/updatemanager/reposign/signature_test.go
Normal file
277
client/internal/updatemanager/reposign/signature_test.go
Normal file
@@ -0,0 +1,277 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestParseSignature_Valid round-trips a fully populated Signature through
// JSON and checks every field survives.
func TestParseSignature_Valid(t *testing.T) {
	timestamp := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC)
	keyID, err := ParseKeyID("0123456789abcdef")
	require.NoError(t, err)

	signatureData := []byte{0x01, 0x02, 0x03, 0x04}

	jsonData, err := json.Marshal(Signature{
		Signature: signatureData,
		Timestamp: timestamp,
		KeyID:     keyID,
		Algorithm: "ed25519",
		HashAlgo:  "blake2s",
	})
	require.NoError(t, err)

	sig, err := ParseSignature(jsonData)
	require.NoError(t, err)
	assert.NotNil(t, sig)
	assert.Equal(t, signatureData, sig.Signature)
	assert.Equal(t, timestamp.Unix(), sig.Timestamp.Unix())
	assert.Equal(t, keyID, sig.KeyID)
	assert.Equal(t, "ed25519", sig.Algorithm)
	assert.Equal(t, "blake2s", sig.HashAlgo)
}
|
||||||
|
|
||||||
|
// TestParseSignature_InvalidJSON checks that malformed JSON is rejected and
// no partial Signature is returned.
func TestParseSignature_InvalidJSON(t *testing.T) {
	invalidJSON := []byte(`{invalid json}`)

	sig, err := ParseSignature(invalidJSON)
	assert.Error(t, err)
	assert.Nil(t, sig)
}
|
||||||
|
|
||||||
|
// TestParseSignature_EmptyData checks that an empty JSON object parses into a
// zero-valued Signature rather than an error.
func TestParseSignature_EmptyData(t *testing.T) {
	emptyJSON := []byte(`{}`)

	sig, err := ParseSignature(emptyJSON)
	require.NoError(t, err)
	assert.NotNil(t, sig)
	assert.Empty(t, sig.Signature)
	assert.True(t, sig.Timestamp.IsZero())
	assert.Equal(t, KeyID{}, sig.KeyID)
	assert.Empty(t, sig.Algorithm)
	assert.Empty(t, sig.HashAlgo)
}
|
||||||
|
|
||||||
|
// TestParseSignature_MissingFields checks that a partially populated JSON
// payload parses, with absent fields left at their zero values.
func TestParseSignature_MissingFields(t *testing.T) {
	// JSON with only some fields
	partialJSON := []byte(`{
		"signature": "AQIDBA==",
		"algorithm": "ed25519"
	}`)

	sig, err := ParseSignature(partialJSON)
	require.NoError(t, err)
	assert.NotNil(t, sig)
	assert.NotEmpty(t, sig.Signature)
	assert.Equal(t, "ed25519", sig.Algorithm)
	assert.True(t, sig.Timestamp.IsZero())
}
|
||||||
|
|
||||||
|
// TestSignature_MarshalUnmarshal_Roundtrip checks that Marshal followed by
// ParseSignature preserves all fields (timestamp compared at second
// precision).
func TestSignature_MarshalUnmarshal_Roundtrip(t *testing.T) {
	timestamp := time.Date(2024, 6, 20, 14, 45, 30, 0, time.UTC)
	keyID, err := ParseKeyID("fedcba9876543210")
	require.NoError(t, err)

	original := Signature{
		Signature: []byte{0xde, 0xad, 0xbe, 0xef, 0xca, 0xfe},
		Timestamp: timestamp,
		KeyID:     keyID,
		Algorithm: "ed25519",
		HashAlgo:  "sha512",
	}

	// Marshal
	jsonData, err := json.Marshal(original)
	require.NoError(t, err)

	// Unmarshal
	parsed, err := ParseSignature(jsonData)
	require.NoError(t, err)

	// Verify
	assert.Equal(t, original.Signature, parsed.Signature)
	assert.Equal(t, original.Timestamp.Unix(), parsed.Timestamp.Unix())
	assert.Equal(t, original.KeyID, parsed.KeyID)
	assert.Equal(t, original.Algorithm, parsed.Algorithm)
	assert.Equal(t, original.HashAlgo, parsed.HashAlgo)
}
|
||||||
|
|
||||||
|
// TestSignature_NilSignatureBytes checks that a nil Signature byte slice
// survives a JSON round trip as nil (encodes as JSON null).
func TestSignature_NilSignatureBytes(t *testing.T) {
	timestamp := time.Now().UTC()
	keyID, err := ParseKeyID("0011223344556677")
	require.NoError(t, err)

	sig := Signature{
		Signature: nil,
		Timestamp: timestamp,
		KeyID:     keyID,
		Algorithm: "ed25519",
		HashAlgo:  "blake2s",
	}

	jsonData, err := json.Marshal(sig)
	require.NoError(t, err)

	parsed, err := ParseSignature(jsonData)
	require.NoError(t, err)
	assert.Nil(t, parsed.Signature)
}
|
||||||
|
|
||||||
|
// TestSignature_LargeSignature checks that a full 64-byte ed25519-sized
// signature round-trips through JSON unchanged.
func TestSignature_LargeSignature(t *testing.T) {
	timestamp := time.Now().UTC()
	keyID, err := ParseKeyID("aabbccddeeff0011")
	require.NoError(t, err)

	// Create a large signature (64 bytes for ed25519)
	largeSignature := make([]byte, 64)
	for i := range largeSignature {
		largeSignature[i] = byte(i)
	}

	sig := Signature{
		Signature: largeSignature,
		Timestamp: timestamp,
		KeyID:     keyID,
		Algorithm: "ed25519",
		HashAlgo:  "blake2s",
	}

	jsonData, err := json.Marshal(sig)
	require.NoError(t, err)

	parsed, err := ParseSignature(jsonData)
	require.NoError(t, err)
	assert.Equal(t, largeSignature, parsed.Signature)
}
|
||||||
|
|
||||||
|
// TestSignature_WithDifferentHashAlgorithms is a table-driven check that the
// HashAlgo string field round-trips unchanged for several values, including
// the empty string.
func TestSignature_WithDifferentHashAlgorithms(t *testing.T) {
	tests := []struct {
		name     string
		hashAlgo string
	}{
		{"blake2s", "blake2s"},
		{"sha512", "sha512"},
		{"sha256", "sha256"},
		{"empty", ""},
	}

	keyID, err := ParseKeyID("1122334455667788")
	require.NoError(t, err)

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sig := Signature{
				Signature: []byte{0x01, 0x02},
				Timestamp: time.Now().UTC(),
				KeyID:     keyID,
				Algorithm: "ed25519",
				HashAlgo:  tt.hashAlgo,
			}

			jsonData, err := json.Marshal(sig)
			require.NoError(t, err)

			parsed, err := ParseSignature(jsonData)
			require.NoError(t, err)
			assert.Equal(t, tt.hashAlgo, parsed.HashAlgo)
		})
	}
}
|
||||||
|
|
||||||
|
// TestSignature_TimestampPrecision checks that a nanosecond-resolution
// timestamp survives JSON marshaling at least to second precision.
func TestSignature_TimestampPrecision(t *testing.T) {
	// Test that timestamp preserves precision through JSON marshaling
	timestamp := time.Date(2024, 3, 15, 10, 30, 45, 123456789, time.UTC)
	keyID, err := ParseKeyID("8877665544332211")
	require.NoError(t, err)

	sig := Signature{
		Signature: []byte{0xaa, 0xbb},
		Timestamp: timestamp,
		KeyID:     keyID,
		Algorithm: "ed25519",
		HashAlgo:  "blake2s",
	}

	jsonData, err := json.Marshal(sig)
	require.NoError(t, err)

	parsed, err := ParseSignature(jsonData)
	require.NoError(t, err)

	// JSON timestamps typically have second or millisecond precision
	// so we check that at least seconds match
	assert.Equal(t, timestamp.Unix(), parsed.Timestamp.Unix())
}
|
||||||
|
|
||||||
|
// TestParseSignature_MalformedKeyID checks that an unparsable key_id value
// fails the whole ParseSignature call.
func TestParseSignature_MalformedKeyID(t *testing.T) {
	// Test with a malformed KeyID field
	malformedJSON := []byte(`{
		"signature": "AQID",
		"timestamp": "2024-01-15T10:30:00Z",
		"key_id": "invalid_keyid_format",
		"algorithm": "ed25519",
		"hash_algo": "blake2s"
	}`)

	// This should fail since "invalid_keyid_format" is not a valid KeyID
	sig, err := ParseSignature(malformedJSON)
	assert.Error(t, err)
	assert.Nil(t, sig)
}
|
||||||
|
|
||||||
|
// TestParseSignature_InvalidTimestamp checks that a non-RFC3339 timestamp
// value fails the whole ParseSignature call.
func TestParseSignature_InvalidTimestamp(t *testing.T) {
	// Test with an invalid timestamp format
	invalidTimestampJSON := []byte(`{
		"signature": "AQID",
		"timestamp": "not-a-timestamp",
		"key_id": "0123456789abcdef",
		"algorithm": "ed25519",
		"hash_algo": "blake2s"
	}`)

	sig, err := ParseSignature(invalidTimestampJSON)
	assert.Error(t, err)
	assert.Nil(t, sig)
}
|
||||||
|
|
||||||
|
// TestSignature_ZeroKeyID checks that the zero-valued KeyID round-trips
// through JSON without error.
func TestSignature_ZeroKeyID(t *testing.T) {
	// Test with a zero KeyID
	sig := Signature{
		Signature: []byte{0x01, 0x02, 0x03},
		Timestamp: time.Now().UTC(),
		KeyID:     KeyID{},
		Algorithm: "ed25519",
		HashAlgo:  "blake2s",
	}

	jsonData, err := json.Marshal(sig)
	require.NoError(t, err)

	parsed, err := ParseSignature(jsonData)
	require.NoError(t, err)
	assert.Equal(t, KeyID{}, parsed.KeyID)
}
|
||||||
|
|
||||||
|
// TestParseSignature_ExtraFields checks that unknown JSON fields are ignored,
// matching encoding/json's default lenient behavior.
func TestParseSignature_ExtraFields(t *testing.T) {
	// JSON with extra fields that should be ignored
	jsonWithExtra := []byte(`{
		"signature": "AQIDBA==",
		"timestamp": "2024-01-15T10:30:00Z",
		"key_id": "0123456789abcdef",
		"algorithm": "ed25519",
		"hash_algo": "blake2s",
		"extra_field": "should be ignored",
		"another_extra": 12345
	}`)

	sig, err := ParseSignature(jsonWithExtra)
	require.NoError(t, err)
	assert.NotNil(t, sig)
	assert.NotEmpty(t, sig.Signature)
	assert.Equal(t, "ed25519", sig.Algorithm)
	assert.Equal(t, "blake2s", sig.HashAlgo)
}
|
||||||
187
client/internal/updatemanager/reposign/verify.go
Normal file
187
client/internal/updatemanager/reposign/verify.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// File names published under the "<keysBaseURL>/keys" path of the
	// signature repository.
	artifactPubKeysFileName    = "artifact-key-pub.pem"
	artifactPubKeysSigFileName = "artifact-key-pub.pem.sig"
	revocationFileName         = "revocation-list.json"
	revocationSignFileName     = "revocation-list.json.sig"

	// Download size caps passed to downloader.DownloadToMemory to bound
	// memory used for each fetched object.
	keySizeLimit    = 5 * 1024 * 1024 // 5MB
	signatureLimit  = 1024
	revocationLimit = 10 * 1024 * 1024
)
|
||||||
|
|
||||||
|
// ArtifactVerify validates downloaded release artifacts against detached
// signatures served from a remote signature repository, anchored in a set of
// trusted root public keys.
type ArtifactVerify struct {
	// rootKeys are the trusted root public keys (typically embedded in the binary).
	rootKeys []PublicKey
	// keysBaseURL is the base URL of the signature repository.
	keysBaseURL *url.URL

	// revocationList is set by Verify before artifact keys are validated.
	revocationList *RevocationList
}
|
||||||
|
|
||||||
|
func NewArtifactVerify(keysBaseURL string) (*ArtifactVerify, error) {
|
||||||
|
allKeys, err := loadEmbeddedPublicKeys()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return newArtifactVerify(keysBaseURL, allKeys)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newArtifactVerify(keysBaseURL string, allKeys []PublicKey) (*ArtifactVerify, error) {
|
||||||
|
ku, err := url.Parse(keysBaseURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid keys base URL %q: %v", keysBaseURL, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
a := &ArtifactVerify{
|
||||||
|
rootKeys: allKeys,
|
||||||
|
keysBaseURL: ku,
|
||||||
|
}
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ArtifactVerify) Verify(ctx context.Context, version string, artifactFile string) error {
|
||||||
|
version = strings.TrimPrefix(version, "v")
|
||||||
|
|
||||||
|
revocationList, err := a.loadRevocationList(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load revocation list: %v", err)
|
||||||
|
}
|
||||||
|
a.revocationList = revocationList
|
||||||
|
|
||||||
|
artifactPubKeys, err := a.loadArtifactKeys(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load artifact keys: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := a.loadArtifactSignature(ctx, version, artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to download signature file for: %s, %v", filepath.Base(artifactFile), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
artifactData, err := os.ReadFile(artifactFile)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to read artifact file: %v", err)
|
||||||
|
return fmt.Errorf("failed to read artifact file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := ValidateArtifact(artifactPubKeys, artifactData, *signature); err != nil {
|
||||||
|
return fmt.Errorf("failed to validate artifact: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ArtifactVerify) loadRevocationList(ctx context.Context) (*RevocationList, error) {
|
||||||
|
downloadURL := a.keysBaseURL.JoinPath("keys", revocationFileName).String()
|
||||||
|
data, err := downloader.DownloadToMemory(ctx, downloadURL, revocationLimit)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
downloadURL = a.keysBaseURL.JoinPath("keys", revocationSignFileName).String()
|
||||||
|
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to download revocation list '%s': %s", downloadURL, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := ParseSignature(sigData)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to parse revocation list signature: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateRevocationList(a.rootKeys, data, *signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ArtifactVerify) loadArtifactKeys(ctx context.Context) ([]PublicKey, error) {
|
||||||
|
downloadURL := a.keysBaseURL.JoinPath("keys", artifactPubKeysFileName).String()
|
||||||
|
log.Debugf("starting downloading artifact keys from: %s", downloadURL)
|
||||||
|
data, err := downloader.DownloadToMemory(ctx, downloadURL, keySizeLimit)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to download artifact keys: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
downloadURL = a.keysBaseURL.JoinPath("keys", artifactPubKeysSigFileName).String()
|
||||||
|
log.Debugf("start downloading signature of artifact pub key from: %s", downloadURL)
|
||||||
|
sigData, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to download signature of public keys: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := ParseSignature(sigData)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to parse signature of public keys: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ValidateArtifactKeys(a.rootKeys, data, *signature, a.revocationList)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ArtifactVerify) loadArtifactSignature(ctx context.Context, version string, artifactFile string) (*Signature, error) {
|
||||||
|
artifactFile = filepath.Base(artifactFile)
|
||||||
|
downloadURL := a.keysBaseURL.JoinPath("tag", "v"+version, artifactFile+".sig").String()
|
||||||
|
data, err := downloader.DownloadToMemory(ctx, downloadURL, signatureLimit)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to download artifact signature: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
signature, err := ParseSignature(data)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("failed to parse artifact signature: %s", err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return signature, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadEmbeddedPublicKeys() ([]PublicKey, error) {
|
||||||
|
files, err := embeddedCerts.ReadDir(embeddedCertsDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read embedded certs: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var allKeys []PublicKey
|
||||||
|
for _, file := range files {
|
||||||
|
if file.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := embeddedCerts.ReadFile(embeddedCertsDir + "/" + file.Name())
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read cert file %s: %w", file.Name(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys, err := parsePublicKeyBundle(data, tagRootPublic)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse cert %s: %w", file.Name(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
allKeys = append(allKeys, keys...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(allKeys) == 0 {
|
||||||
|
return nil, fmt.Errorf("no valid public keys found in embedded certs")
|
||||||
|
}
|
||||||
|
|
||||||
|
return allKeys, nil
|
||||||
|
}
|
||||||
528
client/internal/updatemanager/reposign/verify_test.go
Normal file
528
client/internal/updatemanager/reposign/verify_test.go
Normal file
@@ -0,0 +1,528 @@
|
|||||||
|
package reposign
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ed25519"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Test ArtifactVerify construction
|
||||||
|
|
||||||
|
// TestArtifactVerify_Construction checks that newArtifactVerify stores the
// supplied root keys and base URL correctly.
func TestArtifactVerify_Construction(t *testing.T) {
	// Generate test root key
	rootKey, _, rootPubPEM, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)

	rootPubKey, _, err := parsePublicKey(rootPubPEM, tagRootPublic)
	require.NoError(t, err)

	keysBaseURL := "http://localhost:8080/artifact-signatures"

	av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey})
	require.NoError(t, err)

	assert.NotNil(t, av)
	assert.NotEmpty(t, av.rootKeys)
	assert.Equal(t, keysBaseURL, av.keysBaseURL.String())

	// Verify root key structure
	assert.NotEmpty(t, av.rootKeys[0].Key)
	assert.Equal(t, rootKey.Metadata.ID, av.rootKeys[0].Metadata.ID)
	assert.False(t, av.rootKeys[0].Metadata.CreatedAt.IsZero())
}
|
||||||
|
|
||||||
|
// TestArtifactVerify_MultipleRootKeys checks that several distinct root keys
// can be trusted at once and that independently generated keys get distinct IDs.
func TestArtifactVerify_MultipleRootKeys(t *testing.T) {
	// Generate multiple test root keys
	rootKey1, _, rootPubPEM1, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)
	rootPubKey1, _, err := parsePublicKey(rootPubPEM1, tagRootPublic)
	require.NoError(t, err)

	rootKey2, _, rootPubPEM2, err := GenerateRootKey(365 * 24 * time.Hour)
	require.NoError(t, err)
	rootPubKey2, _, err := parsePublicKey(rootPubPEM2, tagRootPublic)
	require.NoError(t, err)

	keysBaseURL := "http://localhost:8080/artifact-signatures"

	av, err := newArtifactVerify(keysBaseURL, []PublicKey{rootPubKey1, rootPubKey2})
	assert.NoError(t, err)
	assert.Len(t, av.rootKeys, 2)
	assert.NotEqual(t, rootKey1.Metadata.ID, rootKey2.Metadata.ID)
}
|
||||||
|
|
||||||
|
// Test Verify workflow with mock HTTP server
|
||||||
|
|
||||||
|
// TestArtifactVerify_FullWorkflow runs the complete verification pipeline
// against a mock HTTP server: root key -> signed artifact key -> revocation
// list -> key bundle -> signed artifact file -> ArtifactVerify.Verify.
func TestArtifactVerify_FullWorkflow(t *testing.T) {
	// Create temporary test directory
	tempDir := t.TempDir()

	// Step 1: Generate root key
	rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
	require.NoError(t, err)

	// Step 2: Generate artifact key
	artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
	require.NoError(t, err)

	artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
	require.NoError(t, err)

	// Step 3: Create revocation list
	revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
	require.NoError(t, err)

	// Step 4: Bundle artifact keys
	artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
	require.NoError(t, err)

	// Step 5: Create test artifact
	artifactPath := filepath.Join(tempDir, "test-artifact.bin")
	artifactData := []byte("This is test artifact data for verification")
	err = os.WriteFile(artifactPath, artifactData, 0644)
	require.NoError(t, err)

	// Step 6: Sign artifact
	artifactSigData, err := SignData(*artifactKey, artifactData)
	require.NoError(t, err)

	// Step 7: Setup mock HTTP server
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		switch r.URL.Path {
		case "/artifact-signatures/keys/" + revocationFileName:
			_, _ = w.Write(revocationData)
		case "/artifact-signatures/keys/" + revocationSignFileName:
			_, _ = w.Write(revocationSig)
		case "/artifact-signatures/keys/" + artifactPubKeysFileName:
			_, _ = w.Write(artifactKeysBundle)
		case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
			_, _ = w.Write(artifactKeysSig)
		case "/artifacts/v1.0.0/test-artifact.bin":
			_, _ = w.Write(artifactData)
		case "/artifact-signatures/tag/v1.0.0/test-artifact.bin.sig":
			_, _ = w.Write(artifactSigData)
		default:
			http.NotFound(w, r)
		}
	}))
	defer server.Close()

	// Step 8: Create ArtifactVerify with test root key
	rootPubKey := PublicKey{
		Key:      rootKey.Key.Public().(ed25519.PublicKey),
		Metadata: rootKey.Metadata,
	}

	av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
	require.NoError(t, err)

	// Step 9: Verify artifact
	ctx := context.Background()
	err = av.Verify(ctx, "1.0.0", artifactPath)
	assert.NoError(t, err)
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_InvalidRevocationList(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
err := os.WriteFile(artifactPath, []byte("test"), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Setup server with invalid revocation list
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/artifact-signatures/keys/" + revocationFileName:
|
||||||
|
_, _ = w.Write([]byte("invalid data"))
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to load revocation list")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_MissingArtifactFile(t *testing.T) {
|
||||||
|
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create revocation list
|
||||||
|
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create signature for non-existent file
|
||||||
|
testData := []byte("test")
|
||||||
|
artifactSigData, err := SignData(*artifactKey, testData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/artifact-signatures/keys/" + revocationFileName:
|
||||||
|
_, _ = w.Write(revocationData)
|
||||||
|
case "/artifact-signatures/keys/" + revocationSignFileName:
|
||||||
|
_, _ = w.Write(revocationSig)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
|
||||||
|
_, _ = w.Write(artifactKeysBundle)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
|
||||||
|
_, _ = w.Write(artifactKeysSig)
|
||||||
|
case "/artifact-signatures/tag/v1.0.0/missing.bin.sig":
|
||||||
|
_, _ = w.Write(artifactSigData)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = av.Verify(ctx, "1.0.0", "file.bin")
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_ServerUnavailable(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
err := os.WriteFile(artifactPath, []byte("test"), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use URL that doesn't exist
|
||||||
|
av, err := newArtifactVerify("http://localhost:19999/keys", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_ContextCancellation(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
err := os.WriteFile(artifactPath, []byte("test"), 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a server that delays response
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
_, _ = w.Write([]byte("data"))
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
rootKey, _, _, err := GenerateRootKey(365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL, []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create context that cancels quickly
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
assert.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_WithRevocation(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Generate root key
|
||||||
|
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate two artifact keys
|
||||||
|
artifactKey1, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
_, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create revocation list with first key revoked
|
||||||
|
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedRevocation, err := ParseRevocationList(emptyRevocation)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Bundle both keys
|
||||||
|
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create artifact signed by revoked key
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
artifactData := []byte("test data")
|
||||||
|
err = os.WriteFile(artifactPath, artifactData, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactSigData, err := SignData(*artifactKey1, artifactData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/artifact-signatures/keys/" + revocationFileName:
|
||||||
|
_, _ = w.Write(revocationData)
|
||||||
|
case "/artifact-signatures/keys/" + revocationSignFileName:
|
||||||
|
_, _ = w.Write(revocationSig)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
|
||||||
|
_, _ = w.Write(artifactKeysBundle)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
|
||||||
|
_, _ = w.Write(artifactKeysSig)
|
||||||
|
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
|
||||||
|
_, _ = w.Write(artifactSigData)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
// Should fail because the signing key is revoked
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "no signing Key found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_ValidWithSecondKey(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Generate root key
|
||||||
|
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Generate two artifact keys
|
||||||
|
_, _, artifactPubPEM1, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
artifactPubKey1, err := ParseArtifactPubKey(artifactPubPEM1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactKey2, _, artifactPubPEM2, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
artifactPubKey2, err := ParseArtifactPubKey(artifactPubPEM2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create revocation list with first key revoked
|
||||||
|
emptyRevocation, _, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
parsedRevocation, err := ParseRevocationList(emptyRevocation)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
revocationData, revocationSig, err := ExtendRevocationList(*rootKey, *parsedRevocation, artifactPubKey1.Metadata.ID, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Bundle both keys
|
||||||
|
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey1, artifactPubKey2})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create artifact signed by second key (not revoked)
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
artifactData := []byte("test data")
|
||||||
|
err = os.WriteFile(artifactPath, artifactData, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactSigData, err := SignData(*artifactKey2, artifactData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/artifact-signatures/keys/" + revocationFileName:
|
||||||
|
_, _ = w.Write(revocationData)
|
||||||
|
case "/artifact-signatures/keys/" + revocationSignFileName:
|
||||||
|
_, _ = w.Write(revocationSig)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
|
||||||
|
_, _ = w.Write(artifactKeysBundle)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
|
||||||
|
_, _ = w.Write(artifactKeysSig)
|
||||||
|
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
|
||||||
|
_, _ = w.Write(artifactSigData)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
// Should succeed because second key is not revoked
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestArtifactVerify_TamperedArtifact(t *testing.T) {
|
||||||
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
|
// Generate root key and artifact key
|
||||||
|
rootKey, _, _, err := GenerateRootKey(10 * 365 * 24 * time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
artifactKey, _, artifactPubPEM, _, err := GenerateArtifactKey(rootKey, 30*24*time.Hour)
|
||||||
|
require.NoError(t, err)
|
||||||
|
artifactPubKey, err := ParseArtifactPubKey(artifactPubPEM)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create revocation list
|
||||||
|
revocationData, revocationSig, err := CreateRevocationList(*rootKey, defaultRevocationListExpiration)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Bundle keys
|
||||||
|
artifactKeysBundle, artifactKeysSig, err := BundleArtifactKeys(rootKey, []PublicKey{artifactPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sign original data
|
||||||
|
originalData := []byte("original data")
|
||||||
|
artifactSigData, err := SignData(*artifactKey, originalData)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Write tampered data to file
|
||||||
|
artifactPath := filepath.Join(tempDir, "test.bin")
|
||||||
|
tamperedData := []byte("tampered data")
|
||||||
|
err = os.WriteFile(artifactPath, tamperedData, 0644)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/artifact-signatures/keys/" + revocationFileName:
|
||||||
|
_, _ = w.Write(revocationData)
|
||||||
|
case "/artifact-signatures/keys/" + revocationSignFileName:
|
||||||
|
_, _ = w.Write(revocationSig)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysFileName:
|
||||||
|
_, _ = w.Write(artifactKeysBundle)
|
||||||
|
case "/artifact-signatures/keys/" + artifactPubKeysSigFileName:
|
||||||
|
_, _ = w.Write(artifactKeysSig)
|
||||||
|
case "/artifact-signatures/tag/v1.0.0/test.bin.sig":
|
||||||
|
_, _ = w.Write(artifactSigData)
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
rootPubKey := PublicKey{
|
||||||
|
Key: rootKey.Key.Public().(ed25519.PublicKey),
|
||||||
|
Metadata: rootKey.Metadata,
|
||||||
|
}
|
||||||
|
|
||||||
|
av, err := newArtifactVerify(server.URL+"/artifact-signatures", []PublicKey{rootPubKey})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
err = av.Verify(ctx, "1.0.0", artifactPath)
|
||||||
|
// Should fail because artifact was tampered
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "failed to validate artifact")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test URL validation
|
||||||
|
|
||||||
|
func TestArtifactVerify_URLParsing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
keysBaseURL string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Valid HTTP URL",
|
||||||
|
keysBaseURL: "http://example.com/artifact-signatures",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Valid HTTPS URL",
|
||||||
|
keysBaseURL: "https://example.com/artifact-signatures",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "URL with port",
|
||||||
|
keysBaseURL: "http://localhost:8080/artifact-signatures",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Invalid URL",
|
||||||
|
keysBaseURL: "://invalid",
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := newArtifactVerify(tt.keysBaseURL, nil)
|
||||||
|
if tt.expectError {
|
||||||
|
assert.Error(t, err)
|
||||||
|
} else {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
11
client/internal/updatemanager/update.go
Normal file
11
client/internal/updatemanager/update.go
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
package updatemanager
|
||||||
|
|
||||||
|
import v "github.com/hashicorp/go-version"
|
||||||
|
|
||||||
|
type UpdateInterface interface {
|
||||||
|
StopWatch()
|
||||||
|
SetDaemonVersion(newVersion string) bool
|
||||||
|
SetOnUpdateListener(updateFn func())
|
||||||
|
LatestVersion() *v.Version
|
||||||
|
StartFetcher()
|
||||||
|
}
|
||||||
@@ -131,7 +131,7 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
|
|||||||
c.onHostDnsFn = func([]string) {}
|
c.onHostDnsFn = func([]string) {}
|
||||||
cfg.WgIface = interfaceName
|
cfg.WgIface = interfaceName
|
||||||
|
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||||
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -51,7 +51,7 @@
|
|||||||
</ComponentGroup>
|
</ComponentGroup>
|
||||||
|
|
||||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||||
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
|
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" TerminateProcess="0" />
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
// Code generated by protoc-gen-go. DO NOT EDIT.
|
// Code generated by protoc-gen-go. DO NOT EDIT.
|
||||||
// versions:
|
// versions:
|
||||||
// protoc-gen-go v1.36.6
|
// protoc-gen-go v1.36.6
|
||||||
// protoc v6.32.1
|
// protoc v3.21.12
|
||||||
// source: daemon.proto
|
// source: daemon.proto
|
||||||
|
|
||||||
package proto
|
package proto
|
||||||
@@ -893,6 +893,7 @@ type UpRequest struct {
|
|||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
|
||||||
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
|
||||||
|
AutoUpdate *bool `protobuf:"varint,3,opt,name=autoUpdate,proto3,oneof" json:"autoUpdate,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -941,6 +942,13 @@ func (x *UpRequest) GetUsername() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *UpRequest) GetAutoUpdate() bool {
|
||||||
|
if x != nil && x.AutoUpdate != nil {
|
||||||
|
return *x.AutoUpdate
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type UpResponse struct {
|
type UpResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
@@ -5356,6 +5364,94 @@ func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type InstallerResultRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultRequest) Reset() {
|
||||||
|
*x = InstallerResultRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[79]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*InstallerResultRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *InstallerResultRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[79]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use InstallerResultRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*InstallerResultRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{79}
|
||||||
|
}
|
||||||
|
|
||||||
|
type InstallerResultResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
Success bool `protobuf:"varint,1,opt,name=success,proto3" json:"success,omitempty"`
|
||||||
|
ErrorMsg string `protobuf:"bytes,2,opt,name=errorMsg,proto3" json:"errorMsg,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultResponse) Reset() {
|
||||||
|
*x = InstallerResultResponse{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[80]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*InstallerResultResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *InstallerResultResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[80]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use InstallerResultResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*InstallerResultResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{80}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultResponse) GetSuccess() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.Success
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *InstallerResultResponse) GetErrorMsg() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.ErrorMsg
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
type PortInfo_Range struct {
|
type PortInfo_Range struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||||
@@ -5366,7 +5462,7 @@ type PortInfo_Range struct {
|
|||||||
|
|
||||||
func (x *PortInfo_Range) Reset() {
|
func (x *PortInfo_Range) Reset() {
|
||||||
*x = PortInfo_Range{}
|
*x = PortInfo_Range{}
|
||||||
mi := &file_daemon_proto_msgTypes[80]
|
mi := &file_daemon_proto_msgTypes[82]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -5378,7 +5474,7 @@ func (x *PortInfo_Range) String() string {
|
|||||||
func (*PortInfo_Range) ProtoMessage() {}
|
func (*PortInfo_Range) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[80]
|
mi := &file_daemon_proto_msgTypes[82]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -5502,12 +5598,16 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
|
"\buserCode\x18\x01 \x01(\tR\buserCode\x12\x1a\n" +
|
||||||
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
|
"\bhostname\x18\x02 \x01(\tR\bhostname\",\n" +
|
||||||
"\x14WaitSSOLoginResponse\x12\x14\n" +
|
"\x14WaitSSOLoginResponse\x12\x14\n" +
|
||||||
"\x05email\x18\x01 \x01(\tR\x05email\"p\n" +
|
"\x05email\x18\x01 \x01(\tR\x05email\"\xa4\x01\n" +
|
||||||
"\tUpRequest\x12%\n" +
|
"\tUpRequest\x12%\n" +
|
||||||
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
|
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
|
||||||
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
|
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01\x12#\n" +
|
||||||
|
"\n" +
|
||||||
|
"autoUpdate\x18\x03 \x01(\bH\x02R\n" +
|
||||||
|
"autoUpdate\x88\x01\x01B\x0e\n" +
|
||||||
"\f_profileNameB\v\n" +
|
"\f_profileNameB\v\n" +
|
||||||
"\t_username\"\f\n" +
|
"\t_usernameB\r\n" +
|
||||||
|
"\v_autoUpdate\"\f\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"UpResponse\"\xa1\x01\n" +
|
"UpResponse\"\xa1\x01\n" +
|
||||||
"\rStatusRequest\x12,\n" +
|
"\rStatusRequest\x12,\n" +
|
||||||
@@ -5893,7 +5993,11 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x14WaitJWTTokenResponse\x12\x14\n" +
|
"\x14WaitJWTTokenResponse\x12\x14\n" +
|
||||||
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
||||||
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
||||||
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
|
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn\"\x18\n" +
|
||||||
|
"\x16InstallerResultRequest\"O\n" +
|
||||||
|
"\x17InstallerResultResponse\x12\x18\n" +
|
||||||
|
"\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" +
|
||||||
|
"\berrorMsg\x18\x02 \x01(\tR\berrorMsg*b\n" +
|
||||||
"\bLogLevel\x12\v\n" +
|
"\bLogLevel\x12\v\n" +
|
||||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||||
"\x05PANIC\x10\x01\x12\t\n" +
|
"\x05PANIC\x10\x01\x12\t\n" +
|
||||||
@@ -5902,7 +6006,7 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x04WARN\x10\x04\x12\b\n" +
|
"\x04WARN\x10\x04\x12\b\n" +
|
||||||
"\x04INFO\x10\x05\x12\t\n" +
|
"\x04INFO\x10\x05\x12\t\n" +
|
||||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||||
"\x05TRACE\x10\a2\xdb\x12\n" +
|
"\x05TRACE\x10\a2\xb4\x13\n" +
|
||||||
"\rDaemonService\x126\n" +
|
"\rDaemonService\x126\n" +
|
||||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||||
@@ -5938,7 +6042,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
||||||
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
||||||
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
|
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00\x12N\n" +
|
||||||
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00B\bZ\x06/protob\x06proto3"
|
"\x11NotifyOSLifecycle\x12\x1a.daemon.OSLifecycleRequest\x1a\x1b.daemon.OSLifecycleResponse\"\x00\x12W\n" +
|
||||||
|
"\x12GetInstallerResult\x12\x1e.daemon.InstallerResultRequest\x1a\x1f.daemon.InstallerResultResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
file_daemon_proto_rawDescOnce sync.Once
|
file_daemon_proto_rawDescOnce sync.Once
|
||||||
@@ -5953,7 +6058,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 4)
|
||||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 82)
|
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 84)
|
||||||
var file_daemon_proto_goTypes = []any{
|
var file_daemon_proto_goTypes = []any{
|
||||||
(LogLevel)(0), // 0: daemon.LogLevel
|
(LogLevel)(0), // 0: daemon.LogLevel
|
||||||
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
(OSLifecycleRequest_CycleType)(0), // 1: daemon.OSLifecycleRequest.CycleType
|
||||||
@@ -6038,19 +6143,21 @@ var file_daemon_proto_goTypes = []any{
|
|||||||
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
(*RequestJWTAuthResponse)(nil), // 80: daemon.RequestJWTAuthResponse
|
||||||
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
(*WaitJWTTokenRequest)(nil), // 81: daemon.WaitJWTTokenRequest
|
||||||
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
(*WaitJWTTokenResponse)(nil), // 82: daemon.WaitJWTTokenResponse
|
||||||
nil, // 83: daemon.Network.ResolvedIPsEntry
|
(*InstallerResultRequest)(nil), // 83: daemon.InstallerResultRequest
|
||||||
(*PortInfo_Range)(nil), // 84: daemon.PortInfo.Range
|
(*InstallerResultResponse)(nil), // 84: daemon.InstallerResultResponse
|
||||||
nil, // 85: daemon.SystemEvent.MetadataEntry
|
nil, // 85: daemon.Network.ResolvedIPsEntry
|
||||||
(*durationpb.Duration)(nil), // 86: google.protobuf.Duration
|
(*PortInfo_Range)(nil), // 86: daemon.PortInfo.Range
|
||||||
(*timestamppb.Timestamp)(nil), // 87: google.protobuf.Timestamp
|
nil, // 87: daemon.SystemEvent.MetadataEntry
|
||||||
|
(*durationpb.Duration)(nil), // 88: google.protobuf.Duration
|
||||||
|
(*timestamppb.Timestamp)(nil), // 89: google.protobuf.Timestamp
|
||||||
}
|
}
|
||||||
var file_daemon_proto_depIdxs = []int32{
|
var file_daemon_proto_depIdxs = []int32{
|
||||||
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
1, // 0: daemon.OSLifecycleRequest.type:type_name -> daemon.OSLifecycleRequest.CycleType
|
||||||
86, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
88, // 1: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
27, // 2: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||||
87, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
89, // 3: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||||
87, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
89, // 4: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||||
86, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
88, // 5: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||||
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
25, // 6: daemon.SSHServerState.sessions:type_name -> daemon.SSHSessionInfo
|
||||||
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
22, // 7: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||||
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
21, // 8: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||||
@@ -6061,8 +6168,8 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
57, // 13: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||||
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
26, // 14: daemon.FullStatus.sshServerState:type_name -> daemon.SSHServerState
|
||||||
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
33, // 15: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||||
83, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
85, // 16: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||||
84, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
86, // 17: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||||
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
34, // 18: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||||
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
34, // 19: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||||
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
35, // 20: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||||
@@ -6073,10 +6180,10 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
54, // 25: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||||
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
2, // 26: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||||
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
3, // 27: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||||
87, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
89, // 28: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
85, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
87, // 29: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||||
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
57, // 30: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||||
86, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
88, // 31: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
70, // 32: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||||
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
32, // 33: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||||
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
7, // 34: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||||
@@ -6111,40 +6218,42 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
79, // 63: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||||
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
81, // 64: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||||
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
5, // 65: daemon.DaemonService.NotifyOSLifecycle:input_type -> daemon.OSLifecycleRequest
|
||||||
8, // 66: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
83, // 66: daemon.DaemonService.GetInstallerResult:input_type -> daemon.InstallerResultRequest
|
||||||
10, // 67: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
8, // 67: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||||
12, // 68: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
10, // 68: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||||
14, // 69: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
12, // 69: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||||
16, // 70: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
14, // 70: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||||
18, // 71: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
16, // 71: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||||
29, // 72: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
18, // 72: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||||
31, // 73: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
29, // 73: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||||
31, // 74: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
31, // 74: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
36, // 75: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
31, // 75: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
38, // 76: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
36, // 76: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||||
40, // 77: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
38, // 77: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||||
42, // 78: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
40, // 78: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||||
45, // 79: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
42, // 79: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||||
47, // 80: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
45, // 80: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||||
49, // 81: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
47, // 81: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||||
51, // 82: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
49, // 82: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||||
55, // 83: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
51, // 83: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||||
57, // 84: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
55, // 84: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||||
59, // 85: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
57, // 85: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||||
61, // 86: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
59, // 86: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||||
63, // 87: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
61, // 87: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||||
65, // 88: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
63, // 88: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||||
67, // 89: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
65, // 89: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||||
69, // 90: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
67, // 90: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||||
72, // 91: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
69, // 91: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||||
74, // 92: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
72, // 92: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||||
76, // 93: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
74, // 93: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||||
78, // 94: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
76, // 94: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||||
80, // 95: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
78, // 95: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||||
82, // 96: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
80, // 96: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||||
6, // 97: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
82, // 97: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||||
66, // [66:98] is the sub-list for method output_type
|
6, // 98: daemon.DaemonService.NotifyOSLifecycle:output_type -> daemon.OSLifecycleResponse
|
||||||
34, // [34:66] is the sub-list for method input_type
|
84, // 99: daemon.DaemonService.GetInstallerResult:output_type -> daemon.InstallerResultResponse
|
||||||
|
67, // [67:100] is the sub-list for method output_type
|
||||||
|
34, // [34:67] is the sub-list for method input_type
|
||||||
34, // [34:34] is the sub-list for extension type_name
|
34, // [34:34] is the sub-list for extension type_name
|
||||||
34, // [34:34] is the sub-list for extension extendee
|
34, // [34:34] is the sub-list for extension extendee
|
||||||
0, // [0:34] is the sub-list for field type_name
|
0, // [0:34] is the sub-list for field type_name
|
||||||
@@ -6174,7 +6283,7 @@ func file_daemon_proto_init() {
|
|||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||||
NumEnums: 4,
|
NumEnums: 4,
|
||||||
NumMessages: 82,
|
NumMessages: 84,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -95,6 +95,8 @@ service DaemonService {
|
|||||||
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
|
|
||||||
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
|
|
||||||
|
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -215,6 +217,7 @@ message WaitSSOLoginResponse {
|
|||||||
message UpRequest {
|
message UpRequest {
|
||||||
optional string profileName = 1;
|
optional string profileName = 1;
|
||||||
optional string username = 2;
|
optional string username = 2;
|
||||||
|
optional bool autoUpdate = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
message UpResponse {}
|
message UpResponse {}
|
||||||
@@ -772,3 +775,11 @@ message WaitJWTTokenResponse {
|
|||||||
// expiration time in seconds
|
// expiration time in seconds
|
||||||
int64 expiresIn = 3;
|
int64 expiresIn = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
message InstallerResultRequest {
|
||||||
|
}
|
||||||
|
|
||||||
|
message InstallerResultResponse {
|
||||||
|
bool success = 1;
|
||||||
|
string errorMsg = 2;
|
||||||
|
}
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ type DaemonServiceClient interface {
|
|||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||||
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(ctx context.Context, in *OSLifecycleRequest, opts ...grpc.CallOption) (*OSLifecycleResponse, error)
|
||||||
|
GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -392,6 +393,15 @@ func (c *daemonServiceClient) NotifyOSLifecycle(ctx context.Context, in *OSLifec
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) GetInstallerResult(ctx context.Context, in *InstallerResultRequest, opts ...grpc.CallOption) (*InstallerResultResponse, error) {
|
||||||
|
out := new(InstallerResultResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/GetInstallerResult", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -449,6 +459,7 @@ type DaemonServiceServer interface {
|
|||||||
// WaitJWTToken waits for JWT authentication completion
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||||
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error)
|
||||||
|
GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -552,6 +563,9 @@ func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTo
|
|||||||
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
func (UnimplementedDaemonServiceServer) NotifyOSLifecycle(context.Context, *OSLifecycleRequest) (*OSLifecycleResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method NotifyOSLifecycle not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) GetInstallerResult(context.Context, *InstallerResultRequest) (*InstallerResultResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method GetInstallerResult not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -1144,6 +1158,24 @@ func _DaemonService_NotifyOSLifecycle_Handler(srv interface{}, ctx context.Conte
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_GetInstallerResult_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(InstallerResultRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).GetInstallerResult(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/GetInstallerResult",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).GetInstallerResult(ctx, req.(*InstallerResultRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -1275,6 +1307,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "NotifyOSLifecycle",
|
MethodName: "NotifyOSLifecycle",
|
||||||
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
Handler: _DaemonService_NotifyOSLifecycle_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "GetInstallerResult",
|
||||||
|
Handler: _DaemonService_GetInstallerResult_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{
|
Streams: []grpc.StreamDesc{
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -14,4 +14,4 @@ cd "$script_path"
|
|||||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
||||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||||
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
||||||
cd "$old_pwd"
|
cd "$old_pwd"
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ func (s *Server) Start() error {
|
|||||||
s.clientRunning = true
|
s.clientRunning = true
|
||||||
s.clientRunningChan = make(chan struct{})
|
s.clientRunningChan = make(chan struct{})
|
||||||
s.clientGiveUpChan = make(chan struct{})
|
s.clientGiveUpChan = make(chan struct{})
|
||||||
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
go s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, s.clientRunningChan, s.clientGiveUpChan)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -223,7 +223,7 @@ func (s *Server) setDefaultConfigIfNotExists(ctx context.Context) error {
|
|||||||
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
// connectWithRetryRuns runs the client connection with a backoff strategy where we retry the operation as additional
|
||||||
// mechanism to keep the client connected even when the connection is lost.
|
// mechanism to keep the client connected even when the connection is lost.
|
||||||
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
// we cancel retry if the client receive a stop or down command, or if disable auto connect is configured.
|
||||||
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}, giveUpChan chan struct{}) {
|
func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}, giveUpChan chan struct{}) {
|
||||||
defer func() {
|
defer func() {
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
s.clientRunning = false
|
s.clientRunning = false
|
||||||
@@ -231,7 +231,7 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
if s.config.DisableAutoConnect {
|
if s.config.DisableAutoConnect {
|
||||||
if err := s.connect(ctx, s.config, s.statusRecorder, runningChan); err != nil {
|
if err := s.connect(ctx, s.config, s.statusRecorder, doInitialAutoUpdate, runningChan); err != nil {
|
||||||
log.Debugf("run client connection exited with error: %v", err)
|
log.Debugf("run client connection exited with error: %v", err)
|
||||||
}
|
}
|
||||||
log.Tracef("client connection exited")
|
log.Tracef("client connection exited")
|
||||||
@@ -260,7 +260,8 @@ func (s *Server) connectWithRetryRuns(ctx context.Context, profileConfig *profil
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
runOperation := func() error {
|
runOperation := func() error {
|
||||||
err := s.connect(ctx, profileConfig, statusRecorder, runningChan)
|
err := s.connect(ctx, profileConfig, statusRecorder, doInitialAutoUpdate, runningChan)
|
||||||
|
doInitialAutoUpdate = false
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
log.Debugf("run client connection exited with error: %v. Will retry in the background", err)
|
||||||
return err
|
return err
|
||||||
@@ -728,7 +729,12 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
|||||||
s.clientRunning = true
|
s.clientRunning = true
|
||||||
s.clientRunningChan = make(chan struct{})
|
s.clientRunningChan = make(chan struct{})
|
||||||
s.clientGiveUpChan = make(chan struct{})
|
s.clientGiveUpChan = make(chan struct{})
|
||||||
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, s.clientRunningChan, s.clientGiveUpChan)
|
|
||||||
|
var doAutoUpdate bool
|
||||||
|
if msg != nil && msg.AutoUpdate != nil && *msg.AutoUpdate {
|
||||||
|
doAutoUpdate = true
|
||||||
|
}
|
||||||
|
go s.connectWithRetryRuns(ctx, s.config, s.statusRecorder, doAutoUpdate, s.clientRunningChan, s.clientGiveUpChan)
|
||||||
|
|
||||||
return s.waitForUp(callerCtx)
|
return s.waitForUp(callerCtx)
|
||||||
}
|
}
|
||||||
@@ -1539,9 +1545,9 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest)
|
|||||||
return features, nil
|
return features, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, runningChan chan struct{}) error {
|
func (s *Server) connect(ctx context.Context, config *profilemanager.Config, statusRecorder *peer.Status, doInitialAutoUpdate bool, runningChan chan struct{}) error {
|
||||||
log.Tracef("running client connection")
|
log.Tracef("running client connection")
|
||||||
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder)
|
s.connectClient = internal.NewConnectClient(ctx, config, statusRecorder, doInitialAutoUpdate)
|
||||||
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
s.connectClient.SetSyncResponsePersistence(s.persistSyncResponse)
|
||||||
if err := s.connectClient.Run(runningChan); err != nil {
|
if err := s.connectClient.Run(runningChan); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -112,7 +112,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
|||||||
t.Setenv(maxRetryTimeVar, "5s")
|
t.Setenv(maxRetryTimeVar, "5s")
|
||||||
t.Setenv(retryMultiplierVar, "1")
|
t.Setenv(retryMultiplierVar, "1")
|
||||||
|
|
||||||
s.connectWithRetryRuns(ctx, config, s.statusRecorder, nil, nil)
|
s.connectWithRetryRuns(ctx, config, s.statusRecorder, false, nil, nil)
|
||||||
if counter < 3 {
|
if counter < 3 {
|
||||||
t.Fatalf("expected counter > 2, got %d", counter)
|
t.Fatalf("expected counter > 2, got %d", counter)
|
||||||
}
|
}
|
||||||
|
|||||||
30
client/server/updateresult.go
Normal file
30
client/server/updateresult.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *Server) GetInstallerResult(ctx context.Context, _ *proto.InstallerResultRequest) (*proto.InstallerResultResponse, error) {
|
||||||
|
inst := installer.New()
|
||||||
|
dir := inst.TempDir()
|
||||||
|
|
||||||
|
rh := installer.NewResultHandler(dir)
|
||||||
|
result, err := rh.Watch(ctx)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to watch update result: %v", err)
|
||||||
|
return &proto.InstallerResultResponse{
|
||||||
|
Success: false,
|
||||||
|
ErrorMsg: err.Error(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.InstallerResultResponse{
|
||||||
|
Success: result.Success,
|
||||||
|
ErrorMsg: result.Error,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -34,6 +34,7 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
protobuf "google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -43,7 +44,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
"github.com/netbirdio/netbird/client/ui/event"
|
"github.com/netbirdio/netbird/client/ui/event"
|
||||||
"github.com/netbirdio/netbird/client/ui/process"
|
"github.com/netbirdio/netbird/client/ui/process"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
@@ -87,22 +87,24 @@ func main() {
|
|||||||
|
|
||||||
// Create the service client (this also builds the settings or networks UI if requested).
|
// Create the service client (this also builds the settings or networks UI if requested).
|
||||||
client := newServiceClient(&newServiceClientArgs{
|
client := newServiceClient(&newServiceClientArgs{
|
||||||
addr: flags.daemonAddr,
|
addr: flags.daemonAddr,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
app: a,
|
app: a,
|
||||||
showSettings: flags.showSettings,
|
showSettings: flags.showSettings,
|
||||||
showNetworks: flags.showNetworks,
|
showNetworks: flags.showNetworks,
|
||||||
showLoginURL: flags.showLoginURL,
|
showLoginURL: flags.showLoginURL,
|
||||||
showDebug: flags.showDebug,
|
showDebug: flags.showDebug,
|
||||||
showProfiles: flags.showProfiles,
|
showProfiles: flags.showProfiles,
|
||||||
showQuickActions: flags.showQuickActions,
|
showQuickActions: flags.showQuickActions,
|
||||||
|
showUpdate: flags.showUpdate,
|
||||||
|
showUpdateVersion: flags.showUpdateVersion,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Watch for theme/settings changes to update the icon.
|
// Watch for theme/settings changes to update the icon.
|
||||||
go watchSettingsChanges(a, client)
|
go watchSettingsChanges(a, client)
|
||||||
|
|
||||||
// Run in window mode if any UI flag was set.
|
// Run in window mode if any UI flag was set.
|
||||||
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions {
|
if flags.showSettings || flags.showNetworks || flags.showDebug || flags.showLoginURL || flags.showProfiles || flags.showQuickActions || flags.showUpdate {
|
||||||
a.Run()
|
a.Run()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -128,15 +130,17 @@ func main() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type cliFlags struct {
|
type cliFlags struct {
|
||||||
daemonAddr string
|
daemonAddr string
|
||||||
showSettings bool
|
showSettings bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
showProfiles bool
|
showProfiles bool
|
||||||
showDebug bool
|
showDebug bool
|
||||||
showLoginURL bool
|
showLoginURL bool
|
||||||
showQuickActions bool
|
showQuickActions bool
|
||||||
errorMsg string
|
errorMsg string
|
||||||
saveLogsInFile bool
|
saveLogsInFile bool
|
||||||
|
showUpdate bool
|
||||||
|
showUpdateVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// parseFlags reads and returns all needed command-line flags.
|
// parseFlags reads and returns all needed command-line flags.
|
||||||
@@ -156,6 +160,8 @@ func parseFlags() *cliFlags {
|
|||||||
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
flag.StringVar(&flags.errorMsg, "error-msg", "", "displays an error message window")
|
||||||
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
flag.BoolVar(&flags.saveLogsInFile, "use-log-file", false, fmt.Sprintf("save logs in a file: %s/netbird-ui-PID.log", os.TempDir()))
|
||||||
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
flag.BoolVar(&flags.showLoginURL, "login-url", false, "show login URL in a popup window")
|
||||||
|
flag.BoolVar(&flags.showUpdate, "update", false, "show update progress window")
|
||||||
|
flag.StringVar(&flags.showUpdateVersion, "update-version", "", "version to update to")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
return &flags
|
return &flags
|
||||||
}
|
}
|
||||||
@@ -319,6 +325,8 @@ type serviceClient struct {
|
|||||||
mExitNodeDeselectAll *systray.MenuItem
|
mExitNodeDeselectAll *systray.MenuItem
|
||||||
logFile string
|
logFile string
|
||||||
wLoginURL fyne.Window
|
wLoginURL fyne.Window
|
||||||
|
wUpdateProgress fyne.Window
|
||||||
|
updateContextCancel context.CancelFunc
|
||||||
|
|
||||||
connectCancel context.CancelFunc
|
connectCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
@@ -329,15 +337,17 @@ type menuHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type newServiceClientArgs struct {
|
type newServiceClientArgs struct {
|
||||||
addr string
|
addr string
|
||||||
logFile string
|
logFile string
|
||||||
app fyne.App
|
app fyne.App
|
||||||
showSettings bool
|
showSettings bool
|
||||||
showNetworks bool
|
showNetworks bool
|
||||||
showDebug bool
|
showDebug bool
|
||||||
showLoginURL bool
|
showLoginURL bool
|
||||||
showProfiles bool
|
showProfiles bool
|
||||||
showQuickActions bool
|
showQuickActions bool
|
||||||
|
showUpdate bool
|
||||||
|
showUpdateVersion string
|
||||||
}
|
}
|
||||||
|
|
||||||
// newServiceClient instance constructor
|
// newServiceClient instance constructor
|
||||||
@@ -355,7 +365,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
|
|
||||||
showAdvancedSettings: args.showSettings,
|
showAdvancedSettings: args.showSettings,
|
||||||
showNetworks: args.showNetworks,
|
showNetworks: args.showNetworks,
|
||||||
update: version.NewUpdate("nb/client-ui"),
|
update: version.NewUpdateAndStart("nb/client-ui"),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.eventHandler = newEventHandler(s)
|
s.eventHandler = newEventHandler(s)
|
||||||
@@ -375,6 +385,8 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
s.showProfilesUI()
|
s.showProfilesUI()
|
||||||
case args.showQuickActions:
|
case args.showQuickActions:
|
||||||
s.showQuickActionsUI()
|
s.showQuickActionsUI()
|
||||||
|
case args.showUpdate:
|
||||||
|
s.showUpdateProgress(ctx, args.showUpdateVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s
|
return s
|
||||||
@@ -814,7 +826,7 @@ func (s *serviceClient) handleSSOLogin(ctx context.Context, loginResp *proto.Log
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) menuUpClick(ctx context.Context) error {
|
func (s *serviceClient) menuUpClick(ctx context.Context, wannaAutoUpdate bool) error {
|
||||||
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
systray.SetTemplateIcon(iconConnectingMacOS, s.icConnecting)
|
||||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -836,7 +848,9 @@ func (s *serviceClient) menuUpClick(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := conn.Up(ctx, &proto.UpRequest{}); err != nil {
|
if _, err := s.conn.Up(s.ctx, &proto.UpRequest{
|
||||||
|
AutoUpdate: protobuf.Bool(wannaAutoUpdate),
|
||||||
|
}); err != nil {
|
||||||
return fmt.Errorf("start connection: %w", err)
|
return fmt.Errorf("start connection: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1097,6 +1111,26 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
s.updateExitNodes()
|
s.updateExitNodes()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||||
|
// todo use new Category
|
||||||
|
if windowAction, ok := event.Metadata["progress_window"]; ok {
|
||||||
|
targetVersion, ok := event.Metadata["version"]
|
||||||
|
if !ok {
|
||||||
|
targetVersion = "unknown"
|
||||||
|
}
|
||||||
|
log.Debugf("window action: %v", windowAction)
|
||||||
|
if windowAction == "show" {
|
||||||
|
if s.updateContextCancel != nil {
|
||||||
|
s.updateContextCancel()
|
||||||
|
s.updateContextCancel = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
subCtx, cancel := context.WithCancel(s.ctx)
|
||||||
|
go s.eventHandler.runSelfCommand(subCtx, "update", "--update-version", targetVersion)
|
||||||
|
s.updateContextCancel = cancel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
go s.eventManager.Start(s.ctx)
|
go s.eventManager.Start(s.ctx)
|
||||||
go s.eventHandler.listen(s.ctx)
|
go s.eventHandler.listen(s.ctx)
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ func (h *eventHandler) handleConnectClick() {
|
|||||||
go func() {
|
go func() {
|
||||||
defer connectCancel()
|
defer connectCancel()
|
||||||
|
|
||||||
if err := h.client.menuUpClick(connectCtx); err != nil {
|
if err := h.client.menuUpClick(connectCtx, true); err != nil {
|
||||||
st, ok := status.FromError(err)
|
st, ok := status.FromError(err)
|
||||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||||
log.Debugf("connect operation cancelled by user")
|
log.Debugf("connect operation cancelled by user")
|
||||||
@@ -185,7 +185,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
|
|||||||
go func() {
|
go func() {
|
||||||
defer h.client.mAdvancedSettings.Enable()
|
defer h.client.mAdvancedSettings.Enable()
|
||||||
defer h.client.getSrvConfig()
|
defer h.client.getSrvConfig()
|
||||||
h.runSelfCommand(h.client.ctx, "settings", "true")
|
h.runSelfCommand(h.client.ctx, "settings")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +193,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
|
|||||||
h.client.mCreateDebugBundle.Disable()
|
h.client.mCreateDebugBundle.Disable()
|
||||||
go func() {
|
go func() {
|
||||||
defer h.client.mCreateDebugBundle.Enable()
|
defer h.client.mCreateDebugBundle.Enable()
|
||||||
h.runSelfCommand(h.client.ctx, "debug", "true")
|
h.runSelfCommand(h.client.ctx, "debug")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,7 +217,7 @@ func (h *eventHandler) handleNetworksClick() {
|
|||||||
h.client.mNetworks.Disable()
|
h.client.mNetworks.Disable()
|
||||||
go func() {
|
go func() {
|
||||||
defer h.client.mNetworks.Enable()
|
defer h.client.mNetworks.Enable()
|
||||||
h.runSelfCommand(h.client.ctx, "networks", "true")
|
h.runSelfCommand(h.client.ctx, "networks")
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -237,17 +237,21 @@ func (h *eventHandler) updateConfigWithErr() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
|
func (h *eventHandler) runSelfCommand(ctx context.Context, command string, args ...string) {
|
||||||
proc, err := os.Executable()
|
proc, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("error getting executable path: %v", err)
|
log.Errorf("error getting executable path: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd := exec.CommandContext(ctx, proc,
|
// Build the full command arguments
|
||||||
fmt.Sprintf("--%s=%s", command, arg),
|
cmdArgs := []string{
|
||||||
|
fmt.Sprintf("--%s=true", command),
|
||||||
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
|
fmt.Sprintf("--daemon-addr=%s", h.client.addr),
|
||||||
)
|
}
|
||||||
|
cmdArgs = append(cmdArgs, args...)
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, proc, cmdArgs...)
|
||||||
|
|
||||||
if out := h.client.attachOutput(cmd); out != nil {
|
if out := h.client.attachOutput(cmd); out != nil {
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -257,17 +261,17 @@ func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string)
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("running command: %s --%s=%s --daemon-addr=%s", proc, command, arg, h.client.addr)
|
log.Printf("running command: %s", cmd.String())
|
||||||
|
|
||||||
if err := cmd.Run(); err != nil {
|
if err := cmd.Run(); err != nil {
|
||||||
var exitErr *exec.ExitError
|
var exitErr *exec.ExitError
|
||||||
if errors.As(err, &exitErr) {
|
if errors.As(err, &exitErr) {
|
||||||
log.Printf("command '%s %s' failed with exit code %d", command, arg, exitErr.ExitCode())
|
log.Printf("command '%s' failed with exit code %d", cmd.String(), exitErr.ExitCode())
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("command '%s %s' completed successfully", command, arg)
|
log.Printf("command '%s' completed successfully", cmd.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *eventHandler) logout(ctx context.Context) error {
|
func (h *eventHandler) logout(ctx context.Context) error {
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ type profileMenu struct {
|
|||||||
logoutSubItem *subItem
|
logoutSubItem *subItem
|
||||||
profilesState []Profile
|
profilesState []Profile
|
||||||
downClickCallback func() error
|
downClickCallback func() error
|
||||||
upClickCallback func(context.Context) error
|
upClickCallback func(context.Context, bool) error
|
||||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||||
loadSettingsCallback func()
|
loadSettingsCallback func()
|
||||||
app fyne.App
|
app fyne.App
|
||||||
@@ -411,7 +411,7 @@ type newProfileMenuArgs struct {
|
|||||||
profileMenuItem *systray.MenuItem
|
profileMenuItem *systray.MenuItem
|
||||||
emailMenuItem *systray.MenuItem
|
emailMenuItem *systray.MenuItem
|
||||||
downClickCallback func() error
|
downClickCallback func() error
|
||||||
upClickCallback func(context.Context) error
|
upClickCallback func(context.Context, bool) error
|
||||||
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
getSrvClientCallback func(timeout time.Duration) (proto.DaemonServiceClient, error)
|
||||||
loadSettingsCallback func()
|
loadSettingsCallback func()
|
||||||
app fyne.App
|
app fyne.App
|
||||||
@@ -579,7 +579,7 @@ func (p *profileMenu) refresh() {
|
|||||||
connectCtx, connectCancel := context.WithCancel(p.ctx)
|
connectCtx, connectCancel := context.WithCancel(p.ctx)
|
||||||
p.serviceClient.connectCancel = connectCancel
|
p.serviceClient.connectCancel = connectCancel
|
||||||
|
|
||||||
if err := p.upClickCallback(connectCtx); err != nil {
|
if err := p.upClickCallback(connectCtx, false); err != nil {
|
||||||
log.Errorf("failed to handle up click after switching profile: %v", err)
|
log.Errorf("failed to handle up click after switching profile: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -267,7 +267,7 @@ func (s *serviceClient) showQuickActionsUI() {
|
|||||||
|
|
||||||
connCmd := connectCommand{
|
connCmd := connectCommand{
|
||||||
connectClient: func() error {
|
connectClient: func() error {
|
||||||
return s.menuUpClick(s.ctx)
|
return s.menuUpClick(s.ctx, false)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
140
client/ui/update.go
Normal file
140
client/ui/update.go
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
//go:build !(linux && 386)
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2/container"
|
||||||
|
"fyne.io/fyne/v2/widget"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (s *serviceClient) showUpdateProgress(ctx context.Context, version string) {
|
||||||
|
log.Infof("show installer progress window: %s", version)
|
||||||
|
s.wUpdateProgress = s.app.NewWindow("Automatically updating client")
|
||||||
|
|
||||||
|
statusLabel := widget.NewLabel("Updating...")
|
||||||
|
infoLabel := widget.NewLabel(fmt.Sprintf("Your client version is older than the auto-update version set in Management.\nUpdating client to: %s.", version))
|
||||||
|
content := container.NewVBox(infoLabel, statusLabel)
|
||||||
|
s.wUpdateProgress.SetContent(content)
|
||||||
|
s.wUpdateProgress.CenterOnScreen()
|
||||||
|
s.wUpdateProgress.SetFixedSize(true)
|
||||||
|
s.wUpdateProgress.SetCloseIntercept(func() {
|
||||||
|
// this is empty to lock window until result known
|
||||||
|
})
|
||||||
|
s.wUpdateProgress.RequestFocus()
|
||||||
|
s.wUpdateProgress.Show()
|
||||||
|
|
||||||
|
updateWindowCtx, cancel := context.WithTimeout(ctx, 15*time.Minute)
|
||||||
|
|
||||||
|
// Initialize dot updater
|
||||||
|
updateText := dotUpdater()
|
||||||
|
|
||||||
|
// Channel to receive the result from RPC call
|
||||||
|
resultErrCh := make(chan error, 1)
|
||||||
|
resultOkCh := make(chan struct{}, 1)
|
||||||
|
|
||||||
|
// Start RPC call in background
|
||||||
|
go func() {
|
||||||
|
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("backend not reachable, upgrade in progress: %v", err)
|
||||||
|
close(resultOkCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := conn.GetInstallerResult(updateWindowCtx, &proto.InstallerResultRequest{})
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("backend stopped responding, upgrade in progress: %v", err)
|
||||||
|
close(resultOkCh)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !resp.Success {
|
||||||
|
resultErrCh <- mapInstallError(resp.ErrorMsg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Success
|
||||||
|
close(resultOkCh)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Update UI with dots and wait for result
|
||||||
|
go func() {
|
||||||
|
ticker := time.NewTicker(time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// allow closing update window after 10 sec
|
||||||
|
timerResetCloseInterceptor := time.NewTimer(10 * time.Second)
|
||||||
|
defer timerResetCloseInterceptor.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-updateWindowCtx.Done():
|
||||||
|
s.showInstallerResult(statusLabel, updateWindowCtx.Err())
|
||||||
|
return
|
||||||
|
case err := <-resultErrCh:
|
||||||
|
s.showInstallerResult(statusLabel, err)
|
||||||
|
return
|
||||||
|
case <-resultOkCh:
|
||||||
|
log.Info("backend exited, upgrade in progress, closing all UI")
|
||||||
|
killParentUIProcess()
|
||||||
|
s.app.Quit()
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
statusLabel.SetText(updateText())
|
||||||
|
case <-timerResetCloseInterceptor.C:
|
||||||
|
s.wUpdateProgress.SetCloseIntercept(nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *serviceClient) showInstallerResult(statusLabel *widget.Label, err error) {
|
||||||
|
s.wUpdateProgress.SetCloseIntercept(nil)
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, context.DeadlineExceeded):
|
||||||
|
log.Warn("update watcher timed out")
|
||||||
|
statusLabel.SetText("Update timed out. Please try again.")
|
||||||
|
case errors.Is(err, context.Canceled):
|
||||||
|
log.Info("update watcher canceled")
|
||||||
|
statusLabel.SetText("Update canceled.")
|
||||||
|
case err != nil:
|
||||||
|
log.Errorf("update failed: %v", err)
|
||||||
|
statusLabel.SetText("Update failed: " + err.Error())
|
||||||
|
default:
|
||||||
|
s.wUpdateProgress.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dotUpdater returns a closure that cycles through dots for a loading animation.
|
||||||
|
func dotUpdater() func() string {
|
||||||
|
dotCount := 0
|
||||||
|
return func() string {
|
||||||
|
dotCount = (dotCount + 1) % 4
|
||||||
|
return fmt.Sprintf("%s%s", "Updating", strings.Repeat(".", dotCount))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mapInstallError(msg string) error {
|
||||||
|
msg = strings.ToLower(strings.TrimSpace(msg))
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(msg, "deadline exceeded"), strings.Contains(msg, "timeout"):
|
||||||
|
return context.DeadlineExceeded
|
||||||
|
case strings.Contains(msg, "canceled"), strings.Contains(msg, "cancelled"):
|
||||||
|
return context.Canceled
|
||||||
|
case msg == "":
|
||||||
|
return errors.New("unknown update error")
|
||||||
|
default:
|
||||||
|
return errors.New(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
7
client/ui/update_notwindows.go
Normal file
7
client/ui/update_notwindows.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build !windows && !(linux && 386)
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
func killParentUIProcess() {
|
||||||
|
// No-op on non-Windows platforms
|
||||||
|
}
|
||||||
44
client/ui/update_windows.go
Normal file
44
client/ui/update_windows.go
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
//go:build windows
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
nbprocess "github.com/netbirdio/netbird/client/ui/process"
|
||||||
|
)
|
||||||
|
|
||||||
|
// killParentUIProcess finds and kills the parent systray UI process on Windows.
|
||||||
|
// This is a workaround in case the MSI installer fails to properly terminate the UI process.
|
||||||
|
// The installer should handle this via util:CloseApplication with TerminateProcess, but this
|
||||||
|
// provides an additional safety mechanism to ensure the UI is closed before the upgrade proceeds.
|
||||||
|
func killParentUIProcess() {
|
||||||
|
pid, running, err := nbprocess.IsAnotherProcessRunning()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to check for parent UI process: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !running {
|
||||||
|
log.Debug("no parent UI process found to kill")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("killing parent UI process (PID: %d)", pid)
|
||||||
|
|
||||||
|
// Open the process with terminate rights
|
||||||
|
handle, err := windows.OpenProcess(windows.PROCESS_TERMINATE, false, uint32(pid))
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to open parent process %d: %v", pid, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = windows.CloseHandle(handle)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Terminate the process with exit code 0
|
||||||
|
if err := windows.TerminateProcess(handle, 0); err != nil {
|
||||||
|
log.Warnf("failed to terminate parent process %d: %v", pid, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -183,7 +183,7 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
|||||||
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", s.listener.Addr().String())
|
||||||
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
s.serveGRPCWithHTTP(ctx, s.listener, rootHandler, tlsEnabled)
|
||||||
|
|
||||||
s.update = version.NewUpdate("nb/management")
|
s.update = version.NewUpdateAndStart("nb/management")
|
||||||
s.update.SetDaemonVersion(version.NetbirdVersion())
|
s.update.SetDaemonVersion(version.NetbirdVersion())
|
||||||
s.update.SetOnUpdateListener(func() {
|
s.update.SetOnUpdateListener(func() {
|
||||||
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
|
||||||
|
|||||||
@@ -83,6 +83,10 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
|||||||
return nbConfig
|
return nbConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// toPeerConfig builds a proto.PeerConfig from internal peer, network, DNS name, and settings.
|
||||||
|
//
|
||||||
|
// The returned PeerConfig includes the peer's IP with network mask, FQDN, SSH configuration
|
||||||
|
// (including JWT config when SSH is enabled), and flags for routing DNS resolution and lazy connections.
|
||||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow) *proto.PeerConfig {
|
||||||
netmask, _ := network.Net.Mask.Size()
|
netmask, _ := network.Net.Mask.Size()
|
||||||
fqdn := peer.FQDN(dnsName)
|
fqdn := peer.FQDN(dnsName)
|
||||||
@@ -104,7 +108,16 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
// ToSyncResponse constructs a proto.SyncResponse that bundles the peer's runtime configuration,
|
||||||
|
// network map (routes, DNS, peer lists, firewall and forwarding rules), Netbird configuration,
|
||||||
|
// and posture checks for a sync operation.
|
||||||
|
//
|
||||||
|
// The response includes PeerConfig, NetworkMap (with Serial, Routes, DNSConfig, PeerConfig,
|
||||||
|
// RemotePeers, OfflinePeers, FirewallRules, RoutesFirewallRules, and ForwardingRules when present),
|
||||||
|
// NetbirdConfig (extended with integrations), and Checks. Remote peer lists' "IsEmpty" flags are
|
||||||
|
// set based on their lengths. If remotePeerGroupsLookup is non-nil, each remote peer's Groups and
|
||||||
|
// UserId fields are populated using GetPeerGroupNames for that peer.
|
||||||
|
func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nbconfig.HttpServerConfig, deviceFlowConfig *nbconfig.DeviceAuthorizationFlow, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *cache.DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64, remotePeerGroupsLookup PeerGroupsLookup) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, httpConfig, deviceFlowConfig),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
@@ -122,13 +135,13 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||||
|
|
||||||
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
remotePeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||||
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName)
|
remotePeers = appendRemotePeerConfig(remotePeers, networkMap.Peers, dnsName, remotePeerGroupsLookup)
|
||||||
response.RemotePeers = remotePeers
|
response.RemotePeers = remotePeers
|
||||||
response.NetworkMap.RemotePeers = remotePeers
|
response.NetworkMap.RemotePeers = remotePeers
|
||||||
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
response.RemotePeersIsEmpty = len(remotePeers) == 0
|
||||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||||
|
|
||||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName, remotePeerGroupsLookup)
|
||||||
|
|
||||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||||
response.NetworkMap.FirewallRules = firewallRules
|
response.NetworkMap.FirewallRules = firewallRules
|
||||||
@@ -149,14 +162,58 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
|||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
// PeerGroupsLookup provides group names for a peer ID
|
||||||
|
type PeerGroupsLookup interface {
|
||||||
|
GetPeerGroupNames(peerID string) []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccountPeerGroupsLookup implements PeerGroupsLookup using a pre-built reverse index
|
||||||
|
// for O(1) lookup performance instead of O(N*M) iteration.
|
||||||
|
type AccountPeerGroupsLookup struct {
|
||||||
|
peerToGroups map[string][]string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAccountPeerGroupsLookup creates a new AccountPeerGroupsLookup from an Account.
|
||||||
|
// NewAccountPeerGroupsLookup builds an AccountPeerGroupsLookup containing a reverse index
|
||||||
|
// from peer ID to the names of groups that include that peer. If account is nil, it returns nil.
|
||||||
|
func NewAccountPeerGroupsLookup(account *types.Account) *AccountPeerGroupsLookup {
|
||||||
|
if account == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
peerToGroups := make(map[string][]string)
|
||||||
|
for _, group := range account.Groups {
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
peerToGroups[peerID] = append(peerToGroups[peerID], group.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &AccountPeerGroupsLookup{peerToGroups: peerToGroups}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPeerGroupNames returns the group names for a given peer ID.
|
||||||
|
// Returns nil if the peer is not found in any group.
|
||||||
|
func (a *AccountPeerGroupsLookup) GetPeerGroupNames(peerID string) []string {
|
||||||
|
if a == nil || a.peerToGroups == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return a.peerToGroups[peerID]
|
||||||
|
}
|
||||||
|
|
||||||
|
// appendRemotePeerConfig appends a RemotePeerConfig for each peer in peers to dst.
|
||||||
|
// For each peer it adds a RemotePeerConfig populated with the WireGuard public key, an /32 allowed IP derived from the peer IP, SSH public key, FQDN (computed using dnsName), agent version, group names retrieved from groupsLookup when provided, and the user ID, and returns the extended slice.
|
||||||
|
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string, groupsLookup PeerGroupsLookup) []*proto.RemotePeerConfig {
|
||||||
for _, rPeer := range peers {
|
for _, rPeer := range peers {
|
||||||
|
var groups []string
|
||||||
|
if groupsLookup != nil {
|
||||||
|
groups = groupsLookup.GetPeerGroupNames(rPeer.ID)
|
||||||
|
}
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
dst = append(dst, &proto.RemotePeerConfig{
|
||||||
WgPubKey: rPeer.Key,
|
WgPubKey: rPeer.Key,
|
||||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||||
Fqdn: rPeer.FQDN(dnsName),
|
Fqdn: rPeer.FQDN(dnsName),
|
||||||
AgentVersion: rPeer.Meta.WtVersion,
|
AgentVersion: rPeer.Meta.WtVersion,
|
||||||
|
Groups: groups,
|
||||||
|
UserId: rPeer.UserID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return dst
|
return dst
|
||||||
@@ -402,4 +459,4 @@ func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Sprintf("%s://%s/", u.Scheme, u.Host)
|
return fmt.Sprintf("%s://%s/", u.Scheme, u.Host)
|
||||||
}
|
}
|
||||||
@@ -321,7 +321,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
|
|
||||||
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
|
||||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||||
oldSettings.DNSDomain != newSettings.DNSDomain {
|
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||||
|
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||||
updateAccountPeers = true
|
updateAccountPeers = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -360,6 +361,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
|||||||
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
|
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -451,6 +453,14 @@ func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (am *DefaultAccountManager) handleAutoUpdateVersionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||||
|
if oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion {
|
||||||
|
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountAutoUpdateVersionUpdated, map[string]any{
|
||||||
|
"version": newSettings.AutoUpdateVersion,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
|
||||||
if newSettings.PeerInactivityExpirationEnabled {
|
if newSettings.PeerInactivityExpirationEnabled {
|
||||||
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
|
||||||
|
|||||||
@@ -181,6 +181,8 @@ const (
|
|||||||
UserRejected Activity = 90
|
UserRejected Activity = 90
|
||||||
UserCreated Activity = 91
|
UserCreated Activity = 91
|
||||||
|
|
||||||
|
AccountAutoUpdateVersionUpdated Activity = 92
|
||||||
|
|
||||||
AccountDeleted Activity = 99999
|
AccountDeleted Activity = 99999
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -287,9 +289,12 @@ var activityMap = map[Activity]Code{
|
|||||||
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
AccountNetworkRangeUpdated: {"Account network range updated", "account.network.range.update"},
|
||||||
|
|
||||||
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
PeerIPUpdated: {"Peer IP updated", "peer.ip.update"},
|
||||||
UserApproved: {"User approved", "user.approve"},
|
|
||||||
UserRejected: {"User rejected", "user.reject"},
|
UserApproved: {"User approved", "user.approve"},
|
||||||
UserCreated: {"User created", "user.create"},
|
UserRejected: {"User rejected", "user.reject"},
|
||||||
|
UserCreated: {"User created", "user.create"},
|
||||||
|
|
||||||
|
AccountAutoUpdateVersionUpdated: {"Account AutoUpdate Version updated", "account.settings.auto.version.update"},
|
||||||
}
|
}
|
||||||
|
|
||||||
// StringCode returns a string code of the activity
|
// StringCode returns a string code of the activity
|
||||||
|
|||||||
@@ -3,12 +3,15 @@ package accounts
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
|
|
||||||
|
goversion "github.com/hashicorp/go-version"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
@@ -26,7 +29,9 @@ const (
|
|||||||
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
// MinNetworkBits is the minimum prefix length for IPv4 network ranges (e.g., /29 gives 8 addresses, /28 gives 16)
|
||||||
MinNetworkBitsIPv4 = 28
|
MinNetworkBitsIPv4 = 28
|
||||||
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
// MinNetworkBitsIPv6 is the minimum prefix length for IPv6 network ranges
|
||||||
MinNetworkBitsIPv6 = 120
|
MinNetworkBitsIPv6 = 120
|
||||||
|
disableAutoUpdate = "disabled"
|
||||||
|
autoUpdateLatestVersion = "latest"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handler is a handler that handles the server.Account HTTP endpoints
|
// handler is a handler that handles the server.Account HTTP endpoints
|
||||||
@@ -162,6 +167,61 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
|||||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJSONRequestBody) (*types.Settings, error) {
|
||||||
|
returnSettings := &types.Settings{
|
||||||
|
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
||||||
|
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
||||||
|
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
||||||
|
|
||||||
|
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Settings.Extra != nil {
|
||||||
|
returnSettings.Extra = &types.ExtraSettings{
|
||||||
|
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||||
|
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||||
|
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||||
|
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||||
|
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Settings.JwtGroupsEnabled != nil {
|
||||||
|
returnSettings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.GroupsPropagationEnabled != nil {
|
||||||
|
returnSettings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.JwtGroupsClaimName != nil {
|
||||||
|
returnSettings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
||||||
|
}
|
||||||
|
if req.Settings.JwtAllowGroups != nil {
|
||||||
|
returnSettings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
||||||
|
}
|
||||||
|
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
||||||
|
returnSettings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.DnsDomain != nil {
|
||||||
|
returnSettings.DNSDomain = *req.Settings.DnsDomain
|
||||||
|
}
|
||||||
|
if req.Settings.LazyConnectionEnabled != nil {
|
||||||
|
returnSettings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
||||||
|
}
|
||||||
|
if req.Settings.AutoUpdateVersion != nil {
|
||||||
|
_, err := goversion.NewSemver(*req.Settings.AutoUpdateVersion)
|
||||||
|
if *req.Settings.AutoUpdateVersion == autoUpdateLatestVersion ||
|
||||||
|
*req.Settings.AutoUpdateVersion == disableAutoUpdate ||
|
||||||
|
err == nil {
|
||||||
|
returnSettings.AutoUpdateVersion = *req.Settings.AutoUpdateVersion
|
||||||
|
} else if *req.Settings.AutoUpdateVersion != "" {
|
||||||
|
return nil, fmt.Errorf("invalid AutoUpdateVersion")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return returnSettings, nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||||
@@ -186,45 +246,10 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
settings := &types.Settings{
|
settings, err := h.updateAccountRequestSettings(req)
|
||||||
PeerLoginExpirationEnabled: req.Settings.PeerLoginExpirationEnabled,
|
if err != nil {
|
||||||
PeerLoginExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerLoginExpiration)),
|
util.WriteError(r.Context(), err, w)
|
||||||
RegularUsersViewBlocked: req.Settings.RegularUsersViewBlocked,
|
return
|
||||||
|
|
||||||
PeerInactivityExpirationEnabled: req.Settings.PeerInactivityExpirationEnabled,
|
|
||||||
PeerInactivityExpiration: time.Duration(float64(time.Second.Nanoseconds()) * float64(req.Settings.PeerInactivityExpiration)),
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Settings.Extra != nil {
|
|
||||||
settings.Extra = &types.ExtraSettings{
|
|
||||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
|
||||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
|
||||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
|
||||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
|
||||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Settings.JwtGroupsEnabled != nil {
|
|
||||||
settings.JWTGroupsEnabled = *req.Settings.JwtGroupsEnabled
|
|
||||||
}
|
|
||||||
if req.Settings.GroupsPropagationEnabled != nil {
|
|
||||||
settings.GroupsPropagationEnabled = *req.Settings.GroupsPropagationEnabled
|
|
||||||
}
|
|
||||||
if req.Settings.JwtGroupsClaimName != nil {
|
|
||||||
settings.JWTGroupsClaimName = *req.Settings.JwtGroupsClaimName
|
|
||||||
}
|
|
||||||
if req.Settings.JwtAllowGroups != nil {
|
|
||||||
settings.JWTAllowGroups = *req.Settings.JwtAllowGroups
|
|
||||||
}
|
|
||||||
if req.Settings.RoutingPeerDnsResolutionEnabled != nil {
|
|
||||||
settings.RoutingPeerDNSResolutionEnabled = *req.Settings.RoutingPeerDnsResolutionEnabled
|
|
||||||
}
|
|
||||||
if req.Settings.DnsDomain != nil {
|
|
||||||
settings.DNSDomain = *req.Settings.DnsDomain
|
|
||||||
}
|
|
||||||
if req.Settings.LazyConnectionEnabled != nil {
|
|
||||||
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
|
|
||||||
}
|
}
|
||||||
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
if req.Settings.NetworkRange != nil && *req.Settings.NetworkRange != "" {
|
||||||
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
prefix, err := netip.ParsePrefix(*req.Settings.NetworkRange)
|
||||||
@@ -313,6 +338,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
|||||||
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
RoutingPeerDnsResolutionEnabled: &settings.RoutingPeerDNSResolutionEnabled,
|
||||||
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
LazyConnectionEnabled: &settings.LazyConnectionEnabled,
|
||||||
DnsDomain: &settings.DNSDomain,
|
DnsDomain: &settings.DNSDomain,
|
||||||
|
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||||
}
|
}
|
||||||
|
|
||||||
if settings.NetworkRange.IsValid() {
|
if settings.NetworkRange.IsValid() {
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
RoutingPeerDnsResolutionEnabled: br(false),
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
LazyConnectionEnabled: br(false),
|
LazyConnectionEnabled: br(false),
|
||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr(""),
|
||||||
},
|
},
|
||||||
expectedArray: true,
|
expectedArray: true,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -143,6 +144,30 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
RoutingPeerDnsResolutionEnabled: br(false),
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
LazyConnectionEnabled: br(false),
|
LazyConnectionEnabled: br(false),
|
||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr(""),
|
||||||
|
},
|
||||||
|
expectedArray: false,
|
||||||
|
expectedID: accountID,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PutAccount OK with autoUpdateVersion",
|
||||||
|
expectedBody: true,
|
||||||
|
requestType: http.MethodPut,
|
||||||
|
requestPath: "/api/accounts/" + accountID,
|
||||||
|
requestBody: bytes.NewBufferString("{\"settings\": {\"auto_update_version\": \"latest\", \"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"),
|
||||||
|
expectedStatus: http.StatusOK,
|
||||||
|
expectedSettings: api.AccountSettings{
|
||||||
|
PeerLoginExpiration: 15552000,
|
||||||
|
PeerLoginExpirationEnabled: true,
|
||||||
|
GroupsPropagationEnabled: br(false),
|
||||||
|
JwtGroupsClaimName: sr(""),
|
||||||
|
JwtGroupsEnabled: br(false),
|
||||||
|
JwtAllowGroups: &[]string{},
|
||||||
|
RegularUsersViewBlocked: false,
|
||||||
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
|
LazyConnectionEnabled: br(false),
|
||||||
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr("latest"),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -165,6 +190,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
RoutingPeerDnsResolutionEnabled: br(false),
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
LazyConnectionEnabled: br(false),
|
LazyConnectionEnabled: br(false),
|
||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr(""),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -187,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
RoutingPeerDnsResolutionEnabled: br(false),
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
LazyConnectionEnabled: br(false),
|
LazyConnectionEnabled: br(false),
|
||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr(""),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
@@ -209,6 +236,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
|||||||
RoutingPeerDnsResolutionEnabled: br(false),
|
RoutingPeerDnsResolutionEnabled: br(false),
|
||||||
LazyConnectionEnabled: br(false),
|
LazyConnectionEnabled: br(false),
|
||||||
DnsDomain: sr(""),
|
DnsDomain: sr(""),
|
||||||
|
AutoUpdateVersion: sr(""),
|
||||||
},
|
},
|
||||||
expectedArray: false,
|
expectedArray: false,
|
||||||
expectedID: accountID,
|
expectedID: accountID,
|
||||||
|
|||||||
@@ -52,6 +52,9 @@ type Settings struct {
|
|||||||
|
|
||||||
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
|
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
|
||||||
LazyConnectionEnabled bool `gorm:"default:false"`
|
LazyConnectionEnabled bool `gorm:"default:false"`
|
||||||
|
|
||||||
|
// AutoUpdateVersion client auto-update version
|
||||||
|
AutoUpdateVersion string `gorm:"default:'disabled'"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Copy copies the Settings struct
|
// Copy copies the Settings struct
|
||||||
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
|
|||||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||||
DNSDomain: s.DNSDomain,
|
DNSDomain: s.DNSDomain,
|
||||||
NetworkRange: s.NetworkRange,
|
NetworkRange: s.NetworkRange,
|
||||||
|
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||||
}
|
}
|
||||||
if s.Extra != nil {
|
if s.Extra != nil {
|
||||||
settings.Extra = s.Extra.Copy()
|
settings.Extra = s.Extra.Copy()
|
||||||
|
|||||||
@@ -145,6 +145,10 @@ components:
|
|||||||
description: Enables or disables experimental lazy connection
|
description: Enables or disables experimental lazy connection
|
||||||
type: boolean
|
type: boolean
|
||||||
example: true
|
example: true
|
||||||
|
auto_update_version:
|
||||||
|
description: Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||||
|
type: string
|
||||||
|
example: "0.51.2"
|
||||||
required:
|
required:
|
||||||
- peer_login_expiration_enabled
|
- peer_login_expiration_enabled
|
||||||
- peer_login_expiration
|
- peer_login_expiration
|
||||||
|
|||||||
@@ -291,6 +291,9 @@ type AccountRequest struct {
|
|||||||
|
|
||||||
// AccountSettings defines model for AccountSettings.
|
// AccountSettings defines model for AccountSettings.
|
||||||
type AccountSettings struct {
|
type AccountSettings struct {
|
||||||
|
// AutoUpdateVersion Set Clients auto-update version. "latest", "disabled", or a specific version (e.g "0.50.1")
|
||||||
|
AutoUpdateVersion *string `json:"auto_update_version,omitempty"`
|
||||||
|
|
||||||
// DnsDomain Allows to define a custom dns domain for the account
|
// DnsDomain Allows to define a custom dns domain for the account
|
||||||
DnsDomain *string `json:"dns_domain,omitempty"`
|
DnsDomain *string `json:"dns_domain,omitempty"`
|
||||||
Extra *AccountExtraSettings `json:"extra,omitempty"`
|
Extra *AccountExtraSettings `json:"extra,omitempty"`
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -280,6 +280,18 @@ message PeerConfig {
|
|||||||
bool LazyConnectionEnabled = 6;
|
bool LazyConnectionEnabled = 6;
|
||||||
|
|
||||||
int32 mtu = 7;
|
int32 mtu = 7;
|
||||||
|
|
||||||
|
// Auto-update config
|
||||||
|
AutoUpdateSettings autoUpdate = 8;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AutoUpdateSettings {
|
||||||
|
string version = 1;
|
||||||
|
/*
|
||||||
|
alwaysUpdate = true → Updates happen automatically in the background
|
||||||
|
alwaysUpdate = false → Updates only happen when triggered by a peer connection
|
||||||
|
*/
|
||||||
|
bool alwaysUpdate = 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
|
// NetworkMap represents a network state of the peer with the corresponding configuration parameters to establish peer-to-peer connections
|
||||||
|
|||||||
@@ -41,21 +41,28 @@ func NewUpdate(httpAgent string) *Update {
|
|||||||
currentVersion, _ = goversion.NewVersion("0.0.0")
|
currentVersion, _ = goversion.NewVersion("0.0.0")
|
||||||
}
|
}
|
||||||
|
|
||||||
latestAvailable, _ := goversion.NewVersion("0.0.0")
|
|
||||||
|
|
||||||
u := &Update{
|
u := &Update{
|
||||||
httpAgent: httpAgent,
|
httpAgent: httpAgent,
|
||||||
latestAvailable: latestAvailable,
|
uiVersion: currentVersion,
|
||||||
uiVersion: currentVersion,
|
fetchDone: make(chan struct{}),
|
||||||
fetchTicker: time.NewTicker(fetchPeriod),
|
|
||||||
fetchDone: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
go u.startFetcher()
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUpdateAndStart(httpAgent string) *Update {
|
||||||
|
u := NewUpdate(httpAgent)
|
||||||
|
go u.StartFetcher()
|
||||||
|
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
// StopWatch stop the version info fetch loop
|
// StopWatch stop the version info fetch loop
|
||||||
func (u *Update) StopWatch() {
|
func (u *Update) StopWatch() {
|
||||||
|
if u.fetchTicker == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
u.fetchTicker.Stop()
|
u.fetchTicker.Stop()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
@@ -94,7 +101,18 @@ func (u *Update) SetOnUpdateListener(updateFn func()) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *Update) startFetcher() {
|
func (u *Update) LatestVersion() *goversion.Version {
|
||||||
|
u.versionsLock.Lock()
|
||||||
|
defer u.versionsLock.Unlock()
|
||||||
|
return u.latestAvailable
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Update) StartFetcher() {
|
||||||
|
if u.fetchTicker != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
u.fetchTicker = time.NewTicker(fetchPeriod)
|
||||||
|
|
||||||
if changed := u.fetchVersion(); changed {
|
if changed := u.fetchVersion(); changed {
|
||||||
u.checkUpdate()
|
u.checkUpdate()
|
||||||
}
|
}
|
||||||
@@ -181,6 +199,10 @@ func (u *Update) isUpdateAvailable() bool {
|
|||||||
u.versionsLock.Lock()
|
u.versionsLock.Lock()
|
||||||
defer u.versionsLock.Unlock()
|
defer u.versionsLock.Unlock()
|
||||||
|
|
||||||
|
if u.latestAvailable == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
if u.latestAvailable.GreaterThan(u.uiVersion) {
|
if u.latestAvailable.GreaterThan(u.uiVersion) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
onUpdate := false
|
onUpdate := false
|
||||||
u := NewUpdate(httpAgent)
|
u := NewUpdateAndStart(httpAgent)
|
||||||
defer u.StopWatch()
|
defer u.StopWatch()
|
||||||
u.SetOnUpdateListener(func() {
|
u.SetOnUpdateListener(func() {
|
||||||
onUpdate = true
|
onUpdate = true
|
||||||
@@ -48,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
onUpdate := false
|
onUpdate := false
|
||||||
u := NewUpdate(httpAgent)
|
u := NewUpdateAndStart(httpAgent)
|
||||||
defer u.StopWatch()
|
defer u.StopWatch()
|
||||||
u.SetOnUpdateListener(func() {
|
u.SetOnUpdateListener(func() {
|
||||||
onUpdate = true
|
onUpdate = true
|
||||||
@@ -73,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
|
|
||||||
onUpdate := false
|
onUpdate := false
|
||||||
u := NewUpdate(httpAgent)
|
u := NewUpdateAndStart(httpAgent)
|
||||||
defer u.StopWatch()
|
defer u.StopWatch()
|
||||||
u.SetOnUpdateListener(func() {
|
u.SetOnUpdateListener(func() {
|
||||||
onUpdate = true
|
onUpdate = true
|
||||||
|
|||||||
Reference in New Issue
Block a user