Compare commits

..

3 Commits

Author SHA1 Message Date
bcmmbaga
37ce40ede6 build client config in sync response 2026-06-05 23:53:48 +03:00
bcmmbaga
318014552f Merge branch 'main' into refactor/mgmt-bootstrap 2026-06-05 22:02:25 +03:00
bcmmbaga
78ed62535e Clean up middleware and move auth middlewa into module 2026-05-26 20:45:57 +03:00
40 changed files with 535 additions and 1744 deletions

View File

@@ -29,10 +29,10 @@ jobs:
persist-credentials: false
- name: Generate FreeBSD port diff
run: bash -x release_files/freebsd-port-diff.sh
run: bash release_files/freebsd-port-diff.sh
- name: Generate FreeBSD port issue body
run: bash -x release_files/freebsd-port-issue-body.sh
run: bash release_files/freebsd-port-issue-body.sh
- name: Check if diff was generated
id: check_diff

View File

@@ -65,7 +65,7 @@ jobs:
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
if [ ${SIZE} -gt 62914560 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 60MB limit!"
if [ ${SIZE} -gt 58720256 ]; then
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
exit 1
fi

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"os"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
@@ -23,7 +24,6 @@ const (
// Profile represents a profile for gomobile
type Profile struct {
ID string
Name string
IsActive bool
}
@@ -53,10 +53,10 @@ func (p *ProfileArray) Get(i int) *Profile {
├── state.json ← Default profile state
├── active_profile.json ← Active profile tracker (JSON with Name + Username)
└── profiles/ ← Subdirectory for non-default profiles
├── work.json ← Legacy work profile config
├── work.state.json ← Legacy work profile state
├── 4c5f5c8198c3989cffb5b5394f5a7ae0.json ← ID profile config
── 4c5f5c8198c3989cffb5b5394f5a7ae0.state.json ← ID profile state
├── work.json ← Work profile config
├── work.state.json ← Work profile state
├── personal.json ← Personal profile config
── personal.state.json ← Personal profile state
*/
// ProfileManager manages profiles for Android
@@ -99,7 +99,6 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
var profiles []*Profile
for _, p := range internalProfiles {
profiles = append(profiles, &Profile{
ID: p.ID.String(),
Name: p.Name,
IsActive: p.IsActive,
})
@@ -109,65 +108,55 @@ func (pm *ProfileManager) ListProfiles() (*ProfileArray, error) {
}
// GetActiveProfile returns the currently active profile name
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
func (pm *ProfileManager) GetActiveProfile() (string, error) {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
activeState, err := pm.serviceMgr.GetActiveProfileState()
if err != nil {
return nil, fmt.Errorf("failed to get active profile: %w", err)
return "", fmt.Errorf("failed to get active profile: %w", err)
}
// ActiveProfileState only stores the ID (and username), not the display
// name. Resolve the ID to the full profile so callers get the real Name.
prof, err := pm.serviceMgr.ResolveProfile(activeState.ID.String(), androidUsername)
if err != nil {
return nil, fmt.Errorf("failed to resolve active profile %q: %w", activeState.ID, err)
}
return &Profile{ID: prof.ID.String(), Name: prof.Name, IsActive: true}, nil
return activeState.Name, nil
}
// SwitchProfile switches to a different profile
func (pm *ProfileManager) SwitchProfile(id string) error {
func (pm *ProfileManager) SwitchProfile(profileName string) error {
// Use ServiceManager to stay consistent with ListProfiles
// ServiceManager uses active_profile.json
err := pm.serviceMgr.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(id),
Name: profileName,
Username: androidUsername,
})
if err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
log.Infof("switched to profile: %s", id)
log.Infof("switched to profile: %s", profileName)
return nil
}
// AddProfile creates a new profile
func (pm *ProfileManager) AddProfile(profileName string) error {
// Use ServiceManager (creates profile in profiles/ directory)
profile, err := pm.serviceMgr.AddProfile(profileName, androidUsername)
if err != nil {
if err := pm.serviceMgr.AddProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to add profile: %w", err)
}
log.Infof("created new profile: %s", profile.ID)
log.Infof("created new profile: %s", profileName)
return nil
}
// LogoutProfile logs out from a profile (clears authentication)
func (pm *ProfileManager) LogoutProfile(id string) error {
configPath, err := pm.getProfileConfigPath(id)
func (pm *ProfileManager) LogoutProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
configPath, err := pm.getProfileConfigPath(profileName)
if err != nil {
return err
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return fmt.Errorf("id '%s' is not valid", id)
}
// Check if profile exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return fmt.Errorf("profile '%s' does not exist", id)
return fmt.Errorf("profile '%s' does not exist", profileName)
}
// Read current config using internal profilemanager
@@ -185,56 +174,53 @@ func (pm *ProfileManager) LogoutProfile(id string) error {
return fmt.Errorf("failed to save config: %w", err)
}
log.Infof("logged out from profile: %s", id)
log.Infof("logged out from profile: %s", profileName)
return nil
}
// RemoveProfile deletes a profile
func (pm *ProfileManager) RemoveProfile(id string) error {
func (pm *ProfileManager) RemoveProfile(profileName string) error {
// Use ServiceManager (removes profile from profiles/ directory)
if err := pm.serviceMgr.RemoveProfile(profilemanager.ID(id), androidUsername); err != nil {
if err := pm.serviceMgr.RemoveProfile(profileName, androidUsername); err != nil {
return fmt.Errorf("failed to remove profile: %w", err)
}
log.Infof("removed profile: %s", id)
log.Infof("removed profile: %s", profileName)
return nil
}
// getProfileConfigPath returns the config file path for a profile
// This is needed for Android-specific path handling (netbird.cfg for default profile)
func (pm *ProfileManager) getProfileConfigPath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
func (pm *ProfileManager) getProfileConfigPath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
// Android uses netbird.cfg for default profile instead of default.json
// Default profile is stored in root configDir, not in profiles/
return filepath.Join(pm.configDir, defaultConfigFilename), nil
}
// Non-default profiles are stored in profiles subdirectory
// This matches the Java Preferences.java expectation
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".json"), nil
return filepath.Join(profilesDir, profileName+".json"), nil
}
// GetConfigPath returns the config file path for a given profile id
// GetConfigPath returns the config file path for a given profile
// Java should call this instead of constructing paths with Preferences.configFile()
func (pm *ProfileManager) GetConfigPath(id string) (string, error) {
return pm.getProfileConfigPath(id)
func (pm *ProfileManager) GetConfigPath(profileName string) (string, error) {
return pm.getProfileConfigPath(profileName)
}
// GetStateFilePath returns the state file path for a given profile
// Java should call this instead of constructing paths with Preferences.stateFile()
func (pm *ProfileManager) GetStateFilePath(id string) (string, error) {
if id == "" || id == profilemanager.DefaultProfileName {
func (pm *ProfileManager) GetStateFilePath(profileName string) (string, error) {
if profileName == "" || profileName == profilemanager.DefaultProfileName {
return filepath.Join(pm.configDir, "state.json"), nil
}
if !profilemanager.IsValidProfileFilenameStem(profilemanager.ID(id)) {
return "", fmt.Errorf("id %q is not valid", id)
}
profileName = sanitizeProfileName(profileName)
profilesDir := filepath.Join(pm.configDir, profilesSubdir)
return filepath.Join(profilesDir, id+".state.json"), nil
return filepath.Join(profilesDir, profileName+".state.json"), nil
}
// GetActiveConfigPath returns the config file path for the currently active profile
@@ -244,7 +230,7 @@ func (pm *ProfileManager) GetActiveConfigPath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetConfigPath(activeProfile.ID)
return pm.GetConfigPath(activeProfile)
}
// GetActiveStateFilePath returns the state file path for the currently active profile
@@ -254,5 +240,18 @@ func (pm *ProfileManager) GetActiveStateFilePath() (string, error) {
if err != nil {
return "", fmt.Errorf("failed to get active profile: %w", err)
}
return pm.GetStateFilePath(activeProfile.ID)
return pm.GetStateFilePath(activeProfile)
}
// sanitizeProfileName removes invalid characters from profile name
func sanitizeProfileName(name string) string {
// Keep only alphanumeric, underscore, and hyphen
var result strings.Builder
for _, r := range name {
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') ||
(r >= '0' && r <= '9') || r == '_' || r == '-' {
result.WriteRune(r)
}
}
return result.String()
}

View File

@@ -96,19 +96,17 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}
handle := activeProf.ID.String()
loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsUnixDesktopClient: isUnixRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
ProfileName: &handle,
ProfileName: &activeProf.Name,
Username: &username,
}
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -172,13 +170,14 @@ func getActiveProfile(ctx context.Context, pm *profilemanager.ProfileManager, pr
return activeProf, nil
}
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, handle string, username string) error {
resolvedID, err := switchProfile(ctx, handle, username)
func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManager, profileName string, username string) error {
err := switchProfile(context.Background(), profileName, username)
if err != nil {
return fmt.Errorf("switch profile on daemon: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -206,15 +205,11 @@ func switchProfileOnDaemon(ctx context.Context, pm *profilemanager.ProfileManage
return nil
}
// switchProfile asks the daemon to switch to the profile identified by
// handle (a name, ID, or unique ID prefix). Returns the resolved profile
// ID so the caller can update the local active-profile state without
// re-resolving the handle.
func switchProfile(ctx context.Context, handle string, username string) (profilemanager.ID, error) {
func switchProfile(ctx context.Context, profileName string, username string) error {
conn, err := DialClientGRPCServer(ctx, daemonAddr)
if err != nil {
//nolint
return "", fmt.Errorf("failed to connect to daemon error: %v\n"+
return fmt.Errorf("failed to connect to daemon error: %v\n"+
"If the daemon is not running please run: "+
"\nnetbird service install \nnetbird service start\n", err)
}
@@ -222,15 +217,15 @@ func switchProfile(ctx context.Context, handle string, username string) (profile
client := proto.NewDaemonServiceClient(conn)
resp, err := client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
_, err = client.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &username,
})
if err != nil {
return "", fmt.Errorf("switch profile failed: %v", err)
return fmt.Errorf("switch profile failed: %v", err)
}
return profilemanager.ID(resp.Id), nil
return nil
}
func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, activeProf *profilemanager.Profile) error {
@@ -254,7 +249,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err)
}
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.ID)
err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -282,7 +277,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil
}
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string, profileID profilemanager.ID) error {
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
authClient, err := auth.NewAuth(ctx, config.PrivateKey, config.ManagementURL, config)
if err != nil {
return fmt.Errorf("failed to create auth client: %v", err)
@@ -296,7 +291,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := ""
if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileID)
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err)
}
@@ -311,10 +306,10 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil
}
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileID profilemanager.ID) (*auth.TokenInfo, error) {
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileID)
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {

View File

@@ -27,7 +27,7 @@ func TestLogin(t *testing.T) {
profilemanager.ActiveProfileStatePath = tempDir + "/active_profile.json"
sm := profilemanager.ServiceManager{}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "default",
Name: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -2,16 +2,11 @@ package cmd
import (
"context"
"errors"
"fmt"
"os/user"
"strings"
"text/tabwriter"
"time"
"github.com/spf13/cobra"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
@@ -19,8 +14,6 @@ import (
"github.com/netbirdio/netbird/util"
)
var profileListShowID bool
var profileCmd = &cobra.Command{
Use: "profile",
Short: "Manage NetBird client profiles",
@@ -38,32 +31,27 @@ var profileListCmd = &cobra.Command{
var profileAddCmd = &cobra.Command{
Use: "add <profile_name>",
Short: "Add a new profile",
Long: `Add a new profile. Profile name is free-form, a unique ID is generated for the on-disk config file.`,
Long: `Add a new profile to the NetBird client. The profile name must be unique.`,
Args: cobra.ExactArgs(1),
RunE: addProfileFunc,
}
var profileRemoveCmd = &cobra.Command{
Use: "remove <profile>",
Short: "Remove a profile",
Long: `Remove a profile by name, ID, or unique ID prefix.`,
Aliases: []string{"rm"},
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
Use: "remove <profile_name>",
Short: "Remove a profile",
Long: `Remove a profile from the NetBird client. The profile must not be inactive.`,
Args: cobra.ExactArgs(1),
RunE: removeProfileFunc,
}
var profileSelectCmd = &cobra.Command{
Use: "select <profile>",
Use: "select <profile_name>",
Short: "Select a profile",
Long: `Make the specified profile active. Accepts a name, ID, or unique ID prefix.`,
Long: `Make the specified profile active. This will switch the client to use the selected profile's configuration.`,
Args: cobra.ExactArgs(1),
RunE: selectProfileFunc,
}
func init() {
profileListCmd.Flags().BoolVar(&profileListShowID, "show-id", false, "show the profile ID column")
}
func setupCmd(cmd *cobra.Command) error {
SetFlagsFromEnvVars(rootCmd)
SetFlagsFromEnvVars(cmd)
@@ -77,7 +65,6 @@ func setupCmd(cmd *cobra.Command) error {
return nil
}
func listProfilesFunc(cmd *cobra.Command, _ []string) error {
if err := setupCmd(cmd); err != nil {
return err
@@ -96,33 +83,25 @@ func listProfilesFunc(cmd *cobra.Command, _ []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
resp, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
profiles, err := daemonClient.ListProfiles(cmd.Context(), &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return err
}
tw := tabwriter.NewWriter(cmd.OutOrStdout(), 0, 0, 2, ' ', 0)
if profileListShowID {
fmt.Fprintln(tw, "ID\tNAME\tACTIVE")
} else {
fmt.Fprintln(tw, "NAME\tACTIVE")
}
for _, profile := range resp.Profiles {
marker := ""
// list profiles, add a tick if the profile is active
cmd.Println("Found", len(profiles.Profiles), "profiles:")
for _, profile := range profiles.Profiles {
// use a cross to indicate the passive profiles
activeMarker := "✗"
if profile.IsActive {
marker = "✓"
}
name := profilemanager.StripCtrlChars(profile.Name)
id := profilemanager.ID(profile.Id)
if profileListShowID {
fmt.Fprintf(tw, "%s\t%s\t%s\n", id.ShortID(), name, marker)
} else {
fmt.Fprintf(tw, "%s\t%s\n", name, marker)
activeMarker = "✓"
}
cmd.Println(activeMarker, profile.Name)
}
return tw.Flush()
return nil
}
func addProfileFunc(cmd *cobra.Command, args []string) error {
@@ -142,51 +121,19 @@ func addProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
profileName := args[0]
resp, err := daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
_, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err == nil {
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.AlreadyExists {
dupCount, _ := countProfilesWithName(cmd.Context(), daemonClient, currUser.Username, profileName)
if dupCount > 0 {
cmd.Printf("Warning: %d other profile(s) already use the name %q.\n", dupCount, profileName)
cmd.Println("Use `netbird profile list --show-id` to disambiguate later.")
}
resp, err = daemonClient.AddProfile(cmd.Context(), &proto.AddProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return err
}
id := profilemanager.ID(resp.Id)
cmd.Printf("Profile added: %s %s\n", id.ShortID(), profilemanager.StripCtrlChars(profileName))
return nil
}
return err
}
func countProfilesWithName(ctx context.Context, c proto.DaemonServiceClient, username, name string) (int, error) {
resp, err := c.ListProfiles(ctx, &proto.ListProfilesRequest{Username: username})
if err != nil {
return 0, err
return err
}
n := 0
for _, p := range resp.Profiles {
if p.Name == name {
n++
}
}
return n, nil
cmd.Println("Profile added successfully:", profileName)
return nil
}
func removeProfileFunc(cmd *cobra.Command, args []string) error {
@@ -206,17 +153,18 @@ func removeProfileFunc(cmd *cobra.Command, args []string) error {
}
daemonClient := proto.NewDaemonServiceClient(conn)
handle := args[0]
resp, err := daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: handle,
profileName := args[0]
_, err = daemonClient.RemoveProfile(cmd.Context(), &proto.RemoveProfileRequest{
ProfileName: profileName,
Username: currUser.Username,
})
if err != nil {
return wrapAmbiguityError(err, handle)
return err
}
cmd.Printf("Profile removed: %s\n", resp.Id)
cmd.Println("Profile removed successfully:", profileName)
return nil
}
@@ -226,7 +174,7 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
profileManager := profilemanager.NewProfileManager()
handle := args[0]
profileName := args[0]
currUser, err := user.Current()
if err != nil {
@@ -243,15 +191,32 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
daemonClient := proto.NewDaemonServiceClient(conn)
switchResp, err := daemonClient.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
Username: &currUser.Username,
profiles, err := daemonClient.ListProfiles(ctx, &proto.ListProfilesRequest{
Username: currUser.Username,
})
if err != nil {
return wrapAmbiguityError(err, handle)
return fmt.Errorf("list profiles: %w", err)
}
if err := profileManager.SwitchProfile(profilemanager.ID(switchResp.Id)); err != nil {
var profileExists bool
for _, profile := range profiles.Profiles {
if profile.Name == profileName {
profileExists = true
break
}
}
if !profileExists {
return fmt.Errorf("profile %s does not exist", profileName)
}
if err := switchProfile(cmd.Context(), profileName, currUser.Username); err != nil {
return err
}
err = profileManager.SwitchProfile(profileName)
if err != nil {
return err
}
@@ -266,30 +231,6 @@ func selectProfileFunc(cmd *cobra.Command, args []string) error {
}
}
id := profilemanager.ID(switchResp.Id)
cmd.Printf("Profile switched to: %s\n", id.ShortID())
cmd.Println("Profile switched successfully to:", profileName)
return nil
}
// wrapAmbiguityError turns the daemon's gRPC InvalidArgument errors
// (which carry the resolver's message verbatim) into CLI-friendly text
// that points the user at --show-id.
func wrapAmbiguityError(err error, handle string) error {
if err == nil {
return nil
}
st, ok := gstatus.FromError(err)
if !ok {
return err
}
switch st.Code() {
case codes.InvalidArgument:
msg := st.Message()
if strings.Contains(msg, "ambiguous") {
return errors.New(msg + "\nRun `netbird profile list --show-id` to see IDs, then select by ID prefix:\n netbird profile select|remove <id-prefix>")
}
case codes.NotFound:
return fmt.Errorf("profile %q not found", handle)
}
return err
}

View File

@@ -128,12 +128,13 @@ func upFunc(cmd *cobra.Command, args []string) error {
var profileSwitched bool
// switch profile if provided
if profileName != "" {
resolvedID, err := switchProfile(cmd.Context(), profileName, username.Username)
err = switchProfile(cmd.Context(), profileName, username.Username)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
if err := pm.SwitchProfile(resolvedID); err != nil {
err = pm.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %v", err)
}
@@ -189,7 +190,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.ID)
err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil {
return fmt.Errorf("foreground login failed: %v", err)
}
@@ -260,10 +261,10 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command, pm *profilemanager
}
// set the new config
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.ID.String(), username.Username)
req := setupSetConfigReq(customDNSAddressConverted, cmd, activeProf.Name, username.Username)
if _, err := client.SetConfig(ctx, req); err != nil {
if st, ok := gstatus.FromError(err); ok && st.Code() == codes.Unavailable {
log.Warnf("setConfig method is not available in the daemon: %s", st.Message())
log.Warnf("setConfig method is not available in the daemon")
} else {
return fmt.Errorf("call service setConfig method: %v", err)
}
@@ -288,11 +289,10 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
return fmt.Errorf("setup login request: %v", err)
}
profileID := activeProf.ID.String()
loginRequest.ProfileName = &profileID
loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -329,7 +329,7 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
}
if _, err := client.Up(ctx, &proto.UpRequest{
ProfileName: &profileID,
ProfileName: &activeProf.Name,
Username: &username,
}); err != nil {
return fmt.Errorf("call service up method: %v", err)

View File

@@ -29,14 +29,14 @@ func TestUpDaemon(t *testing.T) {
}
sm := profilemanager.ServiceManager{}
created, err := sm.AddProfile("test1", currUser.Username)
err = sm.AddProfile("test1", currUser.Username)
if err != nil {
t.Fatalf("failed to add profile: %v", err)
return
}
err = sm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: created.ID,
Name: "test1",
Username: currUser.Username,
})
if err != nil {

View File

@@ -806,8 +806,6 @@ func (g *BundleGenerator) addSyncResponse() error {
AllowPartial: true,
}
g.maskSecrets()
jsonBytes, err := options.Marshal(g.syncResponse)
if err != nil {
return fmt.Errorf("generate json: %w", err)
@@ -820,27 +818,6 @@ func (g *BundleGenerator) addSyncResponse() error {
return nil
}
func (g *BundleGenerator) maskSecrets() {
if g.syncResponse == nil || g.syncResponse.NetbirdConfig == nil {
return
}
if g.syncResponse.NetbirdConfig.Flow != nil {
g.syncResponse.NetbirdConfig.Flow.TokenPayload = maskedValue
}
if g.syncResponse.NetbirdConfig.Relay != nil {
g.syncResponse.NetbirdConfig.Relay.TokenPayload = maskedValue
}
for i := range g.syncResponse.NetbirdConfig.Turns {
if g.syncResponse.NetbirdConfig.Turns[i] != nil {
g.syncResponse.NetbirdConfig.Turns[i].Password = maskedValue
}
}
}
func (g *BundleGenerator) addStateFile() error {
sm := profilemanager.NewServiceManager("")
path := sm.GetStatePath()

View File

@@ -843,7 +843,6 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
"Name": "non-config: profile name is not needed for debug purposes",
}
mURL, _ := url.Parse("https://api.example.com:443")

View File

@@ -103,10 +103,6 @@ type ConfigInput struct {
// Config Configuration type
type Config struct {
// Name is the human-readable profile name shown in CLI/UI listings.
// It is independent of the profile's on-disk filename (which is the ID).
Name string
// Wireguard private key of local peer
PrivateKey string
PreSharedKey string
@@ -252,16 +248,6 @@ func createNewConfig(input ConfigInput) (*Config, error) {
}
func (config *Config) apply(input ConfigInput) (updated bool, err error) {
if config.Name != "" {
sanitized, err := sanitizeDisplayName(config.Name)
if err != nil {
return false, fmt.Errorf("invalid profile name: %w", err)
}
if sanitized != config.Name {
config.Name = sanitized
updated = true
}
}
if config.ManagementURL == nil {
log.Infof("using default Management URL %s", DefaultManagementURL)
config.ManagementURL, err = parseURL("Management URL", DefaultManagementURL)

View File

@@ -1,118 +0,0 @@
package profilemanager
import (
"crypto/rand"
"encoding/hex"
"fmt"
"path/filepath"
"strings"
"unicode"
"unicode/utf8"
)
const (
// profileIDByteLen is the number of random bytes generated for a new
// profile ID. The resulting hex string is twice this length.
profileIDByteLen = 16
// shortIDLen is the number of leading characters of an ID we render in
// list output. Profiles per device are few, so 8 chars is collision-safe
// in practice and easy to type as a prefix.
shortIDLen = 8
// maxProfileNameLen caps the human-readable profile name to keep table
// output legible and prevent denial-of-service via huge JSON fields.
maxProfileNameLen = 128
// maxProfileIDLen bounds the on-disk filename we'll accept. New
// IDs are 32 hex chars, legacy stems are sanitized profile names. The
// cap is generous enough to cover both without permitting absurdly
// long filenames.
maxProfileIDLen = 64
)
type ID string
// generateProfileID returns a new random hex ID for a profile file.
func generateProfileID() (ID, error) {
buf := make([]byte, profileIDByteLen)
if _, err := rand.Read(buf); err != nil {
return "", fmt.Errorf("read random bytes: %w", err)
}
return ID(hex.EncodeToString(buf)), nil
}
// IsValidProfileFilenameStem reports whether id is safe to use as the stem
// of a profile JSON filename.
func IsValidProfileFilenameStem(id ID) bool {
s := id.String()
if s == "" || len(s) > maxProfileIDLen {
return false
}
if s == defaultProfileName {
return true
}
if strings.ContainsAny(s, `/\`) || strings.Contains(s, "..") {
return false
}
// filepath.Base catches any leftover separators on platforms with
// exotic path conventions.
if filepath.Base(s) != s {
return false
}
for _, r := range s {
if !(unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '-') {
return false
}
}
return true
}
// sanitizeDisplayName normalizes a user-supplied profile display name for
// storage. It strips ASCII control characters, rejects invalid UTF-8, and
// caps the length. Emojis, spaces, punctuation, and non-ASCII letters are
// preserved. Returns an error if nothing usable remains.
func sanitizeDisplayName(name string) (string, error) {
if !utf8.ValidString(name) {
return "", fmt.Errorf("name is not valid UTF-8")
}
name = StripCtrlChars(name)
name = strings.TrimSpace(name)
if name == "" {
return "", fmt.Errorf("name is empty after sanitization")
}
if utf8.RuneCountInString(name) > maxProfileNameLen {
return "", fmt.Errorf("name exceeds %d characters", maxProfileNameLen)
}
return name, nil
}
// StripCtrlChars control characters from a name before printing it.
func StripCtrlChars(name string) string {
var b strings.Builder
b.Grow(len(name))
for _, r := range name {
// Skip C0 controls and DEL, plus C1 controls (0x800x9F).
if r < 0x20 || r == 0x7F || (r >= 0x80 && r <= 0x9F) {
continue
}
b.WriteRune(r)
}
return b.String()
}
// ShortID truncates an ID for display.
func (id ID) ShortID() string {
if id == DefaultProfileName {
return DefaultProfileName
}
runes := []rune(id)
if len(runes) <= shortIDLen {
return id.String()
}
return string(runes[:shortIDLen])
}
func (id ID) String() string {
return string(id)
}

View File

@@ -19,41 +19,19 @@ const (
)
type Profile struct {
// ID is the on-disk filename stem (without .json). For new profiles
// it is a 32-char hex string; legacy profiles created before the
// ID-keyed layout keep their original name as their ID. The reserved
// value "default" identifies the special default profile.
ID ID
// Name is the human-readable display name. Falls back to ID when the
// underlying JSON has no "name" field set.
Name string
// Path is the absolute path to the profile JSON. Populated by the
// loader so callers do not have to reconstruct it from ID + dir.
Path string
Name string
IsActive bool
}
func (p *Profile) FilePath() (string, error) {
if p.Path != "" {
return p.Path, nil
if p.Name == "" {
return "", fmt.Errorf("active profile name is empty")
}
id := p.ID
if id == "" {
id = ID(p.Name)
}
if id == "" {
return "", fmt.Errorf("profile ID is empty")
}
if id == defaultProfileName {
if p.Name == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(id) {
return "", fmt.Errorf("invalid profile ID: %q", id)
}
username, err := user.Current()
if err != nil {
return "", fmt.Errorf("failed to get current user: %w", err)
@@ -64,13 +42,10 @@ func (p *Profile) FilePath() (string, error) {
return "", fmt.Errorf("failed to get config directory for user %s: %w", username.Username, err)
}
return filepath.Join(configDir, id.String()+".json"), nil
return filepath.Join(configDir, p.Name+".json"), nil
}
func (p *Profile) IsDefault() bool {
if p.ID != "" {
return p.ID == defaultProfileName
}
return p.Name == defaultProfileName
}
@@ -82,24 +57,18 @@ func NewProfileManager() *ProfileManager {
return &ProfileManager{}
}
// GetActiveProfile returns the active profile as recorded in the local
// user state file. Only ID is populated.
func (pm *ProfileManager) GetActiveProfile() (*Profile, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
id := pm.getActiveProfileState()
return &Profile{ID: id}, nil
prof := pm.getActiveProfileState()
return &Profile{Name: prof}, nil
}
// SwitchProfile records the given profile ID as active in the local user
// state file.
func (pm *ProfileManager) SwitchProfile(id ID) error {
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
func (pm *ProfileManager) SwitchProfile(profileName string) error {
profileName = sanitizeProfileName(profileName)
if err := pm.setActiveProfileState(id); err != nil {
if err := pm.setActiveProfileState(profileName); err != nil {
return fmt.Errorf("failed to switch profile: %w", err)
}
return nil
@@ -116,7 +85,7 @@ func sanitizeProfileName(name string) string {
}, name)
}
func (pm *ProfileManager) getActiveProfileState() ID {
func (pm *ProfileManager) getActiveProfileState() string {
configDir, err := getConfigDir()
if err != nil {
@@ -144,10 +113,10 @@ func (pm *ProfileManager) getActiveProfileState() ID {
return defaultProfileName
}
return ID(profileName)
return profileName
}
func (pm *ProfileManager) setActiveProfileState(id ID) error {
func (pm *ProfileManager) setActiveProfileState(profileName string) error {
configDir, err := getConfigDir()
if err != nil {
@@ -156,7 +125,7 @@ func (pm *ProfileManager) setActiveProfileState(id ID) error {
statePath := filepath.Join(configDir, activeProfileStateFilename)
err = os.WriteFile(statePath, []byte(id), 0600)
err = os.WriteFile(statePath, []byte(profileName), 0600)
if err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
@@ -173,7 +142,7 @@ func GetLoginHint() string {
return ""
}
profileState, err := pm.GetProfileState(activeProf.ID)
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
return ""

View File

@@ -50,14 +50,14 @@ func TestServiceManager_CreateAndGetDefaultProfile(t *testing.T) {
state, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, defaultProfileName, state.ID.String()) // No active profile state yet
assert.Equal(t, state.Name, defaultProfileName) // No active profile state yet
err = sm.SetActiveProfileStateToDefault()
assert.NoError(t, err)
active, err := sm.GetActiveProfileState()
assert.NoError(t, err)
assert.Equal(t, "default", active.ID.String())
assert.Equal(t, "default", active.Name)
})
})
}
@@ -92,14 +92,14 @@ func TestServiceManager_SetActiveProfileState(t *testing.T) {
currUser, err := user.Current()
assert.NoError(t, err)
sm := &ServiceManager{}
state := &ActiveProfileState{ID: "foo", Username: currUser.Username}
state := &ActiveProfileState{Name: "foo", Username: currUser.Username}
err = sm.SetActiveProfileState(state)
assert.NoError(t, err)
// Should error on nil or incomplete state
err = sm.SetActiveProfileState(nil)
assert.Error(t, err)
err = sm.SetActiveProfileState(&ActiveProfileState{ID: "", Username: ""})
err = sm.SetActiveProfileState(&ActiveProfileState{Name: "", Username: ""})
assert.Error(t, err)
})
})

View File

@@ -2,7 +2,6 @@ package profilemanager
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -24,43 +23,12 @@ var (
DefaultConfigPathDir = ""
DefaultConfigPath = ""
ActiveProfileStatePath = ""
)
var (
ErrorOldDefaultConfigNotFound = errors.New("old default config not found")
)
// ErrAmbiguousHandle is returned when a profile handle (ID prefix or name)
// matches more than one profile. Callers can render Candidates to help the
// user disambiguate.
type ErrAmbiguousHandle struct {
Handle string
Candidates []Profile
Kind AmbiguityKind
}
// AmbiguityKind describes which matcher produced the ambiguity, so callers
// can tailor the error message.
type AmbiguityKind int
const (
AmbiguityKindIDPrefix AmbiguityKind = iota
AmbiguityKindName
)
// profileMeta is the minimal slice of a profile JSON we need, so we avoid
// reading all fields
type profileMeta struct {
Name string
}
func (e *ErrAmbiguousHandle) Error() string {
switch e.Kind {
case AmbiguityKindIDPrefix:
return fmt.Sprintf("ID prefix %q is ambiguous (matches %d profiles)", e.Handle, len(e.Candidates))
default:
return fmt.Sprintf("name %q is ambiguous (%d profiles share this name)", e.Handle, len(e.Candidates))
}
}
func init() {
DefaultConfigPathDir = "/var/lib/netbird/"
@@ -86,34 +54,25 @@ func init() {
}
type ActiveProfileState struct {
// ID is the on-disk filename stem of the active profile. The JSON tag stays
// as "name" for backwards compatibility with active state files written
// before the ID-based config files. Legacy values were profile names, which
// were also the legacy filename stems, so they still resolve to the correct
// file on disk.
ID ID `json:"name"`
Name string `json:"name"`
Username string `json:"username"`
}
func (a *ActiveProfileState) FilePath() (string, error) {
if a.ID == "" {
return "", fmt.Errorf("active profile ID is empty")
if a.Name == "" {
return "", fmt.Errorf("active profile name is empty")
}
if a.ID == defaultProfileName {
if a.Name == defaultProfileName {
return DefaultConfigPath, nil
}
if !IsValidProfileFilenameStem(a.ID) {
return "", fmt.Errorf("invalid profile ID: %q", a.ID)
}
configDir, err := getConfigDirForUser(a.Username)
if err != nil {
return "", fmt.Errorf("failed to get config directory for user %s: %w", a.Username, err)
}
return filepath.Join(configDir, a.ID.String()+".json"), nil
return filepath.Join(configDir, a.Name+".json"), nil
}
type ServiceManager struct {
@@ -219,7 +178,7 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
}, nil
} else {
@@ -227,12 +186,12 @@ func (s *ServiceManager) GetActiveProfileState() (*ActiveProfileState, error) {
}
}
if activeProfile.ID == "" {
if activeProfile.Name == "" {
if err := s.SetActiveProfileStateToDefault(); err != nil {
return nil, fmt.Errorf("failed to set active profile to default: %w", err)
}
return &ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
}, nil
}
@@ -257,29 +216,25 @@ func (s *ServiceManager) setDefaultActiveState() error {
}
func (s *ServiceManager) SetActiveProfileState(a *ActiveProfileState) error {
if a == nil || a.ID == "" {
if a == nil || a.Name == "" {
return errors.New("invalid active profile state")
}
if a.ID != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.ID)
}
if a.ID != defaultProfileName && !IsValidProfileFilenameStem(a.ID) {
return fmt.Errorf("invalid profile ID: %q", a.ID)
if a.Name != defaultProfileName && a.Username == "" {
return fmt.Errorf("username must be set for non-default profiles, got: %s", a.Name)
}
if err := util.WriteJsonWithRestrictedPermission(context.Background(), ActiveProfileStatePath, a); err != nil {
return fmt.Errorf("failed to write active profile state: %w", err)
}
log.Infof("active profile set to %s for %s", a.ID, a.Username)
log.Infof("active profile set to %s for %s", a.Name, a.Username)
return nil
}
func (s *ServiceManager) SetActiveProfileStateToDefault() error {
return s.SetActiveProfileState(&ActiveProfileState{
ID: defaultProfileName,
Name: "default",
Username: "",
})
}
@@ -288,75 +243,57 @@ func (s *ServiceManager) DefaultProfilePath() string {
return DefaultConfigPath
}
// AddProfile creates a new profile with a generated ID. The user-supplied
// displayName is stored inside the JSON's name field, the on-disk filename
// uses the generated ID.
//
// The returned Profile carries the freshly-generated ID so callers can
// show it to the user (and so the gRPC AddProfileResponse can include
// it).
func (s *ServiceManager) AddProfile(displayName, username string) (*Profile, error) {
func (s *ServiceManager) AddProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
return fmt.Errorf("failed to get config directory: %w", err)
}
displayName, err = sanitizeDisplayName(displayName)
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return nil, fmt.Errorf("invalid profile name: %w", err)
return fmt.Errorf("failed to check if profile exists: %w", err)
}
if profileExists {
return ErrProfileAlreadyExists
}
if displayName == defaultProfileName {
return nil, fmt.Errorf("cannot create profile with reserved name: %s", defaultProfileName)
}
id, err := generateProfileID()
if err != nil {
return nil, fmt.Errorf("generate profile id: %w", err)
}
profPath := filepath.Join(configDir, id.String()+".json")
cfg, err := createNewConfig(ConfigInput{ConfigPath: profPath})
if err != nil {
return nil, fmt.Errorf("failed to create new config: %w", err)
}
cfg.Name = displayName
if err := util.WriteJson(context.Background(), profPath, cfg); err != nil {
return nil, fmt.Errorf("failed to write profile config: %w", err)
return fmt.Errorf("failed to create new config: %w", err)
}
return &Profile{
ID: id,
Name: displayName,
Path: profPath,
}, nil
err = util.WriteJson(context.Background(), profPath, cfg)
if err != nil {
return fmt.Errorf("failed to write profile config: %w", err)
}
return nil
}
// RemoveProfile deletes the profile identified by id. Callers must have
// already resolved any user-supplied handle to a concrete ID via
// ResolveProfile.
func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if id == defaultProfileName {
func (s *ServiceManager) RemoveProfile(profileName, username string) error {
configDir, err := s.getConfigDir(username)
if err != nil {
return fmt.Errorf("failed to get config directory: %w", err)
}
profileName = sanitizeProfileName(profileName)
if profileName == defaultProfileName {
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
}
if !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid profile ID: %q", id)
}
profiles, err := s.loadAllProfiles(username)
profPath := filepath.Join(configDir, profileName+".json")
profileExists, err := fileExists(profPath)
if err != nil {
return fmt.Errorf("load profiles: %w", err)
return fmt.Errorf("failed to check if profile exists: %w", err)
}
var target *Profile
for i := range profiles {
if profiles[i].ID == id {
target = &profiles[i]
break
}
}
if target == nil {
if !profileExists {
return ErrProfileNotFound
}
@@ -364,26 +301,57 @@ func (s *ServiceManager) RemoveProfile(id ID, username string) error {
if err != nil && !errors.Is(err, ErrNoActiveProfile) {
return fmt.Errorf("failed to get active profile: %w", err)
}
if activeProf != nil && activeProf.ID == id {
return fmt.Errorf("cannot remove active profile: %s", id)
if activeProf != nil && activeProf.Name == profileName {
return fmt.Errorf("cannot remove active profile: %s", profileName)
}
if err := util.RemoveJson(target.Path); err != nil {
err = util.RemoveJson(profPath)
if err != nil {
return fmt.Errorf("failed to remove profile config: %w", err)
}
stateFile := filepath.Join(filepath.Dir(target.Path), id.String()+".state.json")
if err := os.Remove(stateFile); err != nil && !os.IsNotExist(err) {
log.Warnf("failed to remove profile state file %s: %v", stateFile, err)
}
return nil
}
// ListProfiles returns every profile for the given user, including the
// default profile, with IsActive flags set.
func (s *ServiceManager) ListProfiles(username string) ([]Profile, error) {
return s.loadAllProfiles(username)
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("failed to get config directory: %w", err)
}
files, err := util.ListFiles(configDir, "*.json")
if err != nil {
return nil, fmt.Errorf("failed to list profile files: %w", err)
}
var filtered []string
for _, file := range files {
if strings.HasSuffix(file, "state.json") {
continue // skip state files
}
filtered = append(filtered, file)
}
sort.Strings(filtered)
var activeProfName string
activeProf, err := s.GetActiveProfileState()
if err == nil {
activeProfName = activeProf.Name
}
var profiles []Profile
// add default profile always
profiles = append(profiles, Profile{Name: defaultProfileName, IsActive: activeProfName == "" || activeProfName == defaultProfileName})
for _, file := range filtered {
profileName := strings.TrimSuffix(filepath.Base(file), ".json")
var isActive bool
if activeProfName != "" && activeProfName == profileName {
isActive = true
}
profiles = append(profiles, Profile{Name: profileName, IsActive: isActive})
}
return profiles, nil
}
// GetStatePath returns the path to the state file based on the operating system
@@ -401,12 +369,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
if activeProf.ID == defaultProfileName {
return defaultStatePath
}
if !IsValidProfileFilenameStem(activeProf.ID) {
log.Warnf("invalid active profile ID %q, using default state path", activeProf.ID)
if activeProf.Name == defaultProfileName {
return defaultStatePath
}
@@ -416,7 +379,7 @@ func (s *ServiceManager) GetStatePath() string {
return defaultStatePath
}
return filepath.Join(configDir, activeProf.ID.String()+".state.json")
return filepath.Join(configDir, activeProf.Name+".state.json")
}
// getConfigDir returns the profiles directory, using profilesDir if set, otherwise getConfigDirForUser
@@ -427,165 +390,3 @@ func (s *ServiceManager) getConfigDir(username string) (string, error) {
return getConfigDirForUser(username)
}
// loadAllProfiles returns every profile visible to the daemon for the
// given user, including the default profile. The returned slice is sorted
// by ID for a stable display order.
//
// Each Profile is fully populated: ID is the filename stem, Name comes
// from the JSON's "name" field (falling back to the filename stem when absent)
// and Path is built from a basename read off disk.
func (s *ServiceManager) loadAllProfiles(username string) ([]Profile, error) {
activeID, activeIsDefault := s.activeProfileID()
profiles := []Profile{{
ID: defaultProfileName,
Name: defaultProfileName,
Path: DefaultConfigPath,
IsActive: activeIsDefault,
}}
configDir, err := s.getConfigDir(username)
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
entries, err := os.ReadDir(configDir)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return profiles, nil
}
return nil, fmt.Errorf("read profile directory: %w", err)
}
var fileProfiles []Profile
for _, entry := range entries {
if entry.IsDir() {
continue
}
base := entry.Name()
if !strings.HasSuffix(base, ".json") {
continue
}
if strings.HasSuffix(base, ".state.json") {
continue
}
stem := ID(strings.TrimSuffix(base, ".json"))
if stem == defaultProfileName {
// default lives at the top-level config dir, not under /<user>
continue
}
if !IsValidProfileFilenameStem(ID(stem)) {
continue
}
path := filepath.Join(configDir, base)
name := readProfileName(path)
if name == "" {
name = stem.String()
}
fileProfiles = append(fileProfiles, Profile{
ID: stem,
Name: name,
Path: path,
IsActive: stem == ID(activeID),
})
}
sort.Slice(fileProfiles, func(i, j int) bool {
if fileProfiles[i].Name != fileProfiles[j].Name {
return fileProfiles[i].Name < fileProfiles[j].Name
}
// Sort tie-break on ID so duplicate names always render in the same order.
return fileProfiles[i].ID < fileProfiles[j].ID
})
profiles = append(profiles, fileProfiles...)
return profiles, nil
}
// readProfileName parses just the "name" field from the profile Json.
func readProfileName(path string) string {
data, err := os.ReadFile(path)
if err != nil {
return ""
}
var meta profileMeta
if err := json.Unmarshal(data, &meta); err != nil {
return ""
}
return meta.Name
}
// activeProfileID returns the currently-active profile's ID. The second
// return value is true when the active profile is the default one.
func (s *ServiceManager) activeProfileID() (ID, bool) {
state, err := s.GetActiveProfileState()
if err != nil || state == nil {
return defaultProfileName, true
}
if state.ID == "" || state.ID == defaultProfileName {
return defaultProfileName, true
}
return state.ID, false
}
// ResolveProfile turns a user-supplied handle into a Profile. Resolution
// precedence is: exact ID match, then unique ID prefix, then unique exact
// name. Ambiguous matches return *ErrAmbiguousHandle so callers can
// surface the candidates.
func (s *ServiceManager) ResolveProfile(handle, username string) (*Profile, error) {
if handle == "" {
return nil, fmt.Errorf("profile handle is empty")
}
profiles, err := s.loadAllProfiles(username)
if err != nil {
return nil, err
}
for i := range profiles {
if profiles[i].ID == ID(handle) {
return &profiles[i], nil
}
}
// ID prefix match. Skip the default profile so `select d` does not
// accidentally pick it via prefix.
var prefixMatches []Profile
for i := range profiles {
if profiles[i].ID == defaultProfileName {
continue
}
if strings.HasPrefix(profiles[i].ID.String(), handle) {
prefixMatches = append(prefixMatches, profiles[i])
}
}
if len(prefixMatches) == 1 {
return &prefixMatches[0], nil
}
if len(prefixMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: prefixMatches,
Kind: AmbiguityKindIDPrefix,
}
}
var nameMatches []Profile
for i := range profiles {
if profiles[i].Name == handle {
nameMatches = append(nameMatches, profiles[i])
}
}
if len(nameMatches) == 1 {
return &nameMatches[0], nil
}
if len(nameMatches) > 1 {
return nil, &ErrAmbiguousHandle{
Handle: handle,
Candidates: nameMatches,
Kind: AmbiguityKindName,
}
}
return nil, ErrProfileNotFound
}

View File

@@ -1,230 +0,0 @@
package profilemanager
import (
"context"
"errors"
"os"
"os/user"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util"
)
// withTestSM wires up patched globals + a clean config dir and returns a
// fully initialized ServiceManager plus the username we are scoped to.
func withTestSM(t *testing.T, fn func(sm *ServiceManager, username string)) {
t.Helper()
withTempConfigDir(t, func(configDir string) {
withPatchedGlobals(t, configDir, func() {
u, err := user.Current()
require.NoError(t, err)
sm := &ServiceManager{}
require.NoError(t, sm.CreateDefaultProfile())
fn(sm, u.Username)
})
})
}
func TestServiceProfile_ExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile(created.ID.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_IDPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
prefix := created.ID[:4]
got, err := sm.ResolveProfile(prefix.String(), username)
require.NoError(t, err)
assert.Equal(t, created.ID, got.ID)
})
}
func TestServiceProfile_AmbiguousPrefix(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
// Plant two profiles whose IDs share a known prefix by writing
// the files directly, since generated IDs are random.
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
for _, id := range []string{"abcd1111aaaa", "abcd2222bbbb"} {
path := filepath.Join(configDir, id+".json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{Name: id}))
}
_, err = sm.ResolveProfile("abcd", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindIDPrefix, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_ExactNameUnique(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
got, err := sm.ResolveProfile("work", username)
require.NoError(t, err)
assert.Equal(t, "work", got.Name)
})
}
func TestServiceProfile_AmbiguousName(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.AddProfile("work", username)
require.NoError(t, err)
_, err = sm.ResolveProfile("work", username)
var amb *ErrAmbiguousHandle
require.ErrorAs(t, err, &amb)
assert.Equal(t, AmbiguityKindName, amb.Kind)
assert.Len(t, amb.Candidates, 2)
})
}
func TestServiceProfile_NotFound(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
_, err := sm.ResolveProfile("nope", username)
assert.ErrorIs(t, err, ErrProfileNotFound)
})
}
func TestServiceProfile_DefaultByExactID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
got, err := sm.ResolveProfile(defaultProfileName, username)
require.NoError(t, err)
assert.Equal(t, defaultProfileName, got.ID.String())
})
}
func TestServiceProfile_LegacyFilenameCoexists(t *testing.T) {
// Legacy profiles stored as <name>.json with no "name" JSON field
// should still be discoverable by name and removable by name.
withTestSM(t, func(sm *ServiceManager, username string) {
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
path := filepath.Join(configDir, "legacy.json")
require.NoError(t, util.WriteJson(context.Background(), path, &Config{}))
got, err := sm.ResolveProfile("legacy", username)
require.NoError(t, err)
assert.Equal(t, "legacy", got.ID.String())
// Name falls back to the filename stem when JSON omits it.
assert.Equal(t, "legacy", got.Name)
})
}
func TestAddProfile_AllowsDuplicateWithFlag(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
first, err := sm.AddProfile("work", username)
require.NoError(t, err)
second, err := sm.AddProfile("work", username)
require.NoError(t, err)
assert.NotEqual(t, first.ID, second.ID)
assert.Equal(t, "work", second.Name)
})
}
func TestAddProfile_RejectsInvalidNames(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
cases := []string{
"", // empty
"\x00\x01", // only control chars (becomes empty)
strings.Repeat("a", maxProfileNameLen+1), // too long
}
for _, name := range cases {
_, err := sm.AddProfile(name, username)
assert.Error(t, err, "expected error for %q", name)
}
})
}
func TestRemoveProfile_RejectsInvalidID(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
err := sm.RemoveProfile("../escape", username)
assert.Error(t, err)
})
}
func TestSanitizeDisplayName(t *testing.T) {
cases := []struct {
in string
want string
wantErr bool
}{
{"work", "work", false},
{"My Work Account", "My Work Account", false},
{"emoji 🚀 ok", "emoji 🚀 ok", false},
{"漢字テスト", "漢字テスト", false},
{"with\x00null", "withnull", false},
{"\x01\x02\x03", "", true},
{"", "", true},
}
for _, tc := range cases {
got, err := sanitizeDisplayName(tc.in)
if tc.wantErr {
assert.Error(t, err, "case %q", tc.in)
continue
}
assert.NoError(t, err, "case %q", tc.in)
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestIsValidProfileFilenameStem(t *testing.T) {
cases := []struct {
in string
want bool
}{
{"default", true},
{"abc123def456", true},
{"legacy-name", true},
{"legacy_name", true},
{"", false},
{"..", false},
{"../etc", false},
{"foo/bar", false},
{`foo\bar`, false},
{"with space", false},
{"with.dot", false},
{strings.Repeat("a", maxProfileIDLen+1), false},
}
for _, tc := range cases {
got := IsValidProfileFilenameStem(ID(tc.in))
assert.Equal(t, tc.want, got, "case %q", tc.in)
}
}
func TestRemoveProfile_DeletesStateFile(t *testing.T) {
withTestSM(t, func(sm *ServiceManager, username string) {
created, err := sm.AddProfile("work", username)
require.NoError(t, err)
configDir, err := sm.getConfigDir(username)
require.NoError(t, err)
statePath := filepath.Join(configDir, created.ID.String()+".state.json")
require.NoError(t, os.WriteFile(statePath, []byte(`{"email":"a@b"}`), 0600))
require.NoError(t, sm.RemoveProfile(created.ID, username))
_, err = os.Stat(statePath)
assert.True(t, errors.Is(err, os.ErrNotExist), "state file should be removed")
})
}

View File

@@ -13,20 +13,13 @@ type ProfileState struct {
Email string `json:"email"`
}
// GetProfileState reads the per-profile state file keyed by profile ID.
// The state file lives in the user's config directory. Legacy state files
// keyed by the old profile name remain readable.
func (pm *ProfileManager) GetProfileState(id ID) (*ProfileState, error) {
func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, error) {
configDir, err := getConfigDir()
if err != nil {
return nil, fmt.Errorf("get config directory: %w", err)
}
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return nil, fmt.Errorf("invalid profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFile := filepath.Join(configDir, profileName+".state.json")
stateFileExists, err := fileExists(stateFile)
if err != nil {
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
@@ -58,12 +51,7 @@ func (pm *ProfileManager) SetActiveProfileState(state *ProfileState) error {
return fmt.Errorf("get active profile: %w", err)
}
id := activeProf.ID
if id != defaultProfileName && !IsValidProfileFilenameStem(id) {
return fmt.Errorf("invalid active profile ID: %q", id)
}
stateFile := filepath.Join(configDir, id.String()+".state.json")
stateFile := filepath.Join(configDir, activeProf.Name+".state.json")
err = util.WriteJsonWithRestrictedPermission(context.Background(), stateFile, state)
if err != nil {
return fmt.Errorf("write profile state: %w", err)

View File

@@ -3931,11 +3931,9 @@ func (x *GetEventsResponse) GetEvents() []*SystemEvent {
}
type SwitchProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,2,opt,name=username,proto3,oneof" json:"username,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -3985,11 +3983,7 @@ func (x *SwitchProfileRequest) GetUsername() string {
}
type SwitchProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4024,13 +4018,6 @@ func (*SwitchProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{55}
}
func (x *SwitchProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type SetConfigRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
@@ -4387,11 +4374,9 @@ func (*SetConfigResponse) Descriptor() ([]byte, []int) {
}
type AddProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4441,10 +4426,7 @@ func (x *AddProfileRequest) GetProfileName() string {
}
type AddProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4479,19 +4461,10 @@ func (*AddProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{59}
}
func (x *AddProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type RemoveProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
ProfileName string `protobuf:"bytes,2,opt,name=profileName,proto3" json:"profileName,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4541,10 +4514,7 @@ func (x *RemoveProfileRequest) GetProfileName() string {
}
type RemoveProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"`
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4579,13 +4549,6 @@ func (*RemoveProfileResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{61}
}
func (x *RemoveProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type ListProfilesRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
Username string `protobuf:"bytes,1,opt,name=username,proto3" json:"username,omitempty"`
@@ -4678,7 +4641,6 @@ type Profile struct {
state protoimpl.MessageState `protogen:"open.v1"`
Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"`
IsActive bool `protobuf:"varint,2,opt,name=is_active,json=isActive,proto3" json:"is_active,omitempty"`
Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4727,13 +4689,6 @@ func (x *Profile) GetIsActive() bool {
return false
}
func (x *Profile) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type GetActiveProfileRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -4774,7 +4729,6 @@ type GetActiveProfileResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName string `protobuf:"bytes,1,opt,name=profileName,proto3" json:"profileName,omitempty"`
Username string `protobuf:"bytes,2,opt,name=username,proto3" json:"username,omitempty"`
Id string `protobuf:"bytes,3,opt,name=id,proto3" json:"id,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4823,13 +4777,6 @@ func (x *GetActiveProfileResponse) GetUsername() string {
return ""
}
func (x *GetActiveProfileResponse) GetId() string {
if x != nil {
return x.Id
}
return ""
}
type LogoutRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
ProfileName *string `protobuf:"bytes,1,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
@@ -6651,9 +6598,8 @@ const file_daemon_proto_rawDesc = "" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"'\n" +
"\x15SwitchProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"\x98\x11\n" +
"\t_username\"\x17\n" +
"\x15SwitchProfileResponse\"\x98\x11\n" +
"\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -6722,27 +6668,23 @@ const file_daemon_proto_rawDesc = "" +
"\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"$\n" +
"\x12AddProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"T\n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x14\n" +
"\x12AddProfileResponse\"T\n" +
"\x14RemoveProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"'\n" +
"\x15RemoveProfileResponse\x12\x0e\n" +
"\x02id\x18\x01 \x01(\tR\x02id\"1\n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\"\x17\n" +
"\x15RemoveProfileResponse\"1\n" +
"\x13ListProfilesRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\"C\n" +
"\x14ListProfilesResponse\x12+\n" +
"\bprofiles\x18\x01 \x03(\v2\x0f.daemon.ProfileR\bprofiles\"J\n" +
"\bprofiles\x18\x01 \x03(\v2\x0f.daemon.ProfileR\bprofiles\":\n" +
"\aProfile\x12\x12\n" +
"\x04name\x18\x01 \x01(\tR\x04name\x12\x1b\n" +
"\tis_active\x18\x02 \x01(\bR\bisActive\x12\x0e\n" +
"\x02id\x18\x03 \x01(\tR\x02id\"\x19\n" +
"\x17GetActiveProfileRequest\"h\n" +
"\tis_active\x18\x02 \x01(\bR\bisActive\"\x19\n" +
"\x17GetActiveProfileRequest\"X\n" +
"\x18GetActiveProfileResponse\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\x12\x0e\n" +
"\x02id\x18\x03 \x01(\tR\x02id\"t\n" +
"\busername\x18\x02 \x01(\tR\busername\"t\n" +
"\rLogoutRequest\x12%\n" +
"\vprofileName\x18\x01 \x01(\tH\x00R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +

View File

@@ -615,18 +615,11 @@ message GetEventsResponse {
}
message SwitchProfileRequest {
// profileName is treated as a handle: exact ID, unique ID prefix, or
// unique display name. The daemon resolves it server-side.
optional string profileName = 1;
optional string username = 2;
}
message SwitchProfileResponse {
// id is the resolved on-disk ID of the profile that became active.
// Lets CLI clients update their local active-profile state without
// duplicating the resolution logic.
string id = 1;
}
message SwitchProfileResponse {}
message SetConfigRequest {
string username = 1;
@@ -693,29 +686,17 @@ message SetConfigResponse{}
message AddProfileRequest {
string username = 1;
// profileName carries the human-readable display name for the new
// profile. The on-disk filename is a separately-generated ID.
string profileName = 2;
}
message AddProfileResponse {
// id is the generated on-disk ID of the new profile. CLI clients
// display a truncated form, UI clients can ignore it.
string id = 1;
}
message AddProfileResponse {}
message RemoveProfileRequest {
string username = 1;
// profileName is treated as a handle: an exact ID, a unique ID
// prefix, or a unique display name. Resolution happens server-side.
string profileName = 2;
}
message RemoveProfileResponse {
// id is the full resolved ID of the removed profile, so callers can
// confirm exactly which profile a name/prefix handle resolved to.
string id = 1;
}
message RemoveProfileResponse {}
message ListProfilesRequest {
string username = 1;
@@ -728,7 +709,6 @@ message ListProfilesResponse {
message Profile {
string name = 1;
bool is_active = 2;
string id = 3;
}
message GetActiveProfileRequest {}
@@ -736,7 +716,6 @@ message GetActiveProfileRequest {}
message GetActiveProfileResponse {
string profileName = 1;
string username = 2;
string id = 3;
}
message LogoutRequest {

View File

@@ -79,7 +79,7 @@ func TestPersistLoginOverrides(t *testing.T) {
_, err := profilemanager.UpdateOrCreateConfig(seed)
require.NoError(t, err, "seed config")
activeProf := &profilemanager.ActiveProfileState{ID: "default"}
activeProf := &profilemanager.ActiveProfileState{Name: "default"}
err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK)
require.NoError(t, err, "persistLoginOverrides")

View File

@@ -308,14 +308,15 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
return nil, gstatus.Errorf(codes.Unavailable, errUpdateSettingsDisabled)
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
log.Errorf("failed to resolve profile %q: %v", msg.ProfileName, err)
return nil, err
profState := profilemanager.ActiveProfileState{
Name: msg.ProfileName,
Username: msg.Username,
}
profPath := resolved.Path
if profPath == "" {
profPath = profilemanager.DefaultConfigPath
profPath, err := profState.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
var config profilemanager.ConfigInput
@@ -445,9 +446,30 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
if msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
if *msg.ProfileName != "default" && (msg.Username == nil || *msg.Username == "") {
log.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", *msg.ProfileName)
}
var username string
if *msg.ProfileName != "default" {
username = *msg.Username
}
if *msg.ProfileName != activeProf.Name && username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s for user '%s'", *msg.ProfileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
Name: *msg.ProfileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
}
}
}
@@ -457,7 +479,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
s.mutex.Lock()
@@ -689,10 +711,10 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
}
if msg != nil && msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
s.mutex.Unlock()
log.Errorf("failed to switch profile: %v", err)
return nil, err
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
}
@@ -703,7 +725,7 @@ func (s *Server) Up(callerCtx context.Context, msg *proto.UpRequest) (*proto.UpR
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
log.Infof("active profile: %s for %s", activeProf.ID, activeProf.Username)
log.Infof("active profile: %s for %s", activeProf.Name, activeProf.Username)
config, _, err := s.getConfig(activeProf)
if err != nil {
@@ -746,60 +768,34 @@ func (s *Server) waitForUp(callerCtx context.Context) (*proto.UpResponse, error)
}
}
// resolveProfileHandle resolves a wire-level profile handle (display
// name, ID, or unique ID prefix) to a concrete profile. Returns gRPC
// status errors so handlers can return them directly.
func (s *Server) resolveProfileHandle(handle, username string) (*profilemanager.Profile, error) {
p, err := s.profileManager.ResolveProfile(handle, username)
if err == nil {
return p, nil
}
var amb *profilemanager.ErrAmbiguousHandle
if errors.As(err, &amb) {
return nil, gstatus.Errorf(codes.InvalidArgument, "%v", amb)
}
if errors.Is(err, profilemanager.ErrProfileNotFound) {
return nil, gstatus.Errorf(codes.NotFound, "profile %q not found", handle)
}
return nil, fmt.Errorf("resolve profile: %w", err)
}
// switchProfileIfNeeded resolves the user-supplied handle, updates the
// active profile state if it differs from the current one, and returns
// the resolved profile so callers can include its ID in RPC responses.
func (s *Server) switchProfileIfNeeded(handle string, userName *string, activeProf *profilemanager.ActiveProfileState) (*profilemanager.Profile, error) {
if handle != profilemanager.DefaultProfileName && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", handle)
return nil, fmt.Errorf("profile name is set to %s, but username is not provided", handle)
func (s *Server) switchProfileIfNeeded(profileName string, userName *string, activeProf *profilemanager.ActiveProfileState) error {
if profileName != "default" && (userName == nil || *userName == "") {
log.Errorf("profile name is set to %s, but username is not provided", profileName)
return fmt.Errorf("profile name is set to %s, but username is not provided", profileName)
}
var username string
if handle != profilemanager.DefaultProfileName {
if profileName != "default" {
username = *userName
}
resolved, err := s.resolveProfileHandle(handle, username)
if err != nil {
return nil, err
}
if resolved.ID != activeProf.ID || username != activeProf.Username {
if profileName != activeProf.Name || username != activeProf.Username {
if s.checkProfilesDisabled() {
log.Errorf("profiles are disabled, you cannot use this feature without profiles enabled")
return nil, gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
log.Infof("switching to profile %s (%s) for user %s", resolved.Name, resolved.ID, username)
log.Infof("switching to profile %s for user %s", profileName, username)
if err := s.profileManager.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: resolved.ID,
Name: profileName,
Username: username,
}); err != nil {
log.Errorf("failed to set active profile state: %v", err)
return nil, fmt.Errorf("failed to set active profile state: %w", err)
return fmt.Errorf("failed to set active profile state: %w", err)
}
}
return resolved, nil
return nil
}
// SwitchProfile switches the active profile in the daemon.
@@ -814,9 +810,9 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
}
if msg != nil && msg.ProfileName != nil {
if _, err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
if err := s.switchProfileIfNeeded(*msg.ProfileName, msg.Username, activeProf); err != nil {
log.Errorf("failed to switch profile: %v", err)
return nil, err
return nil, fmt.Errorf("failed to switch profile: %w", err)
}
}
activeProf, err = s.profileManager.GetActiveProfileState()
@@ -832,7 +828,7 @@ func (s *Server) SwitchProfile(callerCtx context.Context, msg *proto.SwitchProfi
s.config = config
return &proto.SwitchProfileResponse{Id: activeProf.ID.String()}, nil
return &proto.SwitchProfileResponse{}, nil
}
// Down engine work in the daemon.
@@ -916,27 +912,22 @@ func (s *Server) Logout(ctx context.Context, msg *proto.LogoutRequest) (*proto.L
}
func (s *Server) handleProfileLogout(ctx context.Context, msg *proto.LogoutRequest) (*proto.LogoutResponse, error) {
if err := s.validateProfileOperation(*msg.ProfileName, true); err != nil {
return nil, err
}
if msg.Username == nil || *msg.Username == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "username must be provided when profile name is specified")
}
username := *msg.Username
resolved, err := s.resolveProfileHandle(*msg.ProfileName, username)
if err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, true); err != nil {
return nil, err
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Errorf("failed to logout from profile %s: %v", resolved.ID, err)
if err := s.logoutFromProfile(ctx, *msg.ProfileName, username); err != nil {
log.Errorf("failed to logout from profile %s: %v", *msg.ProfileName, err)
return nil, gstatus.Errorf(codes.Internal, "logout: %v", err)
}
activeProf, _ := s.profileManager.GetActiveProfileState()
if activeProf != nil && activeProf.ID == resolved.ID {
if activeProf != nil && activeProf.Name == *msg.ProfileName {
if err := s.cleanupConnection(); err != nil && !errors.Is(err, ErrServiceNotUp) {
log.Errorf("failed to cleanup connection: %v", err)
}
@@ -998,30 +989,30 @@ func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*prof
return config, configExisted, nil
}
func (s *Server) canRemoveProfile(id profilemanager.ID) error {
if id == profilemanager.DefaultProfileName {
func (s *Server) canRemoveProfile(profileName string) error {
if profileName == profilemanager.DefaultProfileName {
return fmt.Errorf("remove profile with reserved name: %s", profilemanager.DefaultProfileName)
}
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == id {
return fmt.Errorf("remove active profile: %s", id)
if err == nil && activeProf.Name == profileName {
return fmt.Errorf("remove active profile: %s", profileName)
}
return nil
}
func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfile bool) error {
func (s *Server) validateProfileOperation(profileName string, allowActiveProfile bool) error {
if s.checkProfilesDisabled() {
return gstatus.Errorf(codes.Unavailable, errProfilesDisabled)
}
if id == "" {
if profileName == "" {
return gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
if !allowActiveProfile {
if err := s.canRemoveProfile(id); err != nil {
if err := s.canRemoveProfile(profileName); err != nil {
return gstatus.Errorf(codes.InvalidArgument, "%v", err)
}
}
@@ -1029,20 +1020,25 @@ func (s *Server) validateProfileOperation(id profilemanager.ID, allowActiveProfi
return nil
}
func (s *Server) logoutFromProfile(ctx context.Context, profile *profilemanager.Profile) error {
// logoutFromProfile logs out from a specific profile by loading its config and sending logout request
func (s *Server) logoutFromProfile(ctx context.Context, profileName, username string) error {
activeProf, err := s.profileManager.GetActiveProfileState()
if err == nil && activeProf.ID == profile.ID && s.connectClient != nil {
if err == nil && activeProf.Name == profileName && s.connectClient != nil {
return s.sendLogoutRequest(ctx)
}
cfgPath := profile.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
profileState := &profilemanager.ActiveProfileState{
Name: profileName,
Username: username,
}
profilePath, err := profileState.FilePath()
if err != nil {
return fmt.Errorf("get profile path: %w", err)
}
config, err := profilemanager.GetConfig(cfgPath)
config, err := profilemanager.GetConfig(profilePath)
if err != nil {
return fmt.Errorf("profile '%s' not found", profile.ID)
return fmt.Errorf("profile '%s' not found", profileName)
}
return s.sendLogoutRequestWithConfig(ctx, config)
@@ -1456,14 +1452,15 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
return nil, ctx.Err()
}
resolved, err := s.resolveProfileHandle(req.ProfileName, req.Username)
if err != nil {
log.Errorf("failed to resolve profile %q: %v", req.ProfileName, err)
return nil, err
prof := profilemanager.ActiveProfileState{
Name: req.ProfileName,
Username: req.Username,
}
cfgPath := resolved.Path
if cfgPath == "" {
cfgPath = profilemanager.DefaultConfigPath
cfgPath, err := prof.FilePath()
if err != nil {
log.Errorf("failed to get active profile file path: %v", err)
return nil, fmt.Errorf("failed to get active profile file path: %w", err)
}
cfg, err := profilemanager.GetConfig(cfgPath)
@@ -1567,16 +1564,12 @@ func (s *Server) AddProfile(ctx context.Context, msg *proto.AddProfileRequest) (
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name and username must be provided")
}
created, err := s.profileManager.AddProfile(msg.ProfileName, msg.Username)
if err != nil {
if errors.Is(err, profilemanager.ErrProfileAlreadyExists) {
return nil, gstatus.Errorf(codes.AlreadyExists, "profile %q already exists", msg.ProfileName)
}
if err := s.profileManager.AddProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to create profile: %v", err)
return nil, fmt.Errorf("failed to create profile: %w", err)
}
return &proto.AddProfileResponse{Id: created.ID.String()}, nil
return &proto.AddProfileResponse{}, nil
}
// RemoveProfile removes a profile from the daemon.
@@ -1584,29 +1577,20 @@ func (s *Server) RemoveProfile(ctx context.Context, msg *proto.RemoveProfileRequ
s.mutex.Lock()
defer s.mutex.Unlock()
if msg.ProfileName == "" {
return nil, gstatus.Errorf(codes.InvalidArgument, "profile name must be provided")
}
resolved, err := s.resolveProfileHandle(msg.ProfileName, msg.Username)
if err != nil {
if err := s.validateProfileOperation(msg.ProfileName, false); err != nil {
return nil, err
}
if err := s.validateProfileOperation(resolved.ID, false); err != nil {
return nil, err
if err := s.logoutFromProfile(ctx, msg.ProfileName, msg.Username); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", msg.ProfileName, err)
}
if err := s.logoutFromProfile(ctx, resolved); err != nil {
log.Warnf("failed to logout from profile %s before removal: %v", resolved.ID, err)
}
if err := s.profileManager.RemoveProfile(resolved.ID, msg.Username); err != nil {
if err := s.profileManager.RemoveProfile(msg.ProfileName, msg.Username); err != nil {
log.Errorf("failed to remove profile: %v", err)
return nil, fmt.Errorf("failed to remove profile: %w", err)
}
return &proto.RemoveProfileResponse{Id: resolved.ID.String()}, nil
return &proto.RemoveProfileResponse{}, nil
}
// ListProfiles lists all profiles in the daemon.
@@ -1629,7 +1613,6 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
}
for i, profile := range profiles {
response.Profiles[i] = &proto.Profile{
Id: profile.ID.String(),
Name: profile.Name,
IsActive: profile.IsActive,
}
@@ -1638,9 +1621,7 @@ func (s *Server) ListProfiles(ctx context.Context, msg *proto.ListProfilesReques
return response, nil
}
// GetActiveProfile returns the active profile in the daemon. The ProfileName
// field carries the display name for backwards compatibility with UI clients,
// new callers should prefer Id.
// GetActiveProfile returns the active profile in the daemon.
func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfileRequest) (*proto.GetActiveProfileResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
@@ -1651,23 +1632,9 @@ func (s *Server) GetActiveProfile(ctx context.Context, msg *proto.GetActiveProfi
return nil, fmt.Errorf("failed to get active profile state: %w", err)
}
// Fallback to legacy name == ID
displayName := activeProfile.ID.String()
if activeProfile.ID != profilemanager.DefaultProfileName {
if profiles, lerr := s.profileManager.ListProfiles(activeProfile.Username); lerr == nil {
for _, p := range profiles {
if p.ID == activeProfile.ID {
displayName = p.Name
break
}
}
}
}
return &proto.GetActiveProfileResponse{
ProfileName: displayName,
ProfileName: activeProfile.Name,
Username: activeProfile.Username,
Id: activeProfile.ID.String(),
}, nil
}

View File

@@ -97,7 +97,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "test-profile",
Name: "test-profile",
Username: currUser.Username,
})
if err != nil {
@@ -158,7 +158,7 @@ func TestServer_Up(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
})
if err != nil {
@@ -228,7 +228,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: "default",
Name: "default",
Username: currUser.Username,
})
if err != nil {

View File

@@ -47,7 +47,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
pm := profilemanager.ServiceManager{}
err = pm.SetActiveProfileState(&profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
})
require.NoError(t, err)
@@ -96,7 +96,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
DisableNotifications: &disableNotifications,
LazyConnectionEnabled: &lazyConnectionEnabled,
BlockInbound: &blockInbound,
DisableIpv6: &disableIPv6,
DisableIpv6: &disableIPv6,
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
CleanNATExternalIPs: false,
CustomDNSAddress: []byte("1.1.1.1:53"),
@@ -112,7 +112,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
require.NoError(t, err)
profState := profilemanager.ActiveProfileState{
ID: profilemanager.ID(profName),
Name: profName,
Username: currUser.Username,
}
cfgPath, err := profState.FilePath()

View File

@@ -622,7 +622,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
}
req := &proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
}
@@ -787,15 +787,13 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err)
}
handle := activeProf.ID.String()
loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &handle,
ProfileName: &activeProf.Name,
Username: &currUser.Username,
}
profileState, err := s.profileManager.GetProfileState(activeProf.ID)
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
@@ -1311,7 +1309,7 @@ func (s *serviceClient) getSrvConfig() {
}
srvCfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {
@@ -1535,7 +1533,7 @@ func (s *serviceClient) loadSettings() {
}
cfg, err := conn.GetConfig(s.ctx, &proto.GetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
})
if err != nil {
@@ -1612,7 +1610,7 @@ func (s *serviceClient) updateConfig() error {
}
req := proto.SetConfigRequest{
ProfileName: activeProf.ID.String(),
ProfileName: activeProf.Name,
Username: currUser.Username,
DisableAutoConnect: &disableAutoStart,
ServerSSHAllowed: &sshAllowed,

View File

@@ -66,7 +66,7 @@ func (s *serviceClient) showProfilesUI() {
} else {
indicator.SetText("")
}
nameLabel.SetText(formatProfileLabel(profile, profiles))
nameLabel.SetText(profile.Name)
// Configure Select/Active button
selectBtn.SetText(func() string {
@@ -88,7 +88,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
// switch
err = s.switchProfile(profile.ID)
err = s.switchProfile(profile.Name)
if err != nil {
log.Errorf("failed to switch profile: %v", err)
dialog.ShowError(errors.New("failed to select profile"), s.wProfiles)
@@ -130,7 +130,7 @@ func (s *serviceClient) showProfilesUI() {
logoutBtn.Show()
logoutBtn.SetText("Deregister")
logoutBtn.OnTapped = func() {
s.handleProfileLogout(profile, refresh)
s.handleProfileLogout(profile.Name, refresh)
}
// Remove profile
@@ -144,7 +144,7 @@ func (s *serviceClient) showProfilesUI() {
return
}
err = s.removeProfile(profile.ID)
err = s.removeProfile(profile.Name)
if err != nil {
log.Errorf("failed to remove profile: %v", err)
dialog.ShowError(fmt.Errorf("failed to remove profile"), s.wProfiles)
@@ -250,7 +250,7 @@ func (s *serviceClient) addProfile(profileName string) error {
return nil
}
func (s *serviceClient) switchProfile(handle string) error {
func (s *serviceClient) switchProfile(profileName string) error {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
return fmt.Errorf(getClientFMT, err)
@@ -261,15 +261,15 @@ func (s *serviceClient) switchProfile(handle string) error {
return fmt.Errorf("get current user: %w", err)
}
resp, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &handle,
if _, err := conn.SwitchProfile(s.ctx, &proto.SwitchProfileRequest{
ProfileName: &profileName,
Username: &currUser.Username,
})
if err != nil {
}); err != nil {
return fmt.Errorf("switch profile failed: %w", err)
}
if err := s.profileManager.SwitchProfile(profilemanager.ID(resp.Id)); err != nil {
err = s.profileManager.SwitchProfile(profileName)
if err != nil {
return fmt.Errorf("switch profile: %w", err)
}
@@ -299,27 +299,10 @@ func (s *serviceClient) removeProfile(profileName string) error {
}
type Profile struct {
ID string
Name string
IsActive bool
}
// formatProfileLabel returns the display label for a profile. Profiles can
// share the same Name, so when more than one profile in profiles carries this
// Name, a short form of the ID is appended to disambiguate the entries.
func formatProfileLabel(profile Profile, profiles []Profile) string {
count := 0
for _, p := range profiles {
if p.Name == profile.Name {
count++
}
}
if count <= 1 {
return profile.Name
}
return fmt.Sprintf("%s (%s)", profile.Name, profilemanager.ID(profile.ID).ShortID())
}
func (s *serviceClient) getProfiles() ([]Profile, error) {
conn, err := s.getSrvClient(defaultFailTimeout)
if err != nil {
@@ -341,7 +324,6 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -350,10 +332,10 @@ func (s *serviceClient) getProfiles() ([]Profile, error) {
return profiles, nil
}
func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback func()) {
func (s *serviceClient) handleProfileLogout(profileName string, refreshCallback func()) {
dialog.ShowConfirm(
"Deregister",
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profile.Name),
fmt.Sprintf("Are you sure you want to deregister from '%s'?", profileName),
func(confirm bool) {
if !confirm {
return
@@ -374,10 +356,8 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
}
username := currUser.Username
// ProfileName is treated as a handle; send the ID so the
// daemon resolves to exactly this profile.
_, err = conn.Logout(s.ctx, &proto.LogoutRequest{
ProfileName: &profile.ID,
ProfileName: &profileName,
Username: &username,
})
if err != nil {
@@ -388,7 +368,7 @@ func (s *serviceClient) handleProfileLogout(profile Profile, refreshCallback fun
dialog.ShowInformation(
"Deregistered",
fmt.Sprintf("Successfully deregistered from '%s'", profile.Name),
fmt.Sprintf("Successfully deregistered from '%s'", profileName),
s.wProfiles,
)
@@ -481,7 +461,6 @@ func (p *profileMenu) getProfiles() ([]Profile, error) {
for _, profile := range profilesResp.Profiles {
profiles = append(profiles, Profile{
ID: profile.Id,
Name: profile.Name,
IsActive: profile.IsActive,
})
@@ -522,7 +501,7 @@ func (p *profileMenu) refresh() {
}
if activeProf.ProfileName == "default" || activeProf.Username == currUser.Username {
activeProfState, err := p.profileManager.GetProfileState(profilemanager.ID(activeProf.Id))
activeProfState, err := p.profileManager.GetProfileState(activeProf.ProfileName)
if err != nil {
log.Warnf("failed to get active profile state: %v", err)
p.emailMenuItem.Hide()
@@ -533,7 +512,7 @@ func (p *profileMenu) refresh() {
}
for _, profile := range profiles {
item := p.profileMenuItem.AddSubMenuItem(formatProfileLabel(profile, profiles), "")
item := p.profileMenuItem.AddSubMenuItem(profile.Name, "")
if profile.IsActive {
item.Check()
}
@@ -562,8 +541,8 @@ func (p *profileMenu) refresh() {
return
}
switchResp, err := conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.ID,
_, err = conn.SwitchProfile(ctx, &proto.SwitchProfileRequest{
ProfileName: &profile.Name,
Username: &currUser.Username,
})
if err != nil {
@@ -573,7 +552,7 @@ func (p *profileMenu) refresh() {
return
}
err = p.profileManager.SwitchProfile(profilemanager.ID(switchResp.Id))
err = p.profileManager.SwitchProfile(profile.Name)
if err != nil {
log.Errorf("failed to switch profile '%s': %v", profile.Name, err)
return
@@ -716,10 +695,7 @@ func (p *profileMenu) updateMenu() {
}
sort.Slice(profiles, func(i, j int) bool {
if profiles[i].Name != profiles[j].Name {
return profiles[i].Name < profiles[j].Name
}
return profiles[i].ID < profiles[j].ID
return profiles[i].Name < profiles[j].Name
})
p.mu.Lock()

View File

@@ -120,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
func (s *BaseServer) APIHandler() http.Handler {
return Create(s, func() http.Handler {
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.InstanceManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.AuthMiddleware())
if err != nil {
log.Fatalf("failed to create API handler: %v", err)
}
@@ -153,6 +153,20 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
})
}
func (s *BaseServer) AuthMiddleware() mux.MiddlewareFunc {
return Create(s, func() mux.MiddlewareFunc {
m := middleware.NewAuthMiddleware(
s.AuthManager(),
s.AccountManager().GetAccountIDFromUserAuth,
s.AccountManager().SyncUserJWTGroups,
s.AccountManager().GetUserFromUserAuth,
s.RateLimiter(),
s.Metrics().GetMeter(),
)
return m.Handler
})
}
func (s *BaseServer) GRPCServer() *grpc.Server {
return Create(s, func() *grpc.Server {
trustedPeers := s.Config.ReverseProxy.TrustedPeers

View File

@@ -23,6 +23,7 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
@@ -151,6 +152,16 @@ func (s *BaseServer) IdpManager() idp.Manager {
})
}
func (s *BaseServer) InstanceManager() instance.Manager {
return Create(s, func() instance.Manager {
m, err := instance.NewManager(context.Background(), s.Store(), s.IdpManager())
if err != nil {
log.Fatalf("failed to create instance manager: %v", err)
}
return m
})
}
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
@@ -229,6 +240,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
})
}
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
return false
}

View File

@@ -12,8 +12,6 @@ import (
goproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
@@ -147,9 +145,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
Checks: toProtocolChecks(ctx, checks),
}
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
response.NetbirdConfig = extendedConfig
response.NetbirdConfig = toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
response.NetworkMap.PeerConfig = response.PeerConfig
@@ -363,7 +359,6 @@ func toProtocolFirewallRules(rules []*types.FirewallRule, includeIPv6, useSource
return result
}
// populateSourcePrefixes sets SourcePrefixes on fwRule and returns any
// additional rules needed (e.g. a v6 wildcard clone when the peer IP is unspecified).
func populateSourcePrefixes(fwRule *proto.FirewallRule, rule *types.FirewallRule, includeIPv6 bool) []*proto.FirewallRule {

View File

@@ -978,7 +978,6 @@ func shallowCloneMapping(m *proto.ProxyMapping) *proto.ProxyMapping {
Mode: m.Mode,
ListenPort: m.ListenPort,
AccessRestrictions: m.AccessRestrictions,
Private: m.Private,
}
}

View File

@@ -1,88 +0,0 @@
package grpc
import (
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/proto"
)
// authTokenField is the only per-proxy field that shallowCloneMapping must NOT
// copy from the source, since callers assign it individually after cloning.
const authTokenField = "AuthToken"
// TestShallowCloneMapping_ClonesAllFields populates every exported field of
// ProxyMapping with a non-zero value and verifies the clone carries each one
// (except AuthToken). It uses reflection so adding a new field to ProxyMapping
// without updating shallowCloneMapping fails this test.
func TestShallowCloneMapping_ClonesAllFields(t *testing.T) {
src := &proto.ProxyMapping{}
populated := populateExportedFields(t, reflect.ValueOf(src).Elem())
require.NotEmpty(t, populated, "ProxyMapping should expose fields to populate")
clone := shallowCloneMapping(src)
require.NotNil(t, clone, "clone must not be nil")
srcVal := reflect.ValueOf(src).Elem()
cloneVal := reflect.ValueOf(clone).Elem()
for _, name := range populated {
srcField := srcVal.FieldByName(name).Interface()
cloneField := cloneVal.FieldByName(name).Interface()
if name == authTokenField {
assert.Zero(t, cloneField, "AuthToken must not be cloned; it is set per proxy after cloning")
continue
}
assert.Equal(t, srcField, cloneField, "field %s must be carried over by shallowCloneMapping", name)
}
}
// populateExportedFields sets a non-zero value on every settable exported field
// of the struct and returns their names.
func populateExportedFields(t *testing.T, v reflect.Value) []string {
t.Helper()
var names []string
typ := v.Type()
for i := 0; i < v.NumField(); i++ {
field := v.Field(i)
structField := typ.Field(i)
if structField.PkgPath != "" || !field.CanSet() {
continue
}
setNonZero(t, field, structField.Name)
names = append(names, structField.Name)
}
return names
}
// setNonZero assigns a deterministic non-zero value based on the field kind.
func setNonZero(t *testing.T, field reflect.Value, name string) {
t.Helper()
switch field.Kind() {
case reflect.String:
field.SetString("non-zero-" + name)
case reflect.Bool:
field.SetBool(true)
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
field.SetInt(7)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
field.SetUint(7)
case reflect.Ptr:
field.Set(reflect.New(field.Type().Elem()))
case reflect.Slice:
field.Set(reflect.MakeSlice(field.Type(), 1, 1))
case reflect.Map:
field.Set(reflect.MakeMapWithSize(field.Type(), 0))
default:
t.Fatalf("unhandled field kind %s for field %s; extend setNonZero", field.Kind(), name)
}
}

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/groups"
@@ -41,8 +40,6 @@ type TimeBasedAuthSecretsManager struct {
turnHmacToken *auth.TimedHMAC
relayHmacToken *authv2.Generator
updateManager network_map.PeersUpdateManager
settingsManager settings.Manager
groupsManager groups.Manager
turnCancelMap map[string]chan struct{}
relayCancelMap map[string]chan struct{}
wgKey wgtypes.Key
@@ -62,8 +59,6 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
relayCfg: relayCfg,
turnCancelMap: make(map[string]chan struct{}),
relayCancelMap: make(map[string]chan struct{}),
settingsManager: settingsManager,
groupsManager: groupsManager,
wgKey: key,
}
@@ -239,8 +234,6 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont
}
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
@@ -266,26 +259,9 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac
},
}
m.extendNetbirdConfig(ctx, peerID, accountID, update)
log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID)
m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeControlConfig,
})
}
func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) {
extraSettings, err := m.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get extra settings: %v", err)
}
peerGroups, err := m.groupsManager.GetPeerGroupIDs(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer groups: %v", err)
}
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peerID, peerGroups, update.NetbirdConfig, extraSettings)
update.NetbirdConfig = extendedConfig
}

View File

@@ -8,7 +8,6 @@ import (
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
@@ -20,7 +19,6 @@ import (
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
idpmanager "github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/modules/zones"
@@ -34,7 +32,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
"github.com/netbirdio/netbird/management/server/auth"
"github.com/netbirdio/netbird/management/server/geolocation"
nbgroups "github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
@@ -49,7 +46,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
"github.com/netbirdio/netbird/management/server/http/handlers/users"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
@@ -59,7 +55,7 @@ import (
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, instanceManager nbinstance.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, authMiddleware mux.MiddlewareFunc) (http.Handler, error) {
// Register bypass paths for unauthenticated endpoints
if err := bypass.AddBypassPath("/api/instance"); err != nil {
@@ -80,32 +76,11 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
return nil, fmt.Errorf("failed to add bypass path: %w", err)
}
if rateLimiter == nil {
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
rateLimiter = middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimiter,
appMetrics.GetMeter(),
isValidChildAccount,
)
corsMiddleware := cors.AllowAll()
metricsMiddleware := appMetrics.HTTPMiddleware()
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
if err != nil {
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware)
accounts.AddEndpoints(accountManager, settingsManager, router)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)

View File

@@ -8,6 +8,7 @@ import (
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
"go.opentelemetry.io/otel/metric"
@@ -25,7 +26,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
// jwtTokenCtxKey carries the parsed JWT token.
type jwtTokenCtxKey struct{}
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
@@ -35,7 +37,6 @@ type AuthMiddleware struct {
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
patUsageTracker *PATUsageTracker
isValidChildAccount IsValidChildAccountFunc
}
// NewAuthMiddleware instance constructor
@@ -46,7 +47,6 @@ func NewAuthMiddleware(
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiter *APIRateLimiter,
meter metric.Meter,
isValidChildAccount IsValidChildAccountFunc,
) *AuthMiddleware {
var patUsageTracker *PATUsageTracker
if meter != nil {
@@ -64,12 +64,18 @@ func NewAuthMiddleware(
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
patUsageTracker: patUsageTracker,
isValidChildAccount: isValidChildAccount,
}
}
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
// Handler composes the full authentication chain by wrapping the given
// handler with ValidationHandler followed by AccountAccessHandler.
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
return m.ValidationHandler(m.AccountAccessHandler(h))
}
// ValidationHandler authenticates the caller via JWT or PAT and stores the
// resulting UserAuth in the request context. It performs no account-level work.
func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
@@ -86,14 +92,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
switch authType {
case "bearer":
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
if err := m.validateJWT(r, authHeader); err != nil {
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
case "token":
if err := m.checkPATFromRequest(r, authHeader); err != nil {
if err := m.validatePAT(r, authHeader); err != nil {
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
// Check if it's a status error, otherwise default to Unauthorized
if _, ok := status.FromError(err); !ok {
@@ -110,66 +116,55 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
})
}
// CheckJWTFromRequest checks if the JWT is valid
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts)
// AccountAccessHandler runs post-validation access checks for JWT-authenticated
// requests. PAT requests pass through unchanged.
func (m *AuthMiddleware) AccountAccessHandler(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
return
}
// If an error occurs, call the error handler and return an error
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
return
}
if userAuth.IsPAT {
h.ServeHTTP(w, r)
return
}
validatedToken, _ := r.Context().Value(jwtTokenCtxKey{}).(*jwt.Token)
if err := m.applyAccountAccess(r, userAuth, validatedToken); err != nil {
log.WithContext(r.Context()).Errorf("Error applying JWT account access: %s", err.Error())
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
return
}
h.ServeHTTP(w, r)
})
}
func (m *AuthMiddleware) validateJWT(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromJWTRequest(authHeaderParts)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
}
ctx := r.Context()
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(r.Context(), token)
if err != nil {
return err
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
}
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
// Available as userAuth.Email
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return err
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return err
}
err = m.syncUserJWTGroups(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
_, err = m.getUserFromUserAuth(ctx, userAuth)
if err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return err
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
*r = *r.WithContext(context.WithValue(r.Context(), jwtTokenCtxKey{}, validatedToken))
return nil
}
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string) error {
token, err := getTokenFromPATRequest(authHeaderParts)
if err != nil {
return fmt.Errorf("error extracting token: %w", err)
@@ -192,8 +187,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
return fmt.Errorf("token expired")
}
err = m.authManager.MarkPATUsed(ctx, pat.ID)
if err != nil {
if err := m.authManager.MarkPATUsed(ctx, pat.ID); err != nil {
return err
}
@@ -205,11 +199,40 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
IsPAT: true,
}
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
userAuth.AccountId = impersonate[0]
userAuth.IsChild = true
}
// propagates ctx change to upstream middleware
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
return nil
}
// applyAccountAccess executes account-level checks for an authenticated JWT
// user: ensures the account exists, verifies access via JWT groups, syncs
// groups, and fetches the user record.
func (m *AuthMiddleware) applyAccountAccess(r *http.Request, userAuth auth.UserAuth, validatedToken *jwt.Token) error {
ctx := r.Context()
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
accountId, _, err := m.ensureAccount(ctx, userAuth)
if err != nil {
return err
}
if userAuth.AccountId != accountId {
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
userAuth.AccountId = accountId
}
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
if err != nil {
return err
}
if err := m.syncUserJWTGroups(ctx, userAuth); err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
}
if _, err := m.getUserFromUserAuth(ctx, userAuth); err != nil {
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
return err
}
// propagates ctx change to upstream middleware

View File

@@ -211,7 +211,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -271,7 +270,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -324,7 +322,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -368,7 +365,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -413,7 +409,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -478,7 +473,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -538,7 +532,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -594,7 +587,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
},
NewAPIRateLimiter(rateLimitConfig),
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -695,7 +687,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
},
disabledLimiter,
nil,
func(_ context.Context, _, _, _ string) bool { return false },
)
for _, tc := range tt {

View File

@@ -40,6 +40,7 @@ import (
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/groups"
http2 "github.com/netbirdio/netbird/management/server/http"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
@@ -136,8 +137,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
rateLimiter := middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}
@@ -266,8 +271,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
rateLimiter := middleware.NewAPIRateLimiter(nil)
rateLimiter.SetEnabled(false)
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
if err != nil {
t.Fatalf("Failed to create API handler: %v", err)
}

View File

@@ -1216,7 +1216,6 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types
Preload("NetworkResources").
Preload("Onboarding").
Preload("Services.Targets").
Preload("Domains").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
@@ -1303,7 +1302,7 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
}
var wg sync.WaitGroup
errChan := make(chan error, 16)
errChan := make(chan error, 12)
wg.Add(1)
go func() {
@@ -1404,17 +1403,6 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.
account.Services = services
}()
wg.Add(1)
go func() {
defer wg.Done()
domains, err := s.ListCustomDomains(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Domains = domains
}()
wg.Add(1)
go func() {
defer wg.Done()

View File

@@ -4,8 +4,6 @@ import (
"context"
"net"
"net/netip"
"os"
"runtime"
"testing"
"time"
@@ -23,63 +21,6 @@ import (
"github.com/netbirdio/netbird/route"
)
// TestGetAccount_LoadsCustomDomains verifies GetAccount populates account.Domains.
// SynthesizePrivateServiceZones depends on this relation to anchor a custom-domain
// private service's DNS zone; without the preload the relation is empty and the
// service is silently skipped, so a custom domain never resolves on clients.
func TestGetAccount_LoadsCustomDomains(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
defer cleanup()
assertGetAccountLoadsCustomDomains(t, store)
}
func TestPostgresql_GetAccount_LoadsCustomDomains(t *testing.T) {
if (os.Getenv("CI") == "true" && runtime.GOOS == "darwin") || runtime.GOOS == "windows" {
t.Skip("skip CI tests on darwin and windows")
}
t.Setenv("NETBIRD_STORE_ENGINE", string(types.PostgresStoreEngine))
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "", t.TempDir())
require.NoError(t, err)
t.Cleanup(cleanup)
assertGetAccountLoadsCustomDomains(t, store)
}
// assertGetAccountLoadsCustomDomains exercises both the gorm and pgx GetAccount
// paths: it persists two custom domains and asserts the relation comes back
// populated, which SynthesizePrivateServiceZones relies on.
func assertGetAccountLoadsCustomDomains(t *testing.T, store Store) {
t.Helper()
ctx := context.Background()
accountID := "acct-custom-domains"
require.NoError(t, store.SaveAccount(ctx, newAccountWithId(ctx, accountID, "user-1", "")))
_, err := store.CreateCustomDomain(ctx, accountID, "example.com", "eu.proxy.netbird.io", true)
require.NoError(t, err, "creating the first custom domain must succeed")
_, err = store.CreateCustomDomain(ctx, accountID, "apps.acme.io", "us.proxy.netbird.io", false)
require.NoError(t, err, "creating the second custom domain must succeed")
account, err := store.GetAccount(ctx, accountID)
require.NoError(t, err)
require.Len(t, account.Domains, 2, "GetAccount must preload the account's custom domains")
byDomain := map[string]string{}
for _, d := range account.Domains {
require.NotNil(t, d)
byDomain[d.Domain] = d.TargetCluster
}
assert.Equal(t, "eu.proxy.netbird.io", byDomain["example.com"], "custom domain must carry its target cluster")
assert.Equal(t, "us.proxy.netbird.io", byDomain["apps.acme.io"], "custom domain must carry its target cluster")
}
// TestGetAccount_ComprehensiveFieldValidation validates that GetAccount properly loads
// all fields and nested objects from the database, including deeply nested structures.
func TestGetAccount_ComprehensiveFieldValidation(t *testing.T) {

View File

@@ -273,7 +273,7 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
peerGroups := a.GetPeerGroups(peerID)
zonesByApex := map[string]*nbdns.CustomZone{}
zonesByCluster := map[string]*nbdns.CustomZone{}
for _, svc := range a.Services {
if svc == nil || !svc.Enabled || !svc.Private {
@@ -290,24 +290,19 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
continue
}
serviceDomainZone := a.privateServiceDomainZone(svc)
if serviceDomainZone == "" {
continue
}
zone, exists := zonesByApex[serviceDomainZone]
zone, exists := zonesByCluster[svc.ProxyCluster]
if !exists {
// NonAuthoritative makes this a match-only zone: queries for
// names without an explicit record fall through to the
// upstream resolver instead of returning NXDOMAIN. Without
// it, adding a single private service would black-hole every
// other name under the zone apex.
// other name under the cluster apex.
zone = &nbdns.CustomZone{
Domain: dns.Fqdn(serviceDomainZone),
Domain: dns.Fqdn(svc.ProxyCluster),
Records: []nbdns.SimpleRecord{},
NonAuthoritative: true,
}
zonesByApex[serviceDomainZone] = zone
zonesByCluster[svc.ProxyCluster] = zone
}
emitted := 0
@@ -345,8 +340,8 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
}
}
out := make([]nbdns.CustomZone, 0, len(zonesByApex))
for _, zone := range zonesByApex {
out := make([]nbdns.CustomZone, 0, len(zonesByCluster))
for _, zone := range zonesByCluster {
if len(zone.Records) == 0 {
continue
}
@@ -362,33 +357,6 @@ func (a *Account) SynthesizePrivateServiceZones(peerID string) []nbdns.CustomZon
return out
}
// privateServiceDomainZone returns the DNS zone name for the given private service domain by
// looking at the proxy cluster domain then the custom domains.
func (a *Account) privateServiceDomainZone(svc *service.Service) string {
if domainFromSuffix(svc.Domain, svc.ProxyCluster) {
return svc.ProxyCluster
}
// Longest matching custom domain wins
zoneName := ""
for _, d := range a.Domains {
if d == nil || d.TargetCluster != svc.ProxyCluster {
continue
}
if domainFromSuffix(svc.Domain, d.Domain) && len(d.Domain) > len(zoneName) {
zoneName = d.Domain
}
}
return zoneName
}
func domainFromSuffix(domain, suffix string) bool {
if suffix == "" {
return false
}
return domain == suffix || strings.HasSuffix(domain, "."+suffix)
}
// peerInDistributionGroups reports whether any of the peer's groups
// matches the service's bearer-auth distribution_groups.
func peerInDistributionGroups(peerGroups LookupMap, distributionGroups []string) bool {

View File

@@ -11,7 +11,6 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
@@ -235,113 +234,6 @@ func TestPrivateZone_GetPeerNetworkMap_PeerOutsideGroups_OmitsSynthZone(t *testi
assert.False(t, ok, "peer outside the distribution_groups must not see the synth zone")
}
func TestSynthesizePrivateServiceZones_CustomDomain_ZoneApexIsRegisteredDomain(t *testing.T) {
account := privateZoneTestAccount(t)
// A custom-domain service: Domain is the custom FQDN, ProxyCluster
// is the cluster serving it, and account.Domains holds the registered
// custom domain. The synth zone apex must be the registered domain,
// not the cluster, or the client's match-only zone never intercepts
// the query.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "custom-domain service must still produce one zone")
zone := zones[0]
assert.Equal(t, "example.com.", zone.Domain, "zone apex must be the registered custom domain, not the cluster or the service FQDN")
assert.True(t, zone.NonAuthoritative, "synth zone must remain match-only")
require.Len(t, zone.Records, 1, "custom-domain service yields one A record")
rec := zone.Records[0]
assert.Equal(t, "app.example.com.", rec.Name, "record name is the custom service FQDN")
assert.Equal(t, "100.64.0.99", rec.RData, "record points at the embedded proxy peer's tunnel IP")
}
func TestSynthesizePrivateServiceZones_CustomAndFreeDomain_SeparateZones(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "custom",
Domain: "app.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "a free-domain and a custom-domain service must not collapse into one zone")
free, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "free-domain service keeps the shared cluster-apex zone")
require.Len(t, free.Records, 1, "cluster zone carries only the free-domain record")
assert.Equal(t, "myapp.eu.proxy.netbird.io.", free.Records[0].Name, "cluster zone record is the free-domain FQDN")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-domain service gets its own zone at the registered custom domain apex")
require.Len(t, custom.Records, 1, "custom zone carries only the custom-domain record")
assert.Equal(t, "app.example.com.", custom.Records[0].Name, "custom zone record is the custom-domain FQDN")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCustomDomain_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
account.Services[0].Domain = "a.example.com"
account.Services = append(account.Services, &service.Service{
ID: "svc-2",
AccountID: "acct-1",
Name: "bapp",
Domain: "b.example.com",
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
})
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 1, "two services under the same registered custom domain must share one zone")
assert.Equal(t, "example.com.", zones[0].Domain, "shared zone apex is the registered custom domain")
require.Len(t, zones[0].Records, 2, "both services surface as records in the shared custom-domain zone")
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"a.example.com.", "b.example.com."}, names, "both custom-domain service FQDNs must surface")
}
func TestSynthesizePrivateServiceZones_CustomDomainNotRegistered_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// Service domain is outside the cluster and no account.Domains entry
// covers it: there is no apex that would intercept the query, so the
// service must be skipped rather than emit an unmatchable record.
account.Services[0].Domain = "app.example.com"
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom-domain service with no registered domain apex must not produce a zone")
}
func TestSynthesizePrivateServiceZones_CustomDomainClusterMismatch_NoZone(t *testing.T) {
account := privateZoneTestAccount(t)
// The registered custom domain matches the service FQDN by suffix but
// targets a different cluster than the service's ProxyCluster. It must
// be ignored, leaving no apex to intercept the query — otherwise the
// zone would point at this cluster's proxy peers under a domain owned
// by a different cluster.
account.Services[0].Domain = "app.example.com"
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "us.proxy.netbird.io", Validated: true},
}
zones := account.SynthesizePrivateServiceZones("user-peer")
assert.Empty(t, zones, "a custom domain targeting a different cluster must not anchor the service zone")
}
func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing.T) {
account := privateZoneTestAccount(t)
account.Services = append(account.Services, &service.Service{
@@ -362,72 +254,3 @@ func TestSynthesizePrivateServiceZones_TwoServicesSameCluster_OneZone(t *testing
names := []string{zones[0].Records[0].Name, zones[0].Records[1].Name}
assert.ElementsMatch(t, []string{"myapp.eu.proxy.netbird.io.", "anotherapp.eu.proxy.netbird.io."}, names, "both service domains must surface")
}
func TestSynthesizePrivateServiceZones_MixedClusterCustomAndPublic(t *testing.T) {
account := privateZoneTestAccount(t)
account.Domains = []*proxydomain.Domain{
{Domain: "example.com", AccountID: "acct-1", TargetCluster: "eu.proxy.netbird.io", Validated: true},
}
privateService := func(id, domain string) *service.Service {
return &service.Service{
ID: id,
AccountID: "acct-1",
Name: id,
Domain: domain,
ProxyCluster: "eu.proxy.netbird.io",
Enabled: true,
Private: true,
Mode: service.ModeHTTP,
AccessGroups: []string{"grp-admins"},
}
}
publicService := func(id, domain string) *service.Service {
s := privateService(id, domain)
s.Private = false
return s
}
account.Services = []*service.Service{
// 3 private services under the cluster suffix.
privateService("cluster-1", "cluster1.eu.proxy.netbird.io"),
privateService("cluster-2", "cluster2.eu.proxy.netbird.io"),
privateService("cluster-3", "cluster3.eu.proxy.netbird.io"),
// 4 private services under the custom domain suffix.
privateService("custom-1", "custom1.example.com"),
privateService("custom-2", "custom2.example.com"),
privateService("custom-3", "custom3.example.com"),
privateService("custom-4", "custom4.example.com"),
// 2 public services, one per suffix, must not surface.
publicService("public-cluster", "public.eu.proxy.netbird.io"),
publicService("public-custom", "public.example.com"),
}
zones := account.SynthesizePrivateServiceZones("user-peer")
require.Len(t, zones, 2, "one zone per apex: the cluster apex and the custom domain apex")
cluster, ok := findCustomZone(zones, "eu.proxy.netbird.io")
require.True(t, ok, "cluster-suffix services collapse into the cluster-apex zone")
clusterNames := recordNames(cluster)
assert.ElementsMatch(t,
[]string{"cluster1.eu.proxy.netbird.io.", "cluster2.eu.proxy.netbird.io.", "cluster3.eu.proxy.netbird.io."},
clusterNames,
"only the 3 private cluster services surface in the cluster zone (public one excluded)")
custom, ok := findCustomZone(zones, "example.com")
require.True(t, ok, "custom-suffix services collapse into the custom-domain-apex zone")
customNames := recordNames(custom)
assert.ElementsMatch(t,
[]string{"custom1.example.com.", "custom2.example.com.", "custom3.example.com.", "custom4.example.com."},
customNames,
"only the 4 private custom services surface in the custom zone (public one excluded)")
}
// recordNames returns the record names of a zone for order-independent assertions.
func recordNames(zone nbdns.CustomZone) []string {
names := make([]string, 0, len(zone.Records))
for _, r := range zone.Records {
names = append(names, r.Name)
}
return names
}

BIN
ui

Binary file not shown.