mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-02 14:09:56 +00:00
Compare commits
2 Commits
profile-id
...
daemon-own
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dd301f2691 | ||
|
|
174dc24867 |
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -23,7 +24,6 @@ const (
|
||||
|
||||
// Profile represents a profile for gomobile
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
@@ -99,7 +99,6 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
var profiles []*Profile
|
||||
for _, p := range internalProfiles {
|
||||
profiles = append(profiles, &Profile{
|
||||
ID: p.ID,
|
||||
Name: p.Name,
|
||||
IsActive: p.IsActive,
|
||||
})
|
||||
@@ -109,65 +108,55 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the currently active profile name
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
func (pm *ProfileManager) GetActiveProfile() (string, error) {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
activeState, err := pm.serviceMgr.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get active profile: %w", err)
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
|
||||
// ActiveProfileState only stores the ID (and username), not the display
|
||||
// name. Resolve the ID to the full profile so callers get the real Name.
|
||||
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID, androidUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
|
||||
}
|
||||
return &Profile{ID: prof.ID, Name: prof.Name, IsActive: true}, nil
|
||||
return activeState.Name, nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches to a different profile
|
||||
func (pm *ProfileManager) SwitchProfile(id string) error {
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
// Use ServiceManager to stay consistent with ListProfiles
|
||||
// ServiceManager uses active_profile.json
|
||||
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: id,
|
||||
Name: profileName,
|
||||
Username: androidUsername,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("switched to profile: %s", id)
|
||||
log.Infof("switched to profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile
|
||||
func (pm *ProfileManager) AddProfile(profileName string) error {
|
||||
// Use ServiceManager (creates profile in profiles/ directory)
|
||||
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
|
||||
if err != nil {
|
||||
if err := pm.serviceMgr.AddProfile(profileName, androidUsername, nil); err != nil {
|
||||
return fmt.Errorf("failed to add profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("created new profile: %s", profile.ID)
|
||||
log.Infof("created new profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutProfile logs out from a profile (clears authentication)
|
||||
func (pm *ProfileManager) LogoutProfile(id string) error {
|
||||
configPath, err := pm.getProfileConfigPath(id)
|
||||
func (pm *ProfileManager) LogoutProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
configPath, err := pm.getProfileConfigPath(profileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !profilemanager.IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("id '%s' is not valid", id)
|
||||
}
|
||||
|
||||
// Check if profile exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return fmt.Errorf("profile '%s' does not exist", id)
|
||||
return fmt.Errorf("profile '%s' does not exist", profileName)
|
||||
}
|
||||
|
||||
// Read current config using internal profilemanager
|
||||
@@ -185,49 +174,53 @@ func (pm *ProfileManager) LogoutProfile(id string) error {
|
||||
return fmt.Errorf("failed to save config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("logged out from profile: %s", id)
|
||||
log.Infof("logged out from profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes a profile
|
||||
func (pm *ProfileManager) RemoveProfile(id string) error {
|
||||
func (pm *ProfileManager) RemoveProfile(profileName string) error {
|
||||
// Use ServiceManager (removes profile from profiles/ directory)
|
||||
if err := pm.serviceMgr.RemoveProfile(id, androidUsername); err != nil {
|
||||
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
|
||||
return fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("removed profile: %s", id)
|
||||
log.Infof("removed profile: %s", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// getProfileConfigPath returns the config file path for a profile
|
||||
// This is needed for Android-specific path handling (netbird.cfg for default profile)
|
||||
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
|
||||
if id == "" || id == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
// Android uses netbird.cfg for default profile instead of default.json
|
||||
// Default profile is stored in root configDir, not in profiles/
|
||||
return filepath.Join(pm.configDir, defaultConfigFilename), nil
|
||||
}
|
||||
|
||||
// Non-default profiles are stored in profiles subdirectory
|
||||
// This matches the Java Preferences.java expectation
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, id+".json"), nil
|
||||
return filepath.Join(profilesDir, profileName+".json"), nil
|
||||
}
|
||||
|
||||
// GetConfigPath returns the config file path for a given profile id
|
||||
// GetConfigPath returns the config file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.configFile()
|
||||
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
|
||||
return pm.getProfileConfigPath(id)
|
||||
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
|
||||
return pm.getProfileConfigPath(profileName)
|
||||
}
|
||||
|
||||
// GetStateFilePath returns the state file path for a given profile
|
||||
// Java should call this instead of constructing paths with Preferences.stateFile()
|
||||
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
|
||||
if id == "" || id == profilemanager.DefaultProfileName {
|
||||
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
|
||||
if profileName == "" || profileName == profilemanager.DefaultProfileName {
|
||||
return filepath.Join(pm.configDir, "state.json"), nil
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
|
||||
return filepath.Join(profilesDir, id+".state.json"), nil
|
||||
return filepath.Join(profilesDir, profileName+".state.json"), nil
|
||||
}
|
||||
|
||||
// GetActiveConfigPath returns the config file path for the currently active profile
|
||||
@@ -237,7 +230,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetConfigPath(activeProfile.ID)
|
||||
return pm.GetConfigPath(activeProfile)
|
||||
}
|
||||
|
||||
// GetActiveStateFilePath returns the state file path for the currently active profile
|
||||
@@ -247,5 +240,18 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
return pm.GetStateFilePath(activeProfile.ID)
|
||||
return pm.GetStateFilePath(activeProfile)
|
||||
}
|
||||
|
||||
// sanitizeProfileName removes invalid characters from profile name
|
||||
func sanitizeProfileName(name string) string {
|
||||
// Keep only alphanumeric, underscore, and hyphen
|
||||
var result strings.Builder
|
||||
for _, r := range name {
|
||||
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
|
||||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
|
||||
result.WriteRune(r)
|
||||
}
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
@@ -102,11 +102,11 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
|
||||
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||
Hostname: hostName,
|
||||
DnsLabels: dnsLabelsReq,
|
||||
ProfileName: &activeProf.ID,
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &username,
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -170,13 +170,14 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
|
||||
return activeProf, nil
|
||||
}
|
||||
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
|
||||
resolvedID, err := switchProfile(ctx, handle, username)
|
||||
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
|
||||
err := switchProfile(context.Background(), profileName, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile on daemon: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -204,15 +205,11 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
|
||||
return nil
|
||||
}
|
||||
|
||||
// switchProfile asks the daemon to switch to the profile identified by
|
||||
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
|
||||
// ID so the caller can update the local active-profile state without
|
||||
// re-resolving the handle.
|
||||
func switchProfile(ctx context.Context, handle string, username string) (string, error) {
|
||||
func switchProfile(ctx context.Context, profileName string, username string) error {
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
//nolint
|
||||
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||
"If the daemon is not running please run: "+
|
||||
"\nnetbird service install \nnetbird service start\n", err)
|
||||
}
|
||||
@@ -220,15 +217,15 @@ func switchProfile(ctx context.Context, handle string, username string) (string,
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("switch profile failed: %v", err)
|
||||
return fmt.Errorf("switch profile failed: %v", err)
|
||||
}
|
||||
|
||||
return resp.Id, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
|
||||
@@ -252,7 +249,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
||||
return fmt.Errorf("read config file %s: %v", configFilePath, err)
|
||||
}
|
||||
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
|
||||
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("foreground login failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
|
||||
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
|
||||
sm := profilemanager.ServiceManager{}
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "default",
|
||||
Name: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
84
client/cmd/owner.go
Normal file
84
client/cmd/owner.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
var ownerCmd = &cobra.Command{
|
||||
Use: "owner",
|
||||
Short: "Manage daemon owner UIDs",
|
||||
Long: `Manage the list of UIDs allowed to control the NetBird daemon.
|
||||
|
||||
Owners are persisted in the active profile config and survive daemon restarts.
|
||||
The first call from the user logged in at the GUI / console session claims
|
||||
ownership automatically; these subcommands cover the rest of the lifecycle.`,
|
||||
}
|
||||
|
||||
var ownerAddCmd = &cobra.Command{
|
||||
Use: "add <uid>",
|
||||
Short: "Add a UID as an owner of the daemon",
|
||||
Long: `Add a UID to the active profile's owner list. Requires root or an
|
||||
existing owner. Use this to grant another local user permanent access without
|
||||
having them log in at the console first.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: addOwnerFunc,
|
||||
}
|
||||
|
||||
var ownerResetCmd = &cobra.Command{
|
||||
Use: "reset",
|
||||
Short: "Clear the daemon's owner list",
|
||||
Long: `Clear the active profile's owner list, returning the daemon to its
|
||||
unconfigured state. The next call from the active console-session user will
|
||||
re-claim ownership. Requires root.`,
|
||||
RunE: resetOwnerFunc,
|
||||
}
|
||||
|
||||
func addOwnerFunc(cmd *cobra.Command, args []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
uid, err := strconv.ParseUint(args[0], 10, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse uid %q: %w", args[0], err)
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
if _, err := client.AddOwner(cmd.Context(), &proto.AddOwnerRequest{Uid: uint32(uid)}); err != nil {
|
||||
return fmt.Errorf("add owner: %w", err)
|
||||
}
|
||||
|
||||
cmd.Printf("UID %d added as owner\n", uid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func resetOwnerFunc(cmd *cobra.Command, _ []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
if _, err := client.ResetOwner(cmd.Context(), &proto.ResetOwnerRequest{}); err != nil {
|
||||
return fmt.Errorf("reset owner: %w", err)
|
||||
}
|
||||
|
||||
cmd.Println("daemon owner list cleared; next call from the active console user will re-claim ownership")
|
||||
return nil
|
||||
}
|
||||
@@ -2,16 +2,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/user"
|
||||
"strings"
|
||||
"text/tabwriter"
|
||||
"time"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc/codes"
|
||||
gstatus "google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
@@ -19,8 +14,6 @@ import (
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var profileListShowID bool
|
||||
|
||||
var profileCmd = &cobra.Command{
|
||||
Use: "profile",
|
||||
Short: "Manage NetBird client profiles",
|
||||
@@ -38,32 +31,27 @@ var profileListCmd = &cobra.Command{
|
||||
var profileAddCmd = &cobra.Command{
|
||||
Use: "add <profile_name>",
|
||||
Short: "Add a new profile",
|
||||
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
|
||||
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: addProfileFunc,
|
||||
}
|
||||
|
||||
var profileRemoveCmd = &cobra.Command{
|
||||
Use: "remove <profile>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile by name, ID, or unique ID prefix.`,
|
||||
Aliases: []string{"rm"},
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
Use: "remove <profile_name>",
|
||||
Short: "Remove a profile",
|
||||
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: removeProfileFunc,
|
||||
}
|
||||
|
||||
var profileSelectCmd = &cobra.Command{
|
||||
Use: "select <profile>",
|
||||
Use: "select <profile_name>",
|
||||
Short: "Select a profile",
|
||||
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
|
||||
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: selectProfileFunc,
|
||||
}
|
||||
|
||||
func init() {
|
||||
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
|
||||
}
|
||||
|
||||
func setupCmd(cmd *cobra.Command) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
SetFlagsFromEnvVars(cmd)
|
||||
@@ -77,7 +65,6 @@ func setupCmd(cmd *cobra.Command) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
if err := setupCmd(cmd); err != nil {
|
||||
return err
|
||||
@@ -96,32 +83,25 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
|
||||
if profileListShowID {
|
||||
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
|
||||
} else {
|
||||
fmt.Fprintln(tw, "NAME\tACTIVE")
|
||||
}
|
||||
for _, profile := range resp.Profiles {
|
||||
marker := ""
|
||||
// list profiles, add a tick if the profile is active
|
||||
cmd.Println("Found", len(profiles.Profiles), "profiles:")
|
||||
for _, profile := range profiles.Profiles {
|
||||
// use a cross to indicate the passive profiles
|
||||
activeMarker := "✗"
|
||||
if profile.IsActive {
|
||||
marker = "✓"
|
||||
}
|
||||
name := profilemanager.StripCtrlChars(profile.Name)
|
||||
if profileListShowID {
|
||||
fmt.Fprintf(tw, "%s\t%s\t%s\n", profilemanager.ShortID(profile.Id), name, marker)
|
||||
} else {
|
||||
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
|
||||
activeMarker = "✓"
|
||||
}
|
||||
cmd.Println(activeMarker, profile.Name)
|
||||
}
|
||||
return tw.Flush()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -141,49 +121,19 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
profileName := args[0]
|
||||
|
||||
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err == nil {
|
||||
cmd.Printf("Profile added: %s %s\n", profilemanager.ShortID(resp.Id), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
}
|
||||
|
||||
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.AlreadyExists {
|
||||
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
|
||||
if dupCount > 0 {
|
||||
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount, profileName)
|
||||
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
|
||||
}
|
||||
resp, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Printf("Profile added: %s %s\n", profilemanager.ShortID(resp.Id), profilemanager.StripCtrlChars(profileName))
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
|
||||
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
return err
|
||||
}
|
||||
n := 0
|
||||
for _, p := range resp.Profiles {
|
||||
if p.Name == name {
|
||||
n++
|
||||
}
|
||||
}
|
||||
return n, nil
|
||||
|
||||
cmd.Println("Profile added successfully:", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
@@ -203,17 +153,18 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
handle := args[0]
|
||||
|
||||
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: handle,
|
||||
profileName := args[0]
|
||||
|
||||
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
|
||||
ProfileName: profileName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return wrapAmbiguityError(err, handle)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Printf("Profile removed: %s\n", resp.Id)
|
||||
cmd.Println("Profile removed successfully:", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -223,7 +174,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
|
||||
profileManager := profilemanager.NewProfileManager()
|
||||
handle := args[0]
|
||||
profileName := args[0]
|
||||
|
||||
currUser, err := user.Current()
|
||||
if err != nil {
|
||||
@@ -240,15 +191,32 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
|
||||
daemonClient := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
Username: &currUser.Username,
|
||||
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
return wrapAmbiguityError(err, handle)
|
||||
return fmt.Errorf("list profiles: %w", err)
|
||||
}
|
||||
|
||||
if err := profileManager.SwitchProfile(switchResp.Id); err != nil {
|
||||
var profileExists bool
|
||||
|
||||
for _, profile := range profiles.Profiles {
|
||||
if profile.Name == profileName {
|
||||
profileExists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !profileExists {
|
||||
return fmt.Errorf("profile %s does not exist", profileName)
|
||||
}
|
||||
|
||||
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -263,29 +231,6 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
cmd.Printf("Profile switched to: %s\n", profilemanager.ShortID(switchResp.Id))
|
||||
cmd.Println("Profile switched successfully to:", profileName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
|
||||
// (which carry the resolver's message verbatim) into CLI-friendly text
|
||||
// that points the user at --show-id.
|
||||
func wrapAmbiguityError(err error, handle string) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
st, ok := gstatus.FromError(err)
|
||||
if !ok {
|
||||
return err
|
||||
}
|
||||
switch st.Code() {
|
||||
case codes.InvalidArgument:
|
||||
msg := st.Message()
|
||||
if strings.Contains(msg, "ambiguous") {
|
||||
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
|
||||
}
|
||||
case codes.NotFound:
|
||||
return fmt.Errorf("profile %q not found", handle)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -156,8 +157,12 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(ownerCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
ownerCmd.AddCommand(ownerAddCmd)
|
||||
ownerCmd.AddCommand(ownerResetCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
|
||||
@@ -250,11 +255,24 @@ func DialClientGRPCServer(ctx context.Context, addr string) (*grpc.ClientConn, e
|
||||
return grpc.DialContext(
|
||||
ctx,
|
||||
strings.TrimPrefix(addr, "tcp://"),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
daemonDialTransportOption(addr),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
}
|
||||
|
||||
// daemonDialTransportOption returns the appropriate transport credentials for connecting
|
||||
// to the daemon. On Unix socket platforms, uses Unix transport credentials so the server
|
||||
// can extract the caller's UID for owner verification. Otherwise, uses insecure credentials.
|
||||
func daemonDialTransportOption(addr string) grpc.DialOption {
|
||||
if strings.HasPrefix(addr, "unix://") {
|
||||
creds := owner.NewUnixTransportCredentials()
|
||||
if creds != nil {
|
||||
return grpc.WithTransportCredentials(creds)
|
||||
}
|
||||
}
|
||||
return grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
}
|
||||
|
||||
// WithBackOff execute function in backoff cycle.
|
||||
func WithBackOff(bf func() error) error {
|
||||
return backoff.RetryNotify(bf, CLIBackOffSettings, func(err error, duration time.Duration) {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/spf13/cobra"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/server"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -29,9 +30,6 @@ func (p *program) Start(svc service.Service) error {
|
||||
// Collect static system and platform information
|
||||
system.UpdateStaticInfoAsync()
|
||||
|
||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||
p.serv = grpc.NewServer()
|
||||
|
||||
split := strings.Split(daemonAddr, "://")
|
||||
switch split[0] {
|
||||
case "unix":
|
||||
@@ -47,6 +45,12 @@ func (p *program) Start(svc service.Service) error {
|
||||
return fmt.Errorf("unsupported daemon address protocol: %v", split[0])
|
||||
}
|
||||
|
||||
// Set up owner enforcement for Unix sockets.
|
||||
configAdapter := &owner.ConfigAdapter{}
|
||||
serverOpts := ownerServerOpts(split[0], configAdapter)
|
||||
|
||||
p.serv = grpc.NewServer(serverOpts...)
|
||||
|
||||
listen, err := net.Listen(split[0], split[1])
|
||||
if err != nil {
|
||||
return fmt.Errorf("listen daemon interface: %w", err)
|
||||
@@ -65,6 +69,8 @@ func (p *program) Start(svc service.Service) error {
|
||||
if err := serverInstance.Start(); err != nil {
|
||||
log.Fatalf("failed to start daemon: %v", err)
|
||||
}
|
||||
|
||||
configAdapter.SetBackend(serverInstance)
|
||||
proto.RegisterDaemonServiceServer(p.serv, serverInstance)
|
||||
|
||||
p.serverInstanceMu.Lock()
|
||||
@@ -79,6 +85,32 @@ func (p *program) Start(svc service.Service) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// ownerServerOpts returns gRPC server options for owner enforcement.
|
||||
// On Unix socket platforms, this includes transport credentials for peer credential
|
||||
// extraction and interceptors that check the caller's UID. On other platforms or TCP,
|
||||
// no owner enforcement is applied and a warning is logged so operators know the daemon
|
||||
// is running without per-user authorization.
|
||||
func ownerServerOpts(protocol string, configAdapter *owner.ConfigAdapter) []grpc.ServerOption {
|
||||
if protocol != "unix" {
|
||||
log.Warnf("daemon socket owner enforcement is not applied for protocol %q", protocol)
|
||||
return nil
|
||||
}
|
||||
|
||||
creds := owner.NewUnixTransportCredentials()
|
||||
if creds == nil {
|
||||
log.Warnf("daemon socket owner enforcement unavailable on this platform; daemon will accept any local connection")
|
||||
return nil
|
||||
}
|
||||
|
||||
interceptor := owner.NewInterceptor(configAdapter)
|
||||
|
||||
return []grpc.ServerOption{
|
||||
grpc.Creds(creds),
|
||||
grpc.ChainUnaryInterceptor(interceptor.UnaryInterceptor()),
|
||||
grpc.ChainStreamInterceptor(interceptor.StreamInterceptor()),
|
||||
}
|
||||
}
|
||||
|
||||
func (p *program) Stop(srv service.Service) error {
|
||||
p.serverInstanceMu.Lock()
|
||||
if p.serverInstance != nil {
|
||||
|
||||
@@ -44,6 +44,9 @@ const (
|
||||
|
||||
profileNameFlag = "profile"
|
||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||
|
||||
claimOwnerFlag = "owner"
|
||||
claimOwnerDesc = "claim owner privileges for this profile, restricting daemon control to the current user and root"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -54,6 +57,7 @@ var (
|
||||
showQR bool
|
||||
profileName string
|
||||
configPath string
|
||||
claimOwner bool
|
||||
|
||||
upCmd = &cobra.Command{
|
||||
Use: "up",
|
||||
@@ -87,6 +91,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||
upCmd.PersistentFlags().BoolVar(&claimOwner, claimOwnerFlag, false, claimOwnerDesc)
|
||||
|
||||
}
|
||||
|
||||
@@ -128,12 +133,13 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
||||
var profileSwitched bool
|
||||
// switch profile if provided
|
||||
if profileName != "" {
|
||||
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
|
||||
err = switchProfile(cmd.Context(), profileName, username.Username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
if err := pm.SwitchProfile(resolvedID); err != nil {
|
||||
err = pm.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %v", err)
|
||||
}
|
||||
|
||||
@@ -260,10 +266,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
|
||||
}
|
||||
|
||||
// set the new config
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID, username.Username)
|
||||
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
|
||||
if _, err := client.SetConfig(ctx, req); err != nil {
|
||||
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
|
||||
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
|
||||
log.Warnf("setConfig method is not available in the daemon")
|
||||
} else {
|
||||
return fmt.Errorf("call service setConfig method: %v", err)
|
||||
}
|
||||
@@ -288,10 +294,10 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
return fmt.Errorf("setup login request: %v", err)
|
||||
}
|
||||
|
||||
loginRequest.ProfileName = &activeProf.ID
|
||||
loginRequest.ProfileName = &activeProf.Name
|
||||
loginRequest.Username = &username
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -328,8 +334,9 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
|
||||
}
|
||||
|
||||
if _, err := client.Up(ctx, &proto.UpRequest{
|
||||
ProfileName: &activeProf.ID,
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &username,
|
||||
ClaimOwner: claimOwner,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("call service up method: %v", err)
|
||||
}
|
||||
|
||||
@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
|
||||
}
|
||||
|
||||
sm := profilemanager.ServiceManager{}
|
||||
created, err := sm.AddProfile("test1", currUser.Username)
|
||||
err = sm.AddProfile("test1", currUser.Username, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to add profile: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: created.ID,
|
||||
Name: "test1",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
||||
"SSHKey": "sensitive: SSH private key",
|
||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
||||
"Name": "non-config: profile name is not needed for debug purposes",
|
||||
}
|
||||
|
||||
mURL, _ := url.Parse("https://api.example.com:443")
|
||||
|
||||
46
client/internal/owner/config.go
Normal file
46
client/internal/owner/config.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ConfigAdapter is a thread-safe OwnerConfig that delegates to a lazily-set backend.
|
||||
// This allows the interceptor to be created before the daemon server (and its config)
|
||||
// is initialized, which is necessary because gRPC interceptors are set at server creation time.
|
||||
type ConfigAdapter struct {
|
||||
mu sync.RWMutex
|
||||
backend OwnerConfig
|
||||
}
|
||||
|
||||
// SetBackend sets the actual config implementation. Must be called before any RPCs are served.
|
||||
func (a *ConfigAdapter) SetBackend(backend OwnerConfig) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.backend = backend
|
||||
}
|
||||
|
||||
// GetOwnerUIDs delegates to the backend.
|
||||
func (a *ConfigAdapter) GetOwnerUIDs() []UID {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
if a.backend == nil {
|
||||
// No backend yet, return empty (root-only).
|
||||
return []UID{}
|
||||
}
|
||||
|
||||
return a.backend.GetOwnerUIDs()
|
||||
}
|
||||
|
||||
// AddOwnerUID delegates to the backend.
|
||||
func (a *ConfigAdapter) AddOwnerUID(uid UID) error {
|
||||
a.mu.RLock()
|
||||
defer a.mu.RUnlock()
|
||||
|
||||
if a.backend == nil {
|
||||
return fmt.Errorf("owner config backend not initialized")
|
||||
}
|
||||
|
||||
return a.backend.AddOwnerUID(uid)
|
||||
}
|
||||
17
client/internal/owner/consoleuser/consoleuser.go
Normal file
17
client/internal/owner/consoleuser/consoleuser.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Package consoleuser provides the OS-level "active console user" UID lookup
|
||||
// used to gate ownership TOFU. The active console user is the local user
|
||||
// physically at the machine (or in the foreground GUI session): the user that
|
||||
// can legitimately claim the daemon as theirs on first run.
|
||||
package consoleuser
|
||||
|
||||
// ActiveUID returns the UID of the currently active console / GUI session
|
||||
// user, and true if such a user exists. Returns 0, false on platforms without
|
||||
// a console concept (ios, android), on headless servers with no active
|
||||
// session, or on lookup failure.
|
||||
//
|
||||
// Implementations must fail closed: any error or ambiguity returns (0, false)
|
||||
// so that the caller treats the result as "no console user" rather than
|
||||
// granting access to an unverified UID.
|
||||
func ActiveUID() (uint32, bool) {
|
||||
return activeUID()
|
||||
}
|
||||
58
client/internal/owner/consoleuser/consoleuser_darwin.go
Normal file
58
client/internal/owner/consoleuser/consoleuser_darwin.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package consoleuser
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"github.com/ebitengine/purego"
|
||||
)
|
||||
|
||||
// activeUID returns the UID of the user currently logged into the macOS GUI
|
||||
// console session. Uses SCDynamicStoreCopyConsoleUser from the
|
||||
// SystemConfiguration framework via purego (no cgo).
|
||||
func activeUID() (uint32, bool) {
|
||||
sc, err := purego.Dlopen(
|
||||
"/System/Library/Frameworks/SystemConfiguration.framework/SystemConfiguration",
|
||||
purego.RTLD_NOW|purego.RTLD_GLOBAL,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
cf, err := purego.Dlopen(
|
||||
"/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation",
|
||||
purego.RTLD_NOW|purego.RTLD_GLOBAL,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
// CFStringRef SCDynamicStoreCopyConsoleUser(SCDynamicStoreRef store,
|
||||
// uid_t *uid, gid_t *gid);
|
||||
//
|
||||
// We pass nil for the store (NULL is accepted; the framework creates a
|
||||
// transient one), discard the returned CFStringRef username (we only
|
||||
// need the UID), and read uid via the out-pointer.
|
||||
var copyConsoleUser func(store uintptr, uidPtr, gidPtr unsafe.Pointer) uintptr
|
||||
purego.RegisterLibFunc(©ConsoleUser, sc, "SCDynamicStoreCopyConsoleUser")
|
||||
|
||||
var cfRelease func(uintptr)
|
||||
purego.RegisterLibFunc(&cfRelease, cf, "CFRelease")
|
||||
|
||||
var uid uint32
|
||||
var gid uint32
|
||||
|
||||
cfStr := copyConsoleUser(0, unsafe.Pointer(&uid), unsafe.Pointer(&gid))
|
||||
if cfStr == 0 {
|
||||
return 0, false
|
||||
}
|
||||
cfRelease(cfStr)
|
||||
|
||||
// loginwindow / no GUI session reports uid 0. We don't want the
|
||||
// console-user path to grant anything to root (root is already always
|
||||
// allowed by the interceptor), so treat uid 0 as "no console user".
|
||||
if uid == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return uid, true
|
||||
}
|
||||
34
client/internal/owner/consoleuser/consoleuser_freebsd.go
Normal file
34
client/internal/owner/consoleuser/consoleuser_freebsd.go
Normal file
@@ -0,0 +1,34 @@
|
||||
package consoleuser
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// activeUID returns the UID of the user currently logged into the FreeBSD
|
||||
// console. FreeBSD's vt(4) chowns the active virtual terminal device to the
|
||||
// logged-in user, so a non-root owner of any /dev/ttyvN reliably identifies
|
||||
// the console user.
|
||||
//
|
||||
// We scan /dev/ttyv0../dev/ttyv9 and return the first non-root owner. Network
|
||||
// ptys (pts) are intentionally not considered: SSH'd users are not "at the
|
||||
// console" and must not TOFU-claim ownership.
|
||||
func activeUID() (uint32, bool) {
|
||||
for i := 0; i < 10; i++ {
|
||||
path := fmt.Sprintf("/dev/ttyv%d", i)
|
||||
fi, err := os.Stat(path)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
st, ok := fi.Sys().(*syscall.Stat_t)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if st.Uid == 0 {
|
||||
continue
|
||||
}
|
||||
return st.Uid, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
64
client/internal/owner/consoleuser/consoleuser_linux.go
Normal file
64
client/internal/owner/consoleuser/consoleuser_linux.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package consoleuser
|
||||
|
||||
import (
|
||||
"github.com/godbus/dbus/v5"
|
||||
)
|
||||
|
||||
const (
|
||||
loginDest = "org.freedesktop.login1"
|
||||
loginPath = dbus.ObjectPath("/org/freedesktop/login1")
|
||||
loginInterface = "org.freedesktop.login1.Manager"
|
||||
listSessions = loginInterface + ".ListSessions"
|
||||
|
||||
sessionInterface = "org.freedesktop.login1.Session"
|
||||
sessionActive = sessionInterface + ".Active"
|
||||
sessionClass = sessionInterface + ".Class"
|
||||
)
|
||||
|
||||
// activeUID queries systemd-logind for the active local user session and
|
||||
// returns that user's UID. Falls back to (0, false) on any error or when no
|
||||
// active user session exists (headless box, no GUI, no login at the console).
|
||||
func activeUID() (uint32, bool) {
|
||||
conn, err := dbus.SystemBus()
|
||||
if err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
mgr := conn.Object(loginDest, loginPath)
|
||||
|
||||
// ListSessions returns []struct{ID string; UID uint32; User string;
|
||||
// Seat string; Path dbus.ObjectPath}.
|
||||
var sessions []struct {
|
||||
ID string
|
||||
UID uint32
|
||||
User string
|
||||
Seat string
|
||||
Path dbus.ObjectPath
|
||||
}
|
||||
if err := mgr.Call(listSessions, 0).Store(&sessions); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
for _, s := range sessions {
|
||||
obj := conn.Object(loginDest, s.Path)
|
||||
|
||||
active, err := obj.GetProperty(sessionActive)
|
||||
if err != nil || active.Value() != true {
|
||||
continue
|
||||
}
|
||||
|
||||
class, err := obj.GetProperty(sessionClass)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Only "user" sessions count; "greeter" / "lock-screen" / etc. are
|
||||
// not someone we should grant ownership to.
|
||||
if classStr, ok := class.Value().(string); !ok || classStr != "user" {
|
||||
continue
|
||||
}
|
||||
|
||||
return s.UID, true
|
||||
}
|
||||
|
||||
return 0, false
|
||||
}
|
||||
9
client/internal/owner/consoleuser/consoleuser_other.go
Normal file
9
client/internal/owner/consoleuser/consoleuser_other.go
Normal file
@@ -0,0 +1,9 @@
|
||||
//go:build !linux && !darwin && !freebsd && !windows
|
||||
|
||||
package consoleuser
|
||||
|
||||
// activeUID has no meaning on platforms without a console-user concept
|
||||
// (ios, android). Returns no-user so TOFU never fires.
|
||||
func activeUID() (uint32, bool) {
|
||||
return 0, false
|
||||
}
|
||||
59
client/internal/owner/consoleuser/consoleuser_windows.go
Normal file
59
client/internal/owner/consoleuser/consoleuser_windows.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package consoleuser
|
||||
|
||||
import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// activeUID returns a synthetic UID (the user SID's RID) for the currently
|
||||
// active Windows console session. The owner package treats UIDs as opaque
|
||||
// uint32 identifiers; on Windows we use the user account RID, which is stable
|
||||
// per-account on a given machine.
|
||||
//
|
||||
// Returns (0, false) when there is no active console session, the session has
|
||||
// no logged-in user, or any lookup fails.
|
||||
func activeUID() (uint32, bool) {
|
||||
sessionID := windows.WTSGetActiveConsoleSessionId()
|
||||
if sessionID == 0xFFFFFFFF {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
var token windows.Token
|
||||
if err := windows.WTSQueryUserToken(sessionID, &token); err != nil {
|
||||
return 0, false
|
||||
}
|
||||
defer token.Close()
|
||||
|
||||
user, err := tokenUserSID(token)
|
||||
if err != nil || user == nil {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
subCount := user.SubAuthorityCount()
|
||||
if subCount == 0 {
|
||||
return 0, false
|
||||
}
|
||||
rid := user.SubAuthority(uint32(subCount) - 1)
|
||||
if rid == 0 {
|
||||
return 0, false
|
||||
}
|
||||
return rid, true
|
||||
}
|
||||
|
||||
// tokenUserSID returns the user SID associated with the given access token.
|
||||
func tokenUserSID(token windows.Token) (*windows.SID, error) {
|
||||
var size uint32
|
||||
err := windows.GetTokenInformation(token, windows.TokenUser, nil, 0, &size)
|
||||
if err != windows.ERROR_INSUFFICIENT_BUFFER {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
buf := make([]byte, size)
|
||||
if err := windows.GetTokenInformation(token, windows.TokenUser, &buf[0], size, &size); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tu := (*windows.Tokenuser)(unsafe.Pointer(&buf[0]))
|
||||
return tu.User.Sid, nil
|
||||
}
|
||||
37
client/internal/owner/creds.go
Normal file
37
client/internal/owner/creds.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
)
|
||||
|
||||
// UnixAuthInfo implements credentials.AuthInfo carrying the peer's UID from SO_PEERCRED.
|
||||
type UnixAuthInfo struct {
|
||||
credentials.CommonAuthInfo
|
||||
UID UID
|
||||
GID uint32
|
||||
PID int32
|
||||
}
|
||||
|
||||
// AuthType returns the authentication type.
|
||||
func (u UnixAuthInfo) AuthType() string {
|
||||
return "unix_peercred"
|
||||
}
|
||||
|
||||
// UIDFromContext extracts the caller's UID from the gRPC peer context.
|
||||
// Returns uid and true if Unix credentials were available, 0 and false otherwise.
|
||||
func UIDFromContext(ctx context.Context) (UID, bool) {
|
||||
p, ok := peer.FromContext(ctx)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
info, ok := p.AuthInfo.(UnixAuthInfo)
|
||||
if !ok {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
return info.UID, true
|
||||
}
|
||||
48
client/internal/owner/env.go
Normal file
48
client/internal/owner/env.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// EnvOwnerUID is the environment variable that seeds the owner UID list for new config files.
|
||||
// MDM deployments can set this (e.g. via --service-env NB_OWNER_UID=1000) so the first
|
||||
// config created by the daemon pre-populates the owner without requiring "netbird up --owner".
|
||||
// Multiple UIDs can be comma-separated: NB_OWNER_UID=1000,1001
|
||||
const EnvOwnerUID = "NB_OWNER_UID"
|
||||
|
||||
// OwnerUIDsFromEnv parses NB_OWNER_UID into a UID slice.
|
||||
// Returns nil if the variable is unset, allowing the caller to distinguish
|
||||
// "not configured" from "explicitly empty".
|
||||
func OwnerUIDsFromEnv() []UID {
|
||||
val := os.Getenv(EnvOwnerUID)
|
||||
if val == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := strings.Split(val, ",")
|
||||
uids := make([]UID, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
uid, err := strconv.ParseUint(p, 10, 32)
|
||||
if err != nil {
|
||||
log.Warnf("ignoring invalid UID %q in %s: %v", p, EnvOwnerUID, err)
|
||||
continue
|
||||
}
|
||||
uids = append(uids, UID(uid))
|
||||
}
|
||||
|
||||
if len(uids) == 0 {
|
||||
log.Warnf("%s set but contains no valid UIDs, defaulting to root-only", EnvOwnerUID)
|
||||
return []UID{}
|
||||
}
|
||||
|
||||
log.Infof("seeding owner UIDs from %s: %v", EnvOwnerUID, uids)
|
||||
return uids
|
||||
}
|
||||
81
client/internal/owner/env_test.go
Normal file
81
client/internal/owner/env_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOwnerUIDsFromEnv(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
envValue string
|
||||
unset bool
|
||||
want []UID
|
||||
}{
|
||||
{
|
||||
name: "unset returns nil",
|
||||
unset: true,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "empty string returns nil",
|
||||
envValue: "",
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "single UID",
|
||||
envValue: "1000",
|
||||
want: []UID{1000},
|
||||
},
|
||||
{
|
||||
name: "multiple UIDs",
|
||||
envValue: "1000,1001,1002",
|
||||
want: []UID{1000, 1001, 1002},
|
||||
},
|
||||
{
|
||||
name: "spaces around UIDs",
|
||||
envValue: " 1000 , 1001 ",
|
||||
want: []UID{1000, 1001},
|
||||
},
|
||||
{
|
||||
name: "invalid UID skipped",
|
||||
envValue: "1000,notanumber,1001",
|
||||
want: []UID{1000, 1001},
|
||||
},
|
||||
{
|
||||
name: "all invalid returns empty slice",
|
||||
envValue: "abc,def",
|
||||
want: []UID{},
|
||||
},
|
||||
{
|
||||
name: "trailing comma",
|
||||
envValue: "1000,",
|
||||
want: []UID{1000},
|
||||
},
|
||||
{
|
||||
name: "zero UID is valid",
|
||||
envValue: "0",
|
||||
want: []UID{0},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv(EnvOwnerUID, tt.envValue)
|
||||
if tt.unset {
|
||||
os.Unsetenv(EnvOwnerUID)
|
||||
}
|
||||
|
||||
got := OwnerUIDsFromEnv()
|
||||
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
assert.Equal(t, tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
170
client/internal/owner/interceptor.go
Normal file
170
client/internal/owner/interceptor.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/owner/consoleuser"
|
||||
)
|
||||
|
||||
const servicePath = "/daemon.DaemonService/"
|
||||
|
||||
// profileBypassMethods skip the active-profile owner check. They either
|
||||
// operate on a specific target profile (and the handler enforces target-profile
|
||||
// owner-or-root itself) or are per-user listings/creations that don't affect
|
||||
// the active session and shouldn't require active-profile ownership. Peer
|
||||
// credentials are still required.
|
||||
var profileBypassMethods = map[string]bool{
|
||||
servicePath + "AddProfile": true,
|
||||
servicePath + "ListProfiles": true,
|
||||
servicePath + "RemoveProfile": true,
|
||||
servicePath + "SwitchProfile": true,
|
||||
}
|
||||
|
||||
// Error messages returned to denied callers. They are multi-line so the
|
||||
// suggested commands sit on their own line for easy triple-click copy-paste.
|
||||
const (
|
||||
errNoPeerCreds = "peer credentials unavailable; rerun via the netbird CLI"
|
||||
|
||||
errNoOwnerConfigured = `no daemon owner is configured and no console-session user matches your UID.
|
||||
Run as root for one-off use:
|
||||
sudo netbird ...
|
||||
Or call from the active console session: the first call from the user logged in
|
||||
at the GUI/console claims ownership automatically.`
|
||||
|
||||
errOwnerRequired = `this operation requires root or the daemon owner (uid %d is not an owner).
|
||||
Run as root for one-off use:
|
||||
sudo netbird ...
|
||||
Or ask an existing owner (or root) to add you:
|
||||
sudo netbird owner add %[1]d`
|
||||
)
|
||||
|
||||
// consoleUIDLookup is the function used to look up the active console UID.
|
||||
// Overridable in tests; defaults to the platform implementation.
|
||||
var consoleUIDLookup = consoleuser.ActiveUID
|
||||
|
||||
// OwnerConfig provides access to the current owner UIDs setting.
|
||||
// The interceptor reads and writes through this interface so it can
|
||||
// work with the profile manager's config without a direct dependency.
|
||||
type OwnerConfig interface {
|
||||
// GetOwnerUIDs returns the current owner UIDs.
|
||||
// nil means legacy/migration TOFU (field absent from existing config).
|
||||
// empty means fresh install (root-only with console-user TOFU exception).
|
||||
// populated means those UIDs plus root may control the daemon.
|
||||
GetOwnerUIDs() []UID
|
||||
|
||||
// AddOwnerUID adds the given UID to the owner list and persists it.
|
||||
AddOwnerUID(uid UID) error
|
||||
}
|
||||
|
||||
// Interceptor enforces owner restrictions on the daemon gRPC socket.
|
||||
type Interceptor struct {
|
||||
config OwnerConfig
|
||||
// mu serializes the read-then-write of OwnerUIDs during TOFU/claim flows
|
||||
// so two concurrent first-callers can't both end up persisted as owners.
|
||||
// Holds across the OwnerConfig.AddOwnerUID call; safe because no callback
|
||||
// path takes this mutex.
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewInterceptor creates an owner interceptor backed by the given config.
|
||||
func NewInterceptor(config OwnerConfig) *Interceptor {
|
||||
return &Interceptor{config: config}
|
||||
}
|
||||
|
||||
// UnaryInterceptor returns a gRPC unary server interceptor that enforces owner policy.
|
||||
func (i *Interceptor) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(
|
||||
ctx context.Context,
|
||||
req any,
|
||||
info *grpc.UnaryServerInfo,
|
||||
handler grpc.UnaryHandler,
|
||||
) (any, error) {
|
||||
if err := i.authorize(ctx, info.FullMethod); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
// StreamInterceptor returns a gRPC stream server interceptor that enforces owner policy.
|
||||
func (i *Interceptor) StreamInterceptor() grpc.StreamServerInterceptor {
|
||||
return func(
|
||||
srv any,
|
||||
ss grpc.ServerStream,
|
||||
info *grpc.StreamServerInfo,
|
||||
handler grpc.StreamHandler,
|
||||
) error {
|
||||
if err := i.authorize(ss.Context(), info.FullMethod); err != nil {
|
||||
return err
|
||||
}
|
||||
return handler(srv, ss)
|
||||
}
|
||||
}
|
||||
|
||||
// authorize checks whether the caller is allowed to call the given method.
|
||||
// Every RPC is gated; root is always allowed. Non-root callers are accepted
|
||||
// when they are existing owners, when the config is in legacy TOFU state
|
||||
// (claim on first call, preserves pre-enforcement behavior), or when the
|
||||
// config is in fresh-install state and they match the active console user.
|
||||
func (i *Interceptor) authorize(ctx context.Context, fullMethod string) error {
|
||||
uid, ok := UIDFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.PermissionDenied, errNoPeerCreds)
|
||||
}
|
||||
|
||||
if uid == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Profile-management RPCs do their own per-target authorization in the
|
||||
// handler. The interceptor only confirms peer credentials are present.
|
||||
if profileBypassMethods[fullMethod] {
|
||||
return nil
|
||||
}
|
||||
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
ownerUIDs := i.config.GetOwnerUIDs()
|
||||
|
||||
switch {
|
||||
case ownerUIDs == nil:
|
||||
// Legacy / migration TOFU: existing pre-enforcement config has no
|
||||
// owners field. Any non-root local caller claims on first call so
|
||||
// upgrades don't break.
|
||||
return i.claim(uid, "migration TOFU")
|
||||
|
||||
case len(ownerUIDs) == 0:
|
||||
// Fresh-install root-only mode with a console-user exception so the
|
||||
// GUI/CLI just works for the user physically at the machine. SSH'd
|
||||
// or otherwise non-console callers are denied.
|
||||
consoleUID, ok := consoleUIDLookup()
|
||||
if ok && uint32(uid) == consoleUID {
|
||||
return i.claim(uid, "console-user TOFU")
|
||||
}
|
||||
return status.Error(codes.PermissionDenied, errNoOwnerConfigured)
|
||||
|
||||
case slices.Contains(ownerUIDs, uid):
|
||||
return nil
|
||||
|
||||
default:
|
||||
return status.Errorf(codes.PermissionDenied, errOwnerRequired, uid)
|
||||
}
|
||||
}
|
||||
|
||||
// claim adds uid to the owner list and persists it. The caller must hold i.mu.
|
||||
func (i *Interceptor) claim(uid UID, reason string) error {
|
||||
log.Infof("%s: claiming owner for UID %d", reason, uid)
|
||||
if err := i.config.AddOwnerUID(uid); err != nil {
|
||||
log.Errorf("persist owner UID: %v", err)
|
||||
return status.Error(codes.Internal, "persist owner UID")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
277
client/internal/owner/interceptor_test.go
Normal file
277
client/internal/owner/interceptor_test.go
Normal file
@@ -0,0 +1,277 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type mockOwnerConfig struct {
|
||||
uids []UID
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockOwnerConfig) GetOwnerUIDs() []UID {
|
||||
return m.uids
|
||||
}
|
||||
|
||||
func (m *mockOwnerConfig) AddOwnerUID(uid UID) error {
|
||||
if m.err != nil {
|
||||
return m.err
|
||||
}
|
||||
m.uids = append(m.uids, uid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func peerContext(uid UID) context.Context {
|
||||
return peer.NewContext(context.Background(), &peer.Peer{
|
||||
Addr: &net.UnixAddr{Name: "/tmp/test.sock", Net: "unix"},
|
||||
AuthInfo: UnixAuthInfo{
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
|
||||
UID: uid,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func noPeerContext() context.Context {
|
||||
return context.Background()
|
||||
}
|
||||
|
||||
// withConsoleUID overrides the platform console-user lookup for a single test.
|
||||
func withConsoleUID(t *testing.T, uid uint32, ok bool) {
|
||||
t.Helper()
|
||||
prev := consoleUIDLookup
|
||||
consoleUIDLookup = func() (uint32, bool) { return uid, ok }
|
||||
t.Cleanup(func() { consoleUIDLookup = prev })
|
||||
}
|
||||
|
||||
func TestInterceptor_RootAlwaysAllowed(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
for _, method := range []string{
|
||||
"/daemon.DaemonService/Up",
|
||||
"/daemon.DaemonService/Status",
|
||||
"/daemon.DaemonService/Down",
|
||||
} {
|
||||
err := interceptor.authorize(peerContext(0), method)
|
||||
assert.NoError(t, err, "root should always be allowed for %s", method)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInterceptor_NoPeerCreds_AlwaysDenies(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
for _, method := range []string{
|
||||
"/daemon.DaemonService/Status",
|
||||
"/daemon.DaemonService/Up",
|
||||
"/daemon.DaemonService/SomeNewMethod",
|
||||
} {
|
||||
err := interceptor.authorize(noPeerContext(), method)
|
||||
require.Error(t, err, "method %s should be denied without peer creds", method)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
}
|
||||
|
||||
// TestInterceptor_LegacyMigration covers the nil-OwnerUIDs branch:
|
||||
// pre-enforcement configs upgraded to this version. Any non-root local caller
|
||||
// can claim on first call.
|
||||
func TestInterceptor_LegacyMigration_AnyCallerClaims(t *testing.T) {
|
||||
withConsoleUID(t, 0, false) // no console; should not matter for nil
|
||||
cfg := &mockOwnerConfig{uids: nil}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
// First call from any UID claims regardless of method.
|
||||
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Status")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []UID{1000}, cfg.uids)
|
||||
|
||||
// After claim, a different UID is denied.
|
||||
err = interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Status")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
// TestInterceptor_FreshInstall covers the empty-OwnerUIDs branch: console-user
|
||||
// can claim, others denied.
|
||||
func TestInterceptor_FreshInstall_ConsoleUserClaims(t *testing.T) {
|
||||
withConsoleUID(t, 1000, true)
|
||||
cfg := &mockOwnerConfig{uids: []UID{}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Status")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []UID{1000}, cfg.uids)
|
||||
}
|
||||
|
||||
func TestInterceptor_FreshInstall_NonConsoleDenied(t *testing.T) {
|
||||
withConsoleUID(t, 1000, true)
|
||||
cfg := &mockOwnerConfig{uids: []UID{}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Up")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
assert.Empty(t, cfg.uids, "non-console caller must not claim")
|
||||
}
|
||||
|
||||
func TestInterceptor_FreshInstall_NoConsole_Denied(t *testing.T) {
|
||||
withConsoleUID(t, 0, false)
|
||||
cfg := &mockOwnerConfig{uids: []UID{}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Up")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
func TestInterceptor_OwnerUID_AllowsOwner(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Down")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInterceptor_OwnerUID_DeniesOther(t *testing.T) {
|
||||
withConsoleUID(t, 9999, true) // console-user TOFU should not apply once owners exist
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Down")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
func TestInterceptor_MultipleOwners(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000, 2000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(1000), "/daemon.DaemonService/Down")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Up")
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = interceptor.authorize(peerContext(3000), "/daemon.DaemonService/Down")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
// TestInterceptor_UnknownMethodRequiresOwner pins the safe-by-default invariant:
|
||||
// any future RPC still goes through owner enforcement.
|
||||
func TestInterceptor_UnknownMethodRequiresOwner(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/SomeFutureMethod")
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
|
||||
err = interceptor.authorize(peerContext(1000), "/daemon.DaemonService/SomeFutureMethod")
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestInterceptor_ErrorMessageActionable(t *testing.T) {
|
||||
withConsoleUID(t, 9999, true)
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
err := interceptor.authorize(peerContext(2000), "/daemon.DaemonService/Down")
|
||||
require.Error(t, err)
|
||||
msg := status.Convert(err).Message()
|
||||
assert.Contains(t, msg, "sudo netbird")
|
||||
assert.Contains(t, msg, "owner add")
|
||||
}
|
||||
|
||||
func TestInterceptor_UnaryIntegration(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
unary := interceptor.UnaryInterceptor()
|
||||
|
||||
resp, err := unary(peerContext(1000), nil, &grpc.UnaryServerInfo{FullMethod: "/daemon.DaemonService/Down"}, func(ctx context.Context, req any) (any, error) {
|
||||
return "ok", nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "ok", resp)
|
||||
|
||||
_, err = unary(peerContext(2000), nil, &grpc.UnaryServerInfo{FullMethod: "/daemon.DaemonService/Down"}, func(ctx context.Context, req any) (any, error) {
|
||||
t.Fatal("handler should not be called")
|
||||
return nil, nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
func TestInterceptor_StreamIntegration(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
stream := interceptor.StreamInterceptor()
|
||||
|
||||
called := false
|
||||
err := stream(nil, &mockServerStream{ctx: peerContext(1000)},
|
||||
&grpc.StreamServerInfo{FullMethod: "/daemon.DaemonService/SubscribeEvents"},
|
||||
func(srv any, stream grpc.ServerStream) error {
|
||||
called = true
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.True(t, called)
|
||||
|
||||
err = stream(nil, &mockServerStream{ctx: peerContext(2000)},
|
||||
&grpc.StreamServerInfo{FullMethod: "/daemon.DaemonService/SubscribeEvents"},
|
||||
func(srv any, stream grpc.ServerStream) error {
|
||||
t.Fatal("handler should not be called")
|
||||
return nil
|
||||
})
|
||||
require.Error(t, err)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
|
||||
type mockServerStream struct {
|
||||
grpc.ServerStream
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (m *mockServerStream) Context() context.Context { return m.ctx }
|
||||
|
||||
// TestInterceptor_ProfileBypass pins that profile-management methods reach
|
||||
// the handler regardless of active-profile ownership; the handler enforces
|
||||
// per-target-profile auth itself.
|
||||
func TestInterceptor_ProfileBypass(t *testing.T) {
|
||||
cfg := &mockOwnerConfig{uids: []UID{1000}}
|
||||
interceptor := NewInterceptor(cfg)
|
||||
|
||||
// Caller UID 2000 is not an owner of the active profile but must be
|
||||
// allowed through for these methods.
|
||||
for _, method := range []string{
|
||||
"/daemon.DaemonService/AddProfile",
|
||||
"/daemon.DaemonService/ListProfiles",
|
||||
"/daemon.DaemonService/RemoveProfile",
|
||||
"/daemon.DaemonService/SwitchProfile",
|
||||
} {
|
||||
err := interceptor.authorize(peerContext(2000), method)
|
||||
assert.NoError(t, err, "profile method %s should bypass active-owner check", method)
|
||||
}
|
||||
|
||||
// Without peer creds, even bypass methods are denied.
|
||||
for _, method := range []string{
|
||||
"/daemon.DaemonService/AddProfile",
|
||||
"/daemon.DaemonService/SwitchProfile",
|
||||
} {
|
||||
err := interceptor.authorize(noPeerContext(), method)
|
||||
require.Error(t, err, "bypass method %s still requires peer creds", method)
|
||||
assert.Equal(t, codes.PermissionDenied, status.Code(err))
|
||||
}
|
||||
}
|
||||
66
client/internal/owner/transport_bsd.go
Normal file
66
client/internal/owner/transport_bsd.go
Normal file
@@ -0,0 +1,66 @@
|
||||
//go:build darwin || freebsd
|
||||
|
||||
package owner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// NewUnixTransportCredentials returns gRPC TransportCredentials that extract
|
||||
// peer UID from Unix socket connections via LOCAL_PEERCRED (Xucred).
|
||||
func NewUnixTransportCredentials() credentials.TransportCredentials {
|
||||
return &unixCreds{}
|
||||
}
|
||||
|
||||
type unixCreds struct{}
|
||||
|
||||
func (c *unixCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return conn, UnixAuthInfo{}, nil
|
||||
}
|
||||
|
||||
// ServerHandshake extracts peer credentials from the Unix connection using LOCAL_PEERCRED.
|
||||
// Returns an error if credentials cannot be extracted (fail-closed).
|
||||
func (c *unixCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
uc, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("expected *net.UnixConn, got %T", conn)
|
||||
}
|
||||
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get raw conn for peer credentials: %w", err)
|
||||
}
|
||||
|
||||
var xucred *unix.Xucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
xucred, credErr = unix.GetsockoptXucred(int(fd), unix.SOL_LOCAL, unix.LOCAL_PEERCRED)
|
||||
}); err != nil {
|
||||
return nil, nil, fmt.Errorf("control raw conn for peer credentials: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return nil, nil, fmt.Errorf("get peer credentials: %w", credErr)
|
||||
}
|
||||
|
||||
return conn, UnixAuthInfo{
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
|
||||
UID: UID(xucred.Uid),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *unixCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{SecurityProtocol: "unix_peercred"}
|
||||
}
|
||||
|
||||
func (c *unixCreds) Clone() credentials.TransportCredentials {
|
||||
return &unixCreds{}
|
||||
}
|
||||
|
||||
func (c *unixCreds) OverrideServerName(_ string) error {
|
||||
return nil
|
||||
}
|
||||
11
client/internal/owner/transport_generic.go
Normal file
11
client/internal/owner/transport_generic.go
Normal file
@@ -0,0 +1,11 @@
|
||||
//go:build !linux && !darwin && !freebsd
|
||||
|
||||
package owner
|
||||
|
||||
import "google.golang.org/grpc/credentials"
|
||||
|
||||
// NewUnixTransportCredentials returns nil on platforms without Unix socket peer credentials.
|
||||
// The daemon should use insecure credentials and skip owner enforcement.
|
||||
func NewUnixTransportCredentials() credentials.TransportCredentials {
|
||||
return nil
|
||||
}
|
||||
66
client/internal/owner/transport_linux.go
Normal file
66
client/internal/owner/transport_linux.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
// NewUnixTransportCredentials returns gRPC TransportCredentials that extract
|
||||
// peer UID/GID/PID from Unix socket connections via SO_PEERCRED.
|
||||
func NewUnixTransportCredentials() credentials.TransportCredentials {
|
||||
return &unixCreds{}
|
||||
}
|
||||
|
||||
type unixCreds struct{}
|
||||
|
||||
func (c *unixCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
return conn, UnixAuthInfo{}, nil
|
||||
}
|
||||
|
||||
// ServerHandshake extracts peer credentials from the Unix connection.
|
||||
// Returns an error if credentials cannot be extracted (fail-closed).
|
||||
func (c *unixCreds) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
|
||||
uc, ok := conn.(*net.UnixConn)
|
||||
if !ok {
|
||||
return nil, nil, fmt.Errorf("expected *net.UnixConn, got %T", conn)
|
||||
}
|
||||
|
||||
raw, err := uc.SyscallConn()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get raw conn for peer credentials: %w", err)
|
||||
}
|
||||
|
||||
var ucred *unix.Ucred
|
||||
var credErr error
|
||||
if err := raw.Control(func(fd uintptr) {
|
||||
ucred, credErr = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED)
|
||||
}); err != nil {
|
||||
return nil, nil, fmt.Errorf("control raw conn for peer credentials: %w", err)
|
||||
}
|
||||
if credErr != nil {
|
||||
return nil, nil, fmt.Errorf("get peer credentials: %w", credErr)
|
||||
}
|
||||
|
||||
return conn, UnixAuthInfo{
|
||||
CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.NoSecurity},
|
||||
UID: UID(ucred.Uid),
|
||||
GID: ucred.Gid,
|
||||
PID: ucred.Pid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *unixCreds) Info() credentials.ProtocolInfo {
|
||||
return credentials.ProtocolInfo{SecurityProtocol: "unix_peercred"}
|
||||
}
|
||||
|
||||
func (c *unixCreds) Clone() credentials.TransportCredentials {
|
||||
return &unixCreds{}
|
||||
}
|
||||
|
||||
func (c *unixCreds) OverrideServerName(_ string) error {
|
||||
return nil
|
||||
}
|
||||
107
client/internal/owner/transport_test.go
Normal file
107
client/internal/owner/transport_test.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package owner
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc/credentials"
|
||||
)
|
||||
|
||||
func TestUnixTransportCredentials_ServerHandshake(t *testing.T) {
|
||||
creds := NewUnixTransportCredentials()
|
||||
if creds == nil {
|
||||
t.Skip("unix transport credentials not supported on this platform")
|
||||
}
|
||||
|
||||
sockPath := filepath.Join(t.TempDir(), "test.sock")
|
||||
|
||||
ln, err := net.Listen("unix", sockPath)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { ln.Close() })
|
||||
|
||||
done := make(chan struct{})
|
||||
var serverConn net.Conn
|
||||
var serverAuth credentials.AuthInfo
|
||||
var serverErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
raw, err := ln.Accept()
|
||||
if err != nil {
|
||||
serverErr = err
|
||||
return
|
||||
}
|
||||
serverConn, serverAuth, serverErr = creds.ServerHandshake(raw)
|
||||
}()
|
||||
|
||||
client, err := net.Dial("unix", sockPath)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { client.Close() })
|
||||
|
||||
<-done
|
||||
require.NoError(t, serverErr)
|
||||
require.NotNil(t, serverConn)
|
||||
t.Cleanup(func() { serverConn.Close() })
|
||||
|
||||
authInfo, ok := serverAuth.(UnixAuthInfo)
|
||||
require.True(t, ok, "expected UnixAuthInfo, got %T", serverAuth)
|
||||
assert.Equal(t, UID(os.Getuid()), authInfo.UID, "UID should match current user")
|
||||
}
|
||||
|
||||
func TestUnixTransportCredentials_ServerHandshake_NonUnixConn(t *testing.T) {
|
||||
creds := NewUnixTransportCredentials()
|
||||
if creds == nil {
|
||||
t.Skip("unix transport credentials not supported on this platform")
|
||||
}
|
||||
|
||||
// Use a TCP connection, which is not *net.UnixConn.
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { ln.Close() })
|
||||
|
||||
done := make(chan struct{})
|
||||
var handshakeErr error
|
||||
|
||||
go func() {
|
||||
defer close(done)
|
||||
raw, err := ln.Accept()
|
||||
if err != nil {
|
||||
handshakeErr = err
|
||||
return
|
||||
}
|
||||
defer raw.Close()
|
||||
_, _, handshakeErr = creds.ServerHandshake(raw)
|
||||
}()
|
||||
|
||||
client, err := net.Dial("tcp", ln.Addr().String())
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { client.Close() })
|
||||
|
||||
<-done
|
||||
require.Error(t, handshakeErr, "ServerHandshake must fail for non-Unix connections")
|
||||
}
|
||||
|
||||
func TestUnixTransportCredentials_Info(t *testing.T) {
|
||||
creds := NewUnixTransportCredentials()
|
||||
if creds == nil {
|
||||
t.Skip("unix transport credentials not supported on this platform")
|
||||
}
|
||||
|
||||
info := creds.Info()
|
||||
assert.Equal(t, "unix_peercred", info.SecurityProtocol)
|
||||
}
|
||||
|
||||
func TestUnixTransportCredentials_Clone(t *testing.T) {
|
||||
creds := NewUnixTransportCredentials()
|
||||
if creds == nil {
|
||||
t.Skip("unix transport credentials not supported on this platform")
|
||||
}
|
||||
|
||||
cloned := creds.Clone()
|
||||
require.NotNil(t, cloned)
|
||||
assert.Equal(t, creds.Info(), cloned.Info())
|
||||
}
|
||||
5
client/internal/owner/uid.go
Normal file
5
client/internal/owner/uid.go
Normal file
@@ -0,0 +1,5 @@
|
||||
package owner
|
||||
|
||||
// UID is a Unix user ID. Defined as a distinct type so it can't be silently
|
||||
// swapped with GID, PID, or other uint32 values at call sites.
|
||||
type UID uint32
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
@@ -99,14 +100,14 @@ type ConfigInput struct {
|
||||
LazyConnectionEnabled *bool
|
||||
|
||||
MTU *uint16
|
||||
|
||||
// OwnerUIDs sets the UIDs of users allowed to control the daemon.
|
||||
// When non-nil, replaces the config's OwnerUIDs.
|
||||
OwnerUIDs []owner.UID
|
||||
}
|
||||
|
||||
// Config Configuration type
|
||||
type Config struct {
|
||||
// Name is the human-readable profile name shown in CLI/UI listings.
|
||||
// It is independent of the profile's on-disk filename (which is the ID).
|
||||
Name string
|
||||
|
||||
// Wireguard private key of local peer
|
||||
PrivateKey string
|
||||
PreSharedKey string
|
||||
@@ -178,6 +179,12 @@ type Config struct {
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
MTU uint16
|
||||
|
||||
// OwnerUIDs controls who can perform privileged daemon operations via the gRPC socket.
|
||||
// nil (absent from JSON): TOFU mode, first privileged caller claims ownership (backward compat for existing installs).
|
||||
// [] (empty slice): root-only, no non-root owners until explicitly set via "netbird up --owner".
|
||||
// [uid1, uid2, ...]: these UIDs plus root can perform privileged operations.
|
||||
OwnerUIDs []owner.UID `json:"OwnerUIDs"`
|
||||
}
|
||||
|
||||
var ConfigDirOverride string
|
||||
@@ -238,10 +245,18 @@ func fileExists(path string) (bool, error) {
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
func createNewConfig(input ConfigInput) (*Config, error) {
|
||||
// Seed owner UIDs from environment if set (for MDM deployments),
|
||||
// otherwise default to root-only (empty slice).
|
||||
ownerUIDs := owner.OwnerUIDsFromEnv()
|
||||
if ownerUIDs == nil {
|
||||
ownerUIDs = []owner.UID{}
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
// defaults to false only for new (post 0.26) configurations
|
||||
ServerSSHAllowed: util.False(),
|
||||
WgPort: iface.DefaultWgPort,
|
||||
OwnerUIDs: ownerUIDs,
|
||||
}
|
||||
|
||||
if _, err := config.apply(input); err != nil {
|
||||
@@ -616,6 +631,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.OwnerUIDs != nil {
|
||||
if !slices.Equal(config.OwnerUIDs, input.OwnerUIDs) {
|
||||
log.Infof("updating owner UIDs to %v", input.OwnerUIDs)
|
||||
config.OwnerUIDs = input.OwnerUIDs
|
||||
updated = true
|
||||
}
|
||||
}
|
||||
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
const (
|
||||
// profileIDByteLen is the number of random bytes generated for a new
|
||||
// profile ID. The resulting hex string is twice this length.
|
||||
profileIDByteLen = 16
|
||||
|
||||
// shortIDLen is the number of leading characters of an ID we render in
|
||||
// list output. Profiles per device are few, so 8 chars is collision-safe
|
||||
// in practice and easy to type as a prefix.
|
||||
shortIDLen = 8
|
||||
|
||||
// maxProfileNameLen caps the human-readable profile name to keep table
|
||||
// output legible and prevent denial-of-service via huge JSON fields.
|
||||
maxProfileNameLen = 128
|
||||
|
||||
// maxProfileIDLen bounds the on-disk filename we'll accept. New
|
||||
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
|
||||
// cap is generous enough to cover both without permitting absurdly
|
||||
// long filenames.
|
||||
maxProfileIDLen = 64
|
||||
)
|
||||
|
||||
// generateProfileID returns a new random hex ID for a profile file.
|
||||
func generateProfileID() (string, error) {
|
||||
buf := make([]byte, profileIDByteLen)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return "", fmt.Errorf("read random bytes: %w", err)
|
||||
}
|
||||
return hex.EncodeToString(buf), nil
|
||||
}
|
||||
|
||||
// IsValidProfileFilenameStem reports whether s is safe to use as the stem
|
||||
// of a profile JSON filename.
|
||||
func IsValidProfileFilenameStem(s string) bool {
|
||||
if s == "" || len(s) > maxProfileIDLen {
|
||||
return false
|
||||
}
|
||||
if s == defaultProfileName {
|
||||
return true
|
||||
}
|
||||
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
|
||||
return false
|
||||
}
|
||||
// filepath.Base catches any leftover separators on platforms with
|
||||
// exotic path conventions.
|
||||
if filepath.Base(s) != s {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// sanitizeDisplayName normalizes a user-supplied profile display name for
|
||||
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
|
||||
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
|
||||
// preserved. Returns an error if nothing usable remains.
|
||||
func sanitizeDisplayName(name string) (string, error) {
|
||||
if !utf8.ValidString(name) {
|
||||
return "", fmt.Errorf("name is not valid UTF-8")
|
||||
}
|
||||
name = StripCtrlChars(name)
|
||||
name = strings.TrimSpace(name)
|
||||
if name == "" {
|
||||
return "", fmt.Errorf("name is empty after sanitization")
|
||||
}
|
||||
if utf8.RuneCountInString(name) > maxProfileNameLen {
|
||||
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
// StripCtrlChars control characters from a name before printing it.
|
||||
func StripCtrlChars(name string) string {
|
||||
var b strings.Builder
|
||||
b.Grow(len(name))
|
||||
for _, r := range name {
|
||||
// Skip C0 controls and DEL, plus C1 controls (0x80–0x9F).
|
||||
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
|
||||
continue
|
||||
}
|
||||
b.WriteRune(r)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// ShortID truncates an ID for display.
|
||||
func ShortID(id string) string {
|
||||
if id == DefaultProfileName {
|
||||
return id
|
||||
}
|
||||
if len(id) <= shortIDLen {
|
||||
return id
|
||||
}
|
||||
return id[:shortIDLen]
|
||||
}
|
||||
@@ -19,41 +19,19 @@ const (
|
||||
)
|
||||
|
||||
type Profile struct {
|
||||
// ID is the on-disk filename stem (without .json). For new profiles
|
||||
// it is a 32-char hex string; legacy profiles created before the
|
||||
// ID-keyed layout keep their original name as their ID. The reserved
|
||||
// value "default" identifies the special default profile.
|
||||
ID string
|
||||
// Name is the human-readable display name. Falls back to ID when the
|
||||
// underlying JSON has no "name" field set.
|
||||
Name string
|
||||
// Path is the absolute path to the profile JSON. Populated by the
|
||||
// loader so callers do not have to reconstruct it from ID + dir.
|
||||
Path string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
|
||||
func (p *Profile) FilePath() (string, error) {
|
||||
if p.Path != "" {
|
||||
return p.Path, nil
|
||||
if p.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
}
|
||||
|
||||
id := p.ID
|
||||
if id == "" {
|
||||
id = p.Name
|
||||
}
|
||||
if id == "" {
|
||||
return "", fmt.Errorf("profile ID is empty")
|
||||
}
|
||||
|
||||
if id == defaultProfileName {
|
||||
if p.Name == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
username, err := user.Current()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get current user: %w", err)
|
||||
@@ -64,13 +42,10 @@ func (p *Profile) FilePath() (string, error) {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, id+".json"), nil
|
||||
return filepath.Join(configDir, p.Name+".json"), nil
|
||||
}
|
||||
|
||||
func (p *Profile) IsDefault() bool {
|
||||
if p.ID != "" {
|
||||
return p.ID == defaultProfileName
|
||||
}
|
||||
return p.Name == defaultProfileName
|
||||
}
|
||||
|
||||
@@ -82,24 +57,18 @@ func NewProfileManager() *ProfileManager {
|
||||
return &ProfileManager{}
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile as recorded in the local
|
||||
// user state file. Only ID is populated.
|
||||
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
id := pm.getActiveProfileState()
|
||||
return &Profile{ID: id}, nil
|
||||
prof := pm.getActiveProfileState()
|
||||
return &Profile{Name: prof}, nil
|
||||
}
|
||||
|
||||
// SwitchProfile records the given profile ID as active in the local user
|
||||
// state file.
|
||||
func (pm *ProfileManager) SwitchProfile(id string) error {
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
func (pm *ProfileManager) SwitchProfile(profileName string) error {
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if err := pm.setActiveProfileState(id); err != nil {
|
||||
if err := pm.setActiveProfileState(profileName); err != nil {
|
||||
return fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
return nil
|
||||
@@ -173,7 +142,7 @@ func GetLoginHint() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
profileState, err := pm.GetProfileState(activeProf.ID)
|
||||
profileState, err := pm.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
return ""
|
||||
|
||||
@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
|
||||
|
||||
state, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, defaultProfileName, state.ID) // No active profile state yet
|
||||
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
|
||||
|
||||
err = sm.SetActiveProfileStateToDefault()
|
||||
assert.NoError(t, err)
|
||||
|
||||
active, err := sm.GetActiveProfileState()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "default", active.ID)
|
||||
assert.Equal(t, "default", active.Name)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
|
||||
currUser, err := user.Current()
|
||||
assert.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
|
||||
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
|
||||
err = sm.SetActiveProfileState(state)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Should error on nil or incomplete state
|
||||
err = sm.SetActiveProfileState(nil)
|
||||
assert.Error(t, err)
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
|
||||
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
|
||||
assert.Error(t, err)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,7 +2,6 @@ package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -14,6 +13,7 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -24,43 +24,12 @@ var (
|
||||
DefaultConfigPathDir = ""
|
||||
DefaultConfigPath = ""
|
||||
ActiveProfileStatePath = ""
|
||||
)
|
||||
|
||||
var (
|
||||
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
|
||||
)
|
||||
|
||||
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
|
||||
// matches more than one profile. Callers can render Candidates to help the
|
||||
// user disambiguate.
|
||||
type ErrAmbiguousHandle struct {
|
||||
Handle string
|
||||
Candidates []Profile
|
||||
Kind AmbiguityKind
|
||||
}
|
||||
|
||||
// AmbiguityKind describes which matcher produced the ambiguity, so callers
|
||||
// can tailor the error message.
|
||||
type AmbiguityKind int
|
||||
|
||||
const (
|
||||
AmbiguityKindIDPrefix AmbiguityKind = iota
|
||||
AmbiguityKindName
|
||||
)
|
||||
|
||||
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
|
||||
// reading all fields
|
||||
type profileMeta struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (e *ErrAmbiguousHandle) Error() string {
|
||||
switch e.Kind {
|
||||
case AmbiguityKindIDPrefix:
|
||||
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
|
||||
default:
|
||||
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
DefaultConfigPathDir = "/var/lib/netbird/"
|
||||
@@ -86,34 +55,25 @@ func init() {
|
||||
}
|
||||
|
||||
type ActiveProfileState struct {
|
||||
// ID is the on-disk filename stem of the active profile. The JSON tag stays
|
||||
// as "name" for backwards compatibility with active state files written
|
||||
// before the ID-based config files. Legacy values were profile names, which
|
||||
// were also the legacy filename stems, so they still resolve to the correct
|
||||
// file on disk.
|
||||
ID string `json:"name"`
|
||||
Name string `json:"name"`
|
||||
Username string `json:"username"`
|
||||
}
|
||||
|
||||
func (a *ActiveProfileState) FilePath() (string, error) {
|
||||
if a.ID == "" {
|
||||
return "", fmt.Errorf("active profile ID is empty")
|
||||
if a.Name == "" {
|
||||
return "", fmt.Errorf("active profile name is empty")
|
||||
}
|
||||
|
||||
if a.ID == defaultProfileName {
|
||||
if a.Name == defaultProfileName {
|
||||
return DefaultConfigPath, nil
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(a.ID) {
|
||||
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
}
|
||||
|
||||
configDir, err := getConfigDirForUser(a.Username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, a.ID+".json"), nil
|
||||
return filepath.Join(configDir, a.Name+".json"), nil
|
||||
}
|
||||
|
||||
type ServiceManager struct {
|
||||
@@ -219,7 +179,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
ID: defaultProfileName,
|
||||
Name: "default",
|
||||
Username: "",
|
||||
}, nil
|
||||
} else {
|
||||
@@ -227,12 +187,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if activeProfile.ID == "" {
|
||||
if activeProfile.Name == "" {
|
||||
if err := s.SetActiveProfileStateToDefault(); err != nil {
|
||||
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
|
||||
}
|
||||
return &ActiveProfileState{
|
||||
ID: defaultProfileName,
|
||||
Name: "default",
|
||||
Username: "",
|
||||
}, nil
|
||||
}
|
||||
@@ -257,29 +217,25 @@ func (s *ServiceManager) setDefaultActiveState() error {
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
|
||||
if a == nil || a.ID == "" {
|
||||
if a == nil || a.Name == "" {
|
||||
return errors.New("invalid active profile state")
|
||||
}
|
||||
|
||||
if a.ID != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
|
||||
}
|
||||
|
||||
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
|
||||
return fmt.Errorf("invalid profile ID: %q", a.ID)
|
||||
if a.Name != defaultProfileName && a.Username == "" {
|
||||
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
|
||||
}
|
||||
|
||||
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
|
||||
return fmt.Errorf("failed to write active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile set to %s for %s", a.ID, a.Username)
|
||||
log.Infof("active profile set to %s for %s", a.Name, a.Username)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
|
||||
return s.SetActiveProfileState(&ActiveProfileState{
|
||||
ID: defaultProfileName,
|
||||
Name: "default",
|
||||
Username: "",
|
||||
})
|
||||
}
|
||||
@@ -288,75 +244,60 @@ func (s *ServiceManager) DefaultProfilePath() string {
|
||||
return DefaultConfigPath
|
||||
}
|
||||
|
||||
// AddProfile creates a new profile with a generated ID. The user-supplied
|
||||
// displayName is stored inside the JSON's name field, the on-disk filename
|
||||
// uses the generated ID.
|
||||
//
|
||||
// The returned Profile carries the freshly-generated ID so callers can
|
||||
// show it to the user (and so the gRPC AddProfileResponse can include
|
||||
// it).
|
||||
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
|
||||
// AddProfile creates a new profile with the given name. inheritOwnerUIDs is
|
||||
// applied to the new profile's OwnerUIDs (pass the active profile's owners so
|
||||
// the caller stays authorized; pass nil to leave the default empty/env-seeded).
|
||||
func (s *ServiceManager) AddProfile(profileName, username string, inheritOwnerUIDs []owner.UID) error {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
displayName, err = sanitizeDisplayName(displayName)
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid profile name: %w", err)
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
return ErrProfileAlreadyExists
|
||||
}
|
||||
|
||||
if displayName == defaultProfileName {
|
||||
return nil, fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
|
||||
id, err := generateProfileID()
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath, OwnerUIDs: inheritOwnerUIDs})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate profile id: %w", err)
|
||||
return fmt.Errorf("failed to create new config: %w", err)
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, id+".json")
|
||||
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
|
||||
err = util.WriteJson(context.Background(), profPath, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new config: %w", err)
|
||||
}
|
||||
cfg.Name = displayName
|
||||
|
||||
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to write profile config: %w", err)
|
||||
return fmt.Errorf("failed to write profile config: %w", err)
|
||||
}
|
||||
|
||||
return &Profile{
|
||||
ID: id,
|
||||
Name: displayName,
|
||||
Path: profPath,
|
||||
}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveProfile deletes the profile identified by id. Callers must have
|
||||
// already resolved any user-supplied handle to a concrete ID via
|
||||
// ResolveProfile.
|
||||
func (s *ServiceManager) RemoveProfile(id, username string) error {
|
||||
if id == defaultProfileName {
|
||||
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
profileName = sanitizeProfileName(profileName)
|
||||
|
||||
if profileName == defaultProfileName {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
if !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("load profiles: %w", err)
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
|
||||
var target *Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == id {
|
||||
target = &profiles[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if target == nil {
|
||||
if !profileExists {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
@@ -364,26 +305,57 @@ func (s *ServiceManager) RemoveProfile(id, username string) error {
|
||||
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
|
||||
return fmt.Errorf("failed to get active profile: %w", err)
|
||||
}
|
||||
if activeProf != nil && activeProf.ID == id {
|
||||
return fmt.Errorf("cannot remove active profile: %s", id)
|
||||
|
||||
if activeProf != nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("cannot remove active profile: %s", profileName)
|
||||
}
|
||||
|
||||
if err := util.RemoveJson(target.Path); err != nil {
|
||||
err = util.RemoveJson(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove profile config: %w", err)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(filepath.Dir(target.Path), id+".state.json")
|
||||
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListProfiles returns every profile for the given user, including the
|
||||
// default profile, with IsActive flags set.
|
||||
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
|
||||
return s.loadAllProfiles(username)
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get config directory: %w", err)
|
||||
}
|
||||
|
||||
files, err := util.ListFiles(configDir, "*.json")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list profile files: %w", err)
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, file := range files {
|
||||
if strings.HasSuffix(file, "state.json") {
|
||||
continue // skip state files
|
||||
}
|
||||
filtered = append(filtered, file)
|
||||
}
|
||||
sort.Strings(filtered)
|
||||
|
||||
var activeProfName string
|
||||
activeProf, err := s.GetActiveProfileState()
|
||||
if err == nil {
|
||||
activeProfName = activeProf.Name
|
||||
}
|
||||
|
||||
var profiles []Profile
|
||||
// add default profile always
|
||||
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
|
||||
for _, file := range filtered {
|
||||
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
|
||||
var isActive bool
|
||||
if activeProfName != "" && activeProfName == profileName {
|
||||
isActive = true
|
||||
}
|
||||
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
|
||||
}
|
||||
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// GetStatePath returns the path to the state file based on the operating system
|
||||
@@ -401,12 +373,7 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if activeProf.ID == defaultProfileName {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
if !IsValidProfileFilenameStem(activeProf.ID) {
|
||||
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
|
||||
if activeProf.Name == defaultProfileName {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
@@ -416,7 +383,7 @@ func (s *ServiceManager) GetStatePath() string {
|
||||
return defaultStatePath
|
||||
}
|
||||
|
||||
return filepath.Join(configDir, activeProf.ID+".state.json")
|
||||
return filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
}
|
||||
|
||||
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
|
||||
@@ -427,165 +394,3 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
|
||||
|
||||
return getConfigDirForUser(username)
|
||||
}
|
||||
|
||||
// loadAllProfiles returns every profile visible to the daemon for the
|
||||
// given user, including the default profile. The returned slice is sorted
|
||||
// by ID for a stable display order.
|
||||
//
|
||||
// Each Profile is fully populated: ID is the filename stem, Name comes
|
||||
// from the JSON's "name" field (falling back to the filename stem when absent)
|
||||
// and Path is built from a basename read off disk.
|
||||
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
|
||||
activeID, activeIsDefault := s.activeProfileID()
|
||||
|
||||
profiles := []Profile{{
|
||||
ID: defaultProfileName,
|
||||
Name: defaultProfileName,
|
||||
Path: DefaultConfigPath,
|
||||
IsActive: activeIsDefault,
|
||||
}}
|
||||
|
||||
configDir, err := s.getConfigDir(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(configDir)
|
||||
if err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
return profiles, nil
|
||||
}
|
||||
return nil, fmt.Errorf("read profile directory: %w", err)
|
||||
}
|
||||
|
||||
var fileProfiles []Profile
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
base := entry.Name()
|
||||
if !strings.HasSuffix(base, ".json") {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(base, ".state.json") {
|
||||
continue
|
||||
}
|
||||
stem := strings.TrimSuffix(base, ".json")
|
||||
if stem == defaultProfileName {
|
||||
// default lives at the top-level config dir, not under /<user>
|
||||
continue
|
||||
}
|
||||
if !IsValidProfileFilenameStem(stem) {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(configDir, base)
|
||||
name := readProfileName(path)
|
||||
if name == "" {
|
||||
name = stem
|
||||
}
|
||||
fileProfiles = append(fileProfiles, Profile{
|
||||
ID: stem,
|
||||
Name: name,
|
||||
Path: path,
|
||||
IsActive: stem == activeID,
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(fileProfiles, func(i, j int) bool {
|
||||
if fileProfiles[i].Name != fileProfiles[j].Name {
|
||||
return fileProfiles[i].Name < fileProfiles[j].Name
|
||||
}
|
||||
// Sort tie-break on ID so duplicate names always render in the same order.
|
||||
return fileProfiles[i].ID < fileProfiles[j].ID
|
||||
})
|
||||
profiles = append(profiles, fileProfiles...)
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
// readProfileName parses just the "name" field from the profile Json.
|
||||
func readProfileName(path string) string {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
var meta profileMeta
|
||||
if err := json.Unmarshal(data, &meta); err != nil {
|
||||
return ""
|
||||
}
|
||||
return meta.Name
|
||||
}
|
||||
|
||||
// activeProfileID returns the currently-active profile's ID. The second
|
||||
// return value is true when the active profile is the default one.
|
||||
func (s *ServiceManager) activeProfileID() (string, bool) {
|
||||
state, err := s.GetActiveProfileState()
|
||||
if err != nil || state == nil {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
if state.ID == "" || state.ID == defaultProfileName {
|
||||
return defaultProfileName, true
|
||||
}
|
||||
return state.ID, false
|
||||
}
|
||||
|
||||
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
|
||||
// precedence is: exact ID match, then unique ID prefix, then unique exact
|
||||
// name. Ambiguous matches return *ErrAmbiguousHandle so callers can
|
||||
// surface the candidates.
|
||||
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
|
||||
if handle == "" {
|
||||
return nil, fmt.Errorf("profile handle is empty")
|
||||
}
|
||||
|
||||
profiles, err := s.loadAllProfiles(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == handle {
|
||||
return &profiles[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
// ID prefix match. Skip the default profile so `select d` does not
|
||||
// accidentally pick it via prefix.
|
||||
var prefixMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].ID == defaultProfileName {
|
||||
continue
|
||||
}
|
||||
if strings.HasPrefix(profiles[i].ID, handle) {
|
||||
prefixMatches = append(prefixMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(prefixMatches) == 1 {
|
||||
return &prefixMatches[0], nil
|
||||
}
|
||||
if len(prefixMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: prefixMatches,
|
||||
Kind: AmbiguityKindIDPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
var nameMatches []Profile
|
||||
for i := range profiles {
|
||||
if profiles[i].Name == handle {
|
||||
nameMatches = append(nameMatches, profiles[i])
|
||||
}
|
||||
}
|
||||
if len(nameMatches) == 1 {
|
||||
return &nameMatches[0], nil
|
||||
}
|
||||
if len(nameMatches) > 1 {
|
||||
return nil, &ErrAmbiguousHandle{
|
||||
Handle: handle,
|
||||
Candidates: nameMatches,
|
||||
Kind: AmbiguityKindName,
|
||||
}
|
||||
}
|
||||
|
||||
return nil, ErrProfileNotFound
|
||||
}
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
package profilemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// withTestSM wires up patched globals + a clean config dir and returns a
|
||||
// fully initialized ServiceManager plus the username we are scoped to.
|
||||
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
|
||||
t.Helper()
|
||||
withTempConfigDir(t, func(configDir string) {
|
||||
withPatchedGlobals(t, configDir, func() {
|
||||
u, err := user.Current()
|
||||
require.NoError(t, err)
|
||||
sm := &ServiceManager{}
|
||||
require.NoError(t, sm.CreateDefaultProfile())
|
||||
fn(sm, u.Username)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile(created.ID, username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_IDPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
prefix := created.ID[:4]
|
||||
got, err := sm.ResolveProfile(prefix, username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, created.ID, got.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
// Plant two profiles whose IDs share a known prefix by writing
|
||||
// the files directly, since generated IDs are random.
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
|
||||
path := filepath.Join(configDir, id+".json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
|
||||
}
|
||||
|
||||
_, err = sm.ResolveProfile("abcd", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_ExactNameUnique(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := sm.ResolveProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "work", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_AmbiguousName(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
_, err = sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = sm.ResolveProfile("work", username)
|
||||
var amb *ErrAmbiguousHandle
|
||||
require.ErrorAs(t, err, &amb)
|
||||
assert.Equal(t, AmbiguityKindName, amb.Kind)
|
||||
assert.Len(t, amb.Candidates, 2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_NotFound(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
_, err := sm.ResolveProfile("nope", username)
|
||||
assert.ErrorIs(t, err, ErrProfileNotFound)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_DefaultByExactID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
got, err := sm.ResolveProfile(defaultProfileName, username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, defaultProfileName, got.ID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
|
||||
// Legacy profiles stored as <name>.json with no "name" JSON field
|
||||
// should still be discoverable by name and removable by name.
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
path := filepath.Join(configDir, "legacy.json")
|
||||
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
|
||||
|
||||
got, err := sm.ResolveProfile("legacy", username)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "legacy", got.ID)
|
||||
// Name falls back to the filename stem when JSON omits it.
|
||||
assert.Equal(t, "legacy", got.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
first, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
second, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
assert.NotEqual(t, first.ID, second.ID)
|
||||
assert.Equal(t, "work", second.Name)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
cases := []string{
|
||||
"", // empty
|
||||
"\x00\x01", // only control chars (becomes empty)
|
||||
strings.Repeat("a", maxProfileNameLen+1), // too long
|
||||
}
|
||||
for _, name := range cases {
|
||||
_, err := sm.AddProfile(name, username)
|
||||
assert.Error(t, err, "expected error for %q", name)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
err := sm.RemoveProfile("../escape", username)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestSanitizeDisplayName(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{"work", "work", false},
|
||||
{"My Work Account", "My Work Account", false},
|
||||
{"emoji 🚀 ok", "emoji 🚀 ok", false},
|
||||
{"漢字テスト", "漢字テスト", false},
|
||||
{"with\x00null", "withnull", false},
|
||||
{"\x01\x02\x03", "", true},
|
||||
{"", "", true},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got, err := sanitizeDisplayName(tc.in)
|
||||
if tc.wantErr {
|
||||
assert.Error(t, err, "case %q", tc.in)
|
||||
continue
|
||||
}
|
||||
assert.NoError(t, err, "case %q", tc.in)
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsValidProfileFilenameStem(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want bool
|
||||
}{
|
||||
{"default", true},
|
||||
{"abc123def456", true},
|
||||
{"legacy-name", true},
|
||||
{"legacy_name", true},
|
||||
{"", false},
|
||||
{"..", false},
|
||||
{"../etc", false},
|
||||
{"foo/bar", false},
|
||||
{`foo\bar`, false},
|
||||
{"with space", false},
|
||||
{"with.dot", false},
|
||||
{strings.Repeat("a", maxProfileIDLen+1), false},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
got := IsValidProfileFilenameStem(tc.in)
|
||||
assert.Equal(t, tc.want, got, "case %q", tc.in)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
|
||||
withTestSM(t, func(sm *ServiceManager, username string) {
|
||||
created, err := sm.AddProfile("work", username)
|
||||
require.NoError(t, err)
|
||||
|
||||
configDir, err := sm.getConfigDir(username)
|
||||
require.NoError(t, err)
|
||||
statePath := filepath.Join(configDir, created.ID+".state.json")
|
||||
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
|
||||
|
||||
require.NoError(t, sm.RemoveProfile(created.ID, username))
|
||||
_, err = os.Stat(statePath)
|
||||
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
|
||||
})
|
||||
}
|
||||
@@ -13,20 +13,13 @@ type ProfileState struct {
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
// GetProfileState reads the per-profile state file keyed by profile ID.
|
||||
// The state file lives in the user's config directory. Legacy state files
|
||||
// keyed by the old profile name remain readable.
|
||||
func (pm *ProfileManager) GetProfileState(id string) (*ProfileState, error) {
|
||||
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
|
||||
configDir, err := getConfigDir()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get config directory: %w", err)
|
||||
}
|
||||
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return nil, fmt.Errorf("invalid profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id+".state.json")
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
@@ -58,12 +51,7 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
|
||||
return fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
id := activeProf.ID
|
||||
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
|
||||
return fmt.Errorf("invalid active profile ID: %q", id)
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, id+".state.json")
|
||||
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
|
||||
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write profile state: %w", err)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -91,6 +91,15 @@ service DaemonService {
|
||||
|
||||
rpc GetActiveProfile(GetActiveProfileRequest) returns (GetActiveProfileResponse) {}
|
||||
|
||||
// AddOwner adds a UID to the active profile's owner list. Requires
|
||||
// root or an existing owner.
|
||||
rpc AddOwner(AddOwnerRequest) returns (AddOwnerResponse) {}
|
||||
|
||||
// ResetOwner clears the active profile's owner list, returning it to
|
||||
// the unconfigured state. The next call from the active console-session
|
||||
// user will then re-claim ownership. Requires root.
|
||||
rpc ResetOwner(ResetOwnerRequest) returns (ResetOwnerResponse) {}
|
||||
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
rpc Logout(LogoutRequest) returns (LogoutResponse) {}
|
||||
|
||||
@@ -227,6 +236,10 @@ message UpRequest {
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
reserved 3;
|
||||
// When true, the caller claims owner privileges for this profile.
|
||||
// Requires root or current owner; for new installs (root-only mode),
|
||||
// the calling UID becomes an owner.
|
||||
bool claimOwner = 4;
|
||||
}
|
||||
|
||||
message UpResponse {}
|
||||
@@ -613,18 +626,11 @@ message GetEventsResponse {
|
||||
}
|
||||
|
||||
message SwitchProfileRequest {
|
||||
// profileName is treated as a handle: exact ID, unique ID prefix, or
|
||||
// unique display name. The daemon resolves it server-side.
|
||||
optional string profileName = 1;
|
||||
optional string username = 2;
|
||||
}
|
||||
|
||||
message SwitchProfileResponse {
|
||||
// id is the resolved on-disk ID of the profile that became active.
|
||||
// Lets CLI clients update their local active-profile state without
|
||||
// duplicating the resolution logic.
|
||||
string id = 1;
|
||||
}
|
||||
message SwitchProfileResponse {}
|
||||
|
||||
message SetConfigRequest {
|
||||
string username = 1;
|
||||
@@ -691,29 +697,27 @@ message SetConfigResponse{}
|
||||
|
||||
message AddProfileRequest {
|
||||
string username = 1;
|
||||
// profileName carries the human-readable display name for the new
|
||||
// profile. The on-disk filename is a separately-generated ID.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message AddProfileResponse {
|
||||
// id is the generated on-disk ID of the new profile. CLI clients
|
||||
// display a truncated form, UI clients can ignore it.
|
||||
string id = 1;
|
||||
message AddProfileResponse {}
|
||||
|
||||
message AddOwnerRequest {
|
||||
uint32 uid = 1;
|
||||
}
|
||||
|
||||
message AddOwnerResponse {}
|
||||
|
||||
message ResetOwnerRequest {}
|
||||
|
||||
message ResetOwnerResponse {}
|
||||
|
||||
message RemoveProfileRequest {
|
||||
string username = 1;
|
||||
// profileName is treated as a handle: an exact ID, a unique ID
|
||||
// prefix, or a unique display name. Resolution happens server-side.
|
||||
string profileName = 2;
|
||||
}
|
||||
|
||||
message RemoveProfileResponse {
|
||||
// id is the full resolved ID of the removed profile, so callers can
|
||||
// confirm exactly which profile a name/prefix handle resolved to.
|
||||
string id = 1;
|
||||
}
|
||||
message RemoveProfileResponse {}
|
||||
|
||||
message ListProfilesRequest {
|
||||
string username = 1;
|
||||
@@ -726,7 +730,6 @@ message ListProfilesResponse {
|
||||
message Profile {
|
||||
string name = 1;
|
||||
bool is_active = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message GetActiveProfileRequest {}
|
||||
@@ -734,7 +737,6 @@ message GetActiveProfileRequest {}
|
||||
message GetActiveProfileResponse {
|
||||
string profileName = 1;
|
||||
string username = 2;
|
||||
string id = 3;
|
||||
}
|
||||
|
||||
message LogoutRequest {
|
||||
|
||||
@@ -48,6 +48,8 @@ const (
|
||||
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
|
||||
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
|
||||
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
|
||||
DaemonService_AddOwner_FullMethodName = "/daemon.DaemonService/AddOwner"
|
||||
DaemonService_ResetOwner_FullMethodName = "/daemon.DaemonService/ResetOwner"
|
||||
DaemonService_Logout_FullMethodName = "/daemon.DaemonService/Logout"
|
||||
DaemonService_GetFeatures_FullMethodName = "/daemon.DaemonService/GetFeatures"
|
||||
DaemonService_TriggerUpdate_FullMethodName = "/daemon.DaemonService/TriggerUpdate"
|
||||
@@ -115,6 +117,13 @@ type DaemonServiceClient interface {
|
||||
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
|
||||
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
|
||||
// AddOwner adds a UID to the active profile's owner list. Requires
|
||||
// root or an existing owner.
|
||||
AddOwner(ctx context.Context, in *AddOwnerRequest, opts ...grpc.CallOption) (*AddOwnerResponse, error)
|
||||
// ResetOwner clears the active profile's owner list, returning it to
|
||||
// the unconfigured state. The next call from the active console-session
|
||||
// user will then re-claim ownership. Requires root.
|
||||
ResetOwner(ctx context.Context, in *ResetOwnerRequest, opts ...grpc.CallOption) (*ResetOwnerResponse, error)
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error)
|
||||
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
||||
@@ -452,6 +461,26 @@ func (c *daemonServiceClient) GetActiveProfile(ctx context.Context, in *GetActiv
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) AddOwner(ctx context.Context, in *AddOwnerRequest, opts ...grpc.CallOption) (*AddOwnerResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(AddOwnerResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_AddOwner_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) ResetOwner(ctx context.Context, in *ResetOwnerRequest, opts ...grpc.CallOption) (*ResetOwnerResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(ResetOwnerResponse)
|
||||
err := c.cc.Invoke(ctx, DaemonService_ResetOwner_FullMethodName, in, out, cOpts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *daemonServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts ...grpc.CallOption) (*LogoutResponse, error) {
|
||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||
out := new(LogoutResponse)
|
||||
@@ -616,6 +645,13 @@ type DaemonServiceServer interface {
|
||||
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
|
||||
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
|
||||
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
|
||||
// AddOwner adds a UID to the active profile's owner list. Requires
|
||||
// root or an existing owner.
|
||||
AddOwner(context.Context, *AddOwnerRequest) (*AddOwnerResponse, error)
|
||||
// ResetOwner clears the active profile's owner list, returning it to
|
||||
// the unconfigured state. The next call from the active console-session
|
||||
// user will then re-claim ownership. Requires root.
|
||||
ResetOwner(context.Context, *ResetOwnerRequest) (*ResetOwnerResponse, error)
|
||||
// Logout disconnects from the network and deletes the peer from the management server
|
||||
Logout(context.Context, *LogoutRequest) (*LogoutResponse, error)
|
||||
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
||||
@@ -732,6 +768,12 @@ func (UnimplementedDaemonServiceServer) ListProfiles(context.Context, *ListProfi
|
||||
func (UnimplementedDaemonServiceServer) GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method GetActiveProfile not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) AddOwner(context.Context, *AddOwnerRequest) (*AddOwnerResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method AddOwner not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) ResetOwner(context.Context, *ResetOwnerRequest) (*ResetOwnerResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method ResetOwner not implemented")
|
||||
}
|
||||
func (UnimplementedDaemonServiceServer) Logout(context.Context, *LogoutRequest) (*LogoutResponse, error) {
|
||||
return nil, status.Error(codes.Unimplemented, "method Logout not implemented")
|
||||
}
|
||||
@@ -1291,6 +1333,42 @@ func _DaemonService_GetActiveProfile_Handler(srv interface{}, ctx context.Contex
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_AddOwner_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(AddOwnerRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).AddOwner(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_AddOwner_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).AddOwner(ctx, req.(*AddOwnerRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_ResetOwner_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(ResetOwnerRequest)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(DaemonServiceServer).ResetOwner(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: DaemonService_ResetOwner_FullMethodName,
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(DaemonServiceServer).ResetOwner(ctx, req.(*ResetOwnerRequest))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _DaemonService_Logout_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(LogoutRequest)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -1579,6 +1657,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "GetActiveProfile",
|
||||
Handler: _DaemonService_GetActiveProfile_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "AddOwner",
|
||||
Handler: _DaemonService_AddOwner_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ResetOwner",
|
||||
Handler: _DaemonService_ResetOwner_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "Logout",
|
||||
Handler: _DaemonService_Logout_Handler,
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
if ! which realpath >/dev/null 2>&1; then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
if ! which realpath > /dev/null 2>&1
|
||||
then
|
||||
echo realpath is not installed
|
||||
echo run: brew install coreutils
|
||||
exit 1
|
||||
fi
|
||||
|
||||
old_pwd=$(pwd)
|
||||
script_path=$(dirname $(realpath "$0"))
|
||||
cd "$script_path"
|
||||
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
|
||||
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
|
||||
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
|
||||
cd "$old_pwd"
|
||||
|
||||
@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
|
||||
_, err := profilemanager.UpdateOrCreateConfig(seed)
|
||||
require.NoError(t, err, "seed config")
|
||||
|
||||
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
|
||||
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
|
||||
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
|
||||
require.NoError(t, err, "persistLoginOverrides")
|
||||
|
||||
|
||||
172
client/server/owner.go
Normal file
172
client/server/owner.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
// authorizeTargetProfile enforces the "match or root" rule for operations
|
||||
// that target a specific profile (Remove/Switch). The caller must be root
|
||||
// or appear in the target profile config's OwnerUIDs. A target profile in
|
||||
// legacy TOFU state (nil OwnerUIDs) is treated as unowned and therefore
|
||||
// accessible to any peer-creds caller, which matches pre-enforcement
|
||||
// behavior on upgraded installs.
|
||||
func (s *Server) authorizeTargetProfile(ctx context.Context, profileName, username string) error {
|
||||
uid, ok := owner.UIDFromContext(ctx)
|
||||
if !ok {
|
||||
return status.Error(codes.PermissionDenied, "peer credentials unavailable")
|
||||
}
|
||||
if uid == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg, err := s.readProfileConfig(profileName, username)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read target profile config: %w", err)
|
||||
}
|
||||
|
||||
// Legacy / never-claimed target: allow, mirroring the migration TOFU
|
||||
// semantics in the interceptor.
|
||||
if cfg.OwnerUIDs == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if slices.Contains(cfg.OwnerUIDs, uid) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return status.Errorf(codes.PermissionDenied,
|
||||
"profile %q is owned by another user (uid %d is not in its owner list)", profileName, uid)
|
||||
}
|
||||
|
||||
// readProfileConfig loads a profile's config from disk without making it
|
||||
// active. Used by authorizeTargetProfile.
|
||||
func (s *Server) readProfileConfig(profileName, username string) (*profilemanager.Config, error) {
|
||||
state := &profilemanager.ActiveProfileState{Name: profileName, Username: username}
|
||||
path, err := state.FilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("resolve profile path: %w", err)
|
||||
}
|
||||
cfg, err := profilemanager.GetConfig(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("load %s: %w", path, err)
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// GetOwnerUIDs returns the current owner UIDs from the active config.
|
||||
// nil means TOFU mode, empty means root-only, populated means those UIDs are owners.
|
||||
func (s *Server) GetOwnerUIDs() []owner.UID {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.config == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return s.config.OwnerUIDs
|
||||
}
|
||||
|
||||
// AddOwnerUID adds the given UID to the owner list in the active profile config.
|
||||
func (s *Server) AddOwnerUID(uid owner.UID) error {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
return s.addOwnerUIDLocked(uid)
|
||||
}
|
||||
|
||||
// addOwnerUIDLocked adds uid to the active profile's owner list and persists it.
|
||||
// The caller must hold s.mutex.
|
||||
func (s *Server) addOwnerUIDLocked(uid owner.UID) error {
|
||||
if s.config == nil {
|
||||
return fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
if slices.Contains(s.config.OwnerUIDs, uid) {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.config.OwnerUIDs = append(s.config.OwnerUIDs, uid)
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get profile file path: %w", err)
|
||||
}
|
||||
|
||||
if err := util.WriteJson(context.Background(), cfgPath, s.config); err != nil {
|
||||
return fmt.Errorf("write config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("owner UID %d added in %s (owners: %v)", uid, cfgPath, s.config.OwnerUIDs)
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddOwner handles the AddOwner RPC. The interceptor has already gated this
|
||||
// call (caller must be root or an existing owner); the handler just persists
|
||||
// the new UID into the active profile config.
|
||||
func (s *Server) AddOwner(_ context.Context, msg *proto.AddOwnerRequest) (*proto.AddOwnerResponse, error) {
|
||||
if msg == nil || msg.Uid == 0 {
|
||||
return nil, status.Error(codes.InvalidArgument, "uid must be non-zero")
|
||||
}
|
||||
if err := s.AddOwnerUID(owner.UID(msg.Uid)); err != nil {
|
||||
return nil, fmt.Errorf("add owner: %w", err)
|
||||
}
|
||||
return &proto.AddOwnerResponse{}, nil
|
||||
}
|
||||
|
||||
// ResetOwner clears the active profile's owner list. Only callable by root
|
||||
// (the interceptor enforces this: a non-owner non-root caller is denied
|
||||
// before reaching the handler, and only owners or root can reach Add/Reset
|
||||
// at all; we additionally require root here so existing owners can't reset
|
||||
// each other out).
|
||||
func (s *Server) ResetOwner(ctx context.Context, _ *proto.ResetOwnerRequest) (*proto.ResetOwnerResponse, error) {
|
||||
uid, ok := owner.UIDFromContext(ctx)
|
||||
if !ok {
|
||||
return nil, status.Error(codes.PermissionDenied, "peer credentials unavailable")
|
||||
}
|
||||
if uid != 0 {
|
||||
return nil, status.Error(codes.PermissionDenied, "reset-owner requires root")
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if s.config == nil {
|
||||
return nil, fmt.Errorf("config not loaded")
|
||||
}
|
||||
|
||||
// Reset to the fresh-install state (empty, not nil): only root and the
|
||||
// active console-session user can reclaim. nil would be legacy migration
|
||||
// TOFU, where any non-root caller (including SSH) could reclaim.
|
||||
s.config.OwnerUIDs = []owner.UID{}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get active profile: %w", err)
|
||||
}
|
||||
cfgPath, err := activeProf.FilePath()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get profile file path: %w", err)
|
||||
}
|
||||
if err := util.WriteJson(context.Background(), cfgPath, s.config); err != nil {
|
||||
return nil, fmt.Errorf("write config: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("owner list reset; next call from the active console user will re-claim ownership")
|
||||
return &proto.ResetOwnerResponse{}, nil
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/owner"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sleephandler "github.com/netbirdio/netbird/client/internal/sleep/handler"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
@@ -308,14 +309,15 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
|
||||
return nil, err
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
Name: msg.ProfileName,
|
||||
Username: msg.Username,
|
||||
}
|
||||
profPath := resolved.Path
|
||||
if profPath == "" {
|
||||
profPath = profilemanager.DefaultConfigPath
|
||||
|
||||
profPath, err := profState.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
var config profilemanager.ConfigInput
|
||||
@@ -445,9 +447,30 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
}
|
||||
|
||||
if msg.ProfileName != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, err
|
||||
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
|
||||
}
|
||||
|
||||
var username string
|
||||
if *msg.ProfileName != "default" {
|
||||
username = *msg.Username
|
||||
}
|
||||
|
||||
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
Name: *msg.ProfileName,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -457,7 +480,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
|
||||
s.mutex.Lock()
|
||||
|
||||
@@ -689,10 +712,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
s.mutex.Unlock()
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -703,7 +726,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
|
||||
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
|
||||
|
||||
config, _, err := s.getConfig(activeProf)
|
||||
if err != nil {
|
||||
@@ -713,6 +736,18 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
|
||||
}
|
||||
s.config = config
|
||||
|
||||
// An explicit --owner claim locks the active profile to the calling user
|
||||
// (plus root). Root has no specific UID to claim, so only non-root callers
|
||||
// take effect here; the interceptor has already authorized the call.
|
||||
if msg != nil && msg.ClaimOwner {
|
||||
if uid, ok := owner.UIDFromContext(callerCtx); ok && uid != 0 {
|
||||
if err := s.addOwnerUIDLocked(uid); err != nil {
|
||||
s.mutex.Unlock()
|
||||
return nil, fmt.Errorf("claim owner: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.statusRecorder.UpdateManagementAddress(s.config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(s.config.RosenpassEnabled, s.config.RosenpassPermissive)
|
||||
|
||||
@@ -746,64 +781,50 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveProfileHandle resolves a wire-level profile handle (display
|
||||
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
|
||||
// status errors so handlers can return them directly.
|
||||
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
|
||||
p, err := s.profileManager.ResolveProfile(handle, username)
|
||||
if err == nil {
|
||||
return p, nil
|
||||
}
|
||||
var amb *profilemanager.ErrAmbiguousHandle
|
||||
if errors.As(err, &amb) {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
|
||||
}
|
||||
if errors.Is(err, profilemanager.ErrProfileNotFound) {
|
||||
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
|
||||
}
|
||||
return nil, fmt.Errorf("resolve profile: %w", err)
|
||||
}
|
||||
|
||||
// switchProfileIfNeeded resolves the user-supplied handle, updates the
|
||||
// active profile state if it differs from the current one, and returns
|
||||
// the resolved profile so callers can include its ID in RPC responses.
|
||||
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
|
||||
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
|
||||
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
|
||||
if profileName != "default" && (userName == nil || *userName == "") {
|
||||
log.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
|
||||
}
|
||||
|
||||
var username string
|
||||
if handle != profilemanager.DefaultProfileName {
|
||||
if profileName != "default" {
|
||||
username = *userName
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(handle, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resolved.ID != activeProf.ID || username != activeProf.Username {
|
||||
if profileName != activeProf.Name || username != activeProf.Username {
|
||||
if s.checkProfilesDisabled() {
|
||||
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
|
||||
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
|
||||
log.Infof("switching to profile %s for user %s", profileName, username)
|
||||
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: resolved.ID,
|
||||
Name: profileName,
|
||||
Username: username,
|
||||
}); err != nil {
|
||||
log.Errorf("failed to set active profile state: %v", err)
|
||||
return nil, fmt.Errorf("failed to set active profile state: %w", err)
|
||||
return fmt.Errorf("failed to set active profile state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return resolved, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// SwitchProfile switches the active profile in the daemon.
|
||||
func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfileRequest) (*proto.SwitchProfileResponse, error) {
|
||||
// Switching downs the current session and starts another, so the caller
|
||||
// must own the target profile (or be root).
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
username := ""
|
||||
if msg.Username != nil {
|
||||
username = *msg.Username
|
||||
}
|
||||
if err := s.authorizeTargetProfile(callerCtx, *msg.ProfileName, username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
@@ -814,9 +835,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
}
|
||||
|
||||
if msg != nil && msg.ProfileName != nil {
|
||||
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to switch profile: %w", err)
|
||||
}
|
||||
}
|
||||
activeProf, err = s.profileManager.GetActiveProfileState()
|
||||
@@ -832,7 +853,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
|
||||
|
||||
s.config = config
|
||||
|
||||
return &proto.SwitchProfileResponse{Id: activeProf.ID}, nil
|
||||
return &proto.SwitchProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// Down engine work in the daemon.
|
||||
@@ -916,27 +937,22 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
|
||||
}
|
||||
|
||||
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
|
||||
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if msg.Username == nil || *msg.Username == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
|
||||
}
|
||||
username := *msg.Username
|
||||
|
||||
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
|
||||
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
|
||||
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
|
||||
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
|
||||
}
|
||||
|
||||
activeProf, _ := s.profileManager.GetActiveProfileState()
|
||||
if activeProf != nil && activeProf.ID == resolved.ID {
|
||||
if activeProf != nil && activeProf.Name == *msg.ProfileName {
|
||||
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
|
||||
log.Errorf("failed to cleanup connection: %v", err)
|
||||
}
|
||||
@@ -998,30 +1014,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
|
||||
return config, configExisted, nil
|
||||
}
|
||||
|
||||
func (s *Server) canRemoveProfile(id string) error {
|
||||
if id == profilemanager.DefaultProfileName {
|
||||
func (s *Server) canRemoveProfile(profileName string) error {
|
||||
if profileName == profilemanager.DefaultProfileName {
|
||||
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
|
||||
}
|
||||
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.ID == id {
|
||||
return fmt.Errorf("remove active profile: %s", id)
|
||||
if err == nil && activeProf.Name == profileName {
|
||||
return fmt.Errorf("remove active profile: %s", profileName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateProfileOperation(id string, allowActiveProfile bool) error {
|
||||
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
|
||||
if s.checkProfilesDisabled() {
|
||||
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
|
||||
}
|
||||
|
||||
if id == "" {
|
||||
if profileName == "" {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
if !allowActiveProfile {
|
||||
if err := s.canRemoveProfile(id); err != nil {
|
||||
if err := s.canRemoveProfile(profileName); err != nil {
|
||||
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
}
|
||||
@@ -1029,15 +1045,25 @@ func (s *Server) validateProfileOperation(id string, allowActiveProfile bool) er
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
|
||||
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
|
||||
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
|
||||
activeProf, err := s.profileManager.GetActiveProfileState()
|
||||
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
|
||||
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
|
||||
return s.sendLogoutRequest(ctx)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(profile.Path)
|
||||
profileState := &profilemanager.ActiveProfileState{
|
||||
Name: profileName,
|
||||
Username: username,
|
||||
}
|
||||
profilePath, err := profileState.FilePath()
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profile.ID)
|
||||
return fmt.Errorf("get profile path: %w", err)
|
||||
}
|
||||
|
||||
config, err := profilemanager.GetConfig(profilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("profile '%s' not found", profileName)
|
||||
}
|
||||
|
||||
return s.sendLogoutRequestWithConfig(ctx, config)
|
||||
@@ -1451,14 +1477,15 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
|
||||
return nil, err
|
||||
prof := profilemanager.ActiveProfileState{
|
||||
Name: req.ProfileName,
|
||||
Username: req.Username,
|
||||
}
|
||||
cfgPath := resolved.Path
|
||||
if cfgPath == "" {
|
||||
cfgPath = profilemanager.DefaultConfigPath
|
||||
|
||||
cfgPath, err := prof.FilePath()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get active profile file path: %v", err)
|
||||
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
|
||||
}
|
||||
|
||||
cfg, err := profilemanager.GetConfig(cfgPath)
|
||||
@@ -1562,46 +1589,47 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
|
||||
}
|
||||
|
||||
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
if errors.Is(err, profilemanager.ErrProfileAlreadyExists) {
|
||||
return nil, gstatus.Errorf(codes.AlreadyExists, "profile %q already exists", msg.ProfileName)
|
||||
}
|
||||
// New profiles auto-claim the caller as their sole owner so the user who
|
||||
// just created the profile retains control (and other local users can't
|
||||
// touch it via SwitchProfile/RemoveProfile). When called by root, leave
|
||||
// OwnerUIDs at the default (empty/env-seeded); root explicitly didn't
|
||||
// claim ownership for any specific user.
|
||||
var initialOwners []owner.UID
|
||||
if uid, ok := owner.UIDFromContext(ctx); ok && uid != 0 {
|
||||
initialOwners = []owner.UID{uid}
|
||||
}
|
||||
|
||||
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username, initialOwners); err != nil {
|
||||
log.Errorf("failed to create profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to create profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.AddProfileResponse{Id: created.ID}, nil
|
||||
return &proto.AddProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveProfile removes a profile from the daemon.
|
||||
func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequest) (*proto.RemoveProfileResponse, error) {
|
||||
if err := s.authorizeTargetProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
if msg.ProfileName == "" {
|
||||
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
|
||||
}
|
||||
|
||||
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
|
||||
if err != nil {
|
||||
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.validateProfileOperation(resolved.ID, false); err != nil {
|
||||
return nil, err
|
||||
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
|
||||
}
|
||||
|
||||
if err := s.logoutFromProfile(ctx, resolved); err != nil {
|
||||
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
|
||||
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
return nil, fmt.Errorf("failed to remove profile: %w", err)
|
||||
}
|
||||
|
||||
return &proto.RemoveProfileResponse{Id: resolved.ID}, nil
|
||||
return &proto.RemoveProfileResponse{}, nil
|
||||
}
|
||||
|
||||
// ListProfiles lists all profiles in the daemon.
|
||||
@@ -1624,7 +1652,6 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
}
|
||||
for i, profile := range profiles {
|
||||
response.Profiles[i] = &proto.Profile{
|
||||
Id: profile.ID,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
}
|
||||
@@ -1633,9 +1660,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// GetActiveProfile returns the active profile in the daemon. The
|
||||
// ProfileName field carries the display name for backwards compatibility
|
||||
// with UI clients, new callers should prefer Id.
|
||||
// GetActiveProfile returns the active profile in the daemon.
|
||||
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
@@ -1646,22 +1671,9 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
|
||||
return nil, fmt.Errorf("failed to get active profile state: %w", err)
|
||||
}
|
||||
|
||||
displayName := activeProfile.ID
|
||||
if activeProfile.ID != profilemanager.DefaultProfileName {
|
||||
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
|
||||
for _, p := range profiles {
|
||||
if p.ID == activeProfile.ID {
|
||||
displayName = p.Name
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &proto.GetActiveProfileResponse{
|
||||
ProfileName: displayName,
|
||||
ProfileName: activeProfile.Name,
|
||||
Username: activeProfile.Username,
|
||||
Id: activeProfile.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "test-profile",
|
||||
Name: "test-profile",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: profName,
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: "default",
|
||||
Name: "default",
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
|
||||
pm := profilemanager.ServiceManager{}
|
||||
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
|
||||
ID: profName,
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
profState := profilemanager.ActiveProfileState{
|
||||
ID: profName,
|
||||
Name: profName,
|
||||
Username: currUser.Username,
|
||||
}
|
||||
cfgPath, err := profState.FilePath()
|
||||
|
||||
@@ -622,7 +622,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
||||
}
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: activeProf.ID,
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
}
|
||||
|
||||
@@ -789,11 +789,11 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
|
||||
|
||||
loginReq := &proto.LoginRequest{
|
||||
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
|
||||
ProfileName: &activeProf.ID,
|
||||
ProfileName: &activeProf.Name,
|
||||
Username: &currUser.Username,
|
||||
}
|
||||
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
|
||||
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get profile state for login hint: %v", err)
|
||||
} else if profileState.Email != "" {
|
||||
@@ -1309,7 +1309,7 @@ func (s *serviceClient) getSrvConfig() {
|
||||
}
|
||||
|
||||
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID,
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1533,7 +1533,7 @@ func (s *serviceClient) loadSettings() {
|
||||
}
|
||||
|
||||
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
|
||||
ProfileName: activeProf.ID,
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1610,7 +1610,7 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
req := proto.SetConfigRequest{
|
||||
ProfileName: activeProf.ID,
|
||||
ProfileName: activeProf.Name,
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
|
||||
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
// switch
|
||||
err = s.switchProfile(profile.ID)
|
||||
err = s.switchProfile(profile.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile: %v", err)
|
||||
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
|
||||
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
logoutBtn.Show()
|
||||
logoutBtn.SetText("Deregister")
|
||||
logoutBtn.OnTapped = func() {
|
||||
s.handleProfileLogout(profile, refresh)
|
||||
s.handleProfileLogout(profile.Name, refresh)
|
||||
}
|
||||
|
||||
// Remove profile
|
||||
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
|
||||
return
|
||||
}
|
||||
|
||||
err = s.removeProfile(profile.ID)
|
||||
err = s.removeProfile(profile.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to remove profile: %v", err)
|
||||
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
|
||||
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) switchProfile(handle string) error {
|
||||
func (s *serviceClient) switchProfile(profileName string) error {
|
||||
conn, err := s.getSrvClient(defaultFailTimeout)
|
||||
if err != nil {
|
||||
return fmt.Errorf(getClientFMT, err)
|
||||
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(handle string) error {
|
||||
return fmt.Errorf("get current user: %w", err)
|
||||
}
|
||||
|
||||
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &handle,
|
||||
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profileName,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
}); err != nil {
|
||||
return fmt.Errorf("switch profile failed: %w", err)
|
||||
}
|
||||
|
||||
if err := s.profileManager.SwitchProfile(resp.Id); err != nil {
|
||||
err = s.profileManager.SwitchProfile(profileName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("switch profile: %w", err)
|
||||
}
|
||||
|
||||
@@ -299,7 +299,6 @@ func (s *serviceClient) removeProfile(profileName string) error {
|
||||
}
|
||||
|
||||
type Profile struct {
|
||||
ID string
|
||||
Name string
|
||||
IsActive bool
|
||||
}
|
||||
@@ -325,7 +324,6 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -334,10 +332,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
|
||||
return profiles, nil
|
||||
}
|
||||
|
||||
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
|
||||
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
|
||||
dialog.ShowConfirm(
|
||||
"Deregister",
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
|
||||
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
|
||||
func(confirm bool) {
|
||||
if !confirm {
|
||||
return
|
||||
@@ -358,10 +356,8 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
|
||||
}
|
||||
|
||||
username := currUser.Username
|
||||
// ProfileName is treated as a handle; send the ID so the
|
||||
// daemon resolves to exactly this profile.
|
||||
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
|
||||
ProfileName: &profile.ID,
|
||||
ProfileName: &profileName,
|
||||
Username: &username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -372,7 +368,7 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
|
||||
|
||||
dialog.ShowInformation(
|
||||
"Deregistered",
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
|
||||
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
|
||||
s.wProfiles,
|
||||
)
|
||||
|
||||
@@ -465,7 +461,6 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
|
||||
|
||||
for _, profile := range profilesResp.Profiles {
|
||||
profiles = append(profiles, Profile{
|
||||
ID: profile.Id,
|
||||
Name: profile.Name,
|
||||
IsActive: profile.IsActive,
|
||||
})
|
||||
@@ -506,7 +501,7 @@ func (p *profileMenu) refresh() {
|
||||
}
|
||||
|
||||
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
|
||||
activeProfState, err := p.profileManager.GetProfileState(activeProf.Id)
|
||||
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get active profile state: %v", err)
|
||||
p.emailMenuItem.Hide()
|
||||
@@ -546,8 +541,8 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.ID,
|
||||
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
|
||||
ProfileName: &profile.Name,
|
||||
Username: &currUser.Username,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -557,7 +552,7 @@ func (p *profileMenu) refresh() {
|
||||
return
|
||||
}
|
||||
|
||||
err = p.profileManager.SwitchProfile(switchResp.Id)
|
||||
err = p.profileManager.SwitchProfile(profile.Name)
|
||||
if err != nil {
|
||||
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
|
||||
return
|
||||
|
||||
@@ -6,9 +6,11 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
goproto "google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
|
||||
@@ -185,9 +187,38 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
response.NetworkMap.SshAuth = &proto.SSHAuth{AuthorizedUsers: hashedUsers, MachineUsers: machineUsers, UserIDClaim: userIDClaim}
|
||||
}
|
||||
|
||||
// settings == nil → field stays nil → "no info in this snapshot", client
|
||||
// preserves the deadline it already had. settings non-nil → emit either a
|
||||
// valid deadline or the explicit-zero "disabled" sentinel via
|
||||
// encodeSessionExpiresAt.
|
||||
if settings != nil {
|
||||
response.SessionExpiresAt = encodeSessionExpiresAt(
|
||||
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||
)
|
||||
}
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
// encodeSessionExpiresAt encodes a server-side deadline into the 3-state wire
|
||||
// representation used on LoginResponse, SyncResponse and
|
||||
// ExtendAuthSessionResponse. See the proto comments on those messages.
|
||||
//
|
||||
// - deadline.IsZero() → returns &Timestamp{} (seconds=0, nanos=0): the
|
||||
// "expiry disabled or peer is not SSO-tracked" sentinel; the client clears
|
||||
// its anchor.
|
||||
// - deadline non-zero → returns timestamppb.New(deadline): the new absolute
|
||||
// UTC deadline.
|
||||
//
|
||||
// Returning nil ("no info, preserve client's anchor") is the caller's job —
|
||||
// only meaningful on Sync builds where settings were not resolved.
|
||||
func encodeSessionExpiresAt(deadline time.Time) *timestamppb.Timestamp {
|
||||
if deadline.IsZero() {
|
||||
return ×tamppb.Timestamp{}
|
||||
}
|
||||
return timestamppb.New(deadline)
|
||||
}
|
||||
|
||||
func buildAuthorizedUsersProto(ctx context.Context, authorizedUsers map[string]map[string]struct{}) ([][]byte, map[string]*proto.MachineUserIndexes) {
|
||||
userIDToIndex := make(map[string]uint32)
|
||||
var hashedUsers [][]byte
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
@@ -200,3 +201,29 @@ func TestBuildJWTConfig_Audiences(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeSessionExpiresAt pins the wire encoding the client's
|
||||
// applySessionDeadline depends on:
|
||||
//
|
||||
// - zero deadline → &Timestamp{} (seconds=0, nanos=0): the explicit
|
||||
// "expiry disabled or peer is not SSO-tracked" sentinel.
|
||||
// - non-zero → timestamppb.New(deadline): the absolute UTC deadline.
|
||||
//
|
||||
// The third state (nil pointer = "no info in this snapshot") is the caller's
|
||||
// responsibility on the Sync path when settings could not be resolved; the
|
||||
// helper itself never returns nil.
|
||||
func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||
t.Run("zero deadline encodes as explicit-zero sentinel", func(t *testing.T) {
|
||||
got := encodeSessionExpiresAt(time.Time{})
|
||||
assert.NotNil(t, got, "must not return nil; nil means 'no info', not 'disabled'")
|
||||
assert.Equal(t, int64(0), got.GetSeconds())
|
||||
assert.Equal(t, int32(0), got.GetNanos())
|
||||
})
|
||||
|
||||
t.Run("non-zero deadline round-trips", func(t *testing.T) {
|
||||
deadline := time.Date(2030, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
got := encodeSessionExpiresAt(deadline)
|
||||
assert.NotNil(t, got)
|
||||
assert.True(t, got.AsTime().Equal(deadline))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -821,6 +821,80 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExtendAuthSession refreshes the peer's SSO session expiry deadline using a
|
||||
// fresh JWT. The same JWT validation pipeline as Login is used. The tunnel
|
||||
// stays up; no network map sync is performed. The new deadline is returned
|
||||
// in ExtendAuthSessionResponse.SessionExpiresAt.
|
||||
func (s *Server) ExtendAuthSession(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) {
|
||||
extendReq := &proto.ExtendAuthSessionRequest{}
|
||||
peerKey, err := s.parseRequest(ctx, req, extendReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
|
||||
if accountID, accErr := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String()); accErr == nil {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
|
||||
}
|
||||
|
||||
jwt := extendReq.GetJwtToken()
|
||||
if jwt == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "jwt token is required")
|
||||
}
|
||||
|
||||
var userID string
|
||||
const attempts = 3
|
||||
for i := 0; i < attempts; i++ {
|
||||
userID, err = s.validateToken(ctx, peerKey.String(), jwt)
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
if i == attempts-1 {
|
||||
break
|
||||
}
|
||||
log.WithContext(ctx).Warnf("failed validating JWT token while extending session for peer %s: %v. Retrying (idP cache).", peerKey.String(), err)
|
||||
select {
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID == "" {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "jwt token did not yield a user id")
|
||||
}
|
||||
|
||||
deadline, err := s.accountManager.ExtendPeerSession(ctx, peerKey.String(), userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed extending session for peer %s: %v", peerKey.String(), err)
|
||||
return nil, mapError(ctx, err)
|
||||
}
|
||||
|
||||
// Success path normally returns a non-zero deadline. A defensive zero
|
||||
// would still encode as the explicit "disabled" sentinel rather than nil,
|
||||
// so the client clears any stale anchor instead of preserving it.
|
||||
resp := &proto.ExtendAuthSessionResponse{
|
||||
SessionExpiresAt: encodeSessionExpiresAt(deadline),
|
||||
}
|
||||
|
||||
wgKey, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed processing request")
|
||||
}
|
||||
encrypted, err := encryption.EncryptMessage(peerKey, wgKey, resp)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed encrypting response")
|
||||
}
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: wgKey.PublicKey().String(),
|
||||
Body: encrypted,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, netMap *types.NetworkMap, postureChecks []*posture.Checks) (*proto.LoginResponse, error) {
|
||||
var relayToken *Token
|
||||
var err error
|
||||
@@ -844,6 +918,12 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
// settings is always non-nil here, so we never emit nil — encoder returns
|
||||
// either a valid deadline or the explicit-zero "disabled" sentinel.
|
||||
loginResp.SessionExpiresAt = encodeSessionExpiresAt(
|
||||
peer.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration),
|
||||
)
|
||||
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -355,7 +355,17 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
|
||||
oldSettings.DNSDomain != newSettings.DNSDomain ||
|
||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
|
||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways {
|
||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
|
||||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
|
||||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||
// Session deadline is derived from LastLogin + PeerLoginExpiration
|
||||
// on every Login/Sync response. Without a fan-out push, connected
|
||||
// peers keep the deadline they received at login time and only see
|
||||
// the new value after the next unrelated NetworkMap change. Add
|
||||
// these two fields to the trigger list so admin-side expiry tweaks
|
||||
// (e.g. shortening from 24h to 1h) reach every connected peer
|
||||
// within seconds, which is what the proactive-warning feature
|
||||
// relies on (see client/internal/auth/sessionwatch).
|
||||
updateAccountPeers = true
|
||||
}
|
||||
|
||||
|
||||
@@ -109,6 +109,7 @@ type Manager interface {
|
||||
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
|
||||
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
|
||||
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
|
||||
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
|
||||
GetExternalCacheManager() ExternalCacheManager
|
||||
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
|
||||
|
||||
@@ -1304,6 +1304,21 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LoginPeer", reflect.TypeOf((*MockManager)(nil).LoginPeer), ctx, login)
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks base method.
|
||||
func (m *MockManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ExtendPeerSession", ctx, peerPubKey, userID)
|
||||
ret0, _ := ret[0].(time.Time)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ExtendPeerSession indicates an expected call of ExtendPeerSession.
|
||||
func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ExtendPeerSession", reflect.TypeOf((*MockManager)(nil).ExtendPeerSession), ctx, peerPubKey, userID)
|
||||
}
|
||||
|
||||
// MarkPeerConnected mocks base method.
|
||||
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -240,6 +240,10 @@ const (
|
||||
AccountLocalMfaEnabled Activity = 123
|
||||
// AccountLocalMfaDisabled indicates that a user disabled TOTP MFA for local users
|
||||
AccountLocalMfaDisabled Activity = 124
|
||||
// UserExtendedPeerSession indicates that a user refreshed their peer's
|
||||
// SSO session deadline via ExtendAuthSession without re-establishing the
|
||||
// tunnel. Distinct from UserLoggedInPeer (full interactive login).
|
||||
UserExtendedPeerSession Activity = 125
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
@@ -394,6 +398,8 @@ var activityMap = map[Activity]Code{
|
||||
AccountLocalMfaEnabled: {"Account local MFA enabled", "account.setting.local.mfa.enable"},
|
||||
AccountLocalMfaDisabled: {"Account local MFA disabled", "account.setting.local.mfa.disable"},
|
||||
|
||||
UserExtendedPeerSession: {"User extended peer session", "user.peer.session.extend"},
|
||||
|
||||
DomainAdded: {"Domain added", "domain.add"},
|
||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||
DomainValidated: {"Domain validated", "domain.validate"},
|
||||
|
||||
@@ -98,6 +98,7 @@ type MockAccountManager struct {
|
||||
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
|
||||
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
|
||||
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
ExtendPeerSessionFunc func(ctx context.Context, peerPubKey, userID string) (time.Time, error)
|
||||
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
|
||||
ApproveUserFunc func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error)
|
||||
@@ -860,6 +861,14 @@ func (am *MockAccountManager) LoginPeer(ctx context.Context, login types.PeerLog
|
||||
return nil, nil, nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented")
|
||||
}
|
||||
|
||||
// ExtendPeerSession mocks ExtendPeerSession of the AccountManager interface
|
||||
func (am *MockAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
if am.ExtendPeerSessionFunc != nil {
|
||||
return am.ExtendPeerSessionFunc(ctx, peerPubKey, userID)
|
||||
}
|
||||
return time.Time{}, status.Errorf(codes.Unimplemented, "method ExtendPeerSession is not implemented")
|
||||
}
|
||||
|
||||
// SyncPeer mocks SyncPeer of the AccountManager interface
|
||||
func (am *MockAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if am.SyncPeerFunc != nil {
|
||||
|
||||
@@ -1151,6 +1151,79 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return p, nmap, pc, err
|
||||
}
|
||||
|
||||
// ExtendPeerSession refreshes the peer's SSO session deadline by updating
|
||||
// LastLogin after a successful JWT validation. The tunnel is untouched: no
|
||||
// network map sync, no peer reconnect.
|
||||
//
|
||||
// Preconditions enforced here:
|
||||
// - userID must be present (caller validated the JWT and extracted the user ID).
|
||||
// - The peer must exist and be SSO-registered (AddedWithSSOLogin) with
|
||||
// LoginExpirationEnabled.
|
||||
// - Account-level PeerLoginExpirationEnabled must be true.
|
||||
// - The JWT user must match peer.UserID (mirrors LoginPeer at peer.go ~1028).
|
||||
//
|
||||
// Returns the new absolute UTC deadline.
|
||||
func (am *DefaultAccountManager) ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) {
|
||||
if userID == "" {
|
||||
return time.Time{}, status.Errorf(status.PermissionDenied, "session extend requires a JWT")
|
||||
}
|
||||
|
||||
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
if !settings.PeerLoginExpirationEnabled {
|
||||
return time.Time{}, status.Errorf(status.PreconditionFailed, "peer login expiration is disabled for the account")
|
||||
}
|
||||
|
||||
var refreshed *nbpeer.Peer
|
||||
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err := transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !peer.AddedWithSSOLogin() || !peer.LoginExpirationEnabled {
|
||||
return status.Errorf(status.PreconditionFailed, "peer is not eligible for session extension")
|
||||
}
|
||||
|
||||
if peer.UserID != userID {
|
||||
log.WithContext(ctx).Warnf("user mismatch when extending session for peer %s: peer user %s, jwt user %s", peer.ID, peer.UserID, userID)
|
||||
return status.NewPeerLoginMismatchError()
|
||||
}
|
||||
|
||||
peer = peer.UpdateLastLogin()
|
||||
if err := transaction.SavePeer(ctx, accountID, peer); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := transaction.SaveUserLastLogin(ctx, accountID, userID, peer.GetLastLogin()); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update user last login during session extend: %v", err)
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userID, peer.ID, accountID, activity.UserExtendedPeerSession, peer.EventMeta(am.networkMapController.GetDNSDomain(settings)))
|
||||
refreshed = peer
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return time.Time{}, err
|
||||
}
|
||||
|
||||
// Reschedule the per-account expiration job. schedulePeerLoginExpiration
|
||||
// is a no-op when a job is already running, but the running job will pick
|
||||
// up the new LastLogin on its next tick. Calling it here is harmless and
|
||||
// guarantees a job is in flight even if a prior one ended right before
|
||||
// the extend.
|
||||
am.schedulePeerLoginExpiration(ctx, accountID)
|
||||
|
||||
return refreshed.SessionExpiresAt(settings.PeerLoginExpirationEnabled, settings.PeerLoginExpiration), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
|
||||
@@ -367,6 +367,22 @@ func (p *Peer) LoginExpired(expiresIn time.Duration) (bool, time.Duration) {
|
||||
return timeLeft <= 0, timeLeft
|
||||
}
|
||||
|
||||
// SessionExpiresAt returns the absolute UTC instant at which the peer's SSO
|
||||
// session expires, derived from LastLogin and the account-level
|
||||
// PeerLoginExpiration setting. Returns the zero value when login expiration
|
||||
// does not apply (peer not SSO-registered, peer-level toggle off, or account
|
||||
// expiry disabled). Callers should treat the zero value as "no deadline".
|
||||
func (p *Peer) SessionExpiresAt(accountExpirationEnabled bool, expiresIn time.Duration) time.Time {
|
||||
if !accountExpirationEnabled || !p.AddedWithSSOLogin() || !p.LoginExpirationEnabled {
|
||||
return time.Time{}
|
||||
}
|
||||
last := p.GetLastLogin()
|
||||
if last.IsZero() {
|
||||
return time.Time{}
|
||||
}
|
||||
return last.Add(expiresIn).UTC()
|
||||
}
|
||||
|
||||
// FQDN returns peers FQDN combined of the peer's DNS label and the system's DNS domain
|
||||
func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
|
||||
@@ -16,6 +16,10 @@ type Client interface {
|
||||
Job(ctx context.Context, msgHandler func(msg *proto.JobRequest) *proto.JobResponse) error
|
||||
Register(setupKey string, jwtToken string, sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
Login(sysInfo *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
// ExtendAuthSession refreshes the peer's SSO session deadline using a fresh JWT.
|
||||
// Returns the new absolute deadline; zero time when the server reports the peer
|
||||
// is not eligible for session extension.
|
||||
ExtendAuthSession(sysInfo *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error)
|
||||
GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlow() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetNetworkMap(sysInfo *system.Info) (*proto.NetworkMap, error)
|
||||
|
||||
@@ -607,6 +607,61 @@ func (c *GrpcClient) Login(sysInfo *system.Info, pubSSHKey []byte, dnsLabels dom
|
||||
return c.login(&proto.LoginRequest{Meta: infoToMetaData(sysInfo), PeerKeys: keys, DnsLabels: dnsLabels.ToPunycodeList()})
|
||||
}
|
||||
|
||||
// ExtendAuthSession refreshes the peer's SSO session deadline on the management
|
||||
// server using a freshly issued JWT. The tunnel is untouched: no network map
|
||||
// sync, no peer reconnect. Returns the new absolute UTC deadline (zero time
|
||||
// when the server reports the field empty).
|
||||
func (c *GrpcClient) ExtendAuthSession(sysInfo *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error) {
|
||||
if !c.ready() {
|
||||
return nil, errors.New(errMsgNoMgmtConnection)
|
||||
}
|
||||
|
||||
serverKey, err := c.getServerPublicKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reqBody, err := encryption.EncryptMessage(*serverKey, c.key, &proto.ExtendAuthSessionRequest{
|
||||
JwtToken: jwtToken,
|
||||
Meta: infoToMetaData(sysInfo),
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("failed to encrypt extend auth session message: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var resp *proto.EncryptedMessage
|
||||
operation := func() error {
|
||||
mgmCtx, cancel := context.WithTimeout(context.Background(), ConnectTimeout)
|
||||
defer cancel()
|
||||
|
||||
var err error
|
||||
resp, err = c.realClient.ExtendAuthSession(mgmCtx, &proto.EncryptedMessage{
|
||||
WgPubKey: c.key.PublicKey().String(),
|
||||
Body: reqBody,
|
||||
})
|
||||
if err != nil {
|
||||
if s, ok := gstatus.FromError(err); ok && s.Code() == codes.Canceled {
|
||||
return err
|
||||
}
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := backoff.Retry(operation, nbgrpc.Backoff(c.ctx)); err != nil {
|
||||
log.Errorf("failed to extend auth session on Management Service: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := &proto.ExtendAuthSessionResponse{}
|
||||
if err := encryption.DecryptMessage(*serverKey, c.key, resp.Body, out); err != nil {
|
||||
log.Errorf("failed to decrypt extend auth session response: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// GetDeviceAuthorizationFlow returns a device authorization flow information.
|
||||
// It also takes care of encrypting and decrypting messages.
|
||||
func (c *GrpcClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
|
||||
@@ -14,6 +14,7 @@ type MockClient struct {
|
||||
SyncFunc func(ctx context.Context, sysInfo *system.Info, msgHandler func(msg *proto.SyncResponse) error) error
|
||||
RegisterFunc func(setupKey string, jwtToken string, info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
LoginFunc func(info *system.Info, sshKey []byte, dnsLabels domain.List) (*proto.LoginResponse, error)
|
||||
ExtendAuthSessionFunc func(info *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error)
|
||||
GetDeviceAuthorizationFlowFunc func() (*proto.DeviceAuthorizationFlow, error)
|
||||
GetPKCEAuthorizationFlowFunc func() (*proto.PKCEAuthorizationFlow, error)
|
||||
GetServerURLFunc func() string
|
||||
@@ -65,6 +66,13 @@ func (m *MockClient) Login(info *system.Info, sshKey []byte, dnsLabels domain.Li
|
||||
return m.LoginFunc(info, sshKey, dnsLabels)
|
||||
}
|
||||
|
||||
func (m *MockClient) ExtendAuthSession(info *system.Info, jwtToken string) (*proto.ExtendAuthSessionResponse, error) {
|
||||
if m.ExtendAuthSessionFunc == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.ExtendAuthSessionFunc(info, jwtToken)
|
||||
}
|
||||
|
||||
func (m *MockClient) GetDeviceAuthorizationFlow() (*proto.DeviceAuthorizationFlow, error) {
|
||||
if m.GetDeviceAuthorizationFlowFunc == nil {
|
||||
return nil, nil
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -52,6 +52,14 @@ service ManagementService {
|
||||
// Executes a job on a target peer (e.g., debug bundle)
|
||||
rpc Job(stream EncryptedMessage) returns (stream EncryptedMessage) {}
|
||||
|
||||
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
|
||||
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
|
||||
// but does not redo the network-map sync. Only valid for SSO-registered peers where
|
||||
// login expiration is enabled. The tunnel remains up.
|
||||
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
|
||||
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
|
||||
rpc ExtendAuthSession(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
// CreateExpose creates a temporary reverse proxy service for a peer
|
||||
rpc CreateExpose(EncryptedMessage) returns (EncryptedMessage) {}
|
||||
|
||||
@@ -133,6 +141,15 @@ message SyncResponse {
|
||||
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 6;
|
||||
|
||||
// 3-state session deadline. Carried on every Sync snapshot so admin-side
|
||||
// changes propagate live without a client reconnect.
|
||||
// field unset (nil) → snapshot carries no info; client keeps the
|
||||
// deadline it already had
|
||||
// set, seconds=0 nanos=0 → explicit "expiry disabled" or peer is not
|
||||
// SSO-registered; client clears its anchor
|
||||
// set, valid timestamp → new absolute UTC deadline
|
||||
google.protobuf.Timestamp sessionExpiresAt = 7;
|
||||
}
|
||||
|
||||
message SyncMetaRequest {
|
||||
@@ -244,6 +261,31 @@ message LoginResponse {
|
||||
PeerConfig peerConfig = 2;
|
||||
// Posture checks to be evaluated by client
|
||||
repeated Checks Checks = 3;
|
||||
|
||||
// 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt.
|
||||
// field unset (nil) → no info; client keeps any deadline it had
|
||||
// set, seconds=0 nanos=0 → explicit "expiry disabled" / non-SSO peer
|
||||
// set, valid timestamp → new absolute UTC deadline
|
||||
google.protobuf.Timestamp sessionExpiresAt = 4;
|
||||
}
|
||||
|
||||
// ExtendAuthSessionRequest carries a fresh JWT to refresh the peer's session deadline.
|
||||
// The encrypted body of an EncryptedMessage with this payload is sent to the
|
||||
// ExtendAuthSession RPC.
|
||||
message ExtendAuthSessionRequest {
|
||||
// SSO token (must be a fresh, valid JWT for the peer's owning user)
|
||||
string jwtToken = 1;
|
||||
// Meta data of the peer (used for IdP user info refresh consistent with Login)
|
||||
PeerSystemMeta meta = 2;
|
||||
}
|
||||
|
||||
// ExtendAuthSessionResponse contains the refreshed session deadline.
|
||||
message ExtendAuthSessionResponse {
|
||||
// 3-state session deadline; same encoding as SyncResponse.sessionExpiresAt.
|
||||
// In practice ExtendAuthSession only succeeds for SSO peers with expiry
|
||||
// enabled, so this carries a valid timestamp on the success path. The
|
||||
// 3-state encoding is documented here for symmetry with Login/Sync.
|
||||
google.protobuf.Timestamp sessionExpiresAt = 1;
|
||||
}
|
||||
|
||||
message ServerKeyResponse {
|
||||
|
||||
@@ -52,6 +52,13 @@ type ManagementServiceClient interface {
|
||||
Logout(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*Empty, error)
|
||||
// Executes a job on a target peer (e.g., debug bundle)
|
||||
Job(ctx context.Context, opts ...grpc.CallOption) (ManagementService_JobClient, error)
|
||||
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
|
||||
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
|
||||
// but does not redo the network-map sync. Only valid for SSO-registered peers where
|
||||
// login expiration is enabled. The tunnel remains up.
|
||||
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
|
||||
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
|
||||
ExtendAuthSession(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// CreateExpose creates a temporary reverse proxy service for a peer
|
||||
CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error)
|
||||
// RenewExpose extends the TTL of an active expose session
|
||||
@@ -194,6 +201,15 @@ func (x *managementServiceJobClient) Recv() (*EncryptedMessage, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) ExtendAuthSession(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/ExtendAuthSession", in, out, opts...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *managementServiceClient) CreateExpose(ctx context.Context, in *EncryptedMessage, opts ...grpc.CallOption) (*EncryptedMessage, error) {
|
||||
out := new(EncryptedMessage)
|
||||
err := c.cc.Invoke(ctx, "/management.ManagementService/CreateExpose", in, out, opts...)
|
||||
@@ -259,6 +275,13 @@ type ManagementServiceServer interface {
|
||||
Logout(context.Context, *EncryptedMessage) (*Empty, error)
|
||||
// Executes a job on a target peer (e.g., debug bundle)
|
||||
Job(ManagementService_JobServer) error
|
||||
// ExtendAuthSession refreshes the peer's session expiry deadline using a fresh JWT.
|
||||
// Same JWT validation pipeline as Login (including jwt.UserID == peer.UserID check),
|
||||
// but does not redo the network-map sync. Only valid for SSO-registered peers where
|
||||
// login expiration is enabled. The tunnel remains up.
|
||||
// EncryptedMessage of the request has a body of ExtendAuthSessionRequest.
|
||||
// EncryptedMessage of the response has a body of ExtendAuthSessionResponse.
|
||||
ExtendAuthSession(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// CreateExpose creates a temporary reverse proxy service for a peer
|
||||
CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error)
|
||||
// RenewExpose extends the TTL of an active expose session
|
||||
@@ -299,6 +322,9 @@ func (UnimplementedManagementServiceServer) Logout(context.Context, *EncryptedMe
|
||||
func (UnimplementedManagementServiceServer) Job(ManagementService_JobServer) error {
|
||||
return status.Errorf(codes.Unimplemented, "method Job not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) ExtendAuthSession(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method ExtendAuthSession not implemented")
|
||||
}
|
||||
func (UnimplementedManagementServiceServer) CreateExpose(context.Context, *EncryptedMessage) (*EncryptedMessage, error) {
|
||||
return nil, status.Errorf(codes.Unimplemented, "method CreateExpose not implemented")
|
||||
}
|
||||
@@ -494,6 +520,24 @@ func (x *managementServiceJobServer) Recv() (*EncryptedMessage, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func _ManagementService_ExtendAuthSession_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if interceptor == nil {
|
||||
return srv.(ManagementServiceServer).ExtendAuthSession(ctx, in)
|
||||
}
|
||||
info := &grpc.UnaryServerInfo{
|
||||
Server: srv,
|
||||
FullMethod: "/management.ManagementService/ExtendAuthSession",
|
||||
}
|
||||
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||
return srv.(ManagementServiceServer).ExtendAuthSession(ctx, req.(*EncryptedMessage))
|
||||
}
|
||||
return interceptor(ctx, in, info, handler)
|
||||
}
|
||||
|
||||
func _ManagementService_CreateExpose_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||
in := new(EncryptedMessage)
|
||||
if err := dec(in); err != nil {
|
||||
@@ -583,6 +627,10 @@ var ManagementService_ServiceDesc = grpc.ServiceDesc{
|
||||
MethodName: "Logout",
|
||||
Handler: _ManagementService_Logout_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "ExtendAuthSession",
|
||||
Handler: _ManagementService_ExtendAuthSession_Handler,
|
||||
},
|
||||
{
|
||||
MethodName: "CreateExpose",
|
||||
Handler: _ManagementService_CreateExpose_Handler,
|
||||
|
||||
Reference in New Issue
Block a user