mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
Validate OIDC issuer when creating or updating (#5074)
This commit is contained in:
@@ -2,7 +2,13 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/dexidp/dex/storage"
|
"github.com/dexidp/dex/storage"
|
||||||
"github.com/rs/xid"
|
"github.com/rs/xid"
|
||||||
@@ -17,6 +23,69 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// oidcProviderJSON represents the OpenID Connect discovery document
|
||||||
|
type oidcProviderJSON struct {
|
||||||
|
Issuer string `json:"issuer"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateOIDCIssuer validates the OIDC issuer by fetching the OpenID configuration
|
||||||
|
// and verifying that the returned issuer matches the configured one.
|
||||||
|
func validateOIDCIssuer(ctx context.Context, issuer string) error {
|
||||||
|
wellKnown := strings.TrimSuffix(issuer, "/") + "/.well-known/openid-configuration"
|
||||||
|
|
||||||
|
httpClient := &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnown, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: unable to read response body: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("%w: %s: %s", types.ErrIdentityProviderIssuerUnreachable, resp.Status, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
var p oidcProviderJSON
|
||||||
|
if err := json.Unmarshal(body, &p); err != nil {
|
||||||
|
return fmt.Errorf("%w: failed to decode provider discovery object: %v", types.ErrIdentityProviderIssuerUnreachable, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Issuer != issuer {
|
||||||
|
return fmt.Errorf("%w: expected %q got %q", types.ErrIdentityProviderIssuerMismatch, issuer, p.Issuer)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// validateIdentityProviderConfig validates the identity provider configuration including
|
||||||
|
// basic validation and OIDC issuer verification.
|
||||||
|
func validateIdentityProviderConfig(ctx context.Context, idpConfig *types.IdentityProvider) error {
|
||||||
|
if err := idpConfig.Validate(); err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the issuer by calling the OIDC discovery endpoint
|
||||||
|
if idpConfig.Issuer != "" {
|
||||||
|
if err := validateOIDCIssuer(ctx, idpConfig.Issuer); err != nil {
|
||||||
|
return status.Errorf(status.InvalidArgument, "%s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetIdentityProviders returns all identity providers for an account
|
// GetIdentityProviders returns all identity providers for an account
|
||||||
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
func (am *DefaultAccountManager) GetIdentityProviders(ctx context.Context, accountID, userID string) ([]*types.IdentityProvider, error) {
|
||||||
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
ok, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.IdentityProviders, operations.Read)
|
||||||
@@ -82,8 +151,8 @@ func (am *DefaultAccountManager) CreateIdentityProvider(ctx context.Context, acc
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := idpConfig.Validate(); err != nil {
|
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||||
@@ -119,8 +188,8 @@ func (am *DefaultAccountManager) UpdateIdentityProvider(ctx context.Context, acc
|
|||||||
return nil, status.NewPermissionDeniedError()
|
return nil, status.NewPermissionDeniedError()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := idpConfig.Validate(); err != nil {
|
if err := validateIdentityProviderConfig(ctx, idpConfig); err != nil {
|
||||||
return nil, status.Errorf(status.InvalidArgument, "%s", err.Error())
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
embeddedManager, ok := am.idpManager.(*idp.EmbeddedIdPManager)
|
||||||
|
|||||||
@@ -2,6 +2,10 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
@@ -200,3 +204,109 @@ func TestDefaultAccountManager_UpdateIdentityProvider_Validation(t *testing.T) {
|
|||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
assert.Contains(t, err.Error(), "name is required")
|
assert.Contains(t, err.Error(), "name is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateOIDCIssuer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupServer func() *httptest.Server
|
||||||
|
expectedErr error
|
||||||
|
expectedErrMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "issuer mismatch",
|
||||||
|
setupServer: func() *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
resp := oidcProviderJSON{Issuer: "https://different-issuer.com"}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
expectedErr: types.ErrIdentityProviderIssuerMismatch,
|
||||||
|
expectedErrMsg: "does not match",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server returns non-200 status",
|
||||||
|
setupServer: func() *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
_, _ = w.Write([]byte("not found"))
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||||
|
expectedErrMsg: "404",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "server returns invalid JSON",
|
||||||
|
setupServer: func() *httptest.Server {
|
||||||
|
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_, _ = w.Write([]byte("invalid json"))
|
||||||
|
}))
|
||||||
|
},
|
||||||
|
expectedErr: types.ErrIdentityProviderIssuerUnreachable,
|
||||||
|
expectedErrMsg: "failed to decode",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
server := tt.setupServer()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||||
|
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, tt.expectedErr), "expected error %v, got %v", tt.expectedErr, err)
|
||||||
|
if tt.expectedErrMsg != "" {
|
||||||
|
assert.Contains(t, err.Error(), tt.expectedErrMsg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOIDCIssuer_Success(t *testing.T) {
|
||||||
|
// Create a server that returns its own URL as the issuer
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp := oidcProviderJSON{Issuer: server.URL}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
err := validateOIDCIssuer(context.Background(), server.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOIDCIssuer_UnreachableServer(t *testing.T) {
|
||||||
|
// Use a URL that will definitely fail to connect
|
||||||
|
err := validateOIDCIssuer(context.Background(), "http://localhost:59999")
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerUnreachable))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateOIDCIssuer_TrailingSlash(t *testing.T) {
|
||||||
|
// Test that trailing slashes are handled correctly
|
||||||
|
var server *httptest.Server
|
||||||
|
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.URL.Path != "/.well-known/openid-configuration" {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Return issuer without trailing slash
|
||||||
|
resp := oidcProviderJSON{Issuer: server.URL}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(resp)
|
||||||
|
}))
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
// Pass issuer with trailing slash
|
||||||
|
err := validateOIDCIssuer(context.Background(), server.URL+"/")
|
||||||
|
// This should fail because the issuer returned doesn't have trailing slash
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.True(t, errors.Is(err, types.ErrIdentityProviderIssuerMismatch))
|
||||||
|
}
|
||||||
|
|||||||
@@ -7,12 +7,14 @@ import (
|
|||||||
|
|
||||||
// Identity provider validation errors
|
// Identity provider validation errors
|
||||||
var (
|
var (
|
||||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
ErrIdentityProviderIssuerUnreachable = errors.New("identity provider issuer is unreachable")
|
||||||
|
ErrIdentityProviderIssuerMismatch = errors.New("identity provider issuer does not match the issuer returned by the provider")
|
||||||
|
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||||
)
|
)
|
||||||
|
|
||||||
// IdentityProviderType is the type of identity provider
|
// IdentityProviderType is the type of identity provider
|
||||||
|
|||||||
Reference in New Issue
Block a user