From e33a9b8c888988f7b2e6c500e72d53b57db01f34 Mon Sep 17 00:00:00 2001 From: Elias Schneider Date: Sun, 26 Apr 2026 15:42:34 +0200 Subject: [PATCH] chore: post dependency upgrade fixes --- backend/internal/service/oidc_service.go | 18 ++++--- backend/internal/service/oidc_service_test.go | 34 +++++++------- .../internal/utils/testing/round_tripper.go | 47 +++++++++++++++++-- 3 files changed, 72 insertions(+), 27 deletions(-) diff --git a/backend/internal/service/oidc_service.go b/backend/internal/service/oidc_service.go index 53824588..845c24c8 100644 --- a/backend/internal/service/oidc_service.go +++ b/backend/internal/service/oidc_service.go @@ -1800,14 +1800,18 @@ func (s *OidcService) jwkSetForURL(ctx context.Context, url string) (set jwk.Set // We set a timeout because otherwise Register will keep trying in case of errors registerCtx, registerCancel := context.WithTimeout(ctx, 15*time.Second) defer registerCancel() - // We need to register the URL - err = s.jwkCache.Register( - registerCtx, - url, - jwk.WithMaxInterval(24*time.Hour), - jwk.WithMinInterval(15*time.Minute), + + registerOptions := []jwk.RegisterOption{ + jwk.WithMaxInterval(24 * time.Hour), + jwk.WithMinInterval(15 * time.Minute), jwk.WithWaitReady(true), - ) + } + if s.httpClient != nil { + registerOptions = append(registerOptions, jwk.WithHTTPClient(s.httpClient)) + } + + // We need to register the URL + err = s.jwkCache.Register(registerCtx, url, registerOptions...) // In case of race conditions (two goroutines calling jwkCache.Register at the same time), it's possible we can get a conflict anyways, so we ignore that error if err != nil && !errors.Is(err, httprc.ErrResourceAlreadyExists()) { return nil, fmt.Errorf("failed to register JWK set: %w", err) diff --git a/backend/internal/service/oidc_service_test.go b/backend/internal/service/oidc_service_test.go index 0f807432..6d514578 100644 --- a/backend/internal/service/oidc_service_test.go +++ b/backend/internal/service/oidc_service_test.go @@ -895,6 +895,8 @@ func TestOidcService_updateClientLogoType(t *testing.T) { } func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { + const publicLogoHost = "https://8.8.8.8" + // Create a test database db := testutils.NewDatabaseForTest(t) @@ -940,7 +942,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { // Create a mock HTTP client with responses mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/logo.png": pngResponse, + publicLogoHost + "/logo.png": pngResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -956,7 +958,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { } // Download and save the logo - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo.png", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/logo.png", true) require.NoError(t, err) // Verify the file was saved @@ -985,7 +987,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/dark-logo.webp": webpResponse, + publicLogoHost + "/dark-logo.webp": webpResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1000,7 +1002,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { } // Download and save the dark logo - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/dark-logo.webp", false) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/dark-logo.webp", false) require.NoError(t, err) // Verify the dark logo file was saved @@ -1024,7 +1026,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)), + publicLogoHost + "/icon.svg": testutils.NewMockResponse(http.StatusOK, string(svgContent)), } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1038,7 +1040,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/icon.svg", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/icon.svg", true) require.NoError(t, err) // Verify SVG file was saved @@ -1055,7 +1057,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/logo": jpgResponse, + publicLogoHost + "/logo": jpgResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1069,7 +1071,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/logo", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/logo", true) require.NoError(t, err) // Verify JPG file was saved (jpeg extension is normalized to jpg) @@ -1091,7 +1093,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { t.Run("Returns error for non-200 status code", func(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"), + publicLogoHost + "/not-found.png": testutils.NewMockResponse(http.StatusNotFound, "Not Found"), } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1105,7 +1107,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/not-found.png", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/not-found.png", true) require.Error(t, err) require.ErrorContains(t, err, "failed to fetch logo") }) @@ -1121,7 +1123,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/large.png": largeResponse, + publicLogoHost + "/large.png": largeResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1135,7 +1137,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/large.png", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/large.png", true) require.Error(t, err) require.ErrorIs(t, err, errLogoTooLarge) }) @@ -1147,7 +1149,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/file.txt": textResponse, + publicLogoHost + "/file.txt": textResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1161,7 +1163,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, "https://example.com/file.txt", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), client.ID, publicLogoHost+"/file.txt", true) require.Error(t, err) var fileTypeErr *common.FileTypeNotSupportedError require.ErrorAs(t, err, &fileTypeErr) @@ -1174,7 +1176,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { mockResponses := map[string]*http.Response{ //nolint:bodyclose - "https://example.com/logo.png": pngResponse, + publicLogoHost + "/logo.png": pngResponse, } httpClient := &http.Client{ Transport: &testutils.MockRoundTripper{ @@ -1188,7 +1190,7 @@ func TestOidcService_downloadAndSaveLogoFromURL(t *testing.T) { httpClient: httpClient, } - err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", "https://example.com/logo.png", true) + err := s.downloadAndSaveLogoFromURL(t.Context(), "non-existent-client-id", publicLogoHost+"/logo.png", true) require.Error(t, err) require.ErrorContains(t, err, "failed to look up client") }) diff --git a/backend/internal/utils/testing/round_tripper.go b/backend/internal/utils/testing/round_tripper.go index 806e2eaf..d8b758e4 100644 --- a/backend/internal/utils/testing/round_tripper.go +++ b/backend/internal/utils/testing/round_tripper.go @@ -3,9 +3,11 @@ package testing import ( + "bytes" "io" "net/http" "strings" + "sync" _ "github.com/golang-migrate/migrate/v4/source/file" ) @@ -14,25 +16,62 @@ import ( type MockRoundTripper struct { Err error Responses map[string]*http.Response + + mu sync.Mutex + responseBodies map[string][]byte } // RoundTrip implements the http.RoundTripper interface func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + if m.Err != nil { + return nil, m.Err + } + // Check if we have a specific response for this URL for url, resp := range m.Responses { if req.URL.String() == url { - return resp, nil + return m.cloneResponse(url, resp) } } return NewMockResponse(http.StatusNotFound, ""), nil } +func (m *MockRoundTripper) cloneResponse(url string, resp *http.Response) (*http.Response, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if m.responseBodies == nil { + m.responseBodies = make(map[string][]byte, len(m.Responses)) + } + + body, ok := m.responseBodies[url] + if !ok { + var err error + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + m.responseBodies[url] = body + resp.Body = io.NopCloser(bytes.NewReader(body)) + } + + cloned := new(http.Response) + *cloned = *resp + cloned.Header = resp.Header.Clone() + cloned.Body = io.NopCloser(bytes.NewReader(body)) + cloned.ContentLength = int64(len(body)) + + return cloned, nil +} + // NewMockResponse creates an http.Response with the given status code and body func NewMockResponse(statusCode int, body string) *http.Response { return &http.Response{ - StatusCode: statusCode, - Body: io.NopCloser(strings.NewReader(body)), - Header: make(http.Header), + StatusCode: statusCode, + Body: io.NopCloser(strings.NewReader(body)), + Header: make(http.Header), + ContentLength: int64(len(body)), } }