diff --git a/client/server/login_overrides_test.go b/client/server/login_overrides_test.go new file mode 100644 index 000000000..c45557c59 --- /dev/null +++ b/client/server/login_overrides_test.go @@ -0,0 +1,93 @@ +package server + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/profilemanager" +) + +func TestPersistLoginOverrides(t *testing.T) { + strPtr := func(s string) *string { return &s } + + tests := []struct { + name string + initialMgmtURL string + initialPSK string + newMgmtURL string + newPSK *string + wantMgmtURL string + wantPSK string + }{ + { + name: "persist new management URL", + initialMgmtURL: "https://old.example.com:33073", + newMgmtURL: "https://new.example.com:33073", + wantMgmtURL: "https://new.example.com:33073", + }, + { + name: "persist new pre-shared key", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "old-key", + newPSK: strPtr("new-key"), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "new-key", + }, + { + name: "persist both", + initialMgmtURL: "https://old.example.com:33073", + initialPSK: "old-key", + newMgmtURL: "https://new.example.com:33073", + newPSK: strPtr("new-key"), + wantMgmtURL: "https://new.example.com:33073", + wantPSK: "new-key", + }, + { + name: "no inputs preserves existing", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + { + name: "empty PSK pointer is ignored", + initialMgmtURL: "https://existing.example.com:33073", + initialPSK: "existing-key", + newPSK: strPtr(""), + wantMgmtURL: "https://existing.example.com:33073", + wantPSK: "existing-key", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origDefault := profilemanager.DefaultConfigPath + t.Cleanup(func() { profilemanager.DefaultConfigPath = origDefault }) + + dir := t.TempDir() + profilemanager.DefaultConfigPath = filepath.Join(dir, "default.json") + + seed := profilemanager.ConfigInput{ + ConfigPath: profilemanager.DefaultConfigPath, + ManagementURL: tt.initialMgmtURL, + } + if tt.initialPSK != "" { + seed.PreSharedKey = strPtr(tt.initialPSK) + } + _, err := profilemanager.UpdateOrCreateConfig(seed) + require.NoError(t, err, "seed config") + + activeProf := &profilemanager.ActiveProfileState{Name: "default"} + err = persistLoginOverrides(activeProf, tt.newMgmtURL, tt.newPSK) + require.NoError(t, err, "persistLoginOverrides") + + cfg, err := profilemanager.ReadConfig(profilemanager.DefaultConfigPath) + require.NoError(t, err, "read back config") + + require.Equal(t, tt.wantMgmtURL, cfg.ManagementURL.String(), "management URL") + require.Equal(t, tt.wantPSK, cfg.PreSharedKey, "pre-shared key") + }) + } +} diff --git a/client/server/server.go b/client/server/server.go index 648ffa8ce..a793ca46f 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -489,6 +489,10 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro s.mutex.Unlock() + if err := persistLoginOverrides(activeProf, msg.ManagementUrl, msg.OptionalPreSharedKey); err != nil { + return nil, err + } + config, _, err := s.getConfig(activeProf) if err != nil { log.Errorf("failed to get active profile config: %v", err) @@ -963,7 +967,33 @@ func (s *Server) handleActiveProfileLogout(ctx context.Context) (*proto.LogoutRe return &proto.LogoutResponse{}, nil } -// GetConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist +// persistLoginOverrides writes management URL and pre-shared key from a LoginRequest to the +// active profile config so that subsequent reads pick them up. Empty/nil values are ignored. +func persistLoginOverrides(activeProf *profilemanager.ActiveProfileState, managementURL string, preSharedKey *string) error { + if preSharedKey != nil && *preSharedKey == "" { + preSharedKey = nil + } + if managementURL == "" && preSharedKey == nil { + return nil + } + + cfgPath, err := activeProf.FilePath() + if err != nil { + return fmt.Errorf("get active profile file path: %w", err) + } + + input := profilemanager.ConfigInput{ + ConfigPath: cfgPath, + ManagementURL: managementURL, + PreSharedKey: preSharedKey, + } + if _, err := profilemanager.UpdateOrCreateConfig(input); err != nil { + return fmt.Errorf("persist login overrides: %w", err) + } + return nil +} + +// getConfig reads config file and returns Config and whether the config file already existed. Errors out if it does not exist func (s *Server) getConfig(activeProf *profilemanager.ActiveProfileState) (*profilemanager.Config, bool, error) { cfgPath, err := activeProf.FilePath() if err != nil {