mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Add Zitadel IdP
This commit is contained in:
35
idp/cmd/env.go
Normal file
35
idp/cmd/env.go
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
)
|
||||||
|
|
||||||
|
// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_IDP_
|
||||||
|
func setFlagsFromEnvVars(cmd *cobra.Command) {
|
||||||
|
flags := cmd.PersistentFlags()
|
||||||
|
flags.VisitAll(func(f *pflag.Flag) {
|
||||||
|
newEnvVar := flagNameToEnvVar(f.Name, "NB_IDP_")
|
||||||
|
value, present := os.LookupEnv(newEnvVar)
|
||||||
|
if !present {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
err := flags.Set(f.Name, value)
|
||||||
|
if err != nil {
|
||||||
|
log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// flagNameToEnvVar converts flag name to environment var name adding a prefix,
|
||||||
|
// replacing dashes and making all uppercase (e.g. data-dir is converted to NB_IDP_DATA_DIR)
|
||||||
|
func flagNameToEnvVar(cmdFlag string, prefix string) string {
|
||||||
|
parsed := strings.ReplaceAll(cmdFlag, "-", "_")
|
||||||
|
upper := strings.ToUpper(parsed)
|
||||||
|
return prefix + upper
|
||||||
|
}
|
||||||
148
idp/cmd/root.go
Normal file
148
idp/cmd/root.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/idp/oidcprovider"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds the IdP server configuration
|
||||||
|
type Config struct {
|
||||||
|
ListenPort int
|
||||||
|
Issuer string
|
||||||
|
DataDir string
|
||||||
|
LogLevel string
|
||||||
|
LogFile string
|
||||||
|
DevMode bool
|
||||||
|
DashboardRedirectURIs []string
|
||||||
|
CLIRedirectURIs []string
|
||||||
|
DashboardClientID string
|
||||||
|
CLIClientID string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
config *Config
|
||||||
|
rootCmd = &cobra.Command{
|
||||||
|
Use: "idp",
|
||||||
|
Short: "NetBird Identity Provider",
|
||||||
|
Long: "Embedded OIDC Identity Provider for NetBird",
|
||||||
|
SilenceUsage: true,
|
||||||
|
SilenceErrors: true,
|
||||||
|
RunE: execute,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
_ = util.InitLog("trace", util.LogConsole)
|
||||||
|
config = &Config{}
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().IntVarP(&config.ListenPort, "port", "p", 33081, "port to listen on")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&config.Issuer, "issuer", "i", "", "OIDC issuer URL (default: http://localhost:<port>)")
|
||||||
|
rootCmd.PersistentFlags().StringVarP(&config.DataDir, "data-dir", "d", "/var/lib/netbird", "directory to store IdP data")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&config.LogLevel, "log-level", "info", "log level (trace, debug, info, warn, error)")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&config.LogFile, "log-file", "console", "log file path or 'console'")
|
||||||
|
rootCmd.PersistentFlags().BoolVar(&config.DevMode, "dev-mode", false, "enable development mode (allows HTTP)")
|
||||||
|
rootCmd.PersistentFlags().StringSliceVar(&config.DashboardRedirectURIs, "dashboard-redirect-uris", []string{
|
||||||
|
"http://localhost:3000/callback",
|
||||||
|
"http://localhost:3000/silent-callback",
|
||||||
|
}, "allowed redirect URIs for dashboard client")
|
||||||
|
rootCmd.PersistentFlags().StringSliceVar(&config.CLIRedirectURIs, "cli-redirect-uris", []string{
|
||||||
|
"http://localhost:53000",
|
||||||
|
"http://localhost:54000",
|
||||||
|
}, "allowed redirect URIs for CLI client")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&config.DashboardClientID, "dashboard-client-id", "netbird-dashboard", "client ID for dashboard")
|
||||||
|
rootCmd.PersistentFlags().StringVar(&config.CLIClientID, "cli-client-id", "netbird-client", "client ID for CLI")
|
||||||
|
|
||||||
|
// Add subcommands
|
||||||
|
rootCmd.AddCommand(userCmd)
|
||||||
|
|
||||||
|
setFlagsFromEnvVars(rootCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute runs the root command
|
||||||
|
func Execute() error {
|
||||||
|
return rootCmd.Execute()
|
||||||
|
}
|
||||||
|
|
||||||
|
func execute(cmd *cobra.Command, args []string) error {
|
||||||
|
err := util.InitLog(config.LogLevel, config.LogFile)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to initialize log: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set default issuer if not provided
|
||||||
|
issuer := config.Issuer
|
||||||
|
if issuer == "" {
|
||||||
|
issuer = fmt.Sprintf("http://localhost:%d", config.ListenPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Starting NetBird Identity Provider")
|
||||||
|
log.Infof(" Port: %d", config.ListenPort)
|
||||||
|
log.Infof(" Issuer: %s", issuer)
|
||||||
|
log.Infof(" Data directory: %s", config.DataDir)
|
||||||
|
log.Infof(" Dev mode: %v", config.DevMode)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
// Create provider config
|
||||||
|
providerConfig := &oidcprovider.Config{
|
||||||
|
Issuer: issuer,
|
||||||
|
Port: config.ListenPort,
|
||||||
|
DataDir: config.DataDir,
|
||||||
|
DevMode: config.DevMode,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the provider
|
||||||
|
provider, err := oidcprovider.NewProvider(ctx, providerConfig)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create IdP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure default clients exist
|
||||||
|
if err := provider.EnsureDefaultClients(ctx, config.DashboardRedirectURIs, config.CLIRedirectURIs); err != nil {
|
||||||
|
return fmt.Errorf("failed to create default clients: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start the provider
|
||||||
|
if err := provider.Start(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to start IdP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("IdP is running")
|
||||||
|
log.Infof(" Discovery: %s/.well-known/openid-configuration", issuer)
|
||||||
|
log.Infof(" Authorization: %s/authorize", issuer)
|
||||||
|
log.Infof(" Token: %s/oauth/token", issuer)
|
||||||
|
log.Infof(" Device authorization: %s/device_authorization", issuer)
|
||||||
|
log.Infof(" JWKS: %s/keys", issuer)
|
||||||
|
log.Infof(" Login: %s/login", issuer)
|
||||||
|
log.Infof(" Device flow: %s/device", issuer)
|
||||||
|
|
||||||
|
// Wait for exit signal
|
||||||
|
waitForExitSignal()
|
||||||
|
|
||||||
|
log.Infof("Shutting down IdP...")
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10)
|
||||||
|
defer shutdownCancel()
|
||||||
|
|
||||||
|
if err := provider.Stop(shutdownCtx); err != nil {
|
||||||
|
return fmt.Errorf("failed to stop IdP: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("IdP stopped")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForExitSignal() {
|
||||||
|
osSigs := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-osSigs
|
||||||
|
}
|
||||||
249
idp/cmd/user.go
Normal file
249
idp/cmd/user.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"text/tabwriter"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"golang.org/x/term"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/idp/oidcprovider"
|
||||||
|
)
|
||||||
|
|
||||||
|
var userCmd = &cobra.Command{
|
||||||
|
Use: "user",
|
||||||
|
Short: "Manage IdP users",
|
||||||
|
Long: "Commands for managing users in the embedded IdP",
|
||||||
|
}
|
||||||
|
|
||||||
|
var userAddCmd = &cobra.Command{
|
||||||
|
Use: "add",
|
||||||
|
Short: "Add a new user",
|
||||||
|
Long: "Add a new user to the embedded IdP",
|
||||||
|
RunE: userAdd,
|
||||||
|
}
|
||||||
|
|
||||||
|
var userListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Short: "List all users",
|
||||||
|
Long: "List all users in the embedded IdP",
|
||||||
|
RunE: userList,
|
||||||
|
}
|
||||||
|
|
||||||
|
var userDeleteCmd = &cobra.Command{
|
||||||
|
Use: "delete <username>",
|
||||||
|
Short: "Delete a user",
|
||||||
|
Long: "Delete a user from the embedded IdP",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: userDelete,
|
||||||
|
}
|
||||||
|
|
||||||
|
var userPasswordCmd = &cobra.Command{
|
||||||
|
Use: "password <username>",
|
||||||
|
Short: "Change user password",
|
||||||
|
Long: "Change password for a user in the embedded IdP",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: userChangePassword,
|
||||||
|
}
|
||||||
|
|
||||||
|
// User add flags
|
||||||
|
var (
|
||||||
|
userUsername string
|
||||||
|
userEmail string
|
||||||
|
userFirstName string
|
||||||
|
userLastName string
|
||||||
|
userPassword string
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
userAddCmd.Flags().StringVarP(&userUsername, "username", "u", "", "username (required)")
|
||||||
|
userAddCmd.Flags().StringVarP(&userEmail, "email", "e", "", "email address (required)")
|
||||||
|
userAddCmd.Flags().StringVarP(&userFirstName, "first-name", "f", "", "first name")
|
||||||
|
userAddCmd.Flags().StringVarP(&userLastName, "last-name", "l", "", "last name")
|
||||||
|
userAddCmd.Flags().StringVarP(&userPassword, "password", "p", "", "password (will prompt if not provided)")
|
||||||
|
_ = userAddCmd.MarkFlagRequired("username")
|
||||||
|
_ = userAddCmd.MarkFlagRequired("email")
|
||||||
|
|
||||||
|
userCmd.AddCommand(userAddCmd)
|
||||||
|
userCmd.AddCommand(userListCmd)
|
||||||
|
userCmd.AddCommand(userDeleteCmd)
|
||||||
|
userCmd.AddCommand(userPasswordCmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStore() (*oidcprovider.Store, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
store, err := oidcprovider.NewStore(ctx, config.DataDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open store: %w", err)
|
||||||
|
}
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func userAdd(cmd *cobra.Command, args []string) error {
|
||||||
|
store, err := getStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
password := userPassword
|
||||||
|
if password == "" {
|
||||||
|
// Prompt for password
|
||||||
|
fmt.Print("Enter password: ")
|
||||||
|
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read password: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
fmt.Print("Confirm password: ")
|
||||||
|
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read password confirmation: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if string(bytePassword) != string(byteConfirm) {
|
||||||
|
return fmt.Errorf("passwords do not match")
|
||||||
|
}
|
||||||
|
password = string(bytePassword)
|
||||||
|
}
|
||||||
|
|
||||||
|
if password == "" {
|
||||||
|
return fmt.Errorf("password cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
user := &oidcprovider.User{
|
||||||
|
Username: userUsername,
|
||||||
|
Email: userEmail,
|
||||||
|
FirstName: userFirstName,
|
||||||
|
LastName: userLastName,
|
||||||
|
Password: password,
|
||||||
|
EmailVerified: true, // Mark as verified since admin is creating the user
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
if err := store.CreateUser(ctx, user); err != nil {
|
||||||
|
return fmt.Errorf("failed to create user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("User '%s' created successfully (ID: %s)\n", userUsername, user.ID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func userList(cmd *cobra.Command, args []string) error {
|
||||||
|
store, err := getStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
users, err := store.ListUsers(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list users: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(users) == 0 {
|
||||||
|
fmt.Println("No users found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
|
||||||
|
fmt.Fprintln(w, "ID\tUSERNAME\tEMAIL\tNAME\tVERIFIED\tCREATED")
|
||||||
|
for _, user := range users {
|
||||||
|
name := fmt.Sprintf("%s %s", user.FirstName, user.LastName)
|
||||||
|
verified := "No"
|
||||||
|
if user.EmailVerified {
|
||||||
|
verified = "Yes"
|
||||||
|
}
|
||||||
|
fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||||
|
user.ID,
|
||||||
|
user.Username,
|
||||||
|
user.Email,
|
||||||
|
name,
|
||||||
|
verified,
|
||||||
|
user.CreatedAt.Format("2006-01-02 15:04"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
w.Flush()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func userDelete(cmd *cobra.Command, args []string) error {
|
||||||
|
username := args[0]
|
||||||
|
|
||||||
|
store, err := getStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Find user by username
|
||||||
|
user, err := store.GetUserByUsername(ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("user '%s' not found", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := store.DeleteUser(ctx, user.ID); err != nil {
|
||||||
|
return fmt.Errorf("failed to delete user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("User '%s' deleted successfully\n", username)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func userChangePassword(cmd *cobra.Command, args []string) error {
|
||||||
|
username := args[0]
|
||||||
|
|
||||||
|
store, err := getStore()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer store.Close()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Find user by username
|
||||||
|
user, err := store.GetUserByUsername(ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("user '%s' not found", username)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prompt for new password
|
||||||
|
fmt.Print("Enter new password: ")
|
||||||
|
bytePassword, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read password: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
fmt.Print("Confirm new password: ")
|
||||||
|
byteConfirm, err := term.ReadPassword(int(syscall.Stdin))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to read password confirmation: %w", err)
|
||||||
|
}
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
if string(bytePassword) != string(byteConfirm) {
|
||||||
|
return fmt.Errorf("passwords do not match")
|
||||||
|
}
|
||||||
|
|
||||||
|
password := string(bytePassword)
|
||||||
|
if password == "" {
|
||||||
|
return fmt.Errorf("password cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := store.UpdateUserPassword(ctx, user.ID, password); err != nil {
|
||||||
|
return fmt.Errorf("failed to update password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Password updated for user '%s'\n", username)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
13
idp/main.go
Normal file
13
idp/main.go
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/idp/cmd"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := cmd.Execute(); err != nil {
|
||||||
|
log.Fatalf("failed to execute command: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
249
idp/oidcprovider/client.go
Normal file
249
idp/oidcprovider/client.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OIDCClient wraps the database Client model and implements op.Client interface
|
||||||
|
type OIDCClient struct {
|
||||||
|
client *Client
|
||||||
|
loginURL func(string) string
|
||||||
|
redirectURIs []string
|
||||||
|
grantTypes []oidc.GrantType
|
||||||
|
responseTypes []oidc.ResponseType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOIDCClient creates an OIDCClient from a database Client
|
||||||
|
func NewOIDCClient(client *Client, loginURL func(string) string) *OIDCClient {
|
||||||
|
return &OIDCClient{
|
||||||
|
client: client,
|
||||||
|
loginURL: loginURL,
|
||||||
|
redirectURIs: ParseJSONArray(client.RedirectURIs),
|
||||||
|
grantTypes: parseGrantTypes(client.GrantTypes),
|
||||||
|
responseTypes: parseResponseTypes(client.ResponseTypes),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetID returns the client ID
|
||||||
|
func (c *OIDCClient) GetID() string {
|
||||||
|
return c.client.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedirectURIs returns the registered redirect URIs
|
||||||
|
func (c *OIDCClient) RedirectURIs() []string {
|
||||||
|
return c.redirectURIs
|
||||||
|
}
|
||||||
|
|
||||||
|
// PostLogoutRedirectURIs returns the registered post-logout redirect URIs
|
||||||
|
func (c *OIDCClient) PostLogoutRedirectURIs() []string {
|
||||||
|
return ParseJSONArray(c.client.PostLogoutURIs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplicationType returns the application type (native, web, user_agent)
|
||||||
|
func (c *OIDCClient) ApplicationType() op.ApplicationType {
|
||||||
|
switch c.client.ApplicationType {
|
||||||
|
case "native":
|
||||||
|
return op.ApplicationTypeNative
|
||||||
|
case "web":
|
||||||
|
return op.ApplicationTypeWeb
|
||||||
|
case "user_agent":
|
||||||
|
return op.ApplicationTypeUserAgent
|
||||||
|
default:
|
||||||
|
return op.ApplicationTypeWeb
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthMethod returns the authentication method
|
||||||
|
func (c *OIDCClient) AuthMethod() oidc.AuthMethod {
|
||||||
|
switch c.client.AuthMethod {
|
||||||
|
case "none":
|
||||||
|
return oidc.AuthMethodNone
|
||||||
|
case "client_secret_basic":
|
||||||
|
return oidc.AuthMethodBasic
|
||||||
|
case "client_secret_post":
|
||||||
|
return oidc.AuthMethodPost
|
||||||
|
case "private_key_jwt":
|
||||||
|
return oidc.AuthMethodPrivateKeyJWT
|
||||||
|
default:
|
||||||
|
return oidc.AuthMethodNone
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponseTypes returns the allowed response types
|
||||||
|
func (c *OIDCClient) ResponseTypes() []oidc.ResponseType {
|
||||||
|
return c.responseTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
// GrantTypes returns the allowed grant types
|
||||||
|
func (c *OIDCClient) GrantTypes() []oidc.GrantType {
|
||||||
|
return c.grantTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoginURL returns the login URL for this client
|
||||||
|
func (c *OIDCClient) LoginURL(authRequestID string) string {
|
||||||
|
if c.loginURL != nil {
|
||||||
|
return c.loginURL(authRequestID)
|
||||||
|
}
|
||||||
|
return "/login?authRequestID=" + authRequestID
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessTokenType returns the access token type
|
||||||
|
func (c *OIDCClient) AccessTokenType() op.AccessTokenType {
|
||||||
|
switch c.client.AccessTokenType {
|
||||||
|
case "jwt":
|
||||||
|
return op.AccessTokenTypeJWT
|
||||||
|
default:
|
||||||
|
return op.AccessTokenTypeBearer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDTokenLifetime returns the ID token lifetime
|
||||||
|
func (c *OIDCClient) IDTokenLifetime() time.Duration {
|
||||||
|
if c.client.IDTokenLifetime > 0 {
|
||||||
|
return time.Duration(c.client.IDTokenLifetime) * time.Second
|
||||||
|
}
|
||||||
|
return time.Hour // default 1 hour
|
||||||
|
}
|
||||||
|
|
||||||
|
// DevMode returns whether the client is in development mode
|
||||||
|
func (c *OIDCClient) DevMode() bool {
|
||||||
|
return c.client.DevMode
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestrictAdditionalIdTokenScopes returns any restricted scopes for ID tokens
|
||||||
|
func (c *OIDCClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string {
|
||||||
|
return func(scopes []string) []string {
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RestrictAdditionalAccessTokenScopes returns any restricted scopes for access tokens
|
||||||
|
func (c *OIDCClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string {
|
||||||
|
return func(scopes []string) []string {
|
||||||
|
return scopes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsScopeAllowed checks if a scope is allowed for this client
|
||||||
|
func (c *OIDCClient) IsScopeAllowed(scope string) bool {
|
||||||
|
// Allow all standard OIDC scopes
|
||||||
|
switch scope {
|
||||||
|
case oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, oidc.ScopeAddress, oidc.ScopeOfflineAccess:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return true // Allow custom scopes as well
|
||||||
|
}
|
||||||
|
|
||||||
|
// IDTokenUserinfoClaimsAssertion returns whether userinfo claims should be included in ID token
|
||||||
|
func (c *OIDCClient) IDTokenUserinfoClaimsAssertion() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClockSkew returns the allowed clock skew for this client
|
||||||
|
func (c *OIDCClient) ClockSkew() time.Duration {
|
||||||
|
if c.client.ClockSkew > 0 {
|
||||||
|
return time.Duration(c.client.ClockSkew) * time.Second
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for parsing grant types and response types
|
||||||
|
|
||||||
|
func parseGrantTypes(jsonStr string) []oidc.GrantType {
|
||||||
|
types := ParseJSONArray(jsonStr)
|
||||||
|
if len(types) == 0 {
|
||||||
|
// Default grant types
|
||||||
|
return []oidc.GrantType{
|
||||||
|
oidc.GrantTypeCode,
|
||||||
|
oidc.GrantTypeRefreshToken,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]oidc.GrantType, 0, len(types))
|
||||||
|
for _, t := range types {
|
||||||
|
switch t {
|
||||||
|
case "authorization_code":
|
||||||
|
result = append(result, oidc.GrantTypeCode)
|
||||||
|
case "refresh_token":
|
||||||
|
result = append(result, oidc.GrantTypeRefreshToken)
|
||||||
|
case "client_credentials":
|
||||||
|
result = append(result, oidc.GrantTypeClientCredentials)
|
||||||
|
case "urn:ietf:params:oauth:grant-type:device_code":
|
||||||
|
result = append(result, oidc.GrantTypeDeviceCode)
|
||||||
|
case "urn:ietf:params:oauth:grant-type:token-exchange":
|
||||||
|
result = append(result, oidc.GrantTypeTokenExchange)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseResponseTypes(jsonStr string) []oidc.ResponseType {
|
||||||
|
types := ParseJSONArray(jsonStr)
|
||||||
|
if len(types) == 0 {
|
||||||
|
// Default response types
|
||||||
|
return []oidc.ResponseType{oidc.ResponseTypeCode}
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]oidc.ResponseType, 0, len(types))
|
||||||
|
for _, t := range types {
|
||||||
|
switch t {
|
||||||
|
case "code":
|
||||||
|
result = append(result, oidc.ResponseTypeCode)
|
||||||
|
case "id_token":
|
||||||
|
result = append(result, oidc.ResponseTypeIDToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateNativeClient creates a native client configuration (for CLI/mobile apps with PKCE)
|
||||||
|
func CreateNativeClient(id, name string, redirectURIs []string) *Client {
|
||||||
|
return &Client{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
RedirectURIs: ToJSONArray(redirectURIs),
|
||||||
|
ApplicationType: "native",
|
||||||
|
AuthMethod: "none", // Public client
|
||||||
|
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||||
|
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}),
|
||||||
|
AccessTokenType: "bearer",
|
||||||
|
DevMode: true,
|
||||||
|
IDTokenLifetime: 3600,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWebClient creates a web client configuration (for SPAs/web apps)
|
||||||
|
func CreateWebClient(id, secret, name string, redirectURIs []string) *Client {
|
||||||
|
return &Client{
|
||||||
|
ID: id,
|
||||||
|
Secret: secret,
|
||||||
|
Name: name,
|
||||||
|
RedirectURIs: ToJSONArray(redirectURIs),
|
||||||
|
ApplicationType: "web",
|
||||||
|
AuthMethod: "client_secret_basic",
|
||||||
|
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||||
|
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
|
||||||
|
AccessTokenType: "bearer",
|
||||||
|
DevMode: false,
|
||||||
|
IDTokenLifetime: 3600,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateSPAClient creates a Single Page Application client configuration (public client for SPAs)
|
||||||
|
func CreateSPAClient(id, name string, redirectURIs []string) *Client {
|
||||||
|
return &Client{
|
||||||
|
ID: id,
|
||||||
|
Name: name,
|
||||||
|
RedirectURIs: ToJSONArray(redirectURIs),
|
||||||
|
ApplicationType: "user_agent",
|
||||||
|
AuthMethod: "none", // Public client for SPA
|
||||||
|
ResponseTypes: ToJSONArray([]string{"code"}),
|
||||||
|
GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}),
|
||||||
|
AccessTokenType: "bearer",
|
||||||
|
DevMode: true,
|
||||||
|
IDTokenLifetime: 3600,
|
||||||
|
}
|
||||||
|
}
|
||||||
220
idp/oidcprovider/device.go
Normal file
220
idp/oidcprovider/device.go
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/gorilla/securecookie"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeviceHandler handles the device authorization flow
|
||||||
|
type DeviceHandler struct {
|
||||||
|
storage *OIDCStorage
|
||||||
|
tmpl *template.Template
|
||||||
|
secureCookie *securecookie.SecureCookie
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDeviceHandler creates a new device handler
|
||||||
|
func NewDeviceHandler(storage *OIDCStorage) (*DeviceHandler, error) {
|
||||||
|
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate secure cookie keys
|
||||||
|
hashKey := securecookie.GenerateRandomKey(32)
|
||||||
|
blockKey := securecookie.GenerateRandomKey(32)
|
||||||
|
|
||||||
|
return &DeviceHandler{
|
||||||
|
storage: storage,
|
||||||
|
tmpl: tmpl,
|
||||||
|
secureCookie: securecookie.New(hashKey, blockKey),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router returns the device flow router
|
||||||
|
func (h *DeviceHandler) Router() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Get("/", h.userCodePage)
|
||||||
|
r.Post("/login", h.handleLogin)
|
||||||
|
r.Post("/confirm", h.handleConfirm)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// userCodePage displays the user code entry form
|
||||||
|
func (h *DeviceHandler) userCodePage(w http.ResponseWriter, r *http.Request) {
|
||||||
|
userCode := r.URL.Query().Get("user_code")
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"UserCode": userCode,
|
||||||
|
"Error": "",
|
||||||
|
"Step": "code", // code, login, or confirm
|
||||||
|
}
|
||||||
|
|
||||||
|
if userCode != "" {
|
||||||
|
// Verify the user code exists
|
||||||
|
_, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
|
||||||
|
if err != nil {
|
||||||
|
data["Error"] = "Invalid or expired user code"
|
||||||
|
data["UserCode"] = ""
|
||||||
|
} else {
|
||||||
|
data["Step"] = "login"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.tmpl.ExecuteTemplate(w, "device.html", data); err != nil {
|
||||||
|
log.Errorf("failed to render device template: %v", err)
|
||||||
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogin processes the login form on the device flow
|
||||||
|
func (h *DeviceHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode := r.FormValue("user_code")
|
||||||
|
username := r.FormValue("username")
|
||||||
|
password := r.FormValue("password")
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"UserCode": userCode,
|
||||||
|
"Error": "",
|
||||||
|
"Step": "login",
|
||||||
|
}
|
||||||
|
|
||||||
|
if userCode == "" || username == "" || password == "" {
|
||||||
|
data["Error"] = "Please fill in all fields"
|
||||||
|
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate credentials
|
||||||
|
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("device login failed for user %s: %v", username, err)
|
||||||
|
data["Error"] = "Invalid username or password"
|
||||||
|
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get device authorization info
|
||||||
|
authState, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode)
|
||||||
|
if err != nil {
|
||||||
|
data["Error"] = "Invalid or expired user code"
|
||||||
|
data["Step"] = "code"
|
||||||
|
data["UserCode"] = ""
|
||||||
|
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set secure cookie with user info for confirmation step
|
||||||
|
cookieValue := map[string]string{
|
||||||
|
"user_code": userCode,
|
||||||
|
"user_id": userID,
|
||||||
|
}
|
||||||
|
|
||||||
|
encoded, err := h.secureCookie.Encode("device_auth", cookieValue)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to encode cookie: %v", err)
|
||||||
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "device_auth",
|
||||||
|
Value: encoded,
|
||||||
|
Path: "/device",
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: r.TLS != nil,
|
||||||
|
SameSite: http.SameSiteStrictMode,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Show confirmation page
|
||||||
|
data["Step"] = "confirm"
|
||||||
|
data["ClientID"] = authState.ClientID
|
||||||
|
data["Scopes"] = authState.Scopes
|
||||||
|
data["UserID"] = userID
|
||||||
|
|
||||||
|
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleConfirm processes the authorization decision
|
||||||
|
func (h *DeviceHandler) handleConfirm(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get values from cookie
|
||||||
|
cookie, err := r.Cookie("device_auth")
|
||||||
|
if err != nil {
|
||||||
|
http.Redirect(w, r, "/device", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var cookieValue map[string]string
|
||||||
|
if err := h.secureCookie.Decode("device_auth", cookie.Value, &cookieValue); err != nil {
|
||||||
|
http.Redirect(w, r, "/device", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
userCode := cookieValue["user_code"]
|
||||||
|
userID := cookieValue["user_id"]
|
||||||
|
action := r.FormValue("action")
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"Step": "result",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clear the cookie
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: "device_auth",
|
||||||
|
Value: "",
|
||||||
|
Path: "/device",
|
||||||
|
MaxAge: -1,
|
||||||
|
HttpOnly: true,
|
||||||
|
})
|
||||||
|
|
||||||
|
if action == "allow" {
|
||||||
|
if err := h.storage.CompleteDeviceAuthorization(r.Context(), userCode, userID); err != nil {
|
||||||
|
log.Errorf("failed to complete device authorization: %v", err)
|
||||||
|
data["Error"] = "Failed to authorize device"
|
||||||
|
} else {
|
||||||
|
data["Success"] = true
|
||||||
|
data["Message"] = "Device authorized successfully! You can now close this window."
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := h.storage.DenyDeviceAuthorization(r.Context(), userCode); err != nil {
|
||||||
|
log.Errorf("failed to deny device authorization: %v", err)
|
||||||
|
}
|
||||||
|
data["Success"] = false
|
||||||
|
data["Message"] = "Authorization denied. You can close this window."
|
||||||
|
}
|
||||||
|
|
||||||
|
h.tmpl.ExecuteTemplate(w, "device.html", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateUserCode generates a user-friendly code for device flow
|
||||||
|
func GenerateUserCode() string {
|
||||||
|
// Generate a base20 code (BCDFGHJKLMNPQRSTVWXZ - no vowels to avoid words)
|
||||||
|
chars := "BCDFGHJKLMNPQRSTVWXZ"
|
||||||
|
b := securecookie.GenerateRandomKey(8)
|
||||||
|
result := make([]byte, 8)
|
||||||
|
for i := range result {
|
||||||
|
result[i] = chars[int(b[i])%len(chars)]
|
||||||
|
}
|
||||||
|
// Format as XXXX-XXXX
|
||||||
|
return string(result[:4]) + "-" + string(result[4:])
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateDeviceCode generates a secure device code
|
||||||
|
func GenerateDeviceCode() string {
|
||||||
|
b := securecookie.GenerateRandomKey(32)
|
||||||
|
return base64.RawURLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
105
idp/oidcprovider/login.go
Normal file
105
idp/oidcprovider/login.go
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"html/template"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed templates/*.html
|
||||||
|
var templateFS embed.FS
|
||||||
|
|
||||||
|
// LoginHandler handles the login flow
|
||||||
|
type LoginHandler struct {
|
||||||
|
storage *OIDCStorage
|
||||||
|
callback func(string) string
|
||||||
|
tmpl *template.Template
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewLoginHandler creates a new login handler
|
||||||
|
func NewLoginHandler(storage *OIDCStorage, callback func(string) string) (*LoginHandler, error) {
|
||||||
|
tmpl, err := template.ParseFS(templateFS, "templates/*.html")
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &LoginHandler{
|
||||||
|
storage: storage,
|
||||||
|
callback: callback,
|
||||||
|
tmpl: tmpl,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router returns the login router
|
||||||
|
func (h *LoginHandler) Router() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
r.Get("/", h.loginPage)
|
||||||
|
r.Post("/", h.handleLogin)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// loginPage displays the login form
|
||||||
|
func (h *LoginHandler) loginPage(w http.ResponseWriter, r *http.Request) {
|
||||||
|
authRequestID := r.URL.Query().Get("authRequestID")
|
||||||
|
if authRequestID == "" {
|
||||||
|
http.Error(w, "missing auth request ID", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"AuthRequestID": authRequestID,
|
||||||
|
"Error": "",
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.tmpl.ExecuteTemplate(w, "login.html", data); err != nil {
|
||||||
|
log.Errorf("failed to render login template: %v", err)
|
||||||
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLogin processes the login form submission
|
||||||
|
func (h *LoginHandler) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
http.Error(w, "invalid form", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
authRequestID := r.FormValue("authRequestID")
|
||||||
|
username := r.FormValue("username")
|
||||||
|
password := r.FormValue("password")
|
||||||
|
|
||||||
|
if authRequestID == "" || username == "" || password == "" {
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"AuthRequestID": authRequestID,
|
||||||
|
"Error": "Please fill in all fields",
|
||||||
|
}
|
||||||
|
h.tmpl.ExecuteTemplate(w, "login.html", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate credentials and get user ID
|
||||||
|
userID, err := h.storage.CheckUsernamePasswordSimple(username, password)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("login failed for user %s: %v", username, err)
|
||||||
|
data := map[string]interface{}{
|
||||||
|
"AuthRequestID": authRequestID,
|
||||||
|
"Error": "Invalid username or password",
|
||||||
|
}
|
||||||
|
h.tmpl.ExecuteTemplate(w, "login.html", data)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the auth request
|
||||||
|
if err := h.storage.CompleteAuthRequest(r.Context(), authRequestID, userID); err != nil {
|
||||||
|
log.Errorf("failed to complete auth request: %v", err)
|
||||||
|
http.Error(w, "internal error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect to callback
|
||||||
|
callbackURL := h.callback(authRequestID)
|
||||||
|
http.Redirect(w, r, callbackURL, http.StatusFound)
|
||||||
|
}
|
||||||
136
idp/oidcprovider/models.go
Normal file
136
idp/oidcprovider/models.go
Normal file
@@ -0,0 +1,136 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/text/language"
|
||||||
|
)
|
||||||
|
|
||||||
|
// User represents an OIDC user stored in the database
|
||||||
|
type User struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Username string `gorm:"uniqueIndex;not null"`
|
||||||
|
Password string `gorm:"not null"` // bcrypt hashed
|
||||||
|
Email string
|
||||||
|
EmailVerified bool
|
||||||
|
FirstName string
|
||||||
|
LastName string
|
||||||
|
Phone string
|
||||||
|
PhoneVerified bool
|
||||||
|
PreferredLanguage string // language tag string
|
||||||
|
IsAdmin bool
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPreferredLanguage returns the user's preferred language as a language.Tag
|
||||||
|
func (u *User) GetPreferredLanguage() language.Tag {
|
||||||
|
if u.PreferredLanguage == "" {
|
||||||
|
return language.English
|
||||||
|
}
|
||||||
|
tag, err := language.Parse(u.PreferredLanguage)
|
||||||
|
if err != nil {
|
||||||
|
return language.English
|
||||||
|
}
|
||||||
|
return tag
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client represents an OIDC client (application) stored in the database
|
||||||
|
type Client struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Secret string // bcrypt hashed, empty for public clients
|
||||||
|
Name string
|
||||||
|
RedirectURIs string // JSON array of redirect URIs
|
||||||
|
PostLogoutURIs string // JSON array of post-logout redirect URIs
|
||||||
|
ApplicationType string // native, web, user_agent
|
||||||
|
AuthMethod string // none, client_secret_basic, client_secret_post, private_key_jwt
|
||||||
|
ResponseTypes string // JSON array: code, id_token, token
|
||||||
|
GrantTypes string // JSON array: authorization_code, refresh_token, client_credentials, urn:ietf:params:oauth:grant-type:device_code
|
||||||
|
AccessTokenType string // bearer or jwt
|
||||||
|
DevMode bool // allows non-HTTPS redirect URIs
|
||||||
|
IDTokenLifetime int64 // in seconds, default 3600 (1 hour)
|
||||||
|
ClockSkew int64 // in seconds, allowed clock skew
|
||||||
|
CreatedAt time.Time
|
||||||
|
UpdatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequest represents an ongoing authorization request
|
||||||
|
type AuthRequest struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
ClientID string `gorm:"index"`
|
||||||
|
Scopes string // JSON array of scopes
|
||||||
|
RedirectURI string
|
||||||
|
State string
|
||||||
|
Nonce string
|
||||||
|
ResponseType string
|
||||||
|
ResponseMode string
|
||||||
|
CodeChallenge string
|
||||||
|
CodeMethod string // S256 or plain
|
||||||
|
UserID string // set after user authentication
|
||||||
|
Done bool // true when user has authenticated
|
||||||
|
AuthTime time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
MaxAge int64 // max authentication age in seconds
|
||||||
|
Prompt string // none, login, consent, select_account
|
||||||
|
UILocales string // space-separated list of locales
|
||||||
|
LoginHint string
|
||||||
|
ACRValues string // space-separated list of ACR values
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthCode represents an authorization code
|
||||||
|
type AuthCode struct {
|
||||||
|
Code string `gorm:"primaryKey"`
|
||||||
|
AuthRequestID string `gorm:"index"`
|
||||||
|
CreatedAt time.Time
|
||||||
|
ExpiresAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// AccessToken represents an access token
|
||||||
|
type AccessToken struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
ApplicationID string `gorm:"index"`
|
||||||
|
Subject string `gorm:"index"`
|
||||||
|
Audience string // JSON array
|
||||||
|
Scopes string // JSON array
|
||||||
|
Expiration time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken represents a refresh token
|
||||||
|
type RefreshToken struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Token string `gorm:"uniqueIndex"`
|
||||||
|
AuthRequestID string
|
||||||
|
ApplicationID string `gorm:"index"`
|
||||||
|
Subject string `gorm:"index"`
|
||||||
|
Audience string // JSON array
|
||||||
|
Scopes string // JSON array
|
||||||
|
AMR string // JSON array of authentication methods
|
||||||
|
AuthTime time.Time
|
||||||
|
Expiration time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuth represents a device authorization request
|
||||||
|
type DeviceAuth struct {
|
||||||
|
DeviceCode string `gorm:"primaryKey"`
|
||||||
|
UserCode string `gorm:"uniqueIndex"`
|
||||||
|
ClientID string `gorm:"index"`
|
||||||
|
Scopes string // JSON array
|
||||||
|
Subject string // set after user authentication
|
||||||
|
Audience string // JSON array
|
||||||
|
Done bool // true when user has authorized
|
||||||
|
Denied bool // true when user has denied
|
||||||
|
Expiration time.Time
|
||||||
|
CreatedAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// SigningKey represents a signing key for JWTs
|
||||||
|
type SigningKey struct {
|
||||||
|
ID string `gorm:"primaryKey"`
|
||||||
|
Algorithm string // RS256
|
||||||
|
PrivateKey []byte // PEM encoded
|
||||||
|
PublicKey []byte // PEM encoded
|
||||||
|
CreatedAt time.Time
|
||||||
|
Active bool
|
||||||
|
}
|
||||||
662
idp/oidcprovider/oidc_storage.go
Normal file
662
idp/oidcprovider/oidc_storage.go
Normal file
@@ -0,0 +1,662 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
jose "github.com/go-jose/go-jose/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/oidc"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidRefreshToken is returned when a token is not a valid refresh token
|
||||||
|
var ErrInvalidRefreshToken = errors.New("invalid refresh token")
|
||||||
|
|
||||||
|
// OIDCStorage implements op.Storage interface for the OIDC provider
|
||||||
|
type OIDCStorage struct {
|
||||||
|
store *Store
|
||||||
|
issuer string
|
||||||
|
loginURL func(string) string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOIDCStorage creates a new OIDCStorage
|
||||||
|
func NewOIDCStorage(store *Store, issuer string) *OIDCStorage {
|
||||||
|
return &OIDCStorage{
|
||||||
|
store: store,
|
||||||
|
issuer: issuer,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLoginURL sets the login URL generator function
|
||||||
|
func (s *OIDCStorage) SetLoginURL(fn func(string) string) {
|
||||||
|
s.loginURL = fn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Health checks if the storage is healthy
|
||||||
|
func (s *OIDCStorage) Health(ctx context.Context) error {
|
||||||
|
sqlDB, err := s.store.db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sqlDB.PingContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAuthRequest creates and stores a new authorization request
|
||||||
|
func (s *OIDCStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) {
|
||||||
|
req := &AuthRequest{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
ClientID: authReq.ClientID,
|
||||||
|
Scopes: ToJSONArray(authReq.Scopes),
|
||||||
|
RedirectURI: authReq.RedirectURI,
|
||||||
|
State: authReq.State,
|
||||||
|
Nonce: authReq.Nonce,
|
||||||
|
ResponseType: string(authReq.ResponseType),
|
||||||
|
ResponseMode: string(authReq.ResponseMode),
|
||||||
|
CodeChallenge: authReq.CodeChallenge,
|
||||||
|
CodeMethod: string(authReq.CodeChallengeMethod),
|
||||||
|
UserID: userID,
|
||||||
|
Done: userID != "",
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Prompt: spaceSeparated(authReq.Prompt),
|
||||||
|
UILocales: authReq.UILocales.String(),
|
||||||
|
LoginHint: authReq.LoginHint,
|
||||||
|
ACRValues: spaceSeparated(authReq.ACRValues),
|
||||||
|
}
|
||||||
|
|
||||||
|
if authReq.MaxAge != nil {
|
||||||
|
req.MaxAge = int64(*authReq.MaxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userID != "" {
|
||||||
|
req.AuthTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.store.SaveAuthRequest(ctx, req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequestByID retrieves an authorization request by ID
|
||||||
|
func (s *OIDCStorage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) {
|
||||||
|
req, err := s.store.GetAuthRequestByID(ctx, id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, fmt.Errorf("auth request not found: %s", id)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequestByCode retrieves an authorization request by code
|
||||||
|
func (s *OIDCStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) {
|
||||||
|
authCode, err := s.store.GetAuthCodeByCode(ctx, code)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, fmt.Errorf("auth code not found: %s", code)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(authCode.ExpiresAt) {
|
||||||
|
_ = s.store.DeleteAuthCode(ctx, code)
|
||||||
|
return nil, errors.New("auth code expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := s.store.GetAuthRequestByID(ctx, authCode.AuthRequestID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OIDCAuthRequest{req: req, storage: s}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveAuthCode saves an authorization code linked to an auth request
|
||||||
|
func (s *OIDCStorage) SaveAuthCode(ctx context.Context, id, code string) error {
|
||||||
|
authCode := &AuthCode{
|
||||||
|
Code: code,
|
||||||
|
AuthRequestID: id,
|
||||||
|
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||||
|
}
|
||||||
|
return s.store.SaveAuthCode(ctx, authCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuthRequest deletes an authorization request
|
||||||
|
func (s *OIDCStorage) DeleteAuthRequest(ctx context.Context, id string) error {
|
||||||
|
return s.store.DeleteAuthRequest(ctx, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccessToken creates and stores an access token
|
||||||
|
func (s *OIDCStorage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) {
|
||||||
|
tokenID := uuid.New().String()
|
||||||
|
expiration := time.Now().Add(5 * time.Minute)
|
||||||
|
|
||||||
|
// Get client ID from the request if possible
|
||||||
|
var clientID string
|
||||||
|
if authReq, ok := request.(op.AuthRequest); ok {
|
||||||
|
clientID = authReq.GetClientID()
|
||||||
|
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
|
||||||
|
clientID = refreshReq.GetClientID()
|
||||||
|
}
|
||||||
|
|
||||||
|
token := &AccessToken{
|
||||||
|
ID: tokenID,
|
||||||
|
ApplicationID: clientID,
|
||||||
|
Subject: request.GetSubject(),
|
||||||
|
Audience: ToJSONArray(request.GetAudience()),
|
||||||
|
Scopes: ToJSONArray(request.GetScopes()),
|
||||||
|
Expiration: expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.store.SaveAccessToken(ctx, token); err != nil {
|
||||||
|
return "", time.Time{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenID, expiration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateAccessAndRefreshTokens creates both access and refresh tokens
|
||||||
|
func (s *OIDCStorage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) {
|
||||||
|
// Delete old refresh token if provided
|
||||||
|
if currentRefreshToken != "" {
|
||||||
|
_ = s.store.DeleteRefreshTokenByToken(ctx, currentRefreshToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create access token
|
||||||
|
accessTokenID, expiration, err = s.CreateAccessToken(ctx, request)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", time.Time{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get additional info from the request if possible
|
||||||
|
var clientID string
|
||||||
|
var authTime time.Time
|
||||||
|
var amr []string
|
||||||
|
|
||||||
|
if authReq, ok := request.(op.AuthRequest); ok {
|
||||||
|
clientID = authReq.GetClientID()
|
||||||
|
authTime = authReq.GetAuthTime()
|
||||||
|
amr = authReq.GetAMR()
|
||||||
|
} else if refreshReq, ok := request.(op.RefreshTokenRequest); ok {
|
||||||
|
clientID = refreshReq.GetClientID()
|
||||||
|
authTime = refreshReq.GetAuthTime()
|
||||||
|
amr = refreshReq.GetAMR()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create refresh token
|
||||||
|
refreshToken := &RefreshToken{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Token: uuid.New().String(),
|
||||||
|
ApplicationID: clientID,
|
||||||
|
Subject: request.GetSubject(),
|
||||||
|
Audience: ToJSONArray(request.GetAudience()),
|
||||||
|
Scopes: ToJSONArray(request.GetScopes()),
|
||||||
|
AuthTime: authTime,
|
||||||
|
AMR: ToJSONArray(amr),
|
||||||
|
Expiration: time.Now().Add(5 * time.Hour), // 5 hour refresh token lifetime
|
||||||
|
}
|
||||||
|
|
||||||
|
if authReq, ok := request.(op.AuthRequest); ok {
|
||||||
|
refreshToken.AuthRequestID = authReq.GetID()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.store.SaveRefreshToken(ctx, refreshToken); err != nil {
|
||||||
|
return "", "", time.Time{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return accessTokenID, refreshToken.Token, expiration, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenRequestByRefreshToken retrieves token request info from refresh token
|
||||||
|
func (s *OIDCStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) {
|
||||||
|
token, err := s.store.GetRefreshToken(ctx, refreshToken)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("refresh token not found")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(token.Expiration) {
|
||||||
|
_ = s.store.DeleteRefreshTokenByToken(ctx, refreshToken)
|
||||||
|
return nil, errors.New("refresh token expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OIDCRefreshToken{token: token}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TerminateSession terminates a user session
|
||||||
|
func (s *OIDCStorage) TerminateSession(ctx context.Context, userID, clientID string) error {
|
||||||
|
// For now, we don't track sessions separately
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeToken revokes a token
|
||||||
|
func (s *OIDCStorage) RevokeToken(ctx context.Context, tokenOrID string, userID string, clientID string) *oidc.Error {
|
||||||
|
// Try to delete as refresh token
|
||||||
|
if err := s.store.DeleteRefreshTokenByToken(ctx, tokenOrID); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to delete as access token
|
||||||
|
if err := s.store.DeleteAccessToken(ctx, tokenOrID); err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil // Silently succeed even if token not found (per spec)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefreshTokenInfo returns info about a refresh token
|
||||||
|
func (s *OIDCStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) {
|
||||||
|
refreshToken, err := s.store.GetRefreshToken(ctx, token)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return "", "", ErrInvalidRefreshToken
|
||||||
|
}
|
||||||
|
return "", "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if refreshToken.ApplicationID != clientID {
|
||||||
|
return "", "", ErrInvalidRefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return refreshToken.Subject, refreshToken.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientByClientID retrieves a client by ID
|
||||||
|
func (s *OIDCStorage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) {
|
||||||
|
client, err := s.store.GetClientByID(ctx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, fmt.Errorf("client not found: %s", clientID)
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return NewOIDCClient(client, s.loginURL), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthorizeClientIDSecret validates client credentials
|
||||||
|
func (s *OIDCStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error {
|
||||||
|
_, err := s.store.ValidateClientSecret(ctx, clientID, clientSecret)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserinfoFromScopes sets userinfo claims based on scopes
|
||||||
|
func (s *OIDCStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error {
|
||||||
|
return s.setUserinfo(ctx, userinfo, userID, scopes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetUserinfoFromToken sets userinfo claims from an access token
|
||||||
|
func (s *OIDCStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error {
|
||||||
|
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return s.setUserinfo(ctx, userinfo, token.Subject, ParseJSONArray(token.Scopes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// setUserinfo populates userinfo based on user data and scopes
|
||||||
|
func (s *OIDCStorage) setUserinfo(ctx context.Context, userinfo *oidc.UserInfo, userID string, scopes []string) error {
|
||||||
|
user, err := s.store.GetUserByID(ctx, userID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, scope := range scopes {
|
||||||
|
switch scope {
|
||||||
|
case oidc.ScopeOpenID:
|
||||||
|
userinfo.Subject = user.ID
|
||||||
|
case oidc.ScopeProfile:
|
||||||
|
userinfo.Name = fmt.Sprintf("%s %s", user.FirstName, user.LastName)
|
||||||
|
userinfo.GivenName = user.FirstName
|
||||||
|
userinfo.FamilyName = user.LastName
|
||||||
|
userinfo.PreferredUsername = user.Username
|
||||||
|
userinfo.Locale = oidc.NewLocale(user.GetPreferredLanguage())
|
||||||
|
case oidc.ScopeEmail:
|
||||||
|
userinfo.Email = user.Email
|
||||||
|
userinfo.EmailVerified = oidc.Bool(user.EmailVerified)
|
||||||
|
case oidc.ScopePhone:
|
||||||
|
userinfo.PhoneNumber = user.Phone
|
||||||
|
userinfo.PhoneNumberVerified = user.PhoneVerified
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIntrospectionFromToken sets introspection response from token
|
||||||
|
func (s *OIDCStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error {
|
||||||
|
token, err := s.store.GetAccessTokenByID(ctx, tokenID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
introspection.Active = true
|
||||||
|
introspection.Subject = token.Subject
|
||||||
|
introspection.ClientID = token.ApplicationID
|
||||||
|
introspection.Scope = ParseJSONArray(token.Scopes)
|
||||||
|
introspection.Expiration = oidc.FromTime(token.Expiration)
|
||||||
|
introspection.IssuedAt = oidc.FromTime(token.CreatedAt)
|
||||||
|
introspection.Audience = ParseJSONArray(token.Audience)
|
||||||
|
introspection.Issuer = s.issuer
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPrivateClaimsFromScopes returns additional claims based on scopes
|
||||||
|
func (s *OIDCStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetKeyByIDAndClientID retrieves a key by ID for a client
|
||||||
|
func (s *OIDCStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) {
|
||||||
|
return nil, errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateJWTProfileScopes validates scopes for JWT profile grant
|
||||||
|
func (s *OIDCStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) {
|
||||||
|
return scopes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SigningKey returns the active signing key for token signing
|
||||||
|
func (s *OIDCStorage) SigningKey(ctx context.Context) (op.SigningKey, error) {
|
||||||
|
key, err := s.store.GetSigningKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(key.PrivateKey)
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("failed to decode private key PEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &signingKey{
|
||||||
|
id: key.ID,
|
||||||
|
algorithm: jose.RS256,
|
||||||
|
privateKey: privateKey,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignatureAlgorithms returns supported signature algorithms
|
||||||
|
func (s *OIDCStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) {
|
||||||
|
return []jose.SignatureAlgorithm{jose.RS256}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeySet returns the public key set for token verification
|
||||||
|
func (s *OIDCStorage) KeySet(ctx context.Context) ([]op.Key, error) {
|
||||||
|
key, err := s.store.GetSigningKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(key.PublicKey)
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("failed to decode public key PEM")
|
||||||
|
}
|
||||||
|
|
||||||
|
publicKey, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse public key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rsaKey, ok := publicKey.(*rsa.PublicKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("public key is not RSA")
|
||||||
|
}
|
||||||
|
|
||||||
|
return []op.Key{
|
||||||
|
&publicKeyInfo{
|
||||||
|
id: key.ID,
|
||||||
|
algorithm: jose.RS256,
|
||||||
|
publicKey: rsaKey,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Device Authorization Flow methods
|
||||||
|
|
||||||
|
// StoreDeviceAuthorization stores a device authorization request
|
||||||
|
func (s *OIDCStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error {
|
||||||
|
auth := &DeviceAuth{
|
||||||
|
DeviceCode: deviceCode,
|
||||||
|
UserCode: userCode,
|
||||||
|
ClientID: clientID,
|
||||||
|
Scopes: ToJSONArray(scopes),
|
||||||
|
Expiration: expires,
|
||||||
|
}
|
||||||
|
return s.store.SaveDeviceAuth(ctx, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationState retrieves the state of a device authorization
|
||||||
|
func (s *OIDCStorage) GetDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) {
|
||||||
|
auth, err := s.store.GetDeviceAuthByDeviceCode(ctx, deviceCode)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("device authorization not found")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.ClientID != clientID {
|
||||||
|
return nil, errors.New("client ID mismatch")
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(auth.Expiration) {
|
||||||
|
_ = s.store.DeleteDeviceAuth(ctx, deviceCode)
|
||||||
|
return &op.DeviceAuthorizationState{Expires: auth.Expiration}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
state := &op.DeviceAuthorizationState{
|
||||||
|
ClientID: auth.ClientID,
|
||||||
|
Scopes: ParseJSONArray(auth.Scopes),
|
||||||
|
Expires: auth.Expiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if auth.Denied {
|
||||||
|
state.Denied = true
|
||||||
|
} else if auth.Done {
|
||||||
|
state.Done = true
|
||||||
|
state.Subject = auth.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
return state, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationByUserCode retrieves device auth by user code
|
||||||
|
func (s *OIDCStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) {
|
||||||
|
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errors.New("device authorization not found")
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().After(auth.Expiration) {
|
||||||
|
return nil, errors.New("device authorization expired")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &op.DeviceAuthorizationState{
|
||||||
|
ClientID: auth.ClientID,
|
||||||
|
Scopes: ParseJSONArray(auth.Scopes),
|
||||||
|
Expires: auth.Expiration,
|
||||||
|
Done: auth.Done,
|
||||||
|
Denied: auth.Denied,
|
||||||
|
Subject: auth.Subject,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteDeviceAuthorization marks a device authorization as complete
|
||||||
|
func (s *OIDCStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error {
|
||||||
|
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Done = true
|
||||||
|
auth.Subject = subject
|
||||||
|
return s.store.UpdateDeviceAuth(ctx, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DenyDeviceAuthorization marks a device authorization as denied
|
||||||
|
func (s *OIDCStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) error {
|
||||||
|
auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
auth.Denied = true
|
||||||
|
return s.store.UpdateDeviceAuth(ctx, auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
// User authentication methods
|
||||||
|
|
||||||
|
// CheckUsernamePassword validates user credentials
|
||||||
|
func (s *OIDCStorage) CheckUsernamePassword(username, password, authRequestID string) error {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := s.store.ValidateUserPassword(ctx, username, password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckUsernamePasswordSimple validates user credentials and returns the user ID
|
||||||
|
func (s *OIDCStorage) CheckUsernamePasswordSimple(username, password string) (string, error) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
user, err := s.store.ValidateUserPassword(ctx, username, password)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user.ID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompleteAuthRequest completes an auth request after user authentication
|
||||||
|
func (s *OIDCStorage) CompleteAuthRequest(ctx context.Context, authRequestID, userID string) error {
|
||||||
|
req, err := s.store.GetAuthRequestByID(ctx, authRequestID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req.UserID = userID
|
||||||
|
req.Done = true
|
||||||
|
req.AuthTime = time.Now()
|
||||||
|
|
||||||
|
return s.store.UpdateAuthRequest(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper types
|
||||||
|
|
||||||
|
// signingKey implements op.SigningKey
|
||||||
|
type signingKey struct {
|
||||||
|
id string
|
||||||
|
algorithm jose.SignatureAlgorithm
|
||||||
|
privateKey *rsa.PrivateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm {
|
||||||
|
return k.algorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *signingKey) Key() interface{} {
|
||||||
|
return k.privateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *signingKey) ID() string {
|
||||||
|
return k.id
|
||||||
|
}
|
||||||
|
|
||||||
|
// publicKeyInfo implements op.Key
|
||||||
|
type publicKeyInfo struct {
|
||||||
|
id string
|
||||||
|
algorithm jose.SignatureAlgorithm
|
||||||
|
publicKey *rsa.PublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *publicKeyInfo) ID() string {
|
||||||
|
return k.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *publicKeyInfo) Algorithm() jose.SignatureAlgorithm {
|
||||||
|
return k.algorithm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *publicKeyInfo) Use() string {
|
||||||
|
return "sig"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (k *publicKeyInfo) Key() interface{} {
|
||||||
|
return k.publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// OIDCAuthRequest wraps AuthRequest for the op.AuthRequest interface
|
||||||
|
type OIDCAuthRequest struct {
|
||||||
|
req *AuthRequest
|
||||||
|
storage *OIDCStorage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OIDCAuthRequest) GetID() string { return r.req.ID }
|
||||||
|
func (r *OIDCAuthRequest) GetACR() string { return "" }
|
||||||
|
func (r *OIDCAuthRequest) GetAMR() []string { return []string{"pwd"} }
|
||||||
|
func (r *OIDCAuthRequest) GetAudience() []string { return []string{r.req.ClientID} }
|
||||||
|
func (r *OIDCAuthRequest) GetAuthTime() time.Time { return r.req.AuthTime }
|
||||||
|
func (r *OIDCAuthRequest) GetClientID() string { return r.req.ClientID }
|
||||||
|
func (r *OIDCAuthRequest) GetCodeChallenge() *oidc.CodeChallenge {
|
||||||
|
if r.req.CodeChallenge == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &oidc.CodeChallenge{
|
||||||
|
Challenge: r.req.CodeChallenge,
|
||||||
|
Method: oidc.CodeChallengeMethod(r.req.CodeMethod),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func (r *OIDCAuthRequest) GetNonce() string { return r.req.Nonce }
|
||||||
|
func (r *OIDCAuthRequest) GetRedirectURI() string { return r.req.RedirectURI }
|
||||||
|
func (r *OIDCAuthRequest) GetResponseType() oidc.ResponseType {
|
||||||
|
return oidc.ResponseType(r.req.ResponseType)
|
||||||
|
}
|
||||||
|
func (r *OIDCAuthRequest) GetResponseMode() oidc.ResponseMode {
|
||||||
|
return oidc.ResponseMode(r.req.ResponseMode)
|
||||||
|
}
|
||||||
|
func (r *OIDCAuthRequest) GetScopes() []string { return ParseJSONArray(r.req.Scopes) }
|
||||||
|
func (r *OIDCAuthRequest) GetState() string { return r.req.State }
|
||||||
|
func (r *OIDCAuthRequest) GetSubject() string { return r.req.UserID }
|
||||||
|
func (r *OIDCAuthRequest) Done() bool { return r.req.Done }
|
||||||
|
|
||||||
|
// OIDCRefreshToken wraps RefreshToken for the op.RefreshTokenRequest interface
|
||||||
|
type OIDCRefreshToken struct {
|
||||||
|
token *RefreshToken
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *OIDCRefreshToken) GetAMR() []string { return ParseJSONArray(r.token.AMR) }
|
||||||
|
func (r *OIDCRefreshToken) GetAudience() []string { return ParseJSONArray(r.token.Audience) }
|
||||||
|
func (r *OIDCRefreshToken) GetAuthTime() time.Time { return r.token.AuthTime }
|
||||||
|
func (r *OIDCRefreshToken) GetClientID() string { return r.token.ApplicationID }
|
||||||
|
func (r *OIDCRefreshToken) GetScopes() []string { return ParseJSONArray(r.token.Scopes) }
|
||||||
|
func (r *OIDCRefreshToken) GetSubject() string { return r.token.Subject }
|
||||||
|
func (r *OIDCRefreshToken) SetCurrentScopes(scopes []string) {}
|
||||||
|
|
||||||
|
// Helper functions
|
||||||
|
|
||||||
|
func spaceSeparated(items []string) string {
|
||||||
|
return strings.Join(items, " ")
|
||||||
|
}
|
||||||
265
idp/oidcprovider/provider.go
Normal file
265
idp/oidcprovider/provider.go
Normal file
@@ -0,0 +1,265 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/zitadel/oidc/v3/pkg/op"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds the configuration for the OIDC provider
|
||||||
|
type Config struct {
|
||||||
|
// Issuer is the OIDC issuer URL (e.g., "https://idp.example.com")
|
||||||
|
Issuer string
|
||||||
|
// Port is the port to listen on
|
||||||
|
Port int
|
||||||
|
// DataDir is the directory to store OIDC data (SQLite database)
|
||||||
|
DataDir string
|
||||||
|
// DevMode enables development mode (allows HTTP, localhost)
|
||||||
|
DevMode bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider represents the embedded OIDC provider
|
||||||
|
type Provider struct {
|
||||||
|
config *Config
|
||||||
|
store *Store
|
||||||
|
storage *OIDCStorage
|
||||||
|
provider op.OpenIDProvider
|
||||||
|
router chi.Router
|
||||||
|
httpServer *http.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProvider creates a new OIDC provider
|
||||||
|
func NewProvider(ctx context.Context, config *Config) (*Provider, error) {
|
||||||
|
// Create the SQLite store
|
||||||
|
store, err := NewStore(ctx, config.DataDir)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create OIDC store: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the OIDC storage adapter
|
||||||
|
storage := NewOIDCStorage(store, config.Issuer)
|
||||||
|
|
||||||
|
p := &Provider{
|
||||||
|
config: config,
|
||||||
|
store: store,
|
||||||
|
storage: storage,
|
||||||
|
}
|
||||||
|
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the OIDC provider server
|
||||||
|
func (p *Provider) Start(ctx context.Context) error {
|
||||||
|
// Create the router
|
||||||
|
router := chi.NewRouter()
|
||||||
|
router.Use(middleware.Logger)
|
||||||
|
router.Use(middleware.Recoverer)
|
||||||
|
router.Use(middleware.RequestID)
|
||||||
|
|
||||||
|
// Create the OIDC provider
|
||||||
|
key := sha256.Sum256([]byte(p.config.Issuer + "encryption-key"))
|
||||||
|
|
||||||
|
opConfig := &op.Config{
|
||||||
|
CryptoKey: key,
|
||||||
|
DefaultLogoutRedirectURI: "/logged-out",
|
||||||
|
CodeMethodS256: true,
|
||||||
|
AuthMethodPost: true,
|
||||||
|
AuthMethodPrivateKeyJWT: true,
|
||||||
|
GrantTypeRefreshToken: true,
|
||||||
|
RequestObjectSupported: true,
|
||||||
|
DeviceAuthorization: op.DeviceAuthorizationConfig{
|
||||||
|
Lifetime: 5 * time.Minute,
|
||||||
|
PollInterval: 5 * time.Second,
|
||||||
|
UserFormPath: "/device",
|
||||||
|
UserCode: op.UserCodeBase20,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the login URL generator
|
||||||
|
p.storage.SetLoginURL(func(authRequestID string) string {
|
||||||
|
return fmt.Sprintf("/login?authRequestID=%s", authRequestID)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the provider with options
|
||||||
|
var opts []op.Option
|
||||||
|
if p.config.DevMode {
|
||||||
|
opts = append(opts, op.WithAllowInsecure())
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, err := op.NewProvider(opConfig, p.storage, op.StaticIssuer(p.config.Issuer), opts...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create OIDC provider: %w", err)
|
||||||
|
}
|
||||||
|
p.provider = provider
|
||||||
|
|
||||||
|
// Set up login handler
|
||||||
|
loginHandler, err := NewLoginHandler(p.storage, func(authRequestID string) string {
|
||||||
|
return provider.AuthorizationEndpoint().Absolute("/authorize/callback") + "?id=" + authRequestID
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create login handler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up device handler
|
||||||
|
deviceHandler, err := NewDeviceHandler(p.storage)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create device handler: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mount routes
|
||||||
|
router.Mount("/login", loginHandler.Router())
|
||||||
|
router.Mount("/device", deviceHandler.Router())
|
||||||
|
router.Get("/logged-out", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
w.Write([]byte(`<!DOCTYPE html><html><head><title>Logged Out</title></head><body><h1>You have been logged out</h1><p>You can close this window.</p></body></html>`))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mount the OIDC provider at root
|
||||||
|
router.Mount("/", provider)
|
||||||
|
|
||||||
|
p.router = router
|
||||||
|
|
||||||
|
// Create HTTP server
|
||||||
|
addr := fmt.Sprintf(":%d", p.config.Port)
|
||||||
|
p.httpServer = &http.Server{
|
||||||
|
Addr: addr,
|
||||||
|
Handler: router,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start server in goroutine
|
||||||
|
go func() {
|
||||||
|
log.Infof("Starting OIDC provider on %s (issuer: %s)", addr, p.config.Issuer)
|
||||||
|
if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Errorf("OIDC provider server error: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Start cleanup goroutine
|
||||||
|
go p.cleanupLoop(ctx)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the OIDC provider server
|
||||||
|
func (p *Provider) Stop(ctx context.Context) error {
|
||||||
|
if p.httpServer != nil {
|
||||||
|
if err := p.httpServer.Shutdown(ctx); err != nil {
|
||||||
|
return fmt.Errorf("failed to shutdown OIDC server: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if p.store != nil {
|
||||||
|
if err := p.store.Close(); err != nil {
|
||||||
|
return fmt.Errorf("failed to close OIDC store: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupLoop periodically cleans up expired tokens
|
||||||
|
func (p *Provider) cleanupLoop(ctx context.Context) {
|
||||||
|
ticker := time.NewTicker(15 * time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := p.store.CleanupExpired(ctx); err != nil {
|
||||||
|
log.Warnf("OIDC cleanup error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store returns the underlying store for user/client management
|
||||||
|
func (p *Provider) Store() *Store {
|
||||||
|
return p.store
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIssuer returns the issuer URL
|
||||||
|
func (p *Provider) GetIssuer() string {
|
||||||
|
return p.config.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDiscoveryEndpoint returns the OpenID Connect discovery endpoint
|
||||||
|
func (p *Provider) GetDiscoveryEndpoint() string {
|
||||||
|
return p.config.Issuer + "/.well-known/openid-configuration"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetTokenEndpoint returns the token endpoint
|
||||||
|
func (p *Provider) GetTokenEndpoint() string {
|
||||||
|
return p.config.Issuer + "/oauth/token"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthorizationEndpoint returns the authorization endpoint
|
||||||
|
func (p *Provider) GetAuthorizationEndpoint() string {
|
||||||
|
return p.config.Issuer + "/authorize"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthorizationEndpoint returns the device authorization endpoint
|
||||||
|
func (p *Provider) GetDeviceAuthorizationEndpoint() string {
|
||||||
|
return p.config.Issuer + "/device_authorization"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetJWKSEndpoint returns the JWKS endpoint
|
||||||
|
func (p *Provider) GetJWKSEndpoint() string {
|
||||||
|
return p.config.Issuer + "/keys"
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserInfoEndpoint returns the userinfo endpoint
|
||||||
|
func (p *Provider) GetUserInfoEndpoint() string {
|
||||||
|
return p.config.Issuer + "/userinfo"
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnsureDefaultClients ensures the default NetBird clients exist
|
||||||
|
func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardRedirectURIs, cliRedirectURIs []string) error {
|
||||||
|
// Check if CLI client exists
|
||||||
|
_, err := p.store.GetClientByID(ctx, "netbird-client")
|
||||||
|
if err != nil {
|
||||||
|
// Create CLI client (native, public, supports PKCE and device flow)
|
||||||
|
cliClient := CreateNativeClient("netbird-client", "NetBird CLI", cliRedirectURIs)
|
||||||
|
if err := p.store.CreateClient(ctx, cliClient); err != nil {
|
||||||
|
return fmt.Errorf("failed to create CLI client: %w", err)
|
||||||
|
}
|
||||||
|
log.Info("Created default NetBird CLI client")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if dashboard client exists
|
||||||
|
_, err = p.store.GetClientByID(ctx, "netbird-dashboard")
|
||||||
|
if err != nil {
|
||||||
|
// Create dashboard client (SPA, public, supports PKCE)
|
||||||
|
dashboardClient := CreateSPAClient("netbird-dashboard", "NetBird Dashboard", dashboardRedirectURIs)
|
||||||
|
if err := p.store.CreateClient(ctx, dashboardClient); err != nil {
|
||||||
|
return fmt.Errorf("failed to create dashboard client: %w", err)
|
||||||
|
}
|
||||||
|
log.Info("Created default NetBird Dashboard client")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUser creates a new user (convenience method)
|
||||||
|
func (p *Provider) CreateUser(ctx context.Context, username, password, email, firstName, lastName string) (*User, error) {
|
||||||
|
user := &User{
|
||||||
|
Username: username,
|
||||||
|
Password: password, // Will be hashed by store
|
||||||
|
Email: email,
|
||||||
|
EmailVerified: false,
|
||||||
|
FirstName: firstName,
|
||||||
|
LastName: lastName,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := p.store.CreateUser(ctx, user); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
493
idp/oidcprovider/store.go
Normal file
493
idp/oidcprovider/store.go
Normal file
@@ -0,0 +1,493 @@
|
|||||||
|
package oidcprovider
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/json"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
"gorm.io/driver/sqlite"
|
||||||
|
"gorm.io/gorm"
|
||||||
|
"gorm.io/gorm/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Store handles persistence for OIDC provider data
|
||||||
|
type Store struct {
|
||||||
|
db *gorm.DB
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStore creates a new Store with SQLite backend
|
||||||
|
func NewStore(ctx context.Context, dataDir string) (*Store, error) {
|
||||||
|
dbPath := fmt.Sprintf("%s/oidc.db", dataDir)
|
||||||
|
|
||||||
|
db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{
|
||||||
|
Logger: logger.Default.LogMode(logger.Silent),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open OIDC database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable WAL mode for better concurrency
|
||||||
|
if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to enable WAL mode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Auto-migrate tables
|
||||||
|
if err := db.AutoMigrate(
|
||||||
|
&User{},
|
||||||
|
&Client{},
|
||||||
|
&AuthRequest{},
|
||||||
|
&AuthCode{},
|
||||||
|
&AccessToken{},
|
||||||
|
&RefreshToken{},
|
||||||
|
&DeviceAuth{},
|
||||||
|
&SigningKey{},
|
||||||
|
); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to migrate OIDC database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
store := &Store{db: db}
|
||||||
|
|
||||||
|
// Ensure we have a signing key
|
||||||
|
if err := store.ensureSigningKey(ctx); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to ensure signing key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return store, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the database connection
|
||||||
|
func (s *Store) Close() error {
|
||||||
|
sqlDB, err := s.db.DB()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return sqlDB.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureSigningKey creates a signing key if one doesn't exist
|
||||||
|
func (s *Store) ensureSigningKey(ctx context.Context) error {
|
||||||
|
var key SigningKey
|
||||||
|
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
|
||||||
|
if err == nil {
|
||||||
|
return nil // Key exists
|
||||||
|
}
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new RSA key pair
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate RSA key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: x509.MarshalPKCS1PrivateKey(privateKey),
|
||||||
|
})
|
||||||
|
|
||||||
|
publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to marshal public key: %w", err)
|
||||||
|
}
|
||||||
|
publicKeyPEM := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "PUBLIC KEY",
|
||||||
|
Bytes: publicKeyBytes,
|
||||||
|
})
|
||||||
|
|
||||||
|
newKey := &SigningKey{
|
||||||
|
ID: uuid.New().String(),
|
||||||
|
Algorithm: "RS256",
|
||||||
|
PrivateKey: privateKeyPEM,
|
||||||
|
PublicKey: publicKeyPEM,
|
||||||
|
CreatedAt: time.Now(),
|
||||||
|
Active: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.db.WithContext(ctx).Create(newKey).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSigningKey returns the active signing key
|
||||||
|
func (s *Store) GetSigningKey(ctx context.Context) (*SigningKey, error) {
|
||||||
|
var key SigningKey
|
||||||
|
err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// User operations
|
||||||
|
|
||||||
|
// CreateUser creates a new user with bcrypt hashed password
|
||||||
|
func (s *Store) CreateUser(ctx context.Context, user *User) error {
|
||||||
|
if user.ID == "" {
|
||||||
|
user.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash password: %w", err)
|
||||||
|
}
|
||||||
|
user.Password = string(hashedPassword)
|
||||||
|
user.CreatedAt = time.Now()
|
||||||
|
user.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
return s.db.WithContext(ctx).Create(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByID retrieves a user by ID
|
||||||
|
func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
err := s.db.WithContext(ctx).Where("id = ?", id).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserByUsername retrieves a user by username
|
||||||
|
func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) {
|
||||||
|
var user User
|
||||||
|
err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateUserPassword validates a user's password
|
||||||
|
func (s *Store) ValidateUserPassword(ctx context.Context, username, password string) (*User, error) {
|
||||||
|
user, err := s.GetUserByUsername(ctx, username)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
|
||||||
|
return nil, errors.New("invalid password")
|
||||||
|
}
|
||||||
|
|
||||||
|
return user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListUsers returns all users
|
||||||
|
func (s *Store) ListUsers(ctx context.Context) ([]*User, error) {
|
||||||
|
var users []*User
|
||||||
|
err := s.db.WithContext(ctx).Find(&users).Error
|
||||||
|
return users, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUser updates a user
|
||||||
|
func (s *Store) UpdateUser(ctx context.Context, user *User) error {
|
||||||
|
user.UpdatedAt = time.Now()
|
||||||
|
return s.db.WithContext(ctx).Save(user).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteUser deletes a user
|
||||||
|
func (s *Store) DeleteUser(ctx context.Context, id string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&User{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUserPassword updates a user's password
|
||||||
|
func (s *Store) UpdateUserPassword(ctx context.Context, id, password string) error {
|
||||||
|
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash password: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.db.WithContext(ctx).Model(&User{}).Where("id = ?", id).Updates(map[string]interface{}{
|
||||||
|
"password": string(hashedPassword),
|
||||||
|
"updated_at": time.Now(),
|
||||||
|
}).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Client operations
|
||||||
|
|
||||||
|
// CreateClient creates a new OIDC client
|
||||||
|
func (s *Store) CreateClient(ctx context.Context, client *Client) error {
|
||||||
|
if client.ID == "" {
|
||||||
|
client.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hash secret if provided
|
||||||
|
if client.Secret != "" {
|
||||||
|
hashedSecret, err := bcrypt.GenerateFromPassword([]byte(client.Secret), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to hash client secret: %w", err)
|
||||||
|
}
|
||||||
|
client.Secret = string(hashedSecret)
|
||||||
|
}
|
||||||
|
|
||||||
|
client.CreatedAt = time.Now()
|
||||||
|
client.UpdatedAt = time.Now()
|
||||||
|
|
||||||
|
return s.db.WithContext(ctx).Create(client).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientByID retrieves a client by ID
|
||||||
|
func (s *Store) GetClientByID(ctx context.Context, id string) (*Client, error) {
|
||||||
|
var client Client
|
||||||
|
err := s.db.WithContext(ctx).Where("id = ?", id).First(&client).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateClientSecret validates a client's secret
|
||||||
|
func (s *Store) ValidateClientSecret(ctx context.Context, clientID, secret string) (*Client, error) {
|
||||||
|
client, err := s.GetClientByID(ctx, clientID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Public clients have no secret
|
||||||
|
if client.Secret == "" && secret == "" {
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(secret)); err != nil {
|
||||||
|
return nil, errors.New("invalid client secret")
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListClients returns all clients
|
||||||
|
func (s *Store) ListClients(ctx context.Context) ([]*Client, error) {
|
||||||
|
var clients []*Client
|
||||||
|
err := s.db.WithContext(ctx).Find(&clients).Error
|
||||||
|
return clients, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteClient deletes a client
|
||||||
|
func (s *Store) DeleteClient(ctx context.Context, id string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&Client{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthRequest operations
|
||||||
|
|
||||||
|
// SaveAuthRequest saves an authorization request
|
||||||
|
func (s *Store) SaveAuthRequest(ctx context.Context, req *AuthRequest) error {
|
||||||
|
if req.ID == "" {
|
||||||
|
req.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
req.CreatedAt = time.Now()
|
||||||
|
return s.db.WithContext(ctx).Create(req).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthRequestByID retrieves an auth request by ID
|
||||||
|
func (s *Store) GetAuthRequestByID(ctx context.Context, id string) (*AuthRequest, error) {
|
||||||
|
var req AuthRequest
|
||||||
|
err := s.db.WithContext(ctx).Where("id = ?", id).First(&req).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &req, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateAuthRequest updates an auth request
|
||||||
|
func (s *Store) UpdateAuthRequest(ctx context.Context, req *AuthRequest) error {
|
||||||
|
return s.db.WithContext(ctx).Save(req).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuthRequest deletes an auth request
|
||||||
|
func (s *Store) DeleteAuthRequest(ctx context.Context, id string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&AuthRequest{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthCode operations
|
||||||
|
|
||||||
|
// SaveAuthCode saves an authorization code
|
||||||
|
func (s *Store) SaveAuthCode(ctx context.Context, code *AuthCode) error {
|
||||||
|
code.CreatedAt = time.Now()
|
||||||
|
if code.ExpiresAt.IsZero() {
|
||||||
|
code.ExpiresAt = time.Now().Add(10 * time.Minute) // 10 minute expiry
|
||||||
|
}
|
||||||
|
return s.db.WithContext(ctx).Create(code).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAuthCodeByCode retrieves an auth code
|
||||||
|
func (s *Store) GetAuthCodeByCode(ctx context.Context, code string) (*AuthCode, error) {
|
||||||
|
var authCode AuthCode
|
||||||
|
err := s.db.WithContext(ctx).Where("code = ?", code).First(&authCode).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &authCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAuthCode deletes an auth code
|
||||||
|
func (s *Store) DeleteAuthCode(ctx context.Context, code string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&AuthCode{}, "code = ?", code).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token operations
|
||||||
|
|
||||||
|
// SaveAccessToken saves an access token
|
||||||
|
func (s *Store) SaveAccessToken(ctx context.Context, token *AccessToken) error {
|
||||||
|
if token.ID == "" {
|
||||||
|
token.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
token.CreatedAt = time.Now()
|
||||||
|
return s.db.WithContext(ctx).Create(token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetAccessTokenByID retrieves an access token
|
||||||
|
func (s *Store) GetAccessTokenByID(ctx context.Context, id string) (*AccessToken, error) {
|
||||||
|
var token AccessToken
|
||||||
|
err := s.db.WithContext(ctx).Where("id = ?", id).First(&token).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAccessToken deletes an access token
|
||||||
|
func (s *Store) DeleteAccessToken(ctx context.Context, id string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&AccessToken{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshToken operations
|
||||||
|
|
||||||
|
// SaveRefreshToken saves a refresh token
|
||||||
|
func (s *Store) SaveRefreshToken(ctx context.Context, token *RefreshToken) error {
|
||||||
|
if token.ID == "" {
|
||||||
|
token.ID = uuid.New().String()
|
||||||
|
}
|
||||||
|
if token.Token == "" {
|
||||||
|
token.Token = uuid.New().String()
|
||||||
|
}
|
||||||
|
token.CreatedAt = time.Now()
|
||||||
|
return s.db.WithContext(ctx).Create(token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRefreshToken retrieves a refresh token by token value
|
||||||
|
func (s *Store) GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) {
|
||||||
|
var rt RefreshToken
|
||||||
|
err := s.db.WithContext(ctx).Where("token = ?", token).First(&rt).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &rt, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRefreshToken deletes a refresh token
|
||||||
|
func (s *Store) DeleteRefreshToken(ctx context.Context, id string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "id = ?", id).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRefreshTokenByToken deletes a refresh token by token value
|
||||||
|
func (s *Store) DeleteRefreshTokenByToken(ctx context.Context, token string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&RefreshToken{}, "token = ?", token).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceAuth operations
|
||||||
|
|
||||||
|
// SaveDeviceAuth saves a device authorization
|
||||||
|
func (s *Store) SaveDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
|
||||||
|
auth.CreatedAt = time.Now()
|
||||||
|
return s.db.WithContext(ctx).Create(auth).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthByDeviceCode retrieves device auth by device code
|
||||||
|
func (s *Store) GetDeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuth, error) {
|
||||||
|
var auth DeviceAuth
|
||||||
|
err := s.db.WithContext(ctx).Where("device_code = ?", deviceCode).First(&auth).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDeviceAuthByUserCode retrieves device auth by user code
|
||||||
|
func (s *Store) GetDeviceAuthByUserCode(ctx context.Context, userCode string) (*DeviceAuth, error) {
|
||||||
|
var auth DeviceAuth
|
||||||
|
err := s.db.WithContext(ctx).Where("user_code = ?", userCode).First(&auth).Error
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateDeviceAuth updates a device authorization
|
||||||
|
func (s *Store) UpdateDeviceAuth(ctx context.Context, auth *DeviceAuth) error {
|
||||||
|
return s.db.WithContext(ctx).Save(auth).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDeviceAuth deletes a device authorization
|
||||||
|
func (s *Store) DeleteDeviceAuth(ctx context.Context, deviceCode string) error {
|
||||||
|
return s.db.WithContext(ctx).Delete(&DeviceAuth{}, "device_code = ?", deviceCode).Error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup operations
|
||||||
|
|
||||||
|
// CleanupExpired removes expired tokens and auth requests
|
||||||
|
func (s *Store) CleanupExpired(ctx context.Context) error {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// Delete expired auth codes
|
||||||
|
if err := s.db.WithContext(ctx).Delete(&AuthCode{}, "expires_at < ?", now).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete expired access tokens
|
||||||
|
if err := s.db.WithContext(ctx).Delete(&AccessToken{}, "expiration < ?", now).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete expired refresh tokens
|
||||||
|
if err := s.db.WithContext(ctx).Delete(&RefreshToken{}, "expiration < ?", now).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete expired device authorizations
|
||||||
|
if err := s.db.WithContext(ctx).Delete(&DeviceAuth{}, "expiration < ?", now).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete old auth requests (older than 1 hour)
|
||||||
|
oneHourAgo := now.Add(-1 * time.Hour)
|
||||||
|
if err := s.db.WithContext(ctx).Delete(&AuthRequest{}, "created_at < ?", oneHourAgo).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper functions for JSON serialization
|
||||||
|
|
||||||
|
// ParseJSONArray parses a JSON array string into a slice
|
||||||
|
func ParseJSONArray(jsonStr string) []string {
|
||||||
|
if jsonStr == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var result []string
|
||||||
|
if err := json.Unmarshal([]byte(jsonStr), &result); err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToJSONArray converts a slice to a JSON array string
|
||||||
|
func ToJSONArray(arr []string) string {
|
||||||
|
if len(arr) == 0 {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
data, err := json.Marshal(arr)
|
||||||
|
if err != nil {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
261
idp/oidcprovider/templates/device.html
Normal file
261
idp/oidcprovider/templates/device.html
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Device Authorization - NetBird</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
background: white;
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
||||||
|
width: 100%;
|
||||||
|
max-width: 450px;
|
||||||
|
}
|
||||||
|
.logo {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
.logo h1 {
|
||||||
|
font-size: 28px;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
.logo p {
|
||||||
|
color: #666;
|
||||||
|
margin-top: 8px;
|
||||||
|
}
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
input[type="text"],
|
||||||
|
input[type="password"] {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px 16px;
|
||||||
|
border: 2px solid #e1e5eb;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: border-color 0.2s, box-shadow 0.2s;
|
||||||
|
}
|
||||||
|
input.code-input {
|
||||||
|
text-align: center;
|
||||||
|
font-size: 24px;
|
||||||
|
letter-spacing: 4px;
|
||||||
|
text-transform: uppercase;
|
||||||
|
}
|
||||||
|
input[type="text"]:focus,
|
||||||
|
input[type="password"]:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 16px;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s, box-shadow 0.2s;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
|
||||||
|
}
|
||||||
|
button:active {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
button.secondary {
|
||||||
|
background: #e1e5eb;
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
button.secondary:hover {
|
||||||
|
background: #d1d5db;
|
||||||
|
box-shadow: none;
|
||||||
|
}
|
||||||
|
button.deny {
|
||||||
|
background: #dc2626;
|
||||||
|
}
|
||||||
|
button.deny:hover {
|
||||||
|
background: #b91c1c;
|
||||||
|
}
|
||||||
|
.error {
|
||||||
|
background: #fee;
|
||||||
|
color: #c00;
|
||||||
|
padding: 12px 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 14px;
|
||||||
|
border: 1px solid #fcc;
|
||||||
|
}
|
||||||
|
.success {
|
||||||
|
background: #d4edda;
|
||||||
|
color: #155724;
|
||||||
|
padding: 20px;
|
||||||
|
border-radius: 8px;
|
||||||
|
text-align: center;
|
||||||
|
font-size: 16px;
|
||||||
|
border: 1px solid #c3e6cb;
|
||||||
|
}
|
||||||
|
.info {
|
||||||
|
background: #e8f4fd;
|
||||||
|
color: #0c5460;
|
||||||
|
padding: 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 14px;
|
||||||
|
border: 1px solid #bee5eb;
|
||||||
|
}
|
||||||
|
.scopes {
|
||||||
|
background: #f8f9fa;
|
||||||
|
padding: 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
.scopes h3 {
|
||||||
|
font-size: 14px;
|
||||||
|
color: #666;
|
||||||
|
margin-bottom: 12px;
|
||||||
|
}
|
||||||
|
.scopes ul {
|
||||||
|
list-style: none;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
.scopes li {
|
||||||
|
padding: 8px 0;
|
||||||
|
border-bottom: 1px solid #e1e5eb;
|
||||||
|
color: #333;
|
||||||
|
}
|
||||||
|
.scopes li:last-child {
|
||||||
|
border-bottom: none;
|
||||||
|
}
|
||||||
|
.button-group {
|
||||||
|
display: flex;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
.button-group button {
|
||||||
|
flex: 1;
|
||||||
|
}
|
||||||
|
.footer {
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 24px;
|
||||||
|
color: #888;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<div class="logo">
|
||||||
|
<h1>NetBird</h1>
|
||||||
|
<p>Device Authorization</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{if .Error}}
|
||||||
|
<div class="error">{{.Error}}</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{if eq .Step "code"}}
|
||||||
|
<!-- Step 1: Enter user code -->
|
||||||
|
<div class="info">
|
||||||
|
Enter the code shown on your device to authorize it.
|
||||||
|
</div>
|
||||||
|
<form method="GET" action="/device">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="user_code">Device Code</label>
|
||||||
|
<input type="text" id="user_code" name="user_code" class="code-input"
|
||||||
|
placeholder="XXXX-XXXX" required autofocus
|
||||||
|
pattern="[A-Za-z]{4}-?[A-Za-z]{4}">
|
||||||
|
</div>
|
||||||
|
<button type="submit">Continue</button>
|
||||||
|
</form>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{if eq .Step "login"}}
|
||||||
|
<!-- Step 2: Login -->
|
||||||
|
<div class="info">
|
||||||
|
Sign in to authorize the device.
|
||||||
|
</div>
|
||||||
|
<form method="POST" action="/device/login">
|
||||||
|
<input type="hidden" name="user_code" value="{{.UserCode}}">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="username">Username</label>
|
||||||
|
<input type="text" id="username" name="username" required autofocus>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input type="password" id="password" name="password" required>
|
||||||
|
</div>
|
||||||
|
<button type="submit">Sign In</button>
|
||||||
|
</form>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{if eq .Step "confirm"}}
|
||||||
|
<!-- Step 3: Confirm authorization -->
|
||||||
|
<div class="info">
|
||||||
|
<strong>{{.ClientID}}</strong> is requesting access to your account.
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{{if .Scopes}}
|
||||||
|
<div class="scopes">
|
||||||
|
<h3>This application will have access to:</h3>
|
||||||
|
<ul>
|
||||||
|
{{range .Scopes}}
|
||||||
|
<li>{{.}}</li>
|
||||||
|
{{end}}
|
||||||
|
</ul>
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
<form method="POST" action="/device/confirm">
|
||||||
|
<div class="button-group">
|
||||||
|
<button type="submit" name="action" value="allow">Allow</button>
|
||||||
|
<button type="submit" name="action" value="deny" class="deny">Deny</button>
|
||||||
|
</div>
|
||||||
|
</form>
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
{{if eq .Step "result"}}
|
||||||
|
<!-- Result -->
|
||||||
|
{{if .Success}}
|
||||||
|
<div class="success">
|
||||||
|
{{.Message}}
|
||||||
|
</div>
|
||||||
|
{{else}}
|
||||||
|
<div class="info">
|
||||||
|
{{.Message}}
|
||||||
|
</div>
|
||||||
|
{{end}}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
|
<div class="footer">
|
||||||
|
Powered by NetBird Identity Provider
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
129
idp/oidcprovider/templates/login.html
Normal file
129
idp/oidcprovider/templates/login.html
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Login - NetBird</title>
|
||||||
|
<style>
|
||||||
|
* {
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
body {
|
||||||
|
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, sans-serif;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
min-height: 100vh;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
padding: 20px;
|
||||||
|
}
|
||||||
|
.login-container {
|
||||||
|
background: white;
|
||||||
|
padding: 40px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
|
||||||
|
width: 100%;
|
||||||
|
max-width: 400px;
|
||||||
|
}
|
||||||
|
.logo {
|
||||||
|
text-align: center;
|
||||||
|
margin-bottom: 30px;
|
||||||
|
}
|
||||||
|
.logo h1 {
|
||||||
|
font-size: 28px;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 600;
|
||||||
|
}
|
||||||
|
.logo p {
|
||||||
|
color: #666;
|
||||||
|
margin-top: 8px;
|
||||||
|
}
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
color: #333;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
input[type="text"],
|
||||||
|
input[type="password"] {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px 16px;
|
||||||
|
border: 2px solid #e1e5eb;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 16px;
|
||||||
|
transition: border-color 0.2s, box-shadow 0.2s;
|
||||||
|
}
|
||||||
|
input[type="text"]:focus,
|
||||||
|
input[type="password"]:focus {
|
||||||
|
outline: none;
|
||||||
|
border-color: #667eea;
|
||||||
|
box-shadow: 0 0 0 3px rgba(102, 126, 234, 0.2);
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
width: 100%;
|
||||||
|
padding: 14px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 16px;
|
||||||
|
font-weight: 600;
|
||||||
|
cursor: pointer;
|
||||||
|
transition: transform 0.2s, box-shadow 0.2s;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
transform: translateY(-2px);
|
||||||
|
box-shadow: 0 10px 20px rgba(102, 126, 234, 0.3);
|
||||||
|
}
|
||||||
|
button:active {
|
||||||
|
transform: translateY(0);
|
||||||
|
}
|
||||||
|
.error {
|
||||||
|
background: #fee;
|
||||||
|
color: #c00;
|
||||||
|
padding: 12px 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
font-size: 14px;
|
||||||
|
border: 1px solid #fcc;
|
||||||
|
}
|
||||||
|
.footer {
|
||||||
|
text-align: center;
|
||||||
|
margin-top: 24px;
|
||||||
|
color: #888;
|
||||||
|
font-size: 13px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="login-container">
|
||||||
|
<div class="logo">
|
||||||
|
<h1>NetBird</h1>
|
||||||
|
<p>Sign in to your account</p>
|
||||||
|
</div>
|
||||||
|
{{if .Error}}
|
||||||
|
<div class="error">{{.Error}}</div>
|
||||||
|
{{end}}
|
||||||
|
<form method="POST" action="/login">
|
||||||
|
<input type="hidden" name="authRequestID" value="{{.AuthRequestID}}">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="username">Username</label>
|
||||||
|
<input type="text" id="username" name="username" required autofocus>
|
||||||
|
</div>
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="password">Password</label>
|
||||||
|
<input type="password" id="password" name="password" required>
|
||||||
|
</div>
|
||||||
|
<button type="submit">Sign In</button>
|
||||||
|
</form>
|
||||||
|
<div class="footer">
|
||||||
|
Powered by NetBird Identity Provider
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
Reference in New Issue
Block a user