[client] Add login_hint to oidc flows (#4724)

This commit is contained in:
Viktor Liu
2025-11-05 17:00:20 +01:00
committed by GitHub
parent c92e6c1b5f
commit 75327d9519
13 changed files with 109 additions and 23 deletions

View File

@@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
} }
func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str
Username: &username, Username: &username,
} }
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
loginRequest.OptionalPreSharedKey = &preSharedKey loginRequest.OptionalPreSharedKey = &preSharedKey
} }
@@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
return fmt.Errorf("read config file %s: %v", configFilePath, err) return fmt.Errorf("read config file %s: %v", configFilePath, err)
} }
err = foregroundLogin(ctx, cmd, config, setupKey) err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo
return nil return nil
} }
func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error {
needsLogin := false needsLogin := false
err := WithBackOff(func() error { err := WithBackOff(func() error {
@@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
jwtToken := "" jwtToken := ""
if setupKey == "" && needsLogin { if setupKey == "" && needsLogin {
tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName)
if err != nil { if err != nil {
return fmt.Errorf("interactive sso login failed: %v", err) return fmt.Errorf("interactive sso login failed: %v", err)
} }
@@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
return nil return nil
} }
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) hint := ""
pm := profilemanager.NewProfileManager()
profileState, err := pm.GetProfileState(profileName)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
hint = profileState.Email
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
_, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath)
err = foregroundLogin(ctx, cmd, config, providedSetupKey) err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name)
if err != nil { if err != nil {
return fmt.Errorf("foreground login failed: %v", err) return fmt.Errorf("foreground login failed: %v", err)
} }
@@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ
loginRequest.ProfileName = &activeProf.Name loginRequest.ProfileName = &activeProf.Name
loginRequest.Username = &username loginRequest.Username = &username
profileState, err := pm.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginRequest.Hint = &profileState.Email
}
var loginErr error var loginErr error
var loginResp *proto.LoginResponse var loginResp *proto.LoginResponse

View File

@@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow
deviceCode.VerificationURIComplete = deviceCode.VerificationURI deviceCode.VerificationURIComplete = deviceCode.VerificationURI
} }
if d.providerConfig.LoginHint != "" {
deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint)
if deviceCode.VerificationURI != "" {
deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint)
}
}
return deviceCode, err return deviceCode, err
} }
func appendLoginHint(uri, loginHint string) string {
if uri == "" || loginHint == "" {
return uri
}
parsedURL, err := url.Parse(uri)
if err != nil {
log.Debugf("failed to parse verification URI for login_hint: %v", err)
return uri
}
query := parsedURL.Query()
query.Set("login_hint", loginHint)
parsedURL.RawQuery = query.Encode()
return parsedURL.String()
}
func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) {
form := url.Values{} form := url.Values{}
form.Add("client_id", d.providerConfig.ClientID) form.Add("client_id", d.providerConfig.ClientID)

View File

@@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string {
// and if that also fails, the authentication process is deemed unsuccessful // and if that also fails, the authentication process is deemed unsuccessful
// //
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) {
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config, hint)
} }
pkceFlow, err := authenticateWithPKCEFlow(ctx, config) pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint)
if err != nil { if err != nil {
// fallback to device code flow
log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debugf("failed to initialize pkce authentication with error: %v\n", err)
log.Debug("falling back to device code flow") log.Debug("falling back to device code flow")
return authenticateWithDeviceCodeFlow(ctx, config) return authenticateWithDeviceCodeFlow(ctx, config, hint)
} }
return pkceFlow, nil return pkceFlow, nil
} }
// authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow
func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair)
if err != nil { if err != nil {
return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err)
} }
pkceFlowInfo.ProviderConfig.LoginHint = hint
return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig)
} }
// authenticateWithDeviceCodeFlow initializes the Device Code auth Flow // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow
func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) {
deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL)
if err != nil { if err != nil {
switch s, ok := gstatus.FromError(err); { switch s, ok := gstatus.FromError(err); {
@@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.
} }
} }
deviceFlowInfo.ProviderConfig.LoginHint = hint
return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig)
} }

View File

@@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
params = append(params, oauth2.SetAuthURLParam("max_age", "0")) params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
} }
} }
if p.providerConfig.LoginHint != "" {
params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint))
}
authURL := p.oAuthConfig.AuthCodeURL(state, params...) authURL := p.oAuthConfig.AuthCodeURL(state, params...)

View File

@@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct {
Scope string Scope string
// UseIDToken indicates if the id token should be used for authentication // UseIDToken indicates if the id token should be used for authentication
UseIDToken bool UseIDToken bool
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it

View File

@@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct {
DisablePromptLogin bool DisablePromptLogin bool
// LoginFlag is used to configure the PKCE flow login behavior // LoginFlag is used to configure the PKCE flow login behavior
LoginFlag common.LoginFlag LoginFlag common.LoginFlag
// LoginHint is used to pre-fill the email/username field during authentication
LoginHint string
} }
// GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it

View File

@@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string {
ConfigPath: c.cfgFile, ConfigPath: c.cfgFile,
}) })
oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "")
if err != nil { if err != nil {
return err.Error() return err.Error()
} }

View File

@@ -279,8 +279,10 @@ type LoginRequest struct {
ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"`
Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"`
Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"`
unknownFields protoimpl.UnknownFields // hint is used to pre-fill the email/username field during SSO authentication
sizeCache protoimpl.SizeCache Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
} }
func (x *LoginRequest) Reset() { func (x *LoginRequest) Reset() {
@@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 {
return 0 return 0
} }
func (x *LoginRequest) GetHint() string {
if x != nil && x.Hint != nil {
return *x.Hint
}
return ""
}
type LoginResponse struct { type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" + const file_daemon_proto_rawDesc = "" +
"\n" + "\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\xc3\x0e\n" + "\fEmptyRequest\"\xe5\x0e\n" +
"\fLoginRequest\x12\x1a\n" + "\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" +
"\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" +
"\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" +
"\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" +
"\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" +
"\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_block_inboundB\x0e\n" + "\x0e_block_inboundB\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_usernameB\x06\n" + "\t_usernameB\x06\n" +
"\x04_mtu\"\xb5\x01\n" + "\x04_mtuB\a\n" +
"\x05_hint\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" + "\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +

View File

@@ -158,6 +158,9 @@ message LoginRequest {
optional string username = 31; optional string username = 31;
optional int64 mtu = 32; optional int64 mtu = 32;
// hint is used to pre-fill the email/username field during SSO authentication
optional string hint = 33;
} }
message LoginResponse { message LoginResponse {

View File

@@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
state.Set(internal.StatusConnecting) state.Set(internal.StatusConnecting)
if msg.SetupKey == "" { if msg.SetupKey == "" {
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) hint := ""
if msg.Hint != nil {
hint = *msg.Hint
}
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint)
if err != nil { if err != nil {
state.Set(internal.StatusLoginFailed) state.Set(internal.StatusLoginFailed)
return nil, err return nil, err

View File

@@ -610,11 +610,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe
return nil, fmt.Errorf("get current user: %w", err) return nil, fmt.Errorf("get current user: %w", err)
} }
loginResp, err := conn.Login(ctx, &proto.LoginRequest{ loginReq := &proto.LoginRequest{
IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd",
ProfileName: &activeProf.Name, ProfileName: &activeProf.Name,
Username: &currUser.Username, Username: &currUser.Username,
}) }
profileState, err := s.profileManager.GetProfileState(activeProf.Name)
if err != nil {
log.Debugf("failed to get profile state for login hint: %v", err)
} else if profileState.Email != "" {
loginReq.Hint = &profileState.Email
}
loginResp, err := conn.Login(ctx, loginReq)
if err != nil { if err != nil {
return nil, fmt.Errorf("login to management: %w", err) return nil, fmt.Errorf("login to management: %w", err)
} }