From 614e7d5b90667b807e788dd3f0d7421dac4a8cac Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Fri, 9 Jan 2026 09:45:43 -0500 Subject: [PATCH] Validate OIDC issuer when creating or updating (#5074) --- management/server/identity_provider.go | 77 ++++++++++++- management/server/identity_provider_test.go | 110 +++++++++++++++++++ management/server/types/identity_provider.go | 14 ++- 3 files changed, 191 insertions(+), 10 deletions(-) diff --git a/management/server/identity_provider.go b/management/server/identity_provider.go index 6649c3953..8fd96c238 100644 --- a/management/server/identity_provider.go +++ b/management/server/identity_provider.go @@ -2,7 +2,13 @@ package server import ( "context" + "encoding/json" "errors" + "fmt" + "io" + "net/http" + "strings" + "time" "github.com/dexidp/dex/storage" "github.com/rs/xid" @@ -17,6 +23,69 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// oidcProviderJSON represents the OpenID Connect discovery document +type oidcProviderJSON struct { + Issuer string `json:"issuer"` +} + +// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration +// and verifying that the returned issuer matches the configured one. +func validateOIDCIssuer(ctx context.Context, issuer string) error { + wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration" + + httpClient := &http.Client{ + Timeout: 10 * time.Second, + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil) + if err != nil { + return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err) + } + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err) + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body) + } + + var p oidcProviderJSON + if err := json.Unmarshal(body, &p); err != nil { + return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err) + } + + if p.Issuer != issuer { + return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer) + } + + return nil +} + +// validateIdentityProviderConfig validates the identity provider configuration including +// basic validation and OIDC issuer verification. +func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error { + if err := idpConfig.Validate(); err != nil { + return status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + + // Validate the issuer by calling the OIDC discovery endpoint + if idpConfig.Issuer != "" { + if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil { + return status.Errorf(status.InvalidArgument, "%s", err.Error()) + } + } + + return nil +} + // GetIdentityProviders returns all identity providers for an account func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) { ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read) @@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc return nil, status.NewPermissionDeniedError() } - if err := idpConfig.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) + if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil { + return nil, err } embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) @@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc return nil, status.NewPermissionDeniedError() } - if err := idpConfig.Validate(); err != nil { - return nil, status.Errorf(status.InvalidArgument, "%s", err.Error()) + if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil { + return nil, err } embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager) diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index d637c4a8f..78dcbeb74 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -2,6 +2,10 @@ package server import ( "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" "path/filepath" "testing" @@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) { require.Error(t, err) assert.Contains(t, err.Error(), "name is required") } + +func TestValidateOIDCIssuer(t *testing.T) { + tests := []struct { + name string + setupServer func() *httptest.Server + expectedErr error + expectedErrMsg string + }{ + { + name: "issuer mismatch", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := oidcProviderJSON{Issuer: "https://different-issuer.com"} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + }, + expectedErr: types.ErrIdentityProviderIssuerMismatch, + expectedErrMsg: "does not match", + }, + { + name: "server returns non-200 status", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte("not found")) + })) + }, + expectedErr: types.ErrIdentityProviderIssuerUnreachable, + expectedErrMsg: "404", + }, + { + name: "server returns invalid JSON", + setupServer: func() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("invalid json")) + })) + }, + expectedErr: types.ErrIdentityProviderIssuerUnreachable, + expectedErrMsg: "failed to decode", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := tt.setupServer() + defer server.Close() + + err := validateOIDCIssuer(context.Background(), server.URL) + + require.Error(t, err) + assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err) + if tt.expectedErrMsg != "" { + assert.Contains(t, err.Error(), tt.expectedErrMsg) + } + }) + } +} + +func TestValidateOIDCIssuer_Success(t *testing.T) { + // Create a server that returns its own URL as the issuer + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + resp := oidcProviderJSON{Issuer: server.URL} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + err := validateOIDCIssuer(context.Background(), server.URL) + require.NoError(t, err) +} + +func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) { + // Use a URL that will definitely fail to connect + err := validateOIDCIssuer(context.Background(), "http://localhost:59999") + require.Error(t, err) + assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable)) +} + +func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) { + // Test that trailing slashes are handled correctly + var server *httptest.Server + server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/.well-known/openid-configuration" { + http.NotFound(w, r) + return + } + // Return issuer without trailing slash + resp := oidcProviderJSON{Issuer: server.URL} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(resp) + })) + defer server.Close() + + // Pass issuer with trailing slash + err := validateOIDCIssuer(context.Background(), server.URL+"/") + // This should fail because the issuer returned doesn't have trailing slash + require.Error(t, err) + assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch)) +} diff --git a/management/server/types/identity_provider.go b/management/server/types/identity_provider.go index e809590de..c4498e4d4 100644 --- a/management/server/types/identity_provider.go +++ b/management/server/types/identity_provider.go @@ -7,12 +7,14 @@ import ( // Identity provider validation errors var ( - ErrIdentityProviderNameRequired = errors.New("identity provider name is required") - ErrIdentityProviderTypeRequired = errors.New("identity provider type is required") - ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type") - ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required") - ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL") - ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required") + ErrIdentityProviderNameRequired = errors.New("identity provider name is required") + ErrIdentityProviderTypeRequired = errors.New("identity provider type is required") + ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type") + ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required") + ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL") + ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable") + ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider") + ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required") ) // IdentityProviderType is the type of identity provider