Compare commits

...

2 Commits

Author SHA1 Message Date
mlsmaycon
2d8b0310a4 [client, proxy] IPv6 in-place apply + accept-loop hardening on netstack listeners
Two related fixes for the embedded netbird client and the per-account
inbound listeners that ride on its gVisor netstack.

client/internal/engine.go — replace hasIPv6Changed with reconcileIPv6:

  - First v6 assignment (current had no v6, conf carries one) is applied
    in place via WGIface.UpdateAddr instead of returning ErrResetConnection.
    Pre-fix, every embedded client whose account had IPv6 enabled would
    reset on its first NetworkMap sync — boot config has no v6, the sync
    introduces one, the engine tore itself down to "apply" it. That
    teardown destroys the gVisor netstack and orphans every listener
    bound on it, which is what made the proxy's per-account :80/:443
    silently stop accepting traffic.
  - v6 removed clears in place.
  - v6 swapped to a different non-empty value still resets (gVisor
    netstack can't safely swap its address at runtime).
  - Mutates e.config.WgAddr to match the applied state so subsequent
    PeerConfig comparisons are stable.

proxy/internal/tcp/accept.go (new) + proxy/inbound.go +
proxy/internal/tcp/router.go — harden the two Accept() loops on
netstack-backed listeners:

  - IsClosedListenerErr recognises net.ErrClosed AND gVisor's
    "endpoint is in invalid state" — the latter survives gonet's
    *net.OpError wrapping in a way errors.Is(.., net.ErrClosed) does
    not. Without this the loop spins CPU-hot after the underlying
    netstack is destroyed (peer rekey, embedded-client reset, account
    churn), emitting one log line per iteration.
  - AcceptBackoff implements the exponential backoff that
    net/http.Server.Serve uses on transient Accept errors: 5ms doubling
    up to 1s. Defence-in-depth so an unknown sticky error cannot burn
    a CPU core even if IsClosedListenerErr misses its signature.

proxy/internal/roundtrip/netbird.go — emit a single structured INFO
line summarising every embed.Options flag (account_id, service_id,
public_key, management_url, wg_port, block_inbound, block_lan_access,
disable_ipv6, no_userspace, presence of credentials) when each
per-account embedded client is created. Secrets reduced to a "present"
boolean — never logged verbatim. Diagnostic-only; no behavior change,
but it makes the "why is this embedded peer misbehaving" loop a single
log read instead of a code dive.

Tests (real listeners, scripted errors, no mocks of production code):
  - engine_reconcileipv6_test.go: 8 cases for every transition (first
    assignment, no change, removed, prefix-length changed, value
    changed, invalid bytes, UpdateAddr error) plus a updateConfig
    integration check that the fix actually fires on a v6-added
    PeerConfig.
  - accept_test.go: IsClosedListenerErr matrix + AcceptBackoff
    progression / cap / reset / cancel-during-wait / cancel-before-call.
  - router_test.go, inbound_test.go: scriptedAcceptListener +
    TestRouter_Serve_ExitsOnGVisorInvalidEndpoint and
    TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint —
    regression guards that fail in 2 s if the loop ever spins.
2026-06-18 10:37:51 +02:00
Theodor Midtlien
ee360963f9 [client] Migrate profile identity from display name to ID and allow renaming of profiles (#6367)
* Migrate to profile ids

* Migrate android profile manager

* Clean up

* Fix review

* Add ID type

* Fix test and runes in ShortID()

* Fix profile switch on up and android comments

* Revert android profile to string id

* Fix feedback

* Fix UI feedback

* Fix id assignment

* Add renaming of profiles

* Fix review

* Remove ui binary
* Fix getProfileConfigPath not validating id

* Change resolve handle order and fix server merge problems

* Fix mdm test
2026-06-18 08:49:19 +02:00
35 changed files with 2625 additions and 725 deletions

View File

@@ -6,7 +6,6 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
@@ -24,6 +23,7 @@ const (
// Profile represents a profile for gomobile
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
── personal.state.json ← Personal profile state
├── work.json ← Legacy work profile config
├── work.state.json ← Legacy work profile state
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
*/
// ProfileManager manages profiles for Android
@@ -99,6 +99,7 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
ID: p.ID.String(),
Name: p.Name,
IsActive: p.IsActive,
})
@@ -108,55 +109,65 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (string, error) {
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
return nil, fmt.Errorf("failed to get active profile: %w", err)
}
return activeState.Name, nil
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(profileName string) error {
func (pm *ProfileManager) SwitchProfile(id string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: profilemanager.ID(id),
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", profileName)
log.Infof("switched to profile: %s", id)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
if err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profileName)
log.Infof("created new profile: %s", profile.ID)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) LogoutProfile(id string) error {
configPath, err := pm.getProfileConfigPath(id)
if err != nil {
return err
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", profileName)
return fmt.Errorf("profile '%s' does not exist", id)
}
// Read current config using internal profilemanager
@@ -174,53 +185,57 @@ func (pm *ProfileManager) LogoutProfile(profileName string) error {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", profileName)
log.Infof("logged out from profile: %s", id)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(profileName string) error {
func (pm *ProfileManager) RemoveProfile(id string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", profileName)
log.Infof("removed profile: %s", id)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
if id == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".json"), nil
return filepath.Join(profilesDir, id+".json"), nil
}
// GetConfigPath returns the config file path for a given profile
// GetConfigPath returns the config file path for a given profile id
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
return pm.getProfileConfigPath(id)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
profileName = sanitizeProfileName(profileName)
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, profileName+".state.json"), nil
return filepath.Join(profilesDir, id+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
@@ -230,7 +245,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile)
return pm.GetConfigPath(activeProfile.ID)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
@@ -240,18 +255,5 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
return pm.GetStateFilePath(activeProfile.ID)
}

View File

@@ -96,17 +96,19 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
handle := activeProf.ID.String()
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &activeProf.Name,
ProfileName: &handle,
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -170,14 +172,13 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -205,11 +206,15 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil
}
func switchProfile(ctx context.Context, profileName string, username string) error {
// switchProfile asks the daemon to switch to the profile identified by
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return fmt.Errorf("failed to connect to daemon error: %v\n"+
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
@@ -217,15 +222,15 @@ func switchProfile(ctx context.Context, profileName string, username string) err
client := proto.NewDaemonServiceClient(conn)
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &username,
})
if err != nil {
return fmt.Errorf("switch profile failed: %v", err)
return "", fmt.Errorf("switch profile failed: %v", err)
}
return nil
return profilemanager.ID(resp.Id), nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -249,7 +254,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -277,7 +282,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
@@ -291,7 +296,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -306,10 +311,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
profileState, err := pm.GetProfileState(profileID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -2,11 +2,16 @@ package cmd
import (
"context"
"errors"
"fmt"
"os/user"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -14,6 +19,8 @@ import (
"github.com/netbirdio/netbird/util"
)
var profileListShowID bool
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -31,27 +38,40 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "Add a new profile",
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRenameCmd = &cobra.Command{
Use: "rename <profile> <new_profile_name>",
Short: "Renames an existing profile",
Long: `Renames an existing profile (by a name, ID, or unique ID prefix). Profile name is free-form.`,
Args: cobra.ExactArgs(2),
RunE: renameProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile_name>",
Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
Use: "remove <profile>",
Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`,
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile_name>",
Use: "select <profile>",
Short: "Select a profile",
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
@@ -65,6 +85,7 @@ func setupCmd(cmd *cobra.Command) error {
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -83,25 +104,33 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
if profileListShowID {
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
} else {
fmt.Fprintln(tw, "NAME\tACTIVE")
}
return nil
for _, profile := range resp.Profiles {
marker := ""
if profile.IsActive {
marker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
id := profilemanager.ID(profile.Id)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
}
}
return tw.Flush()
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -121,21 +150,82 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return fmt.Errorf("add profile request: %w", err)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
func renameProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
}
cmd.Println("Profile added successfully:", profileName)
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
if err != nil {
return fmt.Errorf("connect to service CLI interface: %w", err)
}
defer conn.Close()
currUser, err := user.Current()
if err != nil {
return fmt.Errorf("get current user: %w", err)
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
newProfilename := args[1]
resp, err := daemonClient.RenameProfile(cmd.Context(), &proto.RenameProfileRequest{
Handle: handle,
Username: currUser.Username,
NewProfileName: newProfilename,
})
if err != nil {
return wrapAmbiguityError(err, handle)
}
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, newProfilename)
if dupCount > 1 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount-1, newProfilename)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
cmd.Printf("Profile renamed from %s to %s\n", profilemanager.StripCtrlChars(resp.OldProfileName), profilemanager.StripCtrlChars(newProfilename))
return nil
}
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -153,18 +243,17 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: handle,
Username: currUser.Username,
})
if err != nil {
return err
return wrapAmbiguityError(err, handle)
}
cmd.Println("Profile removed successfully:", profileName)
cmd.Printf("Profile removed: %s\n", resp.Id)
return nil
}
@@ -174,7 +263,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
profileManager := profilemanager.NewProfileManager()
profileName := args[0]
handle := args[0]
currUser, err := user.Current()
if err != nil {
@@ -191,32 +280,15 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
})
if err != nil {
return fmt.Errorf("list profiles: %w", err)
return wrapAmbiguityError(err, handle)
}
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
return err
}
@@ -231,6 +303,30 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
cmd.Println("Profile switched successfully to:", profileName)
id := profilemanager.ID(switchResp.Id)
cmd.Printf("Profile switched to: %s\n", id.ShortID())
return nil
}
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}

View File

@@ -190,6 +190,7 @@ func init() {
// profile commands
profileCmd.AddCommand(profileListCmd)
profileCmd.AddCommand(profileAddCmd)
profileCmd.AddCommand(profileRenameCmd)
profileCmd.AddCommand(profileRemoveCmd)
profileCmd.AddCommand(profileSelectCmd)

View File

@@ -128,13 +128,12 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
err = switchProfile(cmd.Context(), profileName, username.Username)
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
err = pm.SwitchProfile(profileName)
if err != nil {
if err := pm.SwitchProfile(resolvedID); err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -190,7 +189,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -261,10 +260,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon")
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
@@ -289,10 +288,11 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err)
}
loginRequest.ProfileName = &activeProf.Name
profileID := activeProf.ID.String()
loginRequest.ProfileName = &profileID
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -329,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &activeProf.Name,
ProfileName: &profileID,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
err = sm.AddProfile("test1", currUser.Username)
created, err := sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test1",
ID: created.ID,
Username: currUser.Username,
})
if err != nil {

View File

@@ -843,6 +843,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
"policy": "non-config: in-memory MDM policy snapshot, surfaced via Config.Policy() / GetConfigResponse.MDMManagedFields",
}

View File

@@ -64,7 +64,6 @@ import (
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -1078,11 +1077,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return ErrResetConnection
}
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
log.Infof("peer IPv6 address changed, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
if !e.config.DisableIPv6 {
reset, err := e.reconcileIPv6(conf)
if err != nil {
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
}
if reset {
log.Infof("peer IPv6 address changed value, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
}
if conf.GetSshConfig() != nil {
@@ -1104,25 +1109,58 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
// differs from the configured address (added, removed, or changed).
// Compares against e.config.WgAddr (not the interface address, which may have
// been cleared by ClearIPv6 if OS assignment failed).
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
current := e.config.WgAddr
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
// engine's WireGuard interface in place when possible. Three transitions:
//
// - First v6 assignment (current had no v6, conf carries one): apply via
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
// boot config has no v6 — without this we reset on every fresh start
// once management has v6 enabled, orphaning any netstack listeners
// held outside the engine.
// - v6 removed (current had v6, conf carries none): clear in place, no
// reset.
// - v6 swapped to a different non-empty value: returns reset=true so the
// caller falls back to the engine-recreate path — the underlying
// interface address can't be safely swapped in place across all
// backends (gVisor netstack in particular fixes its address at
// CreateNetTUN time).
//
// Mutates e.config.WgAddr to match the applied state so subsequent
// PeerConfig comparisons are stable.
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
raw := conf.GetAddressV6()
current := e.config.WgAddr
if len(raw) == 0 {
return current.HasIPv6()
if !current.HasIPv6() {
return false, nil
}
current.ClearIPv6()
e.config.WgAddr = current
if err := e.wgInterface.UpdateAddr(current); err != nil {
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
}
return false, nil
}
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
log.Errorf("decode v6 overlay address: %v", err)
return false
incoming := current
if err := incoming.SetIPv6FromCompact(raw); err != nil {
return false, fmt.Errorf("decode v6 overlay address: %w", err)
}
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
if !current.HasIPv6() {
e.config.WgAddr = incoming
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
}
return false, nil
}
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
return false, nil
}
return true, nil
}
func (e *Engine) receiveJobEvents() {

View File

@@ -0,0 +1,305 @@
package internal
import (
"context"
"errors"
"net/netip"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
// netstack clients: any first NetworkMap update that brings an IPv6 address
// used to trigger ErrResetConnection, which destroys the netstack and orphans
// every listener bound on it (proxy-side inbound listeners in particular).
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
// from "v6 swapped value" (must reset).
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err, "encode v6 prefix %s", p)
return b
}
// reconcileIPv6Fixture builds the smallest Engine the function under test
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
// whose UpdateAddr call we can observe.
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
t.Helper()
var applied wgaddr.Address
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return initial },
UpdateAddrFunc: func(a wgaddr.Address) error {
applied = a
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: initial},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
return e, mock, &applied
}
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
// Embedded clients boot v4-only; management later assigns a v6 overlay.
// The fix: apply v6 in place, return reset=false. Pre-fix this case
// fell through to the "v6 changed" branch and reset the engine.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, mock, applied := reconcileIPv6Fixture(t, v4)
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
_ = mock
}
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
// Steady state: management redelivers the same PeerConfig. No interface
// mutation, no reset. Guards against an infinite reset loop if the
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
}
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
// Management withdraws v6 (e.g. account toggled off the v6 group).
// Cleared in place, no reset.
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
e, _, applied := reconcileIPv6Fixture(t, addr)
e.config.WgAddr = addr
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: nil,
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "v6 removed must NOT trigger reset")
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
}
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
// change because the new netmask redefines the broadcast/scope.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::1/80")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 prefix length change must request a reset")
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
}
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
// v6 was X, now Y. The netstack backend can't safely swap an existing
// address in place — fall back to the engine recreate path.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::2/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 value change must request a reset")
assert.False(t, updateAddrCalled,
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
}
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
// trigger a spurious reset.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, _, applied := reconcileIPv6Fixture(t, v4)
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "malformed v6 bytes must surface an error")
assert.False(t, reset, "decode error must NOT request a reset")
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
}
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
// kernel device), reconcileIPv6 returns the error to the caller for
// logging — but it must NOT request a reset. The whole point of the
// fix is to AVOID the reset cascade on v6 transitions.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "UpdateAddr failure must surface")
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
}
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
// The integration check: updateConfig must not return ErrResetConnection
// when the only change between current state and the new PeerConfig is
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
// every listener bound on the engine's netstack.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
IsUserspaceBindFunc: func() bool { return true },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
statusRecorder: peer.NewRecorder("https://mgm.test"),
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
err := e.updateConfig(conf)
assert.NoError(t, err,
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
assert.NotErrorIs(t, err, ErrResetConnection)
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
}

View File

@@ -66,7 +66,6 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -1707,82 +1706,12 @@ func getPeers(e *Engine) int {
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err)
return b
}
func TestEngine_hasIPv6Changed(t *testing.T) {
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
tests := []struct {
name string
current wgaddr.Address
confV6 []byte
expected bool
}{
{
name: "no v6 before, no v6 now",
current: v4Only,
confV6: nil,
expected: false,
},
{
name: "no v6 before, v6 added",
current: v4Only,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: true,
},
{
name: "had v6, now removed",
current: v4v6,
confV6: nil,
expected: true,
},
{
name: "had v6, same v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: false,
},
{
name: "had v6, different v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
expected: true,
},
{
name: "same v6 addr, different prefix length",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
expected: true,
},
{
name: "decode error keeps status quo",
current: v4Only,
confV6: []byte{1, 2, 3},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{WgAddr: tt.current},
}
conf := &mgmtProto.PeerConfig{
AddressV6: tt.confV6,
}
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
})
}
}
// The former TestEngine_hasIPv6Changed has been superseded by
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
// in place instead of demanding a full engine reset. The behavioral
// matrix the old test enforced is now covered, with corrected expectations,
// by TestReconcileIPv6_* in that sibling file.
func TestFilterAllowedIPs(t *testing.T) {
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")

View File

@@ -108,6 +108,10 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
@@ -270,6 +274,16 @@ func createNewConfig(input ConfigInput) (*Config, error) {
}
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.Name != "" {
sanitized, err := sanitizeDisplayName(config.Name)
if err != nil {
return false, fmt.Errorf("invalid profile name: %w", err)
}
if sanitized != config.Name {
config.Name = sanitized
updated = true
}
}
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)

View File

@@ -0,0 +1,118 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
type ID string
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (ID, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return ID(hex.EncodeToString(buf)), nil
}
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(id ID) bool {
s := id.String()
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func (id ID) ShortID() string {
if id == DefaultProfileName {
return DefaultProfileName
}
runes := []rune(id)
if len(runes) <= shortIDLen {
return id.String()
}
return string(runes[:shortIDLen])
}
func (id ID) String() string {
return string(id)
}

View File

@@ -19,19 +19,41 @@ const (
)
type Profile struct {
Name string
// ID is the on-disk filename stem (without .json). For new profiles
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID ID
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
IsActive bool
}
func (p *Profile) FilePath() (string, error) {
if p.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if p.Path != "" {
return p.Path, nil
}
if p.Name == defaultProfileName {
id := p.ID
if id == "" {
id = ID(p.Name)
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
@@ -42,10 +64,13 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
}
return filepath.Join(configDir, p.Name+".json"), nil
return filepath.Join(configDir, id.String()+".json"), nil
}
func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName
}
@@ -57,18 +82,24 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{}
}
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
prof := pm.getActiveProfileState()
return &Profile{Name: prof}, nil
id := pm.getActiveProfileState()
return &Profile{ID: id}, nil
}
func (pm *ProfileManager) SwitchProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
// SwitchProfile records the given profile ID as active in the local user
// state file.
func (pm *ProfileManager) SwitchProfile(id ID) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
if err := pm.setActiveProfileState(profileName); err != nil {
if err := pm.setActiveProfileState(id); err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
return nil
@@ -85,7 +116,7 @@ func sanitizeProfileName(name string) string {
}, name)
}
func (pm *ProfileManager) getActiveProfileState() string {
func (pm *ProfileManager) getActiveProfileState() ID {
configDir, err := getConfigDir()
if err != nil {
@@ -113,10 +144,10 @@ func (pm *ProfileManager) getActiveProfileState() string {
return defaultProfileName
}
return profileName
return ID(profileName)
}
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
func (pm *ProfileManager) setActiveProfileState(id ID) error {
configDir, err := getConfigDir()
if err != nil {
@@ -125,7 +156,7 @@ func (pm *ProfileManager) setActiveProfileState(profileName string) error {
statePath := filepath.Join(configDir, activeProfileStateFilename)
err = os.WriteFile(statePath, []byte(profileName), 0600)
err = os.WriteFile(statePath, []byte(id), 0600)
if err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
@@ -142,7 +173,7 @@ func GetLoginHint() string {
return ""
}
profileState, err := pm.GetProfileState(activeProf.Name)
profileState, err := pm.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err)
active, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, "default", active.Name)
assert.Equal(t, "default", active.ID.String())
})
})
}
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current()
assert.NoError(t, err)
sm := &ServiceManager{}
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state)
assert.NoError(t, err)
// Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil)
assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
assert.Error(t, err)
})
})

View File

@@ -2,6 +2,7 @@ package profilemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -23,12 +24,43 @@ var (
DefaultConfigPathDir = ""
DefaultConfigPath = ""
ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
)
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() {
DefaultConfigPathDir = "/var/lib/netbird/"
@@ -54,25 +86,34 @@ func init() {
}
type ActiveProfileState struct {
Name string `json:"name"`
// ID is the on-disk filename stem of the active profile. The JSON tag stays
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID ID `json:"name"`
Username string `json:"username"`
}
func (a *ActiveProfileState) FilePath() (string, error) {
if a.Name == "" {
return "", fmt.Errorf("active profile name is empty")
if a.ID == "" {
return "", fmt.Errorf("active profile ID is empty")
}
if a.Name == defaultProfileName {
if a.ID == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username)
if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
}
return filepath.Join(configDir, a.Name+".json"), nil
return filepath.Join(configDir, a.ID.String()+".json"), nil
}
type ServiceManager struct {
@@ -178,7 +219,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
} else {
@@ -186,12 +227,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
}
}
if activeProfile.Name == "" {
if activeProfile.ID == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
}, nil
}
@@ -216,25 +257,29 @@ func (s *ServiceManager) setDefaultActiveState() error {
}
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.Name == "" {
if a == nil || a.ID == "" {
return errors.New("invalid active profile state")
}
if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
if a.ID != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
}
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
log.Infof("active profile set to %s for %s", a.Name, a.Username)
log.Infof("active profile set to %s for %s", a.ID, a.Username)
return nil
}
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{
Name: "default",
ID: defaultProfileName,
Username: "",
})
}
@@ -243,57 +288,117 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
func (s *ServiceManager) AddProfile(profileName, username string) error {
// AddProfile creates a new profile with a generated ID. The user-supplied
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
displayName, err = sanitizeDisplayName(displayName)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
return nil, fmt.Errorf("invalid profile name: %w", err)
}
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id.String()+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil {
return fmt.Errorf("failed to create new config: %w", err)
return nil, fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
}
err = util.WriteJson(context.Background(), profPath, cfg)
return &Profile{
ID: id,
Name: displayName,
Path: profPath,
}, nil
}
func (s *ServiceManager) RenameProfile(id ID, username string, newName string) error {
displayName, err := sanitizeDisplayName(newName)
if err != nil {
return fmt.Errorf("failed to write profile config: %w", err)
return fmt.Errorf("invalid profile name: %w", err)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
data, err := os.ReadFile(target.Path)
if err != nil {
return err
}
var cfg Config
if err := json.Unmarshal(data, &cfg); err != nil {
return err
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), target.Path, cfg); err != nil {
return fmt.Errorf("failed to write profile name: %w", err)
}
return nil
}
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
// RemoveProfile deletes the profile identified by id. Callers must have
// already resolved any user-supplied handle to a concrete ID via
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if id == defaultProfileName {
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
return fmt.Errorf("cannot remove default profile with name: %s", defaultName)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
profiles, err := s.loadAllProfiles(username)
if err != nil {
return fmt.Errorf("failed to check if profile exists: %w", err)
return fmt.Errorf("load profiles: %w", err)
}
if !profileExists {
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
return ErrProfileNotFound
}
@@ -301,57 +406,26 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err)
}
if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id)
}
err = util.RemoveJson(profPath)
if err != nil {
if err := util.RemoveJson(target.Path); err != nil {
return fmt.Errorf("failed to remove profile config: %w", err)
}
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil
}
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
return s.loadAllProfiles(username)
}
// GetStatePath returns the path to the state file based on the operating system
@@ -369,7 +443,12 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
if activeProf.Name == defaultProfileName {
if activeProf.ID == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
return defaultStatePath
}
@@ -379,7 +458,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
return filepath.Join(configDir, activeProf.Name+".state.json")
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -390,3 +469,169 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username)
}
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
defaultName := readProfileName(DefaultConfigPath)
if defaultName == "" {
defaultName = defaultProfileName
}
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := ID(strings.TrimSuffix(base, ".json"))
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(ID(stem)) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem.String()
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == ID(activeID),
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (ID, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique exact name, then unique ID
// prefix. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == ID(handle) {
return &profiles[i], nil
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID.String(), handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -0,0 +1,230 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID.String())
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID.String())
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(ID(tc.in))
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,13 +13,20 @@ type ProfileState struct {
Email string `json:"email"`
}
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
// GetProfileState reads the per-profile state file keyed by profile ID.
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
configDir, err := getConfigDir()
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
stateFile := filepath.Join(configDir, profileName+".state.json")
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -51,7 +58,12 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err)
}
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
id := activeProf.ID
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil {
return fmt.Errorf("write profile state: %w", err)

File diff suppressed because it is too large Load Diff

View File

@@ -85,6 +85,8 @@ service DaemonService {
rpc AddProfile(AddProfileRequest) returns (AddProfileResponse) {}
rpc RenameProfile(RenameProfileRequest) returns (RenameProfileResponse) {}
rpc RemoveProfile(RemoveProfileRequest) returns (RemoveProfileResponse) {}
rpc ListProfiles(ListProfilesRequest) returns (ListProfilesResponse) {}
@@ -625,11 +627,18 @@ message GetEventsResponse {
}
message SwitchProfileRequest {
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
optional string profileName = 1;
optional string username = 2;
}
message SwitchProfileResponse {}
message SwitchProfileResponse {
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
string id = 1;
}
message SetConfigRequest {
string username = 1;
@@ -696,17 +705,42 @@ message SetConfigResponse{}
message AddProfileRequest {
string username = 1;
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
string profileName = 2;
}
message AddProfileResponse {}
message AddProfileResponse {
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
string id = 1;
}
message RenameProfileRequest {
string username = 1;
// handle: an exact ID, a unique ID prefix, or a unique display name.
string handle = 2;
// newProfileName is the new human-readable display name for the profile.
string newProfileName = 3;
}
message RenameProfileResponse {
// confirm the old profile name after resolving handle.
string oldProfileName = 1;
}
message RemoveProfileRequest {
string username = 1;
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
string profileName = 2;
}
message RemoveProfileResponse {}
message RemoveProfileResponse {
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
string id = 1;
}
message ListProfilesRequest {
string username = 1;
@@ -719,6 +753,7 @@ message ListProfilesResponse {
message Profile {
string name = 1;
bool is_active = 2;
string id = 3;
}
message GetActiveProfileRequest {}
@@ -726,6 +761,7 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
string id = 3;
}
message LogoutRequest {

View File

@@ -45,6 +45,7 @@ const (
DaemonService_SwitchProfile_FullMethodName = "/daemon.DaemonService/SwitchProfile"
DaemonService_SetConfig_FullMethodName = "/daemon.DaemonService/SetConfig"
DaemonService_AddProfile_FullMethodName = "/daemon.DaemonService/AddProfile"
DaemonService_RenameProfile_FullMethodName = "/daemon.DaemonService/RenameProfile"
DaemonService_RemoveProfile_FullMethodName = "/daemon.DaemonService/RemoveProfile"
DaemonService_ListProfiles_FullMethodName = "/daemon.DaemonService/ListProfiles"
DaemonService_GetActiveProfile_FullMethodName = "/daemon.DaemonService/GetActiveProfile"
@@ -112,6 +113,7 @@ type DaemonServiceClient interface {
SwitchProfile(ctx context.Context, in *SwitchProfileRequest, opts ...grpc.CallOption) (*SwitchProfileResponse, error)
SetConfig(ctx context.Context, in *SetConfigRequest, opts ...grpc.CallOption) (*SetConfigResponse, error)
AddProfile(ctx context.Context, in *AddProfileRequest, opts ...grpc.CallOption) (*AddProfileResponse, error)
RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error)
RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error)
ListProfiles(ctx context.Context, in *ListProfilesRequest, opts ...grpc.CallOption) (*ListProfilesResponse, error)
GetActiveProfile(ctx context.Context, in *GetActiveProfileRequest, opts ...grpc.CallOption) (*GetActiveProfileResponse, error)
@@ -422,6 +424,16 @@ func (c *daemonServiceClient) AddProfile(ctx context.Context, in *AddProfileRequ
return out, nil
}
func (c *daemonServiceClient) RenameProfile(ctx context.Context, in *RenameProfileRequest, opts ...grpc.CallOption) (*RenameProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RenameProfileResponse)
err := c.cc.Invoke(ctx, DaemonService_RenameProfile_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) RemoveProfile(ctx context.Context, in *RemoveProfileRequest, opts ...grpc.CallOption) (*RemoveProfileResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoveProfileResponse)
@@ -613,6 +625,7 @@ type DaemonServiceServer interface {
SwitchProfile(context.Context, *SwitchProfileRequest) (*SwitchProfileResponse, error)
SetConfig(context.Context, *SetConfigRequest) (*SetConfigResponse, error)
AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error)
RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error)
RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error)
ListProfiles(context.Context, *ListProfilesRequest) (*ListProfilesResponse, error)
GetActiveProfile(context.Context, *GetActiveProfileRequest) (*GetActiveProfileResponse, error)
@@ -723,6 +736,9 @@ func (UnimplementedDaemonServiceServer) SetConfig(context.Context, *SetConfigReq
func (UnimplementedDaemonServiceServer) AddProfile(context.Context, *AddProfileRequest) (*AddProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AddProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RenameProfile(context.Context, *RenameProfileRequest) (*RenameProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RenameProfile not implemented")
}
func (UnimplementedDaemonServiceServer) RemoveProfile(context.Context, *RemoveProfileRequest) (*RemoveProfileResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RemoveProfile not implemented")
}
@@ -1237,6 +1253,24 @@ func _DaemonService_AddProfile_Handler(srv interface{}, ctx context.Context, dec
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RenameProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RenameProfileRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RenameProfile(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: DaemonService_RenameProfile_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RenameProfile(ctx, req.(*RenameProfileRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RemoveProfile_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveProfileRequest)
if err := dec(in); err != nil {
@@ -1567,6 +1601,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "AddProfile",
Handler: _DaemonService_AddProfile_Handler,
},
{
MethodName: "RenameProfile",
Handler: _DaemonService_RenameProfile_Handler,
},
{
MethodName: "RemoveProfile",
Handler: _DaemonService_RemoveProfile_Handler,

View File

@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")

View File

@@ -78,7 +78,7 @@ type Server struct {
// changed by connectWithRetryRuns goroutine exit — for that
// (goroutine-still-alive) check, see connectionGoroutineRunning() which
// derives from clientGiveUpChan close state. Protected by s.mutex.
clientRunning bool
clientRunning bool
clientRunningChan chan struct{}
clientGiveUpChan chan struct{} // closed when connectWithRetryRuns goroutine exits
@@ -375,7 +375,7 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
return nil, err
}
config, err := setConfigInputFromRequest(msg)
config, err := s.setConfigInputFromRequest(msg)
if err != nil {
return nil, err
}
@@ -398,17 +398,17 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
// field is its own optional case. Returns the resolved ConfigInput
// and a non-nil error only when the active profile file path cannot
// be determined.
func setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profilemanager.ConfigInput, error) {
var config profilemanager.ConfigInput
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath, err := profState.FilePath()
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return config, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
return config, err
}
profPath := resolved.Path
if profPath == "" {
profPath = profilemanager.DefaultConfigPath
}
config.ConfigPath = profPath
@@ -535,30 +535,9 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
if msg.ProfileName != nil {
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
}
var username string
if *msg.ProfileName != "default" {
username = *msg.Username
}
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: *msg.ProfileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
}
}
@@ -568,7 +547,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
s.mutex.Lock()
@@ -806,10 +785,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
@@ -820,7 +799,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
config, _, err := s.getConfig(activeProf)
if err != nil {
@@ -864,34 +843,60 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
}
}
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
if profileName != "default" && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", profileName)
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
// resolveProfileHandle resolves a wire-level profile handle (display
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
// status errors so handlers can return them directly.
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
p, err := s.profileManager.ResolveProfile(handle, username)
if err == nil {
return p, nil
}
var amb *profilemanager.ErrAmbiguousHandle
if errors.As(err, &amb) {
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
}
if errors.Is(err, profilemanager.ErrProfileNotFound) {
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
}
return nil, fmt.Errorf("resolve profile: %w", err)
}
// switchProfileIfNeeded resolves the user-supplied handle, updates the
// active profile state if it differs from the current one, and returns
// the resolved profile so callers can include its ID in RPC responses.
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", handle)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
}
var username string
if profileName != "default" {
if handle != profilemanager.DefaultProfileName {
username = *userName
}
if profileName != activeProf.Name || username != activeProf.Username {
resolved, err := s.resolveProfileHandle(handle, username)
if err != nil {
return nil, err
}
if resolved.ID != activeProf.ID || username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user %s", profileName, username)
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profileName,
ID: resolved.ID,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return fmt.Errorf("failed to set active profile state: %w", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
}
return nil
return resolved, nil
}
// SwitchProfile switches the active profile in the daemon.
@@ -906,9 +911,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
}
if msg != nil && msg.ProfileName != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, fmt.Errorf("failed to switch profile: %w", err)
return nil, err
}
}
activeProf, err = s.profileManager.GetActiveProfileState()
@@ -924,7 +929,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
s.config = config
return &proto.SwitchProfileResponse{}, nil
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
}
// Down engine work in the daemon.
@@ -1014,22 +1019,27 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
if err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if activeProf != nil && activeProf.ID == resolved.ID {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
@@ -1091,30 +1101,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
if id == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
if err == nil && activeProf.ID == id {
return fmt.Errorf("remove active profile: %s", id)
}
return nil
}
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if profileName == "" {
if id == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(profileName); err != nil {
if err := s.canRemoveProfile(id); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
@@ -1122,25 +1132,20 @@ func (s *Server) validateProfileOperation(profileName string, allowActiveProfile
return nil
}
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
if err != nil {
return fmt.Errorf("get profile path: %w", err)
cfgPath := profile.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
}
config, err := profilemanager.GetConfig(profilePath)
config, err := profilemanager.GetConfig(cfgPath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profileName)
return fmt.Errorf("profile '%s' not found", profile.ID)
}
return s.sendLogoutRequestWithConfig(ctx, config)
@@ -1558,15 +1563,14 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
return nil, ctx.Err()
}
prof := profilemanager.ActiveProfileState{
Name: req.ProfileName,
Username: req.Username,
}
cfgPath, err := prof.FilePath()
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
return nil, err
}
cfgPath := resolved.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
}
cfg, err := profilemanager.GetConfig(cfgPath)
@@ -1671,12 +1675,39 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{}, nil
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
}
func (s *Server) RenameProfile(ctx context.Context, msg *proto.RenameProfileRequest) (*proto.RenameProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.Handle == "" || msg.Username == "" || msg.NewProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name, username and new profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.Handle, msg.Username)
if err != nil {
return nil, err
}
err = s.profileManager.RenameProfile(resolved.ID, msg.Username, msg.NewProfileName)
if err != nil {
log.Errorf("failed to rename profile: %v", err)
return nil, fmt.Errorf("failed to rename profile: %w", err)
}
return &proto.RenameProfileResponse{OldProfileName: resolved.Name}, nil
}
// RemoveProfile removes a profile from the daemon.
@@ -1684,20 +1715,29 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
if s.checkProfilesDisabled() {
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
}
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{}, nil
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
}
// ListProfiles lists all profiles in the daemon.
@@ -1720,6 +1760,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Id: profile.ID.String(),
Name: profile.Name,
IsActive: profile.IsActive,
}
@@ -1728,7 +1769,9 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
return response, nil
}
// GetActiveProfile returns the active profile in the daemon.
// GetActiveProfile returns the active profile in the daemon. The ProfileName
// field carries the display name for backwards compatibility with UI clients,
// new callers should prefer Id.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1739,9 +1782,23 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
// Fallback to legacy name == ID
displayName := activeProfile.ID.String()
if activeProfile.ID != profilemanager.DefaultProfileName {
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
for _, p := range profiles {
if p.ID == activeProfile.ID {
displayName = p.Name
break
}
}
}
}
return &proto.GetActiveProfileResponse{
ProfileName: activeProfile.Name,
ProfileName: displayName,
Username: activeProfile.Username,
Id: activeProfile.ID.String(),
}, nil
}

View File

@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "test-profile",
ID: "test-profile",
Username: currUser.Username,
})
if err != nil {
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
})
if err != nil {
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: "default",
ID: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -62,7 +62,7 @@ func setupServerWithProfile(t *testing.T) (s *Server, ctx context.Context, profN
pm := profilemanager.ServiceManager{}
require.NoError(t, pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
}))
@@ -107,9 +107,9 @@ func TestSetConfig_MDMReject_SingleField(t *testing.T) {
func TestSetConfig_MDMReject_MultipleFields(t *testing.T) {
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
mdm.KeyManagementURL: "https://mdm.example.com:443",
mdm.KeyBlockInbound: true,
mdm.KeyRosenpassEnabled: true,
}))
s, ctx, profName, username, _ := setupServerWithProfile(t)

View File

@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
})
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
Name: profName,
ID: profilemanager.ID(profName),
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()

View File

@@ -645,7 +645,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
}
@@ -818,13 +818,15 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
handle := activeProf.ID.String()
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name,
ProfileName: &handle,
Username: &currUser.Username,
}
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -1367,7 +1369,7 @@ func (s *serviceClient) getSrvConfig() {
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
@@ -1613,7 +1615,7 @@ func (s *serviceClient) loadSettings() {
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
})
if err != nil {
@@ -1813,7 +1815,7 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: activeProf.ID.String(),
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,

View File

@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
} else {
indicator.SetText("")
}
nameLabel.SetText(profile.Name)
nameLabel.SetText(formatProfileLabel(profile, profiles))
// Configure Select/Active button
selectBtn.SetText(func() string {
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
// switch
err = s.switchProfile(profile.Name)
err = s.switchProfile(profile.ID)
if err != nil {
log.Errorf("failed to switch profile: %v", err)
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
logoutBtn.Show()
logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile.Name, refresh)
s.handleProfileLogout(profile, refresh)
}
// Remove profile
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
err = s.removeProfile(profile.Name)
err = s.removeProfile(profile.ID)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return nil
}
func (s *serviceClient) switchProfile(profileName string) error {
func (s *serviceClient) switchProfile(handle string) error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf(getClientFMT, err)
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(profileName string) error {
return fmt.Errorf("get current user: %w", err)
}
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
}); err != nil {
})
if err != nil {
return fmt.Errorf("switch profile failed: %w", err)
}
err = s.profileManager.SwitchProfile(profileName)
if err != nil {
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
return fmt.Errorf("switch profile: %w", err)
}
@@ -299,10 +299,27 @@ func (s *serviceClient) removeProfile(profileName string) error {
}
type Profile struct {
ID string
Name string
IsActive bool
}
// formatProfileLabel returns the display label for a profile. Profiles can
// share the same Name, so when more than one profile in profiles carries this
// Name, a short form of the ID is appended to disambiguate the entries.
func formatProfileLabel(profile Profile, profiles []Profile) string {
count := 0
for _, p := range profiles {
if p.Name == profile.Name {
count++
}
}
if count <= 1 {
return profile.Name
}
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
}
func (s *serviceClient) getProfiles() ([]Profile, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -324,6 +341,7 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -332,10 +350,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
dialog.ShowConfirm(
"Deregister",
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
func(confirm bool) {
if !confirm {
return
@@ -356,8 +374,10 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
}
username := currUser.Username
// ProfileName is treated as a handle; send the ID so the
// daemon resolves to exactly this profile.
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profileName,
ProfileName: &profile.ID,
Username: &username,
})
if err != nil {
@@ -368,7 +388,7 @@ func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback
dialog.ShowInformation(
"Deregistered",
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
s.wProfiles,
)
@@ -461,6 +481,7 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -501,7 +522,7 @@ func (p *profileMenu) refresh() {
}
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
if err != nil {
log.Warnf("failed to get active profile state: %v", err)
p.emailMenuItem.Hide()
@@ -512,7 +533,7 @@ func (p *profileMenu) refresh() {
}
for _, profile := range profiles {
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
if profile.IsActive {
item.Check()
}
@@ -541,8 +562,8 @@ func (p *profileMenu) refresh() {
return
}
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.Name,
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.ID,
Username: &currUser.Username,
})
if err != nil {
@@ -552,7 +573,7 @@ func (p *profileMenu) refresh() {
return
}
err = p.profileManager.SwitchProfile(profile.Name)
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
if err != nil {
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
return
@@ -727,7 +748,10 @@ func (p *profileMenu) updateMenu() {
}
sort.Slice(profiles, func(i, j int) bool {
return profiles[i].Name < profiles[j].Name
if profiles[i].Name != profiles[j].Name {
return profiles[i].Name < profiles[j].Name
}
return profiles[i].ID < profiles[j].ID
})
p.mu.Lock()

View File

@@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
_ = ln.Close()
}()
var backoff nbtcp.AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
backoff.Reset()
router.HandleConn(ctx, conn)
}
}

View File

@@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
// scriptedAcceptListener returns pre-scripted errors from Accept(). Used
// to drive the feedRouterFromListener tests without binding a real
// socket — the production code path is a netstack-backed listener that
// returns gVisor's "endpoint is in invalid state" forever after its
// endpoint is destroyed.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// errSentinel carries a literal error message so tests can synthesise
// the exact gVisor text without importing the netstack package.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
// regression guard for the inbound side of the tight-loop bug. The
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
// invalid state" and exit, otherwise it pegs a CPU core and floods the
// account-scoped log with the same accept error every iteration.
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
}()
select {
case <-done:
// Expected: loop recognised the gVisor error and returned.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
// defence-in-depth path: an unknown sticky Accept error must NOT cause
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
const transientCount = 5
errs := make([]error, transientCount)
for i := range errs {
errs[i] = errSentinel("transient: temporary network error")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
}()
time.AfterFunc(150*time.Millisecond, cancel)
select {
case <-done:
// Expected.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the 5 scripted errors would burn in microseconds.
// With backoff the first delay alone is 5ms, so the loop must take
// at least that long even though ctx fires at 150ms.
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
"loop ran without backing off — would burn CPU in production")
}

View File

@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
client, err := embed.New(embed.Options{
embedOpts := embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
@@ -371,7 +371,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
Performance: n.clientCfg.Performance,
})
}
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
client, err := embed.New(embedOpts)
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
}
@@ -847,3 +849,53 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}
// logEmbedOptions emits a single structured INFO line summarising every
// operationally meaningful flag handed to embed.New for this per-account
// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present"
// boolean — never logged verbatim. Use this when an embedded peer
// silently misbehaves: most failure modes (inbound drops, wrong
// management URL, v6 unexpectedly on, userspace flipped, port clash)
// are obvious from these flags before any traffic flows.
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
wgPort := 0
if opts.WireguardPort != nil {
wgPort = *opts.WireguardPort
}
mtu := uint16(0)
if opts.MTU != nil {
mtu = *opts.MTU
}
perfBuffers := uint32(0)
if opts.Performance.PreallocatedBuffersPerPool != nil {
perfBuffers = *opts.Performance.PreallocatedBuffersPerPool
}
perfBatch := uint32(0)
if opts.Performance.MaxBatchSize != nil {
perfBatch = *opts.Performance.MaxBatchSize
}
logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
"public_key": publicKey,
"device_name": opts.DeviceName,
"management_url": opts.ManagementURL,
"log_level": opts.LogLevel,
"wg_port": wgPort,
"mtu": mtu,
"block_inbound": opts.BlockInbound,
"block_lan_access": opts.BlockLANAccess,
"disable_ipv6": opts.DisableIPv6,
"disable_client_routes": opts.DisableClientRoutes,
"no_userspace": opts.NoUserspace,
"config_path_set": opts.ConfigPath != "",
"state_path_set": opts.StatePath != "",
"private_key_present": opts.PrivateKey != "",
"presharedkey_present": opts.PreSharedKey != "",
"setup_key_present": opts.SetupKey != "",
"jwt_token_present": opts.JWTToken != "",
"dns_labels": opts.DNSLabels,
"perf_buffers_per_pool": perfBuffers,
"perf_max_batch_size": perfBatch,
}).Info("starting embedded netbird client for account")
}

View File

@@ -0,0 +1,85 @@
package tcp
import (
"context"
"errors"
"net"
"strings"
"time"
)
// gvisorInvalidEndpointMsg is the canonical text gVisor netstack returns
// when Accept() is called on a listener whose underlying endpoint has
// been destroyed (peer rekey, embedded-client reset, account churn).
// There is no exported sentinel from gvisor.dev/gvisor/pkg/tcpip that
// survives gonet's *net.OpError wrapping in a way errors.Is can match,
// so we fall back to a string check. Stable across the gVisor versions
// netbird pins.
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"
// IsClosedListenerErr reports whether err signals that an accept loop
// should exit because the underlying listener can no longer serve
// connections. It recognises:
//
// - net.ErrClosed for stdlib listeners (Listener.Close was called).
// - gVisor's "endpoint is in invalid state" for netstack-backed
// listeners whose endpoint was destroyed out from under them
// (typically when a per-account WireGuard netstack is reset without
// also tearing the listener entry down).
//
// Without the gVisor branch an accept loop on a netstack listener spins
// CPU-hot forever after the endpoint dies, because Accept never blocks
// again and the error neither matches net.ErrClosed nor cancels ctx.
func IsClosedListenerErr(err error) bool {
if err == nil {
return false
}
if errors.Is(err, net.ErrClosed) {
return true
}
return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
}
// AcceptBackoff implements the exponential backoff used by
// net/http.Server.Serve for transient Accept errors. Without it a loop
// hitting a sticky unknown error burns a full CPU core. The zero value
// is ready to use; call Reset after a successful Accept.
type AcceptBackoff struct {
delay time.Duration
}
// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
// (net/http.Server.Serve) and keep us well below 1 log line per second
// per orphaned listener.
const (
minAcceptDelay = 5 * time.Millisecond
maxAcceptDelay = time.Second
)
// Backoff waits the next exponential delay (5ms doubling up to 1s) and
// returns true when the wait completed. Returns false if ctx fired
// during the wait — callers should treat that as "exit the loop".
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
b.advance()
select {
case <-ctx.Done():
return false
case <-time.After(b.delay):
return true
}
}
// Reset clears the accumulated delay so the next failure starts at the
// minimum delay again. Call after a successful Accept.
func (b *AcceptBackoff) Reset() { b.delay = 0 }
func (b *AcceptBackoff) advance() {
if b.delay == 0 {
b.delay = minAcceptDelay
} else {
b.delay *= 2
}
if b.delay > maxAcceptDelay {
b.delay = maxAcceptDelay
}
}

View File

@@ -0,0 +1,142 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
// and IsClosedListenerErr must unwrap it.
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
assert.True(t, IsClosedListenerErr(wrapped),
"net.OpError wrapping net.ErrClosed must be recognised as closed")
}
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
// regression guard. A gVisor netstack listener whose endpoint has been
// destroyed returns this exact text. Without recognising it the accept
// loop spins forever and burns a CPU core.
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
assert.True(t, IsClosedListenerErr(err),
"gVisor 'endpoint is in invalid state' must be recognised as closed")
}
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
// transient errors must keep returning false so the backoff path runs.
func TestIsClosedListenerErr_OtherError(t *testing.T) {
cases := []error{
errors.New("temporary failure"),
errors.New("accept tcp 10.10.1.254:80: too many open files"),
nil,
}
for _, c := range cases {
assert.False(t, IsClosedListenerErr(c),
"unexpected match on %v — must not be treated as closed", c)
}
}
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a
// real timer but uses tight bounds so a slow CI machine still passes.
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
var b AcceptBackoff
expected := []time.Duration{
5 * time.Millisecond,
10 * time.Millisecond,
20 * time.Millisecond,
40 * time.Millisecond,
}
for i, want := range expected {
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
assert.GreaterOrEqual(t, elapsed, want,
"backoff %d (%v) must wait at least the configured delay", i, want)
assert.Less(t, elapsed, want*4,
"backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want)
}
// Burn enough rounds to reach the cap, then assert subsequent
// rounds stay at exactly maxAcceptDelay (1s) — the timer should
// never exceed it.
for range 6 {
b.Backoff(context.Background())
}
assert.Equal(t, maxAcceptDelay, b.delay,
"after enough doublings the delay must clamp to maxAcceptDelay")
}
// TestAcceptBackoff_Reset confirms that a successful Accept resets the
// schedule — a busy-then-quiet listener mustn't stay on a 1s timer
// after recovery.
func TestAcceptBackoff_Reset(t *testing.T) {
var b AcceptBackoff
for range 5 {
b.Backoff(context.Background())
}
require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")
b.Reset()
assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff after Reset must complete")
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"after Reset the next backoff must restart at minAcceptDelay")
assert.Less(t, elapsed, 50*time.Millisecond,
"after Reset the next backoff must NOT carry over the prior delay")
}
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
// when ctx fires mid-wait. Without this, a tear-down would still take
// up to 1 second per orphaned listener.
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
var b AcceptBackoff
// Drive the backoff up so the next call will wait ~1s — long
// enough that we can detect early cancellation.
for range 10 {
b.Backoff(context.Background())
}
require.Equal(t, maxAcceptDelay, b.delay)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
assert.Less(t, elapsed, 200*time.Millisecond,
"cancellation must short-circuit the timer; took %v", elapsed)
}
// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the
// loop exits without sleeping at all.
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
var b AcceptBackoff
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is already cancelled")
assert.Less(t, elapsed, 50*time.Millisecond,
"already-cancelled ctx must return immediately; took %v", elapsed)
}

View File

@@ -297,18 +297,23 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}()
var backoff AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ctx.Err() != nil || IsClosedListenerErr(err) {
if ok := r.Drain(DefaultDrainTimeout); !ok {
r.logger.Warn("timed out waiting for connections to drain")
}
return nil
}
r.logger.Debugf("SNI router accept: %v", err)
r.logger.Debugf("SNI router accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return nil
}
continue
}
backoff.Reset()
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {

View File

@@ -1836,3 +1836,132 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
t.Fatal("TLS conn never reached the TLS channel")
}
}
// scriptedAcceptListener is a net.Listener whose Accept() returns
// pre-scripted errors. Used by the accept-loop exit tests to simulate
// the failure mode that triggers the tight-loop bug: a netstack
// listener whose endpoint has been destroyed and now returns the gVisor
// "endpoint is in invalid state" error from every Accept call.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
// for the tight-loop bug: when the underlying netstack endpoint is
// destroyed, Accept returns "endpoint is in invalid state" forever. The
// loop must recognise that signal and return, otherwise it pegs a CPU
// core and floods logs.
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan error, 1)
go func() {
done <- router.Serve(context.Background(), ln)
}()
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
// depth path: when Accept returns an unknown transient error, the loop
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
// "Bounded call count" stands in for "no CPU spin" — without backoff
// the goroutine would issue thousands of Accept calls in this window.
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
const transientErrCount = 5
errs := make([]error, transientErrCount)
for i := range errs {
errs[i] = errSentinel("transient: too many open files")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
start := time.Now()
go func() {
done <- router.Serve(ctx, ln)
}()
// Cancel after enough time for the backoff to climb (5ms + 10ms +
// 20ms + 40ms = 75ms minimum), but short enough that a spinning
// loop would have made thousands of calls by now.
time.AfterFunc(150*time.Millisecond, cancel)
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the loop would burn through all 5 scripted errors
// in microseconds and then block on the channel. With backoff the
// total wall time should be at least 5ms (the first backoff).
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"loop ran without backing off — would burn CPU in production")
}
// errSentinel mirrors gVisor's tcpip error message exactly. We can't
// import the gVisor package without dragging in the whole netstack, so
// the test uses the canonical string the production error formatter
// emits — same shape IsClosedListenerErr matches in production.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }