diff --git a/client/android/login.go b/client/android/login.go index 0df78dbc3..16df24ba8 100644 --- a/client/android/login.go +++ b/client/android/login.go @@ -200,7 +200,7 @@ func (a *Auth) login(urlOpener URLOpener) error { } func (a *Auth) foregroundGetTokenInfo(urlOpener URLOpener) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false) + oAuthFlow, err := auth.NewOAuthFlow(a.ctx, a.config, false, "") if err != nil { return nil, err } diff --git a/client/cmd/login.go b/client/cmd/login.go index 40b55f858..b0c877faa 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -106,6 +106,13 @@ func doDaemonLogin(ctx context.Context, cmd *cobra.Command, providedSetupKey str Username: &username, } + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) { loginRequest.OptionalPreSharedKey = &preSharedKey } @@ -241,7 +248,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string, return fmt.Errorf("read config file %s: %v", configFilePath, err) } - err = foregroundLogin(ctx, cmd, config, setupKey) + err = foregroundLogin(ctx, cmd, config, setupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -269,7 +276,7 @@ func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.Lo return nil } -func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey string) error { +func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, setupKey, profileName string) error { needsLogin := false err := WithBackOff(func() error { @@ -286,7 +293,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman jwtToken := "" if setupKey == "" && needsLogin { - tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config) + tokenInfo, err := foregroundGetTokenInfo(ctx, cmd, config, profileName) if err != nil { return fmt.Errorf("interactive sso login failed: %v", err) } @@ -315,8 +322,17 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman return nil } -func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config) (*auth.TokenInfo, error) { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop()) +func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *profilemanager.Config, profileName string) (*auth.TokenInfo, error) { + hint := "" + pm := profilemanager.NewProfileManager() + profileState, err := pm.GetProfileState(profileName) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + hint = profileState.Email + } + + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop(), hint) if err != nil { return nil, err } diff --git a/client/cmd/up.go b/client/cmd/up.go index d047c041e..80175f7be 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -185,7 +185,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr _, _ = profilemanager.UpdateOldManagementURL(ctx, config, configFilePath) - err = foregroundLogin(ctx, cmd, config, providedSetupKey) + err = foregroundLogin(ctx, cmd, config, providedSetupKey, activeProf.Name) if err != nil { return fmt.Errorf("foreground login failed: %v", err) } @@ -286,6 +286,13 @@ func doDaemonUp(ctx context.Context, cmd *cobra.Command, client proto.DaemonServ loginRequest.ProfileName = &activeProf.Name loginRequest.Username = &username + profileState, err := pm.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginRequest.Hint = &profileState.Email + } + var loginErr error var loginResp *proto.LoginResponse diff --git a/client/internal/auth/device_flow.go b/client/internal/auth/device_flow.go index da4f16c8d..8ca760742 100644 --- a/client/internal/auth/device_flow.go +++ b/client/internal/auth/device_flow.go @@ -128,9 +128,34 @@ func (d *DeviceAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlow deviceCode.VerificationURIComplete = deviceCode.VerificationURI } + if d.providerConfig.LoginHint != "" { + deviceCode.VerificationURIComplete = appendLoginHint(deviceCode.VerificationURIComplete, d.providerConfig.LoginHint) + if deviceCode.VerificationURI != "" { + deviceCode.VerificationURI = appendLoginHint(deviceCode.VerificationURI, d.providerConfig.LoginHint) + } + } + return deviceCode, err } +func appendLoginHint(uri, loginHint string) string { + if uri == "" || loginHint == "" { + return uri + } + + parsedURL, err := url.Parse(uri) + if err != nil { + log.Debugf("failed to parse verification URI for login_hint: %v", err) + return uri + } + + query := parsedURL.Query() + query.Set("login_hint", loginHint) + parsedURL.RawQuery = query.Encode() + + return parsedURL.String() +} + func (d *DeviceAuthorizationFlow) requestToken(info AuthFlowInfo) (TokenRequestResponse, error) { form := url.Values{} form.Add("client_id", d.providerConfig.ClientID) diff --git a/client/internal/auth/oauth.go b/client/internal/auth/oauth.go index 4458f600c..9fbd6cf5f 100644 --- a/client/internal/auth/oauth.go +++ b/client/internal/auth/oauth.go @@ -66,32 +66,34 @@ func (t TokenInfo) GetTokenToUse() string { // and if that also fails, the authentication process is deemed unsuccessful // // On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow -func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool) (OAuthFlow, error) { +func NewOAuthFlow(ctx context.Context, config *profilemanager.Config, isUnixDesktopClient bool, hint string) (OAuthFlow, error) { if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient { - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } - pkceFlow, err := authenticateWithPKCEFlow(ctx, config) + pkceFlow, err := authenticateWithPKCEFlow(ctx, config, hint) if err != nil { - // fallback to device code flow log.Debugf("failed to initialize pkce authentication with error: %v\n", err) log.Debug("falling back to device code flow") - return authenticateWithDeviceCodeFlow(ctx, config) + return authenticateWithDeviceCodeFlow(ctx, config, hint) } return pkceFlow, nil } // authenticateWithPKCEFlow initializes the Proof Key for Code Exchange flow auth flow -func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithPKCEFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { pkceFlowInfo, err := internal.GetPKCEAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL, config.ClientCertKeyPair) if err != nil { return nil, fmt.Errorf("getting pkce authorization flow info failed with error: %v", err) } + + pkceFlowInfo.ProviderConfig.LoginHint = hint + return NewPKCEAuthorizationFlow(pkceFlowInfo.ProviderConfig) } // authenticateWithDeviceCodeFlow initializes the Device Code auth Flow -func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config) (OAuthFlow, error) { +func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager.Config, hint string) (OAuthFlow, error) { deviceFlowInfo, err := internal.GetDeviceAuthorizationFlowInfo(ctx, config.PrivateKey, config.ManagementURL) if err != nil { switch s, ok := gstatus.FromError(err); { @@ -107,5 +109,7 @@ func authenticateWithDeviceCodeFlow(ctx context.Context, config *profilemanager. } } + deviceFlowInfo.ProviderConfig.LoginHint = hint + return NewDeviceAuthorizationFlow(deviceFlowInfo.ProviderConfig) } diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 8741e8636..738d3e34f 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -109,6 +109,9 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn params = append(params, oauth2.SetAuthURLParam("max_age", "0")) } } + if p.providerConfig.LoginHint != "" { + params = append(params, oauth2.SetAuthURLParam("login_hint", p.providerConfig.LoginHint)) + } authURL := p.oAuthConfig.AuthCodeURL(state, params...) diff --git a/client/internal/device_auth.go b/client/internal/device_auth.go index 6bd29801d..7f7d06130 100644 --- a/client/internal/device_auth.go +++ b/client/internal/device_auth.go @@ -38,6 +38,8 @@ type DeviceAuthProviderConfig struct { Scope string // UseIDToken indicates if the id token should be used for authentication UseIDToken bool + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetDeviceAuthorizationFlowInfo initialize a DeviceAuthorizationFlow instance and return with it diff --git a/client/internal/pkce_auth.go b/client/internal/pkce_auth.go index a713bb342..23c92e8af 100644 --- a/client/internal/pkce_auth.go +++ b/client/internal/pkce_auth.go @@ -44,6 +44,8 @@ type PKCEAuthProviderConfig struct { DisablePromptLogin bool // LoginFlag is used to configure the PKCE flow login behavior LoginFlag common.LoginFlag + // LoginHint is used to pre-fill the email/username field during authentication + LoginHint string } // GetPKCEAuthorizationFlowInfo initialize a PKCEAuthorizationFlow instance and return with it diff --git a/client/ios/NetBirdSDK/client.go b/client/ios/NetBirdSDK/client.go index 2109d4b15..fa1c89aab 100644 --- a/client/ios/NetBirdSDK/client.go +++ b/client/ios/NetBirdSDK/client.go @@ -228,7 +228,7 @@ func (c *Client) LoginForMobile() string { ConfigPath: c.cfgFile, }) - oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false) + oAuthFlow, err := auth.NewOAuthFlow(ctx, cfg, false, "") if err != nil { return err.Error() } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 841e3c0f7..02f09b08a 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -279,8 +279,10 @@ type LoginRequest struct { ProfileName *string `protobuf:"bytes,30,opt,name=profileName,proto3,oneof" json:"profileName,omitempty"` Username *string `protobuf:"bytes,31,opt,name=username,proto3,oneof" json:"username,omitempty"` Mtu *int64 `protobuf:"varint,32,opt,name=mtu,proto3,oneof" json:"mtu,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // hint is used to pre-fill the email/username field during SSO authentication + Hint *string `protobuf:"bytes,33,opt,name=hint,proto3,oneof" json:"hint,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *LoginRequest) Reset() { @@ -538,6 +540,13 @@ func (x *LoginRequest) GetMtu() int64 { return 0 } +func (x *LoginRequest) GetHint() string { + if x != nil && x.Hint != nil { + return *x.Hint + } + return "" +} + type LoginResponse struct { state protoimpl.MessageState `protogen:"open.v1"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` @@ -4608,7 +4617,7 @@ var File_daemon_proto protoreflect.FileDescriptor const file_daemon_proto_rawDesc = "" + "\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + - "\fEmptyRequest\"\xc3\x0e\n" + + "\fEmptyRequest\"\xe5\x0e\n" + "\fLoginRequest\x12\x1a\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + @@ -4645,7 +4654,8 @@ const file_daemon_proto_rawDesc = "" + "\rblock_inbound\x18\x1d \x01(\bH\x10R\fblockInbound\x88\x01\x01\x12%\n" + "\vprofileName\x18\x1e \x01(\tH\x11R\vprofileName\x88\x01\x01\x12\x1f\n" + "\busername\x18\x1f \x01(\tH\x12R\busername\x88\x01\x01\x12\x15\n" + - "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01B\x13\n" + + "\x03mtu\x18 \x01(\x03H\x13R\x03mtu\x88\x01\x01\x12\x17\n" + + "\x04hint\x18! \x01(\tH\x14R\x04hint\x88\x01\x01B\x13\n" + "\x11_rosenpassEnabledB\x10\n" + "\x0e_interfaceNameB\x10\n" + "\x0e_wireguardPortB\x17\n" + @@ -4665,7 +4675,8 @@ const file_daemon_proto_rawDesc = "" + "\x0e_block_inboundB\x0e\n" + "\f_profileNameB\v\n" + "\t_usernameB\x06\n" + - "\x04_mtu\"\xb5\x01\n" + + "\x04_mtuB\a\n" + + "\x05_hint\"\xb5\x01\n" + "\rLoginResponse\x12$\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index 5b27b4d98..8d1080051 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -158,6 +158,9 @@ message LoginRequest { optional string username = 31; optional int64 mtu = 32; + + // hint is used to pre-fill the email/username field during SSO authentication + optional string hint = 33; } message LoginResponse { diff --git a/client/server/server.go b/client/server/server.go index 3641e6f92..6699cdadc 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -483,7 +483,11 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro state.Set(internal.StatusConnecting) if msg.SetupKey == "" { - oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient) + hint := "" + if msg.Hint != nil { + hint = *msg.Hint + } + oAuthFlow, err := auth.NewOAuthFlow(ctx, config, msg.IsUnixDesktopClient, hint) if err != nil { state.Set(internal.StatusLoginFailed) return nil, err diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index f4350b251..e580be56d 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -610,11 +610,20 @@ func (s *serviceClient) login(ctx context.Context, openURL bool) (*proto.LoginRe return nil, fmt.Errorf("get current user: %w", err) } - loginResp, err := conn.Login(ctx, &proto.LoginRequest{ + loginReq := &proto.LoginRequest{ IsUnixDesktopClient: runtime.GOOS == "linux" || runtime.GOOS == "freebsd", ProfileName: &activeProf.Name, Username: &currUser.Username, - }) + } + + profileState, err := s.profileManager.GetProfileState(activeProf.Name) + if err != nil { + log.Debugf("failed to get profile state for login hint: %v", err) + } else if profileState.Email != "" { + loginReq.Hint = &profileState.Email + } + + loginResp, err := conn.Login(ctx, loginReq) if err != nil { return nil, fmt.Errorf("login to management: %w", err) }