From 0d79301141d9b856d05a000e1e5a21517f5aad5e Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Mon, 17 Nov 2025 15:28:20 +0100 Subject: [PATCH] Update client login success page (#4797) --- client/internal/auth/pkce_flow.go | 11 +- client/internal/templates/pkce-auth-msg.html | 155 ++++----- .../internal/templates/pkce_auth_msg_test.go | 299 ++++++++++++++++++ 3 files changed, 386 insertions(+), 79 deletions(-) create mode 100644 client/internal/templates/pkce_auth_msg_test.go diff --git a/client/internal/auth/pkce_flow.go b/client/internal/auth/pkce_flow.go index 738d3e34f..48873f640 100644 --- a/client/internal/auth/pkce_flow.go +++ b/client/internal/auth/pkce_flow.go @@ -192,17 +192,20 @@ func (p *PKCEAuthorizationFlow) handleRequest(req *http.Request) (*oauth2.Token, if authError := query.Get(queryError); authError != "" { authErrorDesc := query.Get(queryErrorDesc) - return nil, fmt.Errorf("%s.%s", authError, authErrorDesc) + if authErrorDesc != "" { + return nil, fmt.Errorf("authentication failed: %s", authErrorDesc) + } + return nil, fmt.Errorf("authentication failed: %s", authError) } // Prevent timing attacks on the state if state := query.Get(queryState); subtle.ConstantTimeCompare([]byte(p.state), []byte(state)) == 0 { - return nil, fmt.Errorf("invalid state") + return nil, fmt.Errorf("authentication failed: Invalid state") } code := query.Get(queryCode) if code == "" { - return nil, fmt.Errorf("missing code") + return nil, fmt.Errorf("authentication failed: missing code") } return p.oAuthConfig.Exchange( @@ -231,7 +234,7 @@ func (p *PKCEAuthorizationFlow) parseOAuthToken(token *oauth2.Token) (TokenInfo, } if err := isValidAccessToken(tokenInfo.GetTokenToUse(), audience); err != nil { - return TokenInfo{}, fmt.Errorf("validate access token failed with error: %v", err) + return TokenInfo{}, fmt.Errorf("authentication failed: invalid access token - %w", err) } email, err := parseEmailFromIDToken(tokenInfo.IDToken) diff --git a/client/internal/templates/pkce-auth-msg.html b/client/internal/templates/pkce-auth-msg.html index 4825c48e7..175a6f05c 100644 --- a/client/internal/templates/pkce-auth-msg.html +++ b/client/internal/templates/pkce-auth-msg.html @@ -1,88 +1,93 @@ + - + + + + NetBird Login + + + - NetBird Login Successful + -
- -
- {{ if .Error }} - - - - -
-
- Login failed +
+
+ + +
+ + + + + + + + + + + + + + + + + + +
- {{ .Error }}. -
- {{ else }} - - - - -
-
- Login successful + +
+ +
+ + {{ if .Error }} + +
+ + + + +
+ {{ else }} + +
+ + + + +
+ {{ end }} + + +
+ {{ if .Error }} +

Login Failed

+ {{ else }} +

Login Successful

+ {{ end }} +
+ + + {{ if .Error }} +
+ {{ .Error }} +
+ {{ else }} + +
+ Your device is now registered and logged in to NetBird. You can now close this window. +
+ {{ end }} + +
- Your device is now registered and logged in to NetBird. -
- You can now close this window.
- {{ end }}
+ diff --git a/client/internal/templates/pkce_auth_msg_test.go b/client/internal/templates/pkce_auth_msg_test.go new file mode 100644 index 000000000..75b1c9e76 --- /dev/null +++ b/client/internal/templates/pkce_auth_msg_test.go @@ -0,0 +1,299 @@ +package templates + +import ( + "html/template" + "os" + "path/filepath" + "testing" +) + +func TestPKCEAuthMsgTemplate(t *testing.T) { + tests := []struct { + name string + data map[string]string + outputFile string + expectedTitle string + expectedInContent []string + notExpectedInContent []string + }{ + { + name: "error_state", + data: map[string]string{ + "Error": "authentication failed: invalid state", + }, + outputFile: "pkce-auth-error.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication failed: invalid state", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + { + name: "success_state", + data: map[string]string{ + // No error field means success + }, + outputFile: "pkce-auth-success.html", + expectedTitle: "Login Successful", + expectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird. You can now close this window.", + }, + notExpectedInContent: []string{ + "Login Failed", + }, + }, + { + name: "error_state_timeout", + data: map[string]string{ + "Error": "authentication timeout: request expired after 5 minutes", + }, + outputFile: "pkce-auth-timeout.html", + expectedTitle: "Login Failed", + expectedInContent: []string{ + "authentication timeout: request expired after 5 minutes", + "Login Failed", + }, + notExpectedInContent: []string{ + "Login Successful", + "Your device is now registered and logged in to NetBird", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Parse the template + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Failed to parse template: %v", err) + } + + // Create temp directory for this test + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, tt.outputFile) + + // Create output file + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + + // Execute the template + if err := tmpl.Execute(file, tt.data); err != nil { + file.Close() + t.Fatalf("Failed to execute template: %v", err) + } + file.Close() + + t.Logf("Generated test output: %s", outputPath) + + // Read the generated file + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + contentStr := string(content) + + // Verify file has content + if len(contentStr) == 0 { + t.Error("Output file is empty") + } + + // Verify basic HTML structure + basicElements := []string{ + "", + "", + "", + "NetBird", + } + + for _, elem := range basicElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + + // Verify expected title + if !contains(contentStr, tt.expectedTitle) { + t.Errorf("Expected HTML to contain title '%s', but it was not found", tt.expectedTitle) + } + + // Verify expected content is present + for _, expected := range tt.expectedInContent { + if !contains(contentStr, expected) { + t.Errorf("Expected HTML to contain '%s', but it was not found", expected) + } + } + + // Verify unexpected content is not present + for _, notExpected := range tt.notExpectedInContent { + if contains(contentStr, notExpected) { + t.Errorf("Expected HTML to NOT contain '%s', but it was found", notExpected) + } + } + }) + } +} + +func TestPKCEAuthMsgTemplateValidation(t *testing.T) { + // Test that the template can be parsed without errors + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + // Test with empty data + t.Run("empty_data", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "empty-data.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + if err := tmpl.Execute(file, nil); err != nil { + t.Errorf("Template execution with nil data failed: %v", err) + } + }) + + // Test with error data + t.Run("with_error", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "with-error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{ + "Error": "test error message", + } + if err := tmpl.Execute(file, data); err != nil { + t.Errorf("Template execution with error data failed: %v", err) + } + }) +} + +func TestPKCEAuthMsgTemplateContent(t *testing.T) { + // Test that the template contains expected elements + tmpl, err := template.New("pkce-auth-msg").Parse(PKCEAuthMsgTmpl) + if err != nil { + t.Fatalf("Template parsing failed: %v", err) + } + + t.Run("success_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "success.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + data := map[string]string{} + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for success indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Successful", + "NetBird", + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) + + t.Run("error_content", func(t *testing.T) { + tempDir := t.TempDir() + outputPath := filepath.Join(tempDir, "error.html") + + file, err := os.Create(outputPath) + if err != nil { + t.Fatalf("Failed to create output file: %v", err) + } + defer file.Close() + + errorMsg := "test error message" + data := map[string]string{ + "Error": errorMsg, + } + if err := tmpl.Execute(file, data); err != nil { + t.Fatalf("Template execution failed: %v", err) + } + + // Read the file and verify it contains expected content + content, err := os.ReadFile(outputPath) + if err != nil { + t.Fatalf("Failed to read output file: %v", err) + } + + // Check for error indicators + contentStr := string(content) + if len(contentStr) == 0 { + t.Error("Generated HTML is empty") + } + + // Basic HTML structure checks + requiredElements := []string{ + "", + "", + "", + "Login Failed", + errorMsg, + } + + for _, elem := range requiredElements { + if !contains(contentStr, elem) { + t.Errorf("Expected HTML to contain '%s', but it was not found", elem) + } + } + }) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(substr) == 0 || + (len(s) > 0 && len(substr) > 0 && containsHelper(s, substr))) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}