mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
769 lines
22 KiB
Go
769 lines
22 KiB
Go
// Package dex provides an embedded Dex OIDC identity provider.
|
|
package dex
|
|
|
|
import (
	"context"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	dexapi "github.com/dexidp/dex/api/v2"
	"github.com/dexidp/dex/server"
	"github.com/dexidp/dex/server/signer"
	"github.com/dexidp/dex/storage"
	"github.com/dexidp/dex/storage/sql"
	jose "github.com/go-jose/go-jose/v4"
	"github.com/google/uuid"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/crypto/bcrypt"
	"google.golang.org/grpc"

	nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
)
|
|
|
|
// Config carries the settings consumed by NewProvider. Its shape matches what
// management/internals/server/server.go expects.
type Config struct {
	// Issuer is the OIDC issuer URL. NewProvider rejects an empty value and
	// appends "/oauth2" when the URL does not already end with it.
	Issuer string
	// Port is the TCP port the HTTP server listens on; it must be positive.
	Port int
	// DataDir is the directory that holds the SQLite database file
	// ("oidc.db"); it is created if missing.
	DataDir string
	// DevMode is not read anywhere in this file.
	// NOTE(review): confirm which consumer, if any, still uses it.
	DevMode bool

	// GRPCAddr is the address for the gRPC API (e.g., ":5557"). Empty disables gRPC.
	GRPCAddr string
}
|
|
|
|
// Provider wraps a Dex server
|
|
type Provider struct {
|
|
config *Config
|
|
yamlConfig *YAMLConfig
|
|
dexServer *server.Server
|
|
httpServer *http.Server
|
|
listener net.Listener
|
|
grpcServer *grpc.Server
|
|
grpcListener net.Listener
|
|
storage storage.Storage
|
|
logger *slog.Logger
|
|
mu sync.Mutex
|
|
running bool
|
|
}
|
|
|
|
// NewProvider creates and initializes the Dex server
|
|
func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
|
if config.Issuer == "" {
|
|
return nil, fmt.Errorf("issuer is required")
|
|
}
|
|
if config.Port <= 0 {
|
|
return nil, fmt.Errorf("invalid port")
|
|
}
|
|
if config.DataDir == "" {
|
|
return nil, fmt.Errorf("data directory is required")
|
|
}
|
|
|
|
logger := slog.New(slog.NewTextHandler(os.Stderr, nil))
|
|
|
|
// Ensure data directory exists
|
|
if err := os.MkdirAll(config.DataDir, 0o700); err != nil {
|
|
return nil, fmt.Errorf("failed to create data directory: %w", err)
|
|
}
|
|
|
|
// Initialize SQLite storage
|
|
dbPath := filepath.Join(config.DataDir, "oidc.db")
|
|
sqliteConfig := &sql.SQLite3{File: dbPath}
|
|
stor, err := sqliteConfig.Open(logger)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open storage: %w", err)
|
|
}
|
|
|
|
// Ensure a local connector exists (for password authentication)
|
|
if err := ensureLocalConnector(ctx, stor); err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to ensure local connector: %w", err)
|
|
}
|
|
|
|
// Ensure issuer ends with /oauth2 for proper path mounting
|
|
issuer := strings.TrimSuffix(config.Issuer, "/")
|
|
if !strings.HasSuffix(issuer, "/oauth2") {
|
|
issuer += "/oauth2"
|
|
}
|
|
|
|
// Build refresh token policy (required to avoid nil pointer panics)
|
|
refreshPolicy, err := server.NewRefreshTokenPolicy(logger, false, "", "", "")
|
|
if err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
|
|
}
|
|
|
|
localSignerConfig := signer.LocalConfig{
|
|
KeysRotationPeriod: "6h",
|
|
}
|
|
|
|
localSigner, err := localSignerConfig.Open(ctx, stor, 24*time.Hour, time.Now, logger)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create local signer: %w", err)
|
|
}
|
|
|
|
// Build Dex server config - use Dex's types directly
|
|
dexConfig := server.Config{
|
|
Issuer: issuer,
|
|
Storage: stor,
|
|
SkipApprovalScreen: true,
|
|
SupportedResponseTypes: []string{"code"},
|
|
ContinueOnConnectorFailure: true,
|
|
Logger: logger,
|
|
PrometheusRegistry: prometheus.NewRegistry(),
|
|
IDTokensValidFor: 24 * time.Hour,
|
|
RefreshTokenPolicy: refreshPolicy,
|
|
Web: server.WebConfig{
|
|
Issuer: "NetBird",
|
|
},
|
|
Signer: localSigner,
|
|
}
|
|
|
|
dexSrv, err := server.NewServer(ctx, dexConfig)
|
|
if err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to create dex server: %w", err)
|
|
}
|
|
|
|
return &Provider{
|
|
config: config,
|
|
dexServer: dexSrv,
|
|
storage: stor,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
// NewProviderFromYAML creates and initializes the Dex server from a YAMLConfig
|
|
func NewProviderFromYAML(ctx context.Context, yamlConfig *YAMLConfig) (*Provider, error) {
|
|
// Configure log level from config, default to WARN to avoid logging sensitive data (emails)
|
|
logLevel := slog.LevelWarn
|
|
if yamlConfig.Logger.Level != "" {
|
|
switch strings.ToLower(yamlConfig.Logger.Level) {
|
|
case "debug":
|
|
logLevel = slog.LevelDebug
|
|
case "info":
|
|
logLevel = slog.LevelInfo
|
|
case "warn", "warning":
|
|
logLevel = slog.LevelWarn
|
|
case "error":
|
|
logLevel = slog.LevelError
|
|
}
|
|
}
|
|
logger := slog.New(NewLogrusHandler(logLevel))
|
|
|
|
stor, err := yamlConfig.Storage.OpenStorage(logger)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to open storage: %w", err)
|
|
}
|
|
|
|
if err := initializeStorage(ctx, stor, yamlConfig); err != nil {
|
|
stor.Close()
|
|
return nil, err
|
|
}
|
|
|
|
dexConfig := buildDexConfig(yamlConfig, stor, logger)
|
|
dexConfig.RefreshTokenPolicy, err = yamlConfig.GetRefreshTokenPolicy(logger)
|
|
if err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to create refresh token policy: %w", err)
|
|
}
|
|
|
|
localSigner, err := getSigner(ctx, stor, yamlConfig, logger)
|
|
if err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to create local signer: %w", err)
|
|
}
|
|
|
|
dexConfig.Signer = localSigner
|
|
|
|
dexSrv, err := server.NewServer(ctx, dexConfig)
|
|
if err != nil {
|
|
stor.Close()
|
|
return nil, fmt.Errorf("failed to create dex server: %w", err)
|
|
}
|
|
|
|
return &Provider{
|
|
config: &Config{Issuer: yamlConfig.Issuer, GRPCAddr: yamlConfig.GRPC.Addr},
|
|
yamlConfig: yamlConfig,
|
|
dexServer: dexSrv,
|
|
storage: stor,
|
|
logger: logger,
|
|
}, nil
|
|
}
|
|
|
|
func getSigner(ctx context.Context, stor storage.Storage, yamlConfig *YAMLConfig, logger *slog.Logger) (signer.Signer, error) {
|
|
// Parse expiry durations
|
|
idTokensValidFor := 24 * time.Hour // default
|
|
if yamlConfig.Expiry.IDTokens != "" {
|
|
var err error
|
|
idTokensValidFor, err = parseDuration(yamlConfig.Expiry.IDTokens)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid config value %q for id token expiry: %v", yamlConfig.Expiry.IDTokens, err)
|
|
}
|
|
}
|
|
|
|
localSignerConfig := &signer.LocalConfig{
|
|
KeysRotationPeriod: "720h", // 30 Days
|
|
}
|
|
|
|
if yamlConfig.Expiry.SigningKeys != "" {
|
|
if _, err := parseDuration(yamlConfig.Expiry.SigningKeys); err != nil {
|
|
return nil, fmt.Errorf("invalid config value %q for signing key expiry: %v", yamlConfig.Expiry.SigningKeys, err)
|
|
}
|
|
|
|
localSignerConfig.KeysRotationPeriod = yamlConfig.Expiry.SigningKeys
|
|
}
|
|
|
|
return localSignerConfig.Open(ctx, stor, idTokensValidFor, time.Now, logger)
|
|
}
|
|
|
|
// initializeStorage sets up connectors, passwords, and clients in storage
|
|
func initializeStorage(ctx context.Context, stor storage.Storage, cfg *YAMLConfig) error {
|
|
if cfg.EnablePasswordDB {
|
|
if err := ensureLocalConnector(ctx, stor); err != nil {
|
|
return fmt.Errorf("failed to ensure local connector: %w", err)
|
|
}
|
|
}
|
|
if err := ensureStaticPasswords(ctx, stor, cfg.StaticPasswords); err != nil {
|
|
return err
|
|
}
|
|
if err := ensureStaticClients(ctx, stor, cfg.StaticClients); err != nil {
|
|
return err
|
|
}
|
|
return ensureStaticConnectors(ctx, stor, cfg.StaticConnectors)
|
|
}
|
|
|
|
// ensureStaticPasswords creates or updates static passwords in storage
|
|
func ensureStaticPasswords(ctx context.Context, stor storage.Storage, passwords []Password) error {
|
|
for _, pw := range passwords {
|
|
existing, err := stor.GetPassword(ctx, pw.Email)
|
|
if errors.Is(err, storage.ErrNotFound) {
|
|
if err := stor.CreatePassword(ctx, storage.Password(pw)); err != nil {
|
|
return fmt.Errorf("failed to create password for %s: %w", pw.Email, err)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get password for %s: %w", pw.Email, err)
|
|
}
|
|
if string(existing.Hash) != string(pw.Hash) {
|
|
if err := stor.UpdatePassword(ctx, pw.Email, func(old storage.Password) (storage.Password, error) {
|
|
old.Hash = pw.Hash
|
|
old.Username = pw.Username
|
|
return old, nil
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to update password for %s: %w", pw.Email, err)
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ensureStaticClients creates or updates static clients in storage
|
|
func ensureStaticClients(ctx context.Context, stor storage.Storage, clients []storage.Client) error {
|
|
for _, client := range clients {
|
|
_, err := stor.GetClient(ctx, client.ID)
|
|
if errors.Is(err, storage.ErrNotFound) {
|
|
if err := stor.CreateClient(ctx, client); err != nil {
|
|
return fmt.Errorf("failed to create client %s: %w", client.ID, err)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get client %s: %w", client.ID, err)
|
|
}
|
|
if err := stor.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) {
|
|
old.RedirectURIs = client.RedirectURIs
|
|
old.Name = client.Name
|
|
old.Public = client.Public
|
|
old.PostLogoutRedirectURIs = client.PostLogoutRedirectURIs
|
|
old.MFAChain = client.MFAChain
|
|
return old, nil
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// buildDexConfig creates a server.Config with defaults applied
|
|
func buildDexConfig(yamlConfig *YAMLConfig, stor storage.Storage, logger *slog.Logger) server.Config {
|
|
cfg := yamlConfig.ToServerConfig(stor, logger)
|
|
cfg.PrometheusRegistry = prometheus.NewRegistry()
|
|
if cfg.IDTokensValidFor == 0 {
|
|
cfg.IDTokensValidFor = 24 * time.Hour
|
|
}
|
|
if cfg.Web.Issuer == "" {
|
|
cfg.Web.Issuer = "NetBird"
|
|
}
|
|
if len(cfg.SupportedResponseTypes) == 0 {
|
|
cfg.SupportedResponseTypes = []string{"code"}
|
|
}
|
|
cfg.ContinueOnConnectorFailure = true
|
|
return cfg
|
|
}
|
|
|
|
// Start starts the HTTP server and optionally the gRPC API server
|
|
func (p *Provider) Start(_ context.Context) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if p.running {
|
|
return fmt.Errorf("already running")
|
|
}
|
|
|
|
// Determine listen address from config
|
|
var addr string
|
|
if p.yamlConfig != nil {
|
|
addr = p.yamlConfig.Web.HTTP
|
|
if addr == "" {
|
|
addr = p.yamlConfig.Web.HTTPS
|
|
}
|
|
} else if p.config != nil && p.config.Port > 0 {
|
|
addr = fmt.Sprintf(":%d", p.config.Port)
|
|
}
|
|
if addr == "" {
|
|
return fmt.Errorf("no listen address configured")
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", addr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on %s: %w", addr, err)
|
|
}
|
|
p.listener = listener
|
|
|
|
// Mount Dex at /oauth2/ path for reverse proxy compatibility
|
|
// Don't strip the prefix - Dex's issuer includes /oauth2 so it expects the full path
|
|
mux := http.NewServeMux()
|
|
mux.Handle("/oauth2/", p.dexServer)
|
|
|
|
p.httpServer = &http.Server{Handler: mux}
|
|
p.running = true
|
|
|
|
go func() {
|
|
if err := p.httpServer.Serve(listener); err != nil && err != http.ErrServerClosed {
|
|
p.logger.Error("http server error", "error", err)
|
|
}
|
|
}()
|
|
|
|
// Start gRPC API server if configured
|
|
if p.config.GRPCAddr != "" {
|
|
if err := p.startGRPCServer(); err != nil {
|
|
// Clean up HTTP server on failure
|
|
_ = p.httpServer.Close()
|
|
_ = p.listener.Close()
|
|
return fmt.Errorf("failed to start gRPC server: %w", err)
|
|
}
|
|
}
|
|
|
|
p.logger.Info("HTTP server started", "addr", addr)
|
|
return nil
|
|
}
|
|
|
|
// startGRPCServer starts the gRPC API server using Dex's built-in API
|
|
func (p *Provider) startGRPCServer() error {
|
|
grpcListener, err := net.Listen("tcp", p.config.GRPCAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to listen on %s: %w", p.config.GRPCAddr, err)
|
|
}
|
|
p.grpcListener = grpcListener
|
|
|
|
p.grpcServer = grpc.NewServer()
|
|
// Use Dex's built-in API server implementation
|
|
// server.NewAPI(storage, logger, version, dexServer)
|
|
dexapi.RegisterDexServer(p.grpcServer, server.NewAPI(p.storage, p.logger, "netbird-dex", p.dexServer))
|
|
|
|
go func() {
|
|
if err := p.grpcServer.Serve(grpcListener); err != nil {
|
|
p.logger.Error("grpc server error", "error", err)
|
|
}
|
|
}()
|
|
|
|
p.logger.Info("gRPC API server started", "addr", p.config.GRPCAddr)
|
|
return nil
|
|
}
|
|
|
|
// Stop gracefully shuts down
|
|
func (p *Provider) Stop(ctx context.Context) error {
|
|
p.mu.Lock()
|
|
defer p.mu.Unlock()
|
|
|
|
if !p.running {
|
|
return nil
|
|
}
|
|
|
|
var errs []error
|
|
|
|
// Stop gRPC server first
|
|
if p.grpcServer != nil {
|
|
p.grpcServer.GracefulStop()
|
|
p.grpcServer = nil
|
|
}
|
|
if p.grpcListener != nil {
|
|
p.grpcListener.Close()
|
|
p.grpcListener = nil
|
|
}
|
|
|
|
if p.httpServer != nil {
|
|
if err := p.httpServer.Shutdown(ctx); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
// Explicitly close listener as fallback (Shutdown should do this, but be safe)
|
|
if p.listener != nil {
|
|
if err := p.listener.Close(); err != nil {
|
|
// Ignore "use of closed network connection" - expected after Shutdown
|
|
if !strings.Contains(err.Error(), "use of closed") {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
p.listener = nil
|
|
}
|
|
|
|
if p.storage != nil {
|
|
if err := p.storage.Close(); err != nil {
|
|
errs = append(errs, err)
|
|
}
|
|
}
|
|
|
|
p.httpServer = nil
|
|
p.running = false
|
|
|
|
if len(errs) > 0 {
|
|
return fmt.Errorf("shutdown errors: %v", errs)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// EnsureDefaultClients creates dashboard and CLI OAuth clients
|
|
// Uses Dex's storage.Client directly - no custom wrappers
|
|
func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardURIs, cliURIs []string) error {
|
|
clients := []storage.Client{
|
|
{
|
|
ID: "netbird-dashboard",
|
|
Name: "NetBird Dashboard",
|
|
RedirectURIs: dashboardURIs,
|
|
Public: true,
|
|
},
|
|
{
|
|
ID: "netbird-cli",
|
|
Name: "NetBird CLI",
|
|
RedirectURIs: cliURIs,
|
|
Public: true,
|
|
},
|
|
}
|
|
|
|
for _, client := range clients {
|
|
_, err := p.storage.GetClient(ctx, client.ID)
|
|
if err == storage.ErrNotFound {
|
|
if err := p.storage.CreateClient(ctx, client); err != nil {
|
|
return fmt.Errorf("failed to create client %s: %w", client.ID, err)
|
|
}
|
|
continue
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get client %s: %w", client.ID, err)
|
|
}
|
|
// Update if exists
|
|
if err := p.storage.UpdateClient(ctx, client.ID, func(old storage.Client) (storage.Client, error) {
|
|
old.RedirectURIs = client.RedirectURIs
|
|
return old, nil
|
|
}); err != nil {
|
|
return fmt.Errorf("failed to update client %s: %w", client.ID, err)
|
|
}
|
|
}
|
|
|
|
p.logger.Info("default OIDC clients ensured")
|
|
return nil
|
|
}
|
|
|
|
// Storage returns the underlying Dex storage for direct access
|
|
// Users can use storage.Client, storage.Password, storage.Connector directly
|
|
func (p *Provider) Storage() storage.Storage {
|
|
return p.storage
|
|
}
|
|
|
|
// Handler returns the Dex server as an http.Handler for embedding in another server.
|
|
// The handler expects requests with path prefix "/oauth2/".
|
|
func (p *Provider) Handler() http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// NOTE: by default Dex will use the /logout route to only logout sessions, doesn't invalidate jwt tokens,
|
|
// to avoid confusion on users, we're not allowing for this, and only enable OIDC logout triggered through
|
|
// the dashboard which will invalidate both the session and the jwt token
|
|
//if strings.HasSuffix(r.URL.Path, "/logout") && r.FormValue("id_token_hint") == "" {
|
|
//http.Redirect(w, r, "/", http.StatusSeeOther)
|
|
//return
|
|
//}
|
|
|
|
p.dexServer.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// CreateUser creates a new user with the given email, username, and password.
|
|
// Returns the encoded user ID in Dex's format (base64-encoded protobuf with connector ID).
|
|
func (p *Provider) CreateUser(ctx context.Context, email, username, password string) (string, error) {
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return "", fmt.Errorf("failed to hash password: %w", err)
|
|
}
|
|
|
|
userID := uuid.New().String()
|
|
err = p.storage.CreatePassword(ctx, storage.Password{
|
|
Email: email,
|
|
Username: username,
|
|
UserID: userID,
|
|
Hash: hash,
|
|
})
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
// Encode the user ID in Dex's format: base64(protobuf{user_id, connector_id})
|
|
// This matches the format Dex uses in JWT tokens
|
|
encodedID := EncodeDexUserID(userID, "local")
|
|
return encodedID, nil
|
|
}
|
|
|
|
// EncodeDexUserID encodes user ID and connector ID into Dex's base64-encoded protobuf format.
// Dex uses this format for the 'sub' claim in JWT tokens.
// Format: base64(protobuf message with field 1 = user_id, field 2 = connector_id)
//
// Both fields are always written, even when empty. Field lengths are encoded
// as protobuf varints, so values of 128 bytes or more are encoded correctly
// (a single length byte would corrupt the message for such values).
func EncodeDexUserID(userID, connectorID string) string {
	// Manually encode protobuf: field 1 (user_id) and field 2 (connector_id)
	// Wire type 2 (length-delimited) for strings
	buf := make([]byte, 0, len(userID)+len(connectorID)+2*(1+binary.MaxVarintLen64))

	// Field 1: user_id (tag = 0x0a = field 1, wire type 2)
	buf = append(buf, 0x0a)
	buf = binary.AppendUvarint(buf, uint64(len(userID)))
	buf = append(buf, userID...)

	// Field 2: connector_id (tag = 0x12 = field 2, wire type 2)
	buf = append(buf, 0x12)
	buf = binary.AppendUvarint(buf, uint64(len(connectorID)))
	buf = append(buf, connectorID...)

	return base64.RawStdEncoding.EncodeToString(buf)
}
|
|
|
|
// DecodeDexUserID decodes Dex's base64-encoded user ID back to the raw user ID and connector ID.
//
// The input is decoded as unpadded standard base64 first and as padded
// standard base64 second. The payload is then parsed as a protobuf message
// consisting solely of length-delimited fields: field 1 is the user ID,
// field 2 the connector ID, other field numbers are skipped. Tags and lengths
// are read as varints, so fields of 128 bytes or more are handled.
//
// An error is returned for invalid base64, for a field with a wire type other
// than length-delimited, and for truncated input.
func DecodeDexUserID(encodedID string) (userID, connectorID string, err error) {
	// Try RawStdEncoding first, then StdEncoding (with padding)
	buf, err := base64.RawStdEncoding.DecodeString(encodedID)
	if err != nil {
		buf, err = base64.StdEncoding.DecodeString(encodedID)
		if err != nil {
			return "", "", fmt.Errorf("failed to decode base64: %w", err)
		}
	}

	// Parse protobuf manually, consuming buf from the front.
	for len(buf) > 0 {
		tag, n := binary.Uvarint(buf)
		if n <= 0 {
			return "", "", fmt.Errorf("truncated message")
		}
		buf = buf[n:]

		fieldNum := tag >> 3
		wireType := tag & 0x07

		if wireType != 2 { // We only expect length-delimited strings
			return "", "", fmt.Errorf("unexpected wire type %d", wireType)
		}

		length, n := binary.Uvarint(buf)
		if n <= 0 {
			return "", "", fmt.Errorf("truncated message")
		}
		buf = buf[n:]

		// Compare in uint64 so an oversized length cannot overflow int.
		if length > uint64(len(buf)) {
			return "", "", fmt.Errorf("truncated string field")
		}
		value := string(buf[:length])
		buf = buf[length:]

		switch fieldNum {
		case 1:
			userID = value
		case 2:
			connectorID = value
		}
	}

	return userID, connectorID, nil
}
|
|
|
|
// GetUser returns a user by email
|
|
func (p *Provider) GetUser(ctx context.Context, email string) (storage.Password, error) {
|
|
return p.storage.GetPassword(ctx, email)
|
|
}
|
|
|
|
// GetUserByID returns a user by user ID.
|
|
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
|
|
// Note: This requires iterating through all users since dex storage doesn't index by userID.
|
|
func (p *Provider) GetUserByID(ctx context.Context, userID string) (storage.Password, error) {
|
|
// Try to decode the user ID in case it's encoded
|
|
rawUserID, _, err := DecodeDexUserID(userID)
|
|
if err != nil {
|
|
// If decoding fails, assume it's already a raw UUID
|
|
rawUserID = userID
|
|
}
|
|
|
|
users, err := p.storage.ListPasswords(ctx)
|
|
if err != nil {
|
|
return storage.Password{}, fmt.Errorf("failed to list users: %w", err)
|
|
}
|
|
for _, user := range users {
|
|
if user.UserID == rawUserID {
|
|
return user, nil
|
|
}
|
|
}
|
|
return storage.Password{}, storage.ErrNotFound
|
|
}
|
|
|
|
// DeleteUser removes a user by email
|
|
func (p *Provider) DeleteUser(ctx context.Context, email string) error {
|
|
return p.storage.DeletePassword(ctx, email)
|
|
}
|
|
|
|
// ListUsers returns all users
|
|
func (p *Provider) ListUsers(ctx context.Context) ([]storage.Password, error) {
|
|
return p.storage.ListPasswords(ctx)
|
|
}
|
|
|
|
// UpdateUserPassword updates the password for a user identified by userID.
|
|
// The userID can be either an encoded Dex ID (base64 protobuf) or a raw UUID.
|
|
// It verifies the current password before updating.
|
|
func (p *Provider) UpdateUserPassword(ctx context.Context, userID string, oldPassword, newPassword string) error {
|
|
// Get the user by ID to find their email
|
|
user, err := p.GetUserByID(ctx, userID)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to get user: %w", err)
|
|
}
|
|
|
|
// Verify old password
|
|
if err := bcrypt.CompareHashAndPassword(user.Hash, []byte(oldPassword)); err != nil {
|
|
return fmt.Errorf("current password is incorrect")
|
|
}
|
|
|
|
// Hash the new password
|
|
newHash, err := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to hash new password: %w", err)
|
|
}
|
|
|
|
// Update the password in storage
|
|
err = p.storage.UpdatePassword(ctx, user.Email, func(old storage.Password) (storage.Password, error) {
|
|
old.Hash = newHash
|
|
return old, nil
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to update password: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetIssuer returns the OIDC issuer URL.
|
|
func (p *Provider) GetIssuer() string {
|
|
if p.config == nil {
|
|
return ""
|
|
}
|
|
issuer := strings.TrimSuffix(p.config.Issuer, "/")
|
|
if !strings.HasSuffix(issuer, "/oauth2") {
|
|
issuer += "/oauth2"
|
|
}
|
|
return issuer
|
|
}
|
|
|
|
// GetKeysLocation returns the JWKS endpoint URL for token validation.
|
|
func (p *Provider) GetKeysLocation() string {
|
|
issuer := p.GetIssuer()
|
|
if issuer == "" {
|
|
return ""
|
|
}
|
|
return issuer + "/keys"
|
|
}
|
|
|
|
// GetTokenEndpoint returns the OAuth2 token endpoint URL.
|
|
func (p *Provider) GetTokenEndpoint() string {
|
|
issuer := p.GetIssuer()
|
|
if issuer == "" {
|
|
return ""
|
|
}
|
|
return issuer + "/token"
|
|
}
|
|
|
|
// GetDeviceAuthEndpoint returns the OAuth2 device authorization endpoint URL.
|
|
func (p *Provider) GetDeviceAuthEndpoint() string {
|
|
issuer := p.GetIssuer()
|
|
if issuer == "" {
|
|
return ""
|
|
}
|
|
return issuer + "/device/code"
|
|
}
|
|
|
|
// GetAuthorizationEndpoint returns the OAuth2 authorization endpoint URL.
|
|
func (p *Provider) GetAuthorizationEndpoint() string {
|
|
issuer := p.GetIssuer()
|
|
if issuer == "" {
|
|
return ""
|
|
}
|
|
return issuer + "/auth"
|
|
}
|
|
|
|
// GetJWKS reads signing keys directly from Dex storage and returns them as Jwks.
|
|
// This avoids HTTP round-trips when the embedded IDP is co-located with the management server.
|
|
// The key retrieval mirrors Dex's own handlePublicKeys/ValidationKeys logic:
|
|
// SigningKeyPub first, then all VerificationKeys, serialized via go-jose.
|
|
func (p *Provider) GetJWKS(ctx context.Context) (*nbjwt.Jwks, error) {
|
|
keys, err := p.storage.GetKeys(ctx)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to get keys from storage: %w", err)
|
|
}
|
|
|
|
if keys.SigningKeyPub == nil {
|
|
return nil, fmt.Errorf("no public keys found in storage")
|
|
}
|
|
|
|
// Build the key set exactly as Dex's localSigner.ValidationKeys does:
|
|
// signing key first, then all verification (rotated) keys.
|
|
joseKeys := make([]jose.JSONWebKey, 0, len(keys.VerificationKeys)+1)
|
|
joseKeys = append(joseKeys, *keys.SigningKeyPub)
|
|
for _, vk := range keys.VerificationKeys {
|
|
if vk.PublicKey != nil {
|
|
joseKeys = append(joseKeys, *vk.PublicKey)
|
|
}
|
|
}
|
|
|
|
// Serialize through go-jose (same as Dex's handlePublicKeys handler)
|
|
// then deserialize into our Jwks type, so the JSON field mapping is identical
|
|
// to what the /keys HTTP endpoint would return.
|
|
joseSet := jose.JSONWebKeySet{Keys: joseKeys}
|
|
data, err := json.Marshal(joseSet)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to marshal JWKS: %w", err)
|
|
}
|
|
|
|
jwks := &nbjwt.Jwks{}
|
|
if err := json.Unmarshal(data, jwks); err != nil {
|
|
return nil, fmt.Errorf("failed to unmarshal JWKS: %w", err)
|
|
}
|
|
|
|
jwks.ExpiresInTime = keys.NextRotation
|
|
|
|
return jwks, nil
|
|
}
|