Compare commits

..

6 Commits

Author SHA1 Message Date
Theodor S. Midtlien
75a898869f Fix review 2026-06-02 15:43:45 +02:00
Theodor S. Midtlien
e27f0ed05b Clean up 2026-06-02 14:48:04 +02:00
Theodor S. Midtlien
1d04a34f45 Migrate android profile manager 2026-06-02 09:58:28 +02:00
Theodor S. Midtlien
827c798334 Migrate to profile ids 2026-06-01 12:53:29 +02:00
Riccardo Manfrin
7ea5e37dd4 [client] Improve rosenpass support (#6136)
* Updates rosenpass version

go-rosenpass v0.4.0 → v0.5.42 bump — detailed findings

Change summary
cunicu.li/go-rosenpass  v0.4.0  → v0.5.42   (target)
cilium/ebpf             v0.15.0 → v0.19.0   (transitive)
gopacket/gopacket       v1.1.1  → v1.4.0    (transitive)
wireguard               2023-07 → 2023-12   (transitive)
wireguard/wgctrl        2023-04 → 2024-12   (transitive)

Wire interop

v0.4.0 (in v0.70.5) <-> v0.5.42 OK
v0.5.42 <-> v0.5.42 OK

Quantum resistance: true both ends

---
**Replay error eliminated.**

Before (on v0.4.0):

`ERROR Failed to handle message: failed to load biscuit (ICR1): detected replay`

Recurring every ~50ms for minutes at a time. Gone entirely after both ends upgraded to v0.5.42. Upstream fix in biscuit/replay handling between v0.4.x and v0.5.x series.

* Fixup [::]:port socket trying to send to v4

* Adds more tests on netbird<->rosenpass interactions

* Anticipates rp handler creation before generateConfig

* [client] Moves deterministic key gen into rosenpass

* go mod tidy

* Adds reminder to reason about rosenpass surface area

* Apply code rabbit suggestions
2026-05-28 09:01:18 +02:00
Riccardo Manfrin
9d7ef9b255 [client] Fix statemanager possible deadlock (#6228)
1. Stop() takes m.mu.Lock() and defers m.mu.Unlock()
2. <-m.done under lock
3. periodicStateSave defers close(m.done)
4. periodicStateSave calls PersistState() (line 256) which does m.mu.Lock()

Double Stop() remains idempotent: second cancel() on dead ctx
 (no-op) and reads done already closed (immediate return).
2026-05-28 08:54:15 +02:00
33 changed files with 1712 additions and 886 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
@@ -24,6 +23,7 @@ const (
// Profile represents a profile for gomobile
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -99,6 +99,7 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
ID: p.ID,
Name: p.Name,
IsActive: p.IsActive,
})
@@ -108,55 +109,65 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
return nil, fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID, androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID, Name: prof.Name, IsActive: true}, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
func (pm *ProfileManager) SwitchProfile(id string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: id,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
log.Infof("switched to profile: %s", id)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
if err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
log.Infof("created new profile: %s", profile.ID)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) LogoutProfile(id string) error {
configPath, err := pm.getProfileConfigPath(id)
if err != nil {
return err
}
if !profilemanager.IsValidProfileFilenameStem(id) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
return fmt.Errorf("profile '%s' does not exist", id)
}
// Read current config using internal profilemanager
@@ -174,53 +185,49 @@ func (pm *ProfileManager) LogoutProfile(profileName string) error {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
log.Infof("logged out from profile: %s", id)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
func (pm *ProfileManager) RemoveProfile(id string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
if err := pm.serviceMgr.RemoveProfile(id, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
log.Infof("removed profile: %s", id)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
return filepath.Join(profilesDir, id+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// GetConfigPath returns the config file path for a given profile id
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
return pm.getProfileConfigPath(id)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
return filepath.Join(profilesDir, id+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
@@ -230,7 +237,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
return pm.GetConfigPath(activeProfile.ID)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
@@ -240,18 +247,5 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
return pm.GetStateFilePath(activeProfile.ID)
}

View File

@@ -102,11 +102,11 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &activeProf.Name,
ProfileName: &activeProf.ID,
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -170,14 +170,13 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -205,11 +204,15 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil
}
func switchProfile(ctx context.Context, profileName string, username string) error {
// switchProfile asks the daemon to switch to the profile identified by
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (string, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
@@ -217,15 +220,15 @@ func switchProfile(ctx context.Context, profileName string, username string) err
client := proto.NewDaemonServiceClient(conn)
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &username,
})
if err != nil {
return fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %v", err)
}
return nil
return resp.Id, nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -249,7 +252,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -2,11 +2,16 @@ package cmd
import (
"context"
"errors"
"fmt"
"os/user"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -14,6 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var profileListShowID bool
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -31,27 +38,32 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>",
Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
Use: "remove <profile>",
Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`,
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>",
Use: "select <profile>",
Short: "Select a profile",
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
@@ -65,6 +77,7 @@ func setupCmd(cmd *cobra.Command) error {
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -83,25 +96,32 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
if profileListShowID {
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
} else {
fmt.Fprintln(tw, "NAME\tACTIVE")
}
return nil
for _, profile := range resp.Profiles {
marker := ""
if profile.IsActive {
marker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", profilemanager.ShortID(profile.Id), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
}
}
return tw.Flush()
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -121,19 +141,49 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
if err == nil {
cmd.Printf("Profile added: %s %s\n", profilemanager.ShortID(resp.Id), profilemanager.StripCtrlChars(profileName))
return nil
}
cmd.Println("Profile added successfully:", profileName)
return nil
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.AlreadyExists {
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 0 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
resp, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
cmd.Printf("Profile added: %s %s\n", profilemanager.ShortID(resp.Id), profilemanager.StripCtrlChars(profileName))
return nil
}
return err
}
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
@@ -153,18 +203,17 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: handle,
Username: currUser.Username,
})
if err != nil {
return err
return wrapAmbiguityError(err, handle)
}
cmd.Println("Profile removed successfully:", profileName)
cmd.Printf("Profile removed: %s\n", resp.Id)
return nil
}
@@ -174,7 +223,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
profileManager := profilemanager.NewProfileManager()
profileName := args[0]
handle := args[0]
currUser, err := user.Current()
if err != nil {
@@ -191,32 +240,15 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
})
if err != nil {
return fmt.Errorf("list profiles: %w", err)
return wrapAmbiguityError(err, handle)
}
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
if err := profileManager.SwitchProfile(switchResp.Id); err != nil {
return err
}
@@ -231,6 +263,29 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
cmd.Println("Profile switched successfully to:", profileName)
cmd.Printf("Profile switched to: %s\n", profilemanager.ShortID(switchResp.Id))
return nil
}
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}

View File

@@ -128,13 +128,12 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
err = switchProfile(cmd.Context(), profileName, username.Username)
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -261,10 +260,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
@@ -289,10 +288,10 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err)
}
loginRequest.ProfileName = &activeProf.Name
loginRequest.ProfileName = &activeProf.ID
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -329,7 +328,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
ProfileName: &activeProf.ID,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
created, err := sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test1",
ID: created.ID,
Username: currUser.Username,
})
if err != nil {

View File

@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
}
mURL, _ := url.Parse("https://api.example.com:443")

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer/id"
"github.com/netbirdio/netbird/client/internal/peer/worker"
"github.com/netbirdio/netbird/client/internal/portforward"
"github.com/netbirdio/netbird/client/internal/rosenpass"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/route"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
@@ -899,7 +900,7 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
}
// Fallback to deterministic key if no NetBird PSK is configured
determKey, err := conn.rosenpassDetermKey()
determKey, err := rosenpass.DeterministicSeedKey(conn.config.LocalKey, conn.config.Key)
if err != nil {
conn.Log.Errorf("failed to generate Rosenpass initial key: %v", err)
return nil
@@ -908,26 +909,6 @@ func (conn *Conn) presharedKey(remoteRosenpassKey []byte) *wgtypes.Key {
return determKey
}
// todo: move this logic into Rosenpass package
func (conn *Conn) rosenpassDetermKey() (*wgtypes.Key, error) {
lk := []byte(conn.config.LocalKey)
rk := []byte(conn.config.Key) // remote key
var keyInput []byte
if string(lk) > string(rk) {
//nolint:gocritic
keyInput = append(lk[:16], rk[:16]...)
} else {
//nolint:gocritic
keyInput = append(rk[:16], lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, err
}
return &key, nil
}
func isController(config ConnConfig) bool {
return config.LocalKey > config.Key
}

View File

@@ -103,6 +103,10 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string

View File

@@ -0,0 +1,110 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (string, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return hex.EncodeToString(buf), nil
}
// IsValidProfileFilenameStem reports whether s is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(s string) bool {
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func ShortID(id string) string {
if id == DefaultProfileName {
return id
}
if len(id) <= shortIDLen {
return id
}
return id[:shortIDLen]
}

View File

@@ -19,19 +19,41 @@ const (
)
type Profile struct {
Name string
// ID is the on-disk filename stem (without .json). For new profiles
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID string
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
IsActive bool
}
func (p *Profile) FilePath() (string, error) {
if p.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if p.Path != "" {
return p.Path, nil
}
if p.Name == defaultProfileName {
id := p.ID
if id == "" {
id = p.Name
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
@@ -42,10 +64,13 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
}
return filepath.Join(configDir, p.Name+".json"), nil
return filepath.Join(configDir, id+".json"), nil
}
func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName
}
@@ -57,18 +82,24 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{}
}
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
prof := pm.getActiveProfileState()
return &Profile{Name: prof}, nil
id := pm.getActiveProfileState()
return &Profile{ID: id}, nil
}
func (pm *ProfileManager) SwitchProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
// SwitchProfile records the given profile ID as active in the local user
// state file.
func (pm *ProfileManager) SwitchProfile(id string) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
if err := pm.setActiveProfileState(profileName); err != nil {
if err := pm.setActiveProfileState(id); err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
return nil
@@ -142,7 +173,7 @@ func GetLoginHint() string {
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
assert.Equal(t, defaultProfileName, state.ID) // No active profile state yet
err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err)
active, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, "default", active.Name)
assert.Equal(t, "default", active.ID)
})
})
}
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current()
assert.NoError(t, err)
sm := &ServiceManager{}
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state)
assert.NoError(t, err)
// Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil)
assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
assert.Error(t, err)
})
})

View File

@@ -2,6 +2,7 @@ package profilemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -23,12 +24,43 @@ var (
DefaultConfigPathDir = ""
DefaultConfigPath = ""
ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
)
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() {
DefaultConfigPathDir = "/var/lib/netbird/"
@@ -54,25 +86,34 @@ func init() {
}
type ActiveProfileState struct {
Name string `json:"name"`
// ID is the on-disk filename stem of the active profile. The JSON tag stays
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID string `json:"name"`
Username string `json:"username"`
}
func (a *ActiveProfileState) FilePath() (string, error) {
if a.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if a.ID == "" {
return "", fmt.Errorf("active profile ID is empty")
}
if a.Name == defaultProfileName {
if a.ID == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username)
if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
}
return filepath.Join(configDir, a.Name+".json"), nil
return filepath.Join(configDir, a.ID+".json"), nil
}
type ServiceManager struct {
@@ -178,7 +219,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
} else {
@@ -186,12 +227,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
}
}
if activeProfile.Name == "" {
if activeProfile.ID == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
}
@@ -216,25 +257,29 @@ func (s *ServiceManager) setDefaultActiveState() error {
}
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.Name == "" {
if a == nil || a.ID == "" {
return errors.New("invalid active profile state")
}
if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
if a.ID != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
}
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
log.Infof("active profile set to %s for %s", a.Name, a.Username)
log.Infof("active profile set to %s for %s", a.ID, a.Username)
return nil
}
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
})
}
@@ -243,57 +288,75 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
// AddProfile creates a new profile with a generated ID. The user-supplied
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
displayName, err = sanitizeDisplayName(displayName)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
return nil, fmt.Errorf("invalid profile name: %w", err)
}
if displayName == defaultProfileName {
return nil, fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil {
return fmt.Errorf("failed to create new config: %w", err)
return nil, fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
}
err = util.WriteJson(context.Background(), profPath, cfg)
if err != nil {
return fmt.Errorf("failed to write profile config: %w", err)
}
return nil
return &Profile{
ID: id,
Name: displayName,
Path: profPath,
}, nil
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
// RemoveProfile deletes the profile identified by id. Callers must have
// already resolved any user-supplied handle to a concrete ID via
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id, username string) error {
if id == defaultProfileName {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
if !profileExists {
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
@@ -301,57 +364,26 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err)
}
if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id)
}
err = util.RemoveJson(profPath)
if err != nil {
if err := util.RemoveJson(target.Path); err != nil {
return fmt.Errorf("failed to remove profile config: %w", err)
}
stateFile := filepath.Join(filepath.Dir(target.Path), id+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil
}
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
return s.loadAllProfiles(username)
}
// GetStatePath returns the path to the state file based on the operating system
@@ -369,7 +401,12 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
if activeProf.Name == defaultProfileName {
if activeProf.ID == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
return defaultStatePath
}
@@ -379,7 +416,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
return filepath.Join(configDir, activeProf.Name+".state.json")
return filepath.Join(configDir, activeProf.ID+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -390,3 +427,165 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username)
}
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultProfileName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := strings.TrimSuffix(base, ".json")
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(stem) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == activeID,
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (string, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique ID prefix, then unique exact
// name. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == handle {
return &profiles[i], nil
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID, handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -0,0 +1,230 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID, username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix, username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID)
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID)
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,13 +13,20 @@ type ProfileState struct {
Email string `json:"email"`
}
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
// GetProfileState reads the per-profile state file keyed by profile ID.
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id string) (*ProfileState, error) {
configDir, err := getConfigDir()
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
stateFile := filepath.Join(configDir, profileName+".state.json")
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -51,7 +58,12 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err)
}
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
id := activeProf.ID
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil {
return fmt.Errorf("write profile state: %w", err)

View File

@@ -28,6 +28,15 @@ func hashRosenpassKey(key []byte) string {
return hex.EncodeToString(hasher.Sum(nil))
}
// rpServer is the subset of rp.Server used by Manager. Defined as an interface
// so tests can substitute a mock without spinning up a real UDP server.
type rpServer interface {
AddPeer(rp.PeerConfig) (rp.PeerID, error)
RemovePeer(rp.PeerID) error
Run() error
Close() error
}
type Manager struct {
ifaceName string
spk []byte
@@ -36,7 +45,7 @@ type Manager struct {
preSharedKey *[32]byte
rpPeerIDs map[string]*rp.PeerID
rpWgHandler *NetbirdHandler
server *rp.Server
server rpServer
lock sync.Mutex
port int
wgIface PresharedKeySetter
@@ -51,7 +60,22 @@ func NewManager(preSharedKey *wgtypes.Key, wgIfaceName string) (*Manager, error)
rpKeyHash := hashRosenpassKey(public)
log.Tracef("generated new rosenpass key pair with public key %s", rpKeyHash)
return &Manager{ifaceName: wgIfaceName, rpKeyHash: rpKeyHash, spk: public, ssk: secret, preSharedKey: (*[32]byte)(preSharedKey), rpPeerIDs: make(map[string]*rp.PeerID), lock: sync.Mutex{}}, nil
return &Manager{
ifaceName: wgIfaceName,
rpKeyHash: rpKeyHash,
spk: public,
ssk: secret,
preSharedKey: (*[32]byte)(preSharedKey),
rpPeerIDs: make(map[string]*rp.PeerID),
// rpWgHandler is created here (instead of only in generateConfig) so it
// is never nil between NewManager and Run(). Otherwise an early
// OnConnected call (race observed on Android, issue #4341) panics on
// nil receiver in addPeer -> m.rpWgHandler.AddPeer. generateConfig will
// replace it with a fresh handler on each Run() to clear stale peer
// state from previous engine sessions.
rpWgHandler: NewNetbirdHandler(),
lock: sync.Mutex{},
}, nil
}
func (m *Manager) GetPubKey() []byte {
@@ -65,6 +89,16 @@ func (m *Manager) GetAddress() *net.UDPAddr {
// addPeer adds a new peer to the Rosenpass server
func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuardIP string, wireGuardPubKey string) error {
// Defense in depth against issue #4341 (Android crash): if Run() has not
// completed yet, m.server / m.rpWgHandler may be nil. Return an explicit
// error instead of panicking on nil-receiver dereference.
if m.server == nil {
return fmt.Errorf("rosenpass server not initialized")
}
if m.rpWgHandler == nil {
return fmt.Errorf("rosenpass wg handler not initialized")
}
var err error
pcfg := rp.PeerConfig{PublicKey: rosenpassPubKey}
if m.preSharedKey != nil {
@@ -79,6 +113,16 @@ func (m *Manager) addPeer(rosenpassPubKey []byte, rosenpassAddr string, wireGuar
if pcfg.Endpoint, err = net.ResolveUDPAddr("udp", peerAddr); err != nil {
return fmt.Errorf("failed to resolve peer endpoint address: %w", err)
}
// Our local Rosenpass UDP server binds on the IPv6 wildcard ([::]) — see
// GetAddress(). The remote peer's endpoint (pcfg.Endpoint) is the destination
// our server will sendto when initiating handshakes. ResolveUDPAddr returns a
// 4-byte IPv4 for IPv4 hosts, which the kernel rejects (EDESTADDRREQ) when
// sent from an AF_INET6 socket. Normalize the remote endpoint to IPv4-mapped
// IPv6 so its address family matches our listening socket.
// TODO: maybe bind the Rosenpass UDP server to the peer wg IP addr
if v4 := pcfg.Endpoint.IP.To4(); v4 != nil {
pcfg.Endpoint.IP = v4.To16()
}
}
peerID, err := m.server.AddPeer(pcfg)
if err != nil {
@@ -182,24 +226,31 @@ func (m *Manager) Run() error {
return err
}
m.server, err = rp.NewUDPServer(conf)
server, err := rp.NewUDPServer(conf)
if err != nil {
return err
}
m.lock.Lock()
m.server = server
m.lock.Unlock()
log.Infof("starting rosenpass server on port %d", m.port)
return m.server.Run()
return server.Run()
}
// Close closes the Rosenpass server
func (m *Manager) Close() error {
if m.server != nil {
err := m.server.Close()
if err != nil {
log.Errorf("failed closing local rosenpass server")
}
m.server = nil
m.lock.Lock()
server := m.server
m.server = nil
m.lock.Unlock()
if server == nil {
return nil
}
if err := server.Close(); err != nil {
log.Errorf("failed closing local rosenpass server: %v", err)
}
return nil
}

View File

@@ -1,14 +1,412 @@
package rosenpass
import (
"errors"
"os"
"sync"
"testing"
rp "cunicu.li/go-rosenpass"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// --- test doubles -----------------------------------------------------------
type addPeerCall struct {
cfg rp.PeerConfig
}
type removePeerCall struct {
id rp.PeerID
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
m.mu.Lock()
defer m.mu.Unlock()
m.addCalls = append(m.addCalls, addPeerCall{cfg: cfg})
if m.addErr != nil {
return rp.PeerID{}, m.addErr
}
// Increment a byte in nextID so distinct peers get distinct IDs.
m.nextID[0]++
return m.nextID, nil
}
func (m *mockServer) RemovePeer(id rp.PeerID) error {
m.mu.Lock()
defer m.mu.Unlock()
m.removed = append(m.removed, removePeerCall{id: id})
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {
peerKey string
psk wgtypes.Key
updateOnly bool
}
type mockIface struct {
mu sync.Mutex
calls []setPSKCall
err error
}
func (m *mockIface) SetPresharedKey(peerKey string, psk wgtypes.Key, updateOnly bool) error {
m.mu.Lock()
defer m.mu.Unlock()
m.calls = append(m.calls, setPSKCall{peerKey: peerKey, psk: psk, updateOnly: updateOnly})
return m.err
}
// newTestManager builds a Manager with deterministic spk so tie-break
// against a peer pubkey is controllable from tests. The provided spk byte
// becomes the first byte; remaining bytes are zero.
func newTestManager(spkFirstByte byte, mock *mockServer) *Manager {
spk := make([]byte, 32)
spk[0] = spkFirstByte
return &Manager{
ifaceName: "wt0",
spk: spk,
ssk: make([]byte, 32),
rpKeyHash: "test-hash",
rpPeerIDs: make(map[string]*rp.PeerID),
rpWgHandler: NewNetbirdHandler(),
server: mock,
}
}
// validWGKey returns a deterministic 32-byte wireguard public key (base64).
func validWGKey(t *testing.T, lastByte byte) string {
t.Helper()
var k wgtypes.Key
k[31] = lastByte
return k.String()
}
// --- pure helpers ----------------------------------------------------------
func TestHashRosenpassKey_Deterministic(t *testing.T) {
key := []byte("hello-rosenpass")
require.Equal(t, hashRosenpassKey(key), hashRosenpassKey(key))
require.Len(t, hashRosenpassKey(key), 64) // sha256 hex
}
func TestHashRosenpassKey_DifferentInputsDifferOutputs(t *testing.T) {
require.NotEqual(t, hashRosenpassKey([]byte("a")), hashRosenpassKey([]byte("b")))
}
func TestGetLogLevel_DefaultWhenUnset(t *testing.T) {
// Snapshot + unset to exercise the LookupEnv ok=false branch. t.Setenv
// can only set, not delete, so do it manually with restore via t.Cleanup.
prev, hadPrev := os.LookupEnv(defaultLogLevelVar)
require.NoError(t, os.Unsetenv(defaultLogLevelVar))
t.Cleanup(func() {
if hadPrev {
_ = os.Setenv(defaultLogLevelVar, prev)
} else {
_ = os.Unsetenv(defaultLogLevelVar)
}
})
require.Equal(t, defaultLog.String(), getLogLevel().String())
}
func TestGetLogLevel_Cases(t *testing.T) {
cases := map[string]string{
"debug": "DEBUG",
"info": "INFO",
"warn": "WARN",
"error": "ERROR",
"unknown": "INFO", // default fallback
}
for input, wantStr := range cases {
input, wantStr := input, wantStr
t.Run(input, func(t *testing.T) {
t.Setenv(defaultLogLevelVar, input)
require.Equal(t, wantStr, getLogLevel().String())
})
}
}
func TestFindRandomAvailableUDPPort(t *testing.T) {
port, err := findRandomAvailableUDPPort()
require.NoError(t, err)
require.Greater(t, port, 0)
require.LessOrEqual(t, port, 65535)
}
// --- addPeer ---------------------------------------------------------------
func TestAddPeer_HigherLocalPubkey_SetsEndpoint(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // local spk lexicographically larger
remotePubKey := make([]byte, 32) // remote spk = all zeros (smaller)
err := m.addPeer(remotePubKey, "rosenpass-host:7000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
require.Len(t, srv.addCalls, 1)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep, "initiator side must set Endpoint")
require.Equal(t, 7000, ep.Port)
require.Equal(t, "100.1.1.1", ep.IP.String())
}
func TestAddPeer_HigherLocalPubkey_EndpointIPIsIPv4Mapped(t *testing.T) {
// Regression guard for the EDESTADDRREQ fix: Endpoint.IP must be 16-byte
// (IPv4-mapped IPv6) so it matches the AF_INET6 listening socket family.
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.NoError(t, err)
ep := srv.addCalls[0].cfg.Endpoint
require.NotNil(t, ep)
require.Len(t, ep.IP, 16, "IPv4 endpoint must be normalized to 16-byte v4-mapped form")
require.True(t, ep.IP.To4() != nil, "Endpoint must still be detected as IPv4")
}
func TestAddPeer_LowerLocalPubkey_LeavesEndpointNil(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0x00, srv) // local spk smaller
remotePubKey := make([]byte, 32)
remotePubKey[0] = 0xFF
err := m.addPeer(remotePubKey, "rp:5000", "100.1.1.1", validWGKey(t, 2))
require.NoError(t, err)
require.Nil(t, srv.addCalls[0].cfg.Endpoint, "responder side must NOT set Endpoint")
}
func TestAddPeer_PresharedKeyPropagated(t *testing.T) {
srv := &mockServer{}
psk := &wgtypes.Key{0x42}
m := newTestManager(0xFF, srv)
m.preSharedKey = (*[32]byte)(psk)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 3))
require.NoError(t, err)
require.Equal(t, [32]byte(*psk), [32]byte(srv.addCalls[0].cfg.PresharedKey))
}
func TestAddPeer_InvalidRosenpassAddr_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv) // initiator path → parses rosenpassAddr
err := m.addPeer(make([]byte, 32), "not-a-host-port", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Empty(t, srv.addCalls, "server.AddPeer must not run when address parse fails")
}
func TestAddPeer_InvalidWireGuardPubKey_ReturnsError(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", "not-a-valid-key")
require.Error(t, err)
}
func TestAddPeer_ServerError_Propagates(t *testing.T) {
srv := &mockServer{addErr: errors.New("boom")}
m := newTestManager(0xFF, srv)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
}
// Regression guard for issue #4341 (Android crash). If Run() has not completed
// before OnConnected fires, m.rpWgHandler or m.server may be nil. Without the
// nil guards, m.rpWgHandler.AddPeer panics on nil receiver.
func TestAddPeer_NilHandler_ReturnsErrorNoCrash(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.rpWgHandler = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "wg handler not initialized")
}
func TestAddPeer_NilServer_ReturnsErrorNoCrash(t *testing.T) {
m := newTestManager(0xFF, nil)
m.server = nil // simulate Run() not yet completed
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", validWGKey(t, 1))
require.Error(t, err)
require.Contains(t, err.Error(), "server not initialized")
}
// NewManager must pre-initialize rpWgHandler so the nil-receiver crash from
// issue #4341 cannot occur in the window between NewManager and Run().
func TestNewManager_PreInitializesHandler(t *testing.T) {
psk := wgtypes.Key{}
m, err := NewManager(&psk, "wt0")
require.NoError(t, err)
require.NotNil(t, m.rpWgHandler, "rpWgHandler must be initialized in NewManager")
}
func TestAddPeer_RecordsPeerID(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 5)
err := m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey)
require.NoError(t, err)
require.Contains(t, m.rpPeerIDs, wgKey)
}
// --- OnConnected / OnDisconnected ------------------------------------------
func TestOnConnected_NilRemotePubKey_NoAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnConnected(validWGKey(t, 1), nil, "100.1.1.1", "rp:5000")
require.Empty(t, srv.addCalls, "nil remote rosenpass pubkey must skip AddPeer")
require.Empty(t, m.rpPeerIDs)
}
func TestOnConnected_ValidPubKey_CallsAddPeer(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
m.OnConnected(wgKey, make([]byte, 32), "100.1.1.1", "rp:5000")
require.Len(t, srv.addCalls, 1)
require.Contains(t, m.rpPeerIDs, wgKey)
}
func TestOnDisconnected_UnknownPeer_NoOp(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
m.OnDisconnected(validWGKey(t, 99))
require.Empty(t, srv.removed, "unknown peer key must not call RemovePeer")
}
func TestOnDisconnected_KnownPeer_CallsRemoveAndForgets(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 1)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.Contains(t, m.rpPeerIDs, wgKey)
m.OnDisconnected(wgKey)
require.Len(t, srv.removed, 1)
require.NotContains(t, m.rpPeerIDs, wgKey, "peer must be forgotten after disconnect")
}
// --- IsPresharedKeyInitialized ---------------------------------------------
func TestIsPresharedKeyInitialized_UnknownPeer_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
require.False(t, m.IsPresharedKeyInitialized(validWGKey(t, 1)))
}
func TestIsPresharedKeyInitialized_AddedButNotHandshaken_ReturnsFalse(t *testing.T) {
srv := &mockServer{}
m := newTestManager(0xFF, srv)
wgKey := validWGKey(t, 2)
require.NoError(t, m.addPeer(make([]byte, 32), "rp:5000", "100.1.1.1", wgKey))
require.False(t, m.IsPresharedKeyInitialized(wgKey))
}
// --- NetbirdHandler.outputKey ----------------------------------------------
func TestHandler_OutputKey_FirstCallUsesUpdateOnlyFalse(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x01}
wgKey := wgtypes.Key{0xAA}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
psk := rp.Key{0xBB}
h.HandshakeCompleted(pid, psk)
require.Len(t, iface.calls, 1)
require.False(t, iface.calls[0].updateOnly, "first PSK rotation must use updateOnly=false")
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}
func TestHandler_OutputKey_SubsequentCallsUseUpdateOnlyTrue(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x02}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xCC}))
h.HandshakeCompleted(pid, rp.Key{0x01}) // first
h.HandshakeCompleted(pid, rp.Key{0x02}) // second
require.Len(t, iface.calls, 2)
require.False(t, iface.calls[0].updateOnly)
require.True(t, iface.calls[1].updateOnly, "subsequent rotations must use updateOnly=true")
}
func TestHandler_OutputKey_NilInterface_NoCrashNoCall(t *testing.T) {
h := NewNetbirdHandler()
// no SetInterface — iface remains nil
pid := rp.PeerID{0x03}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{}))
// Must not panic.
h.HandshakeCompleted(pid, rp.Key{})
}
func TestHandler_OutputKey_UnknownPeer_NoCall(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
h.HandshakeCompleted(rp.PeerID{0xFF}, rp.Key{})
require.Empty(t, iface.calls, "unknown peer id must not trigger SetPresharedKey")
}
func TestHandler_RemovePeer_ClearsInitializedState(t *testing.T) {
h := NewNetbirdHandler()
iface := &mockIface{}
h.SetInterface(iface)
pid := rp.PeerID{0x04}
h.AddPeer(pid, "wt0", rp.Key(wgtypes.Key{0xDD}))
h.HandshakeCompleted(pid, rp.Key{0x01})
require.True(t, h.IsPeerInitialized(pid))
h.RemovePeer(pid)
require.False(t, h.IsPeerInitialized(pid), "RemovePeer must clear initialized flag")
}
func TestHandler_SetInterfaceAfterAddPeer_StillReceivesKey(t *testing.T) {
h := NewNetbirdHandler()
pid := rp.PeerID{0x05}
wgKey := wgtypes.Key{0xEE}
h.AddPeer(pid, "wt0", rp.Key(wgKey))
iface := &mockIface{}
h.SetInterface(iface) // set after AddPeer
h.HandshakeCompleted(pid, rp.Key{0x42})
require.Len(t, iface.calls, 1)
require.Equal(t, wgKey.String(), iface.calls[0].peerKey)
}

View File

@@ -0,0 +1,42 @@
package rosenpass
import (
"fmt"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)
// DeterministicSeedKey derives a 32-byte WireGuard preshared key from a pair
// of peer public keys. Both peers, given the same key pair, produce the same
// output regardless of which side runs the function: the inputs are ordered
// lexicographically before concatenation.
//
// NetBird uses this value as the initial Rosenpass-side preshared key when no
// explicit account-level PSK is configured, so both peers converge on the same
// PSK before the first post-quantum handshake completes.
//
// The resulting key MUST NOT be treated as quantum-safe: it is deterministic
// from public keys and exists only to seed WireGuard until Rosenpass rotates
// in a real post-quantum PSK.
func DeterministicSeedKey(localKey, remoteKey string) (*wgtypes.Key, error) {
lk := []byte(localKey)
rk := []byte(remoteKey)
if len(lk) < 16 || len(rk) < 16 {
return nil, fmt.Errorf("rosenpass: peer keys must be at least 16 bytes (got local=%d, remote=%d)", len(lk), len(rk))
}
var keyInput []byte
if localKey > remoteKey {
keyInput = append(keyInput, lk[:16]...)
keyInput = append(keyInput, rk[:16]...)
} else {
keyInput = append(keyInput, rk[:16]...)
keyInput = append(keyInput, lk[:16]...)
}
key, err := wgtypes.NewKey(keyInput)
if err != nil {
return nil, fmt.Errorf("rosenpass: deterministic seed key: %w", err)
}
return &key, nil
}

View File

@@ -0,0 +1,44 @@
package rosenpass
import (
"strings"
"testing"
"github.com/stretchr/testify/require"
)
func TestDeterministicSeedKey_SameForBothSides(t *testing.T) {
// Peer A and peer B must derive the same PSK regardless of which side
// computes it: the function orders inputs internally.
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyBA, err := DeterministicSeedKey(b, a)
require.NoError(t, err)
require.Equal(t, keyAB.String(), keyBA.String(), "swapping arguments must yield identical key")
}
func TestDeterministicSeedKey_ChangesWithKeys(t *testing.T) {
a := strings.Repeat("a", 32)
b := strings.Repeat("b", 32)
c := strings.Repeat("c", 32)
keyAB, err := DeterministicSeedKey(a, b)
require.NoError(t, err)
keyAC, err := DeterministicSeedKey(a, c)
require.NoError(t, err)
require.NotEqual(t, keyAB.String(), keyAC.String(), "different peer pair must yield different key")
}
func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
short := "short" // < 16 bytes
long := strings.Repeat("x", 32)
_, err := DeterministicSeedKey(short, long)
require.Error(t, err)
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -96,17 +96,19 @@ func (m *Manager) Stop(ctx context.Context) error {
}
m.mu.Lock()
defer m.mu.Unlock()
cancel := m.cancel
done := m.done
m.mu.Unlock()
if m.cancel == nil {
if cancel == nil {
return nil
}
m.cancel()
cancel()
select {
case <-ctx.Done():
return ctx.Err()
case <-m.done:
case <-done:
}
return nil

View File

@@ -76,9 +76,6 @@ type Client struct {
dnsManager dns.IosDnsManager
loginComplete bool
connectClient *internal.ConnectClient
// config holds the active configuration once Run has loaded it. Consumed by
// the in-app SSH client for the NetBird SSH key and the OAuth flow.
config *profilemanager.Config
// preloadedConfig holds config loaded from JSON (used on tvOS where file writes are blocked)
preloadedConfig *profilemanager.Config
}
@@ -163,7 +160,6 @@ func (c *Client) Run(fd int32, interfaceName string, envList *EnvList) error {
ctx = internal.CtxInitState(ctx)
c.onHostDnsFn = func([]string) {}
cfg.WgIface = interfaceName
c.config = cfg
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOniOS(fd, c.networkChangeListener, c.dnsManager, c.stateFile)
@@ -531,13 +527,6 @@ func (c *Client) DeselectRoute(id string) error {
return nil
}
// sshState returns the active config and the running connect client for the
// in-app SSH client. Both are nil until Run has loaded the config and started
// the tunnel.
func (c *Client) sshState() (*profilemanager.Config, *internal.ConnectClient) {
return c.config, c.connectClient
}
func formatDuration(d time.Duration) string {
ds := d.String()
dotIndex := strings.Index(ds, ".")

View File

@@ -1,431 +0,0 @@
//go:build ios
package NetBirdSDK
import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"
"time"
log "github.com/sirupsen/logrus"
gossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
sshDialTimeout = 30 * time.Second
sshDetectionTimeout = 5 * time.Second
)
// SSHTerminalListener receives SSH session events. It is implemented in Swift.
//
// All callbacks are invoked from goroutines and may run concurrently with each
// other; the implementation must be safe to call from any thread.
type SSHTerminalListener interface {
OnConnected()
OnData(data []byte)
OnClose(reason string)
OnError(message string)
}
// SSHClient is a NetBird-aware SSH client exposed to Swift via gomobile.
//
// It dials through the running NetBird tunnel and runs a standard SSH session
// on top with PTY enabled. Host-key verification uses the NetBird-provided
// peer SSH host keys, identical to the desktop client.
type SSHClient struct {
nb *Client
mu sync.Mutex
listener SSHTerminalListener
urlOpener URLOpener
sshClient *gossh.Client
session *gossh.Session
stdin io.WriteCloser
closed bool
}
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
func NewSSHClient(c *Client) *SSHClient {
return &SSHClient{nb: c}
}
// SetListener registers the Swift listener. Must be called before Connect to
// receive any events.
func (s *SSHClient) SetListener(l SSHTerminalListener) {
s.mu.Lock()
s.listener = l
s.mu.Unlock()
}
// SetURLOpener registers the Swift URL opener used to display the device-code
// authorization page in an in-app browser when the target peer requires JWT
// authentication. Must be set before Connect to be effective.
func (s *SSHClient) SetURLOpener(opener URLOpener) {
s.mu.Lock()
s.urlOpener = opener
s.mu.Unlock()
}
// Connect dials the SSH server through the NetBird tunnel and performs the
// SSH handshake. It auto-detects the server type via SSH banner inspection
// and selects the appropriate authentication path:
//
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
// flow, opens the verification URL through the registered URLOpener, and
// uses the resulting token as the SSH password. Host-key verification
// uses the NetBird peer registry.
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
// private key. Host-key verification uses the NetBird peer registry.
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
// first (so a user-installed NetBird public key works), then falls back
// to the supplied password if non-empty. Host-key verification is
// disabled (TOFU pending).
//
// The password parameter is only consulted for regular SSH servers.
func (s *SSHClient) Connect(host string, port int, user, password string) error {
cfg, cc := s.nb.sshState()
if cc == nil {
return errors.New("netbird client not running")
}
if cfg == nil {
return errors.New("netbird config not loaded")
}
engine := cc.Engine()
if engine == nil {
return errors.New("netbird engine not available")
}
serverType := detectServerType(host, port)
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
if err != nil {
return err
}
clientConfig := &gossh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
Timeout: sshDialTimeout,
}
return s.dialAndHandshake(host, port, clientConfig)
}
// StartSession requests a PTY and starts an interactive shell. Output from
// the session is forwarded to the listener via OnData.
func (s *SSHClient) StartSession(cols, rows int) error {
log.Debugf("SSH: starting session %dx%d", cols, rows)
s.mu.Lock()
sshClient := s.sshClient
s.mu.Unlock()
if sshClient == nil {
return errors.New("ssh client not connected")
}
session, err := sshClient.NewSession()
if err != nil {
return fmt.Errorf("new session: %w", err)
}
modes := gossh.TerminalModes{
gossh.ECHO: 1,
gossh.TTY_OP_ISPEED: 14400,
gossh.TTY_OP_OSPEED: 14400,
gossh.VINTR: 3,
gossh.VQUIT: 28,
gossh.VERASE: 127,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
closeQuiet(session, "session after pty error")
return fmt.Errorf("request pty: %w", err)
}
stdin, err := session.StdinPipe()
if err != nil {
closeQuiet(session, "session after stdin error")
return fmt.Errorf("stdin pipe: %w", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
closeQuiet(session, "session after stdout error")
return fmt.Errorf("stdout pipe: %w", err)
}
stderr, err := session.StderrPipe()
if err != nil {
closeQuiet(session, "session after stderr error")
return fmt.Errorf("stderr pipe: %w", err)
}
if err := session.Shell(); err != nil {
closeQuiet(session, "session after shell error")
return fmt.Errorf("start shell: %w", err)
}
s.mu.Lock()
s.session = session
s.stdin = stdin
s.mu.Unlock()
go s.readLoop(stdout, "stdout")
go s.readLoop(stderr, "stderr")
log.Debug("SSH: session started, shell running")
return nil
}
// Write sends data to the SSH session stdin.
func (s *SSHClient) Write(data []byte) error {
s.mu.Lock()
stdin := s.stdin
s.mu.Unlock()
if stdin == nil {
return errors.New("ssh session not started")
}
if _, err := stdin.Write(data); err != nil {
return fmt.Errorf("write stdin: %w", err)
}
return nil
}
// Resize updates the PTY window size.
func (s *SSHClient) Resize(cols, rows int) error {
s.mu.Lock()
session := s.session
s.mu.Unlock()
if session == nil {
return errors.New("ssh session not started")
}
return session.WindowChange(rows, cols)
}
// Close terminates the SSH session and underlying connection. Safe to call
// multiple times.
func (s *SSHClient) Close() error {
s.mu.Lock()
sshClient := s.sshClient
session := s.session
stdin := s.stdin
s.sshClient = nil
s.session = nil
s.stdin = nil
s.mu.Unlock()
if stdin != nil {
if err := stdin.Close(); err != nil {
log.Debugf("ssh: stdin close: %v", err)
}
}
if session != nil {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: session close: %v", err)
}
}
var firstErr error
if sshClient != nil {
if err := sshClient.Close(); err != nil {
firstErr = err
}
}
s.notifyClose("closed by client")
return firstErr
}
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
switch serverType {
case detection.ServerTypeNetBirdJWT:
token, err := s.requestJWTToken(cfg)
if err != nil {
return nil, nil, fmt.Errorf("jwt: %w", err)
}
auths := []gossh.AuthMethod{gossh.Password(token)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
case detection.ServerTypeNetBirdNoJWT:
if cfg.SSHKey == "" {
return nil, nil, errors.New("no NetBird SSH key available")
}
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
if err != nil {
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
}
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
default: // regular SSH
var auths []gossh.AuthMethod
if cfg.SSHKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
auths = append(auths, gossh.PublicKeys(signer))
} else {
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
}
}
if password != "" {
pw := password
auths = append(auths, gossh.Password(pw))
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
answers := make([]string, len(questions))
for i := range questions {
answers[i] = pw
}
return answers, nil
}))
}
if len(auths) == 0 {
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
}
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
}
}
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
s.mu.Lock()
urlOpener := s.urlOpener
s.mu.Unlock()
if urlOpener == nil {
return "", errors.New("URL opener not configured for JWT auth")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
if err != nil {
return "", fmt.Errorf("create oauth flow: %w", err)
}
flowInfo, err := flow.RequestAuthInfo(ctx)
if err != nil {
return "", fmt.Errorf("request auth info: %w", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
if err != nil {
return "", fmt.Errorf("wait for token: %w", err)
}
token := tokenInfo.GetTokenToUse()
if token == "" {
return "", errors.New("empty token returned by IdP")
}
return token, nil
}
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel()
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("dial %s: %w", addr, err)
}
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
if err != nil {
if cerr := conn.Close(); cerr != nil {
log.Debugf("ssh: close after handshake error: %v", cerr)
}
return fmt.Errorf("ssh handshake: %w", err)
}
s.mu.Lock()
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
listener := s.listener
s.mu.Unlock()
log.Infof("SSH: connected to %s", addr)
if listener != nil {
listener.OnConnected()
}
return nil
}
func (s *SSHClient) readLoop(r io.Reader, name string) {
buf := make([]byte, 4096)
for {
n, err := r.Read(buf)
if n > 0 {
s.mu.Lock()
listener := s.listener
s.mu.Unlock()
if listener != nil {
chunk := make([]byte, n)
copy(chunk, buf[:n])
listener.OnData(chunk)
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
log.Debugf("ssh %s read: %v", name, err)
}
s.notifyClose(err.Error())
return
}
}
}
func (s *SSHClient) notifyClose(reason string) {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return
}
s.closed = true
listener := s.listener
s.mu.Unlock()
if listener != nil {
listener.OnClose(reason)
}
}
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
type engineHostKeyVerifier struct {
engine *internal.Engine
}
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
if !found {
return nbssh.ErrPeerNotFound
}
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
}
func detectServerType(host string, port int) detection.ServerType {
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
defer cancel()
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
return detection.ServerTypeRegular
}
return serverType
}
func closeQuiet(c io.Closer, label string) {
if c == nil {
return
}
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: close %s: %v", label, err)
}
}

View File

@@ -3915,9 +3915,11 @@ func (x *GetEventsResponse) GetEvents() []*SystemEvent {
}
type SwitchProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -3967,7 +3969,11 @@ func (x *SwitchProfileRequest) GetUsername() string {
}
type SwitchProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState `protogen:"open.v1"`
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4002,6 +4008,13 @@ func (*SwitchProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{55}
}
func (x *SwitchProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type SetConfigRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
@@ -4358,9 +4371,11 @@ func (*SetConfigResponse) Descriptor() ([]byte, []int) {
}
type AddProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4410,7 +4425,10 @@ func (x *AddProfileRequest) GetProfileName() string {
}
type AddProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState `protogen:"open.v1"`
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4445,10 +4463,19 @@ func (*AddProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{59}
}
func (x *AddProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type RemoveProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4498,7 +4525,10 @@ func (x *RemoveProfileRequest) GetProfileName() string {
}
type RemoveProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
state protoimpl.MessageState `protogen:"open.v1"`
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4533,6 +4563,13 @@ func (*RemoveProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{61}
}
func (x *RemoveProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type ListProfilesRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
@@ -4625,6 +4662,7 @@ type Profile struct {
state protoimpl.MessageState `protogen:"open.v1"`
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
IsActive bool `protobuf:"varint,2,opt,name=is_active,json=isActive,proto3" json:"is_active,omitempty"`
Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4673,6 +4711,13 @@ func (x *Profile) GetIsActive() bool {
return false
}
func (x *Profile) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type GetActiveProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -4713,6 +4758,7 @@ type GetActiveProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName string `protobuf:"bytes,1,opt,name=profileName,proto3" json:"profileName,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"`
Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4761,6 +4807,13 @@ func (x *GetActiveProfileResponse) GetUsername() string {
return ""
}
func (x *GetActiveProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type LogoutRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
@@ -6578,8 +6631,9 @@ const file_daemon_proto_rawDesc = "" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\x17\n" +
"\x15SwitchProfileResponse\"\x98\x11\n" +
"\t_username\"'\n" +
"\x15SwitchProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"\x98\x11\n" +
"\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -6648,23 +6702,27 @@ const file_daemon_proto_rawDesc = "" +
"\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x14\n" +
"\x12AddProfileResponse\"T\n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"$\n" +
"\x12AddProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"T\n" +
"\x14RemoveProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x17\n" +
"\x15RemoveProfileResponse\"1\n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"'\n" +
"\x15RemoveProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"1\n" +
"\x13ListProfilesRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\"C\n" +
"\x14ListProfilesResponse\x12+\n" +
"\bprofiles\x18\x01 \x03(\v2\x0f.daemon.ProfileR\bprofiles\":\n" +
"\bprofiles\x18\x01 \x03(\v2\x0f.daemon.ProfileR\bprofiles\"J\n" +
"\aProfile\x12\x12\n" +
"\x04name\x18\x01 \x01(\tR\x04name\x12\x1b\n" +
"\tis_active\x18\x02 \x01(\bR\bisActive\"\x19\n" +
"\x17GetActiveProfileRequest\"X\n" +
"\tis_active\x18\x02 \x01(\bR\bisActive\x12\x0e\n" +
"\x02id\x18\x03 \x01(\tR\x02id\"\x19\n" +
"\x17GetActiveProfileRequest\"h\n" +
"\x18GetActiveProfileResponse\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\"t\n" +
"\busername\x18\x02 \x01(\tR\busername\x12\x0e\n" +
"\x02id\x18\x03 \x01(\tR\x02id\"t\n" +
"\rLogoutRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +

View File

@@ -613,11 +613,18 @@ message GetEventsResponse {
}
message SwitchProfileRequest {
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
optional string profileName = 1;
optional string username = 2;
}
message SwitchProfileResponse {}
message SwitchProfileResponse {
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
string id = 1;
}
message SetConfigRequest {
string username = 1;
@@ -684,17 +691,29 @@ message SetConfigResponse{}
message AddProfileRequest {
string username = 1;
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
string profileName = 2;
}
message AddProfileResponse {}
message AddProfileResponse {
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
string id = 1;
}
message RemoveProfileRequest {
string username = 1;
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
string profileName = 2;
}
message RemoveProfileResponse {}
message RemoveProfileResponse {
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
string id = 1;
}
message ListProfilesRequest {
string username = 1;
@@ -707,6 +726,7 @@ message ListProfilesResponse {
message Profile {
string name = 1;
bool is_active = 2;
string id = 3;
}
message GetActiveProfileRequest {}
@@ -714,6 +734,7 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
string id = 3;
}
message LogoutRequest {

View File

@@ -1,17 +1,16 @@
#!/bin/bash
set -e
if ! which realpath > /dev/null 2>&1
then
echo realpath is not installed
echo run: brew install coreutils
exit 1
if ! which realpath >/dev/null 2>&1; then
echo realpath is not installed
echo run: brew install coreutils
exit 1
fi
old_pwd=$(pwd)
script_path=$(dirname $(realpath "$0"))
cd "$script_path"
go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.1
go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.6.1
protoc -I ./ ./daemon.proto --go_out=../ --go-grpc_out=../ --experimental_allow_proto3_optional
cd "$old_pwd"

View File

@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")

View File

@@ -308,15 +308,14 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath, err := profState.FilePath()
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
return nil, err
}
profPath := resolved.Path
if profPath == "" {
profPath = profilemanager.DefaultConfigPath
}
var config profilemanager.ConfigInput
@@ -446,30 +445,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
if msg.ProfileName != nil {
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
}
var username string
if *msg.ProfileName != "default" {
username = *msg.Username
}
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: *msg.ProfileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
}
}
@@ -479,7 +457,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
s.mutex.Lock()
@@ -711,10 +689,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
@@ -725,7 +703,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
config, _, err := s.getConfig(activeProf)
if err != nil {
@@ -768,34 +746,60 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
}
}
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
if profileName != "default" && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", profileName)
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
// resolveProfileHandle resolves a wire-level profile handle (display
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
// status errors so handlers can return them directly.
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
p, err := s.profileManager.ResolveProfile(handle, username)
if err == nil {
return p, nil
}
var amb *profilemanager.ErrAmbiguousHandle
if errors.As(err, &amb) {
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
}
if errors.Is(err, profilemanager.ErrProfileNotFound) {
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
}
return nil, fmt.Errorf("resolve profile: %w", err)
}
// switchProfileIfNeeded resolves the user-supplied handle, updates the
// active profile state if it differs from the current one, and returns
// the resolved profile so callers can include its ID in RPC responses.
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", handle)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
}
var username string
if profileName != "default" {
if handle != profilemanager.DefaultProfileName {
username = *userName
}
if profileName != activeProf.Name || username != activeProf.Username {
resolved, err := s.resolveProfileHandle(handle, username)
if err != nil {
return nil, err
}
if resolved.ID != activeProf.ID || username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user %s", profileName, username)
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: resolved.ID,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return fmt.Errorf("failed to set active profile state: %w", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
}
return nil
return resolved, nil
}
// SwitchProfile switches the active profile in the daemon.
@@ -810,9 +814,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
activeProf, err = s.profileManager.GetActiveProfileState()
@@ -828,7 +832,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
s.config = config
return &proto.SwitchProfileResponse{}, nil
return &proto.SwitchProfileResponse{Id: activeProf.ID}, nil
}
// Down engine work in the daemon.
@@ -912,22 +916,27 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
if err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if activeProf != nil && activeProf.ID == resolved.ID {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
@@ -989,30 +998,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
func (s *Server) canRemoveProfile(id string) error {
if id == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
if err == nil && activeProf.ID == id {
return fmt.Errorf("remove active profile: %s", id)
}
return nil
}
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
func (s *Server) validateProfileOperation(id string, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if profileName == "" {
if id == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(profileName); err != nil {
if err := s.canRemoveProfile(id); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
@@ -1020,25 +1029,15 @@ func (s *Server) validateProfileOperation(profileName string, allowActiveProfile
return nil
}
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
config, err := profilemanager.GetConfig(profile.Path)
if err != nil {
return fmt.Errorf("get profile path: %w", err)
}
config, err := profilemanager.GetConfig(profilePath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profileName)
return fmt.Errorf("profile '%s' not found", profile.ID)
}
return s.sendLogoutRequestWithConfig(ctx, config)
@@ -1452,15 +1451,14 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
return nil, ctx.Err()
}
prof := profilemanager.ActiveProfileState{
Name: req.ProfileName,
Username: req.Username,
}
cfgPath, err := prof.FilePath()
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
return nil, err
}
cfgPath := resolved.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
}
cfg, err := profilemanager.GetConfig(cfgPath)
@@ -1564,12 +1562,16 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
if err != nil {
if errors.Is(err, profilemanager.ErrProfileAlreadyExists) {
return nil, gstatus.Errorf(codes.AlreadyExists, "profile %q already exists", msg.ProfileName)
}
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
return &proto.AddProfileResponse{Id: created.ID}, nil
}
// RemoveProfile removes a profile from the daemon.
@@ -1577,20 +1579,29 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
if err := s.validateProfileOperation(resolved.ID, false); err != nil {
return nil, err
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
}
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
return &proto.RemoveProfileResponse{Id: resolved.ID}, nil
}
// ListProfiles lists all profiles in the daemon.
@@ -1613,6 +1624,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Id: profile.ID,
Name: profile.Name,
IsActive: profile.IsActive,
}
@@ -1621,7 +1633,9 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
// GetActiveProfile returns the active profile in the daemon. The
// ProfileName field carries the display name for backwards compatibility
// with UI clients, new callers should prefer Id.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1632,9 +1646,22 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
displayName := activeProfile.ID
if activeProfile.ID != profilemanager.DefaultProfileName {
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
for _, p := range profiles {
if p.ID == activeProfile.ID {
displayName = p.Name
break
}
}
}
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
ProfileName: displayName,
Username: activeProfile.Username,
Id: activeProfile.ID,
}, nil
}

View File

@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test-profile",
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profName,
Username: currUser.Username,
})
if err != nil {
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profName,
Username: currUser.Username,
})
require.NoError(t, err)
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
ID: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()

View File

@@ -622,7 +622,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID,
Username: currUser.Username,
}
@@ -789,11 +789,11 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
ProfileName: &activeProf.ID,
Username: &currUser.Username,
}
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -1309,7 +1309,7 @@ func (s *serviceClient) getSrvConfig() {
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID,
Username: currUser.Username,
})
if err != nil {
@@ -1533,7 +1533,7 @@ func (s *serviceClient) loadSettings() {
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID,
Username: currUser.Username,
})
if err != nil {
@@ -1610,7 +1610,7 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID,
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,

View File

@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
// switch
err = s.switchProfile(profile.Name)
err = s.switchProfile(profile.ID)
if err != nil {
log.Errorf("failed to switch profile: %v", err)
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
logoutBtn.Show()
logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh)
s.handleProfileLogout(profile, refresh)
}
// Remove profile
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
err = s.removeProfile(profile.Name)
err = s.removeProfile(profile.ID)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return nil
}
func (s *serviceClient) switchProfile(profileName string) error {
func (s *serviceClient) switchProfile(handle string) error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf(getClientFMT, err)
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
}); err != nil {
})
if err != nil {
return fmt.Errorf("switch profile failed: %w", err)
}
err = s.profileManager.SwitchProfile(profileName)
if err != nil {
if err := s.profileManager.SwitchProfile(resp.Id); err != nil {
return fmt.Errorf("switch profile: %w", err)
}
@@ -299,6 +299,7 @@ func (s *serviceClient) removeProfile(profileName string) error {
}
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -324,6 +325,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -332,10 +334,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
dialog.ShowConfirm(
"Deregister",
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
func(confirm bool) {
if !confirm {
return
@@ -356,8 +358,10 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}
username := currUser.Username
// ProfileName is treated as a handle; send the ID so the
// daemon resolves to exactly this profile.
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName,
ProfileName: &profile.ID,
Username: &username,
})
if err != nil {
@@ -368,7 +372,7 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
dialog.ShowInformation(
"Deregistered",
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
s.wProfiles,
)
@@ -461,6 +465,7 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -501,7 +506,7 @@ func (p *profileMenu) refresh() {
}
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
activeProfState, err := p.profileManager.GetProfileState(activeProf.Id)
if err != nil {
log.Warnf("failed to get active profile state: %v", err)
p.emailMenuItem.Hide()
@@ -541,8 +546,8 @@ func (p *profileMenu) refresh() {
return
}
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.Name,
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.ID,
Username: &currUser.Username,
})
if err != nil {
@@ -552,7 +557,7 @@ func (p *profileMenu) refresh() {
return
}
err = p.profileManager.SwitchProfile(profile.Name)
err = p.profileManager.SwitchProfile(switchResp.Id)
if err != nil {
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
return

10
go.mod
View File

@@ -3,7 +3,7 @@ module github.com/netbirdio/netbird
go 1.25.5
require (
cunicu.li/go-rosenpass v0.4.0
cunicu.li/go-rosenpass v0.5.42
github.com/cenkalti/backoff/v4 v4.3.0
github.com/cloudflare/circl v1.3.3 // indirect
github.com/golang/protobuf v1.5.4
@@ -19,8 +19,8 @@ require (
github.com/vishvananda/netlink v1.3.1
golang.org/x/crypto v0.50.0
golang.org/x/sys v0.43.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard v0.0.0-20231211153847-12269c276173
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10
golang.zx2c4.com/wireguard/windows v0.5.3
google.golang.org/grpc v1.80.0
google.golang.org/protobuf v1.36.11
@@ -38,7 +38,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.87.3
github.com/c-robinson/iplib v1.0.3
github.com/caddyserver/certmagic v0.21.3
github.com/cilium/ebpf v0.15.0
github.com/cilium/ebpf v0.19.0
github.com/coder/websocket v1.8.14
github.com/coreos/go-iptables v0.7.0
github.com/coreos/go-oidc/v3 v3.18.0
@@ -60,7 +60,7 @@ require (
github.com/google/go-cmp v0.7.0
github.com/google/gopacket v1.1.19
github.com/google/nftables v0.3.0
github.com/gopacket/gopacket v1.1.1
github.com/gopacket/gopacket v1.4.0
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.0.2-0.20240212192251-757544f21357
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/go-secure-stdlib/base62 v0.1.2

22
go.sum
View File

@@ -7,8 +7,8 @@ cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdB
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 h1:b8xUw3004wk+3ipBhu0VU4RtUJsegMIiqjxSK4++lzA=
codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6/go.mod h1:+CauBF6R70Jqcyl8N2hC8pAXYbWkGIezuSbuGLtRhnw=
cunicu.li/go-rosenpass v0.4.0 h1:LtPtBgFWY/9emfgC4glKLEqS0MJTylzV6+ChRhiZERw=
cunicu.li/go-rosenpass v0.4.0/go.mod h1:MPbjH9nxV4l3vEagKVdFNwHOketqgS5/To1VYJplf/M=
cunicu.li/go-rosenpass v0.5.42 h1:fRDsGwCxd7DhDgZI1Pxeo8GtNyq8BESZJ7w2/BGGJtU=
cunicu.li/go-rosenpass v0.5.42/go.mod h1:YRBeyKOe/gWpSX2kpDUec5p9t0XOLsshTguId5gTGVg=
dario.cat/mergo v1.0.1 h1:Ra4+bf83h2ztPIQYNP99R6m+Y7KfnARDfID+a+vLl4s=
dario.cat/mergo v1.0.1/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
filippo.io/edwards25519 v1.1.1 h1:YpjwWWlNmGIDyXOn8zLzqiD+9TyIlPhGFG96P39uBpw=
@@ -111,8 +111,8 @@ github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1x
github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cilium/ebpf v0.15.0 h1:7NxJhNiBT3NG8pZJ3c+yfrVdHY8ScgKD27sScgjLMMk=
github.com/cilium/ebpf v0.15.0/go.mod h1:DHp1WyrLeiBh19Cf/tfiSMhqheEiK8fXFZ4No0P1Hso=
github.com/cilium/ebpf v0.19.0 h1:Ro/rE64RmFBeA9FGjcTc+KmCeY6jXmryu6FfnzPRIao=
github.com/cilium/ebpf v0.19.0/go.mod h1:fLCgMo3l8tZmAdM3B2XqdFzXBpwkcSTroaVqN08OWVY=
github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g=
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
@@ -225,8 +225,8 @@ github.com/go-openapi/validate v0.24.0 h1:LdfDKwNbpB6Vn40xhTdNZAnfLECL81w+VX3Bum
github.com/go-openapi/validate v0.24.0/go.mod h1:iyeX1sEufmv3nPbBdX3ieNviWnOZaJ1+zquzJEf2BAQ=
github.com/go-playground/locales v0.12.1/go.mod h1:IUMDtCfWo/w/mtMfIE/IG2K+Ey3ygWanZIBtBW0W2TM=
github.com/go-playground/universal-translator v0.16.0/go.mod h1:1AnU7NaIRDWWzGEKwgtJRd2xk99HeFyHw3yid4rvQIY=
github.com/go-quicktest/qt v1.101.0 h1:O1K29Txy5P2OK0dGo59b7b0LR6wKfIhttaAhHUyn7eI=
github.com/go-quicktest/qt v1.101.0/go.mod h1:14Bz/f7NwaXPtdYEgzsx46kqSxVwTbzVZsDC26tQJow=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6 h1:teYtXy9B7y5lHTp8V9KPxpYRAVA7dozigQcMiBust1s=
github.com/go-quicktest/qt v1.101.1-0.20240301121107-c6c8733fa1e6/go.mod h1:p4lGIVX+8Wa6ZPNDvqcxq36XpUDLh42FLetFU7odllI=
github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-sql-driver/mysql v1.9.3 h1:U/N249h2WzJ3Ukj8SowVFjdtZKfu9vlLZxjPXV1aweo=
github.com/go-sql-driver/mysql v1.9.3/go.mod h1:qn46aNg1333BRMNU69Lq93t8du/dwxI64Gl8i5p1WMU=
@@ -307,8 +307,8 @@ github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA
github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg=
github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI=
github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4=
github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw=
github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs=
github.com/gopacket/gopacket v1.4.0 h1:cr1OlFpzksCkZHNO0eLjaSSOrMQnpPXg0j6qHIY3y2U=
github.com/gopacket/gopacket v1.4.0/go.mod h1:EpvsxINeehp5qj4YMKMLf2/dekdhKn2IIAO/ZOifS7o=
github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE=
github.com/gorilla/handlers v1.5.2/go.mod h1:dX+xVpaxdSw+q0Qek8SSsl3dfMk3jNddUkMzo0GtH0w=
github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY=
@@ -390,6 +390,8 @@ github.com/jonboulle/clockwork v0.5.0 h1:Hyh9A8u51kptdkR+cqRpT1EebBwTn1oK9YfGYbd
github.com/jonboulle/clockwork v0.5.0/go.mod h1:3mZlmanh0g2NDKO5TWZVJAfofYk64M7XN3SzBPjZF60=
github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink/v2 v2.0.1 h1:xda7qaHDSVOsADNouv7ukSuicKZO7GgVUCXxpaIEIlM=
github.com/jsimonetti/rtnetlink/v2 v2.0.1/go.mod h1:7MoNYNbb3UaDHtF8udiJo/RH6VsTKP1pqKLUTVCvToE=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25 h1:YLvr1eE6cdCqjOe972w/cYF+FjW34v27+9Vo5106B4M=
github.com/jsummers/gobmp v0.0.0-20230614200233-a9de23ed2e25/go.mod h1:kLgvv7o6UM+0QSf0QjAse3wReFDsb9qbZJdfexWlrQw=
@@ -900,8 +902,8 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvYQH2OU3/TnxLx97WDSUDRABfT18pCOYwc2GE=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10 h1:3GDAcqdIg1ozBNLgPy4SLT84nfcBjr6rhGtXYtrkWLU=
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20241231184526-a9ab2273dd10/go.mod h1:T97yPqesLiNrOYxkwmhMI0ZIlJDm+p0PMR8eRVeR5tQ=
golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE=
golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI=
gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4=