diff --git a/idp/cmd/env.go b/idp/cmd/env.go new file mode 100644 index 000000000..e952fd085 --- /dev/null +++ b/idp/cmd/env.go @@ -0,0 +1,35 @@ +package cmd + +import ( + "os" + "strings" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// setFlagsFromEnvVars reads and updates flag values from environment variables with prefix NB_IDP_ +func setFlagsFromEnvVars(cmd *cobra.Command) { + flags := cmd.PersistentFlags() + flags.VisitAll(func(f *pflag.Flag) { + newEnvVar := flagNameToEnvVar(f.Name, "NB_IDP_") + value, present := os.LookupEnv(newEnvVar) + if !present { + return + } + + err := flags.Set(f.Name, value) + if err != nil { + log.Infof("unable to configure flag %s using variable %s, err: %v", f.Name, newEnvVar, err) + } + }) +} + +// flagNameToEnvVar converts flag name to environment var name adding a prefix, +// replacing dashes and making all uppercase (e.g. data-dir is converted to NB_IDP_DATA_DIR) +func flagNameToEnvVar(cmdFlag string, prefix string) string { + parsed := strings.ReplaceAll(cmdFlag, "-", "_") + upper := strings.ToUpper(parsed) + return prefix + upper +} diff --git a/idp/cmd/root.go b/idp/cmd/root.go new file mode 100644 index 000000000..0c06477f2 --- /dev/null +++ b/idp/cmd/root.go @@ -0,0 +1,148 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + + log "github.com/sirupsen/logrus" + "github.com/spf13/cobra" + + "github.com/netbirdio/netbird/idp/oidcprovider" + "github.com/netbirdio/netbird/util" +) + +// Config holds the IdP server configuration +type Config struct { + ListenPort int + Issuer string + DataDir string + LogLevel string + LogFile string + DevMode bool + DashboardRedirectURIs []string + CLIRedirectURIs []string + DashboardClientID string + CLIClientID string +} + +var ( + config *Config + rootCmd = &cobra.Command{ + Use: "idp", + Short: "NetBird Identity Provider", + Long: "Embedded OIDC Identity Provider for NetBird", + SilenceUsage: true, + SilenceErrors: true, + RunE: execute, + } +) + +func init() { + _ = util.InitLog("trace", util.LogConsole) + config = &Config{} + + rootCmd.PersistentFlags().IntVarP(&config.ListenPort, "port", "p", 33081, "port to listen on") + rootCmd.PersistentFlags().StringVarP(&config.Issuer, "issuer", "i", "", "OIDC issuer URL (default: http://localhost:)") + rootCmd.PersistentFlags().StringVarP(&config.DataDir, "data-dir", "d", "/var/lib/netbird", "directory to store IdP data") + rootCmd.PersistentFlags().StringVar(&config.LogLevel, "log-level", "info", "log level (trace, debug, info, warn, error)") + rootCmd.PersistentFlags().StringVar(&config.LogFile, "log-file", "console", "log file path or 'console'") + rootCmd.PersistentFlags().BoolVar(&config.DevMode, "dev-mode", false, "enable development mode (allows HTTP)") + rootCmd.PersistentFlags().StringSliceVar(&config.DashboardRedirectURIs, "dashboard-redirect-uris", []string{ + "http://localhost:3000/callback", + "http://localhost:3000/silent-callback", + }, "allowed redirect URIs for dashboard client") + rootCmd.PersistentFlags().StringSliceVar(&config.CLIRedirectURIs, "cli-redirect-uris", []string{ + "http://localhost:53000", + "http://localhost:54000", + }, "allowed redirect URIs for CLI client") + rootCmd.PersistentFlags().StringVar(&config.DashboardClientID, "dashboard-client-id", "netbird-dashboard", "client ID for dashboard") + rootCmd.PersistentFlags().StringVar(&config.CLIClientID, "cli-client-id", "netbird-client", "client ID for CLI") + + // Add subcommands + rootCmd.AddCommand(userCmd) + + setFlagsFromEnvVars(rootCmd) +} + +// Execute runs the root command +func Execute() error { + return rootCmd.Execute() +} + +func execute(cmd *cobra.Command, args []string) error { + err := util.InitLog(config.LogLevel, config.LogFile) + if err != nil { + return fmt.Errorf("failed to initialize log: %s", err) + } + + // Set default issuer if not provided + issuer := config.Issuer + if issuer == "" { + issuer = fmt.Sprintf("http://localhost:%d", config.ListenPort) + } + + log.Infof("Starting NetBird Identity Provider") + log.Infof(" Port: %d", config.ListenPort) + log.Infof(" Issuer: %s", issuer) + log.Infof(" Data directory: %s", config.DataDir) + log.Infof(" Dev mode: %v", config.DevMode) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Create provider config + providerConfig := &oidcprovider.Config{ + Issuer: issuer, + Port: config.ListenPort, + DataDir: config.DataDir, + DevMode: config.DevMode, + } + + // Create the provider + provider, err := oidcprovider.NewProvider(ctx, providerConfig) + if err != nil { + return fmt.Errorf("failed to create IdP: %w", err) + } + + // Ensure default clients exist + if err := provider.EnsureDefaultClients(ctx, config.DashboardRedirectURIs, config.CLIRedirectURIs); err != nil { + return fmt.Errorf("failed to create default clients: %w", err) + } + + // Start the provider + if err := provider.Start(ctx); err != nil { + return fmt.Errorf("failed to start IdP: %w", err) + } + + log.Infof("IdP is running") + log.Infof(" Discovery: %s/.well-known/openid-configuration", issuer) + log.Infof(" Authorization: %s/authorize", issuer) + log.Infof(" Token: %s/oauth/token", issuer) + log.Infof(" Device authorization: %s/device_authorization", issuer) + log.Infof(" JWKS: %s/keys", issuer) + log.Infof(" Login: %s/login", issuer) + log.Infof(" Device flow: %s/device", issuer) + + // Wait for exit signal + waitForExitSignal() + + log.Infof("Shutting down IdP...") + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 10) + defer shutdownCancel() + + if err := provider.Stop(shutdownCtx); err != nil { + return fmt.Errorf("failed to stop IdP: %w", err) + } + + log.Infof("IdP stopped") + return nil +} + +func waitForExitSignal() { + osSigs := make(chan os.Signal, 1) + signal.Notify(osSigs, syscall.SIGINT, syscall.SIGTERM) + <-osSigs +} diff --git a/idp/cmd/user.go b/idp/cmd/user.go new file mode 100644 index 000000000..1e78595ef --- /dev/null +++ b/idp/cmd/user.go @@ -0,0 +1,249 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "syscall" + "text/tabwriter" + + "github.com/spf13/cobra" + "golang.org/x/term" + + "github.com/netbirdio/netbird/idp/oidcprovider" +) + +var userCmd = &cobra.Command{ + Use: "user", + Short: "Manage IdP users", + Long: "Commands for managing users in the embedded IdP", +} + +var userAddCmd = &cobra.Command{ + Use: "add", + Short: "Add a new user", + Long: "Add a new user to the embedded IdP", + RunE: userAdd, +} + +var userListCmd = &cobra.Command{ + Use: "list", + Short: "List all users", + Long: "List all users in the embedded IdP", + RunE: userList, +} + +var userDeleteCmd = &cobra.Command{ + Use: "delete ", + Short: "Delete a user", + Long: "Delete a user from the embedded IdP", + Args: cobra.ExactArgs(1), + RunE: userDelete, +} + +var userPasswordCmd = &cobra.Command{ + Use: "password ", + Short: "Change user password", + Long: "Change password for a user in the embedded IdP", + Args: cobra.ExactArgs(1), + RunE: userChangePassword, +} + +// User add flags +var ( + userUsername string + userEmail string + userFirstName string + userLastName string + userPassword string +) + +func init() { + userAddCmd.Flags().StringVarP(&userUsername, "username", "u", "", "username (required)") + userAddCmd.Flags().StringVarP(&userEmail, "email", "e", "", "email address (required)") + userAddCmd.Flags().StringVarP(&userFirstName, "first-name", "f", "", "first name") + userAddCmd.Flags().StringVarP(&userLastName, "last-name", "l", "", "last name") + userAddCmd.Flags().StringVarP(&userPassword, "password", "p", "", "password (will prompt if not provided)") + _ = userAddCmd.MarkFlagRequired("username") + _ = userAddCmd.MarkFlagRequired("email") + + userCmd.AddCommand(userAddCmd) + userCmd.AddCommand(userListCmd) + userCmd.AddCommand(userDeleteCmd) + userCmd.AddCommand(userPasswordCmd) +} + +func getStore() (*oidcprovider.Store, error) { + ctx := context.Background() + store, err := oidcprovider.NewStore(ctx, config.DataDir) + if err != nil { + return nil, fmt.Errorf("failed to open store: %w", err) + } + return store, nil +} + +func userAdd(cmd *cobra.Command, args []string) error { + store, err := getStore() + if err != nil { + return err + } + defer store.Close() + + password := userPassword + if password == "" { + // Prompt for password + fmt.Print("Enter password: ") + bytePassword, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + fmt.Println() + + fmt.Print("Confirm password: ") + byteConfirm, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return fmt.Errorf("failed to read password confirmation: %w", err) + } + fmt.Println() + + if string(bytePassword) != string(byteConfirm) { + return fmt.Errorf("passwords do not match") + } + password = string(bytePassword) + } + + if password == "" { + return fmt.Errorf("password cannot be empty") + } + + user := &oidcprovider.User{ + Username: userUsername, + Email: userEmail, + FirstName: userFirstName, + LastName: userLastName, + Password: password, + EmailVerified: true, // Mark as verified since admin is creating the user + } + + ctx := context.Background() + if err := store.CreateUser(ctx, user); err != nil { + return fmt.Errorf("failed to create user: %w", err) + } + + fmt.Printf("User '%s' created successfully (ID: %s)\n", userUsername, user.ID) + return nil +} + +func userList(cmd *cobra.Command, args []string) error { + store, err := getStore() + if err != nil { + return err + } + defer store.Close() + + ctx := context.Background() + users, err := store.ListUsers(ctx) + if err != nil { + return fmt.Errorf("failed to list users: %w", err) + } + + if len(users) == 0 { + fmt.Println("No users found") + return nil + } + + w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0) + fmt.Fprintln(w, "ID\tUSERNAME\tEMAIL\tNAME\tVERIFIED\tCREATED") + for _, user := range users { + name := fmt.Sprintf("%s %s", user.FirstName, user.LastName) + verified := "No" + if user.EmailVerified { + verified = "Yes" + } + fmt.Fprintf(w, "%s\t%s\t%s\t%s\t%s\t%s\n", + user.ID, + user.Username, + user.Email, + name, + verified, + user.CreatedAt.Format("2006-01-02 15:04"), + ) + } + w.Flush() + + return nil +} + +func userDelete(cmd *cobra.Command, args []string) error { + username := args[0] + + store, err := getStore() + if err != nil { + return err + } + defer store.Close() + + ctx := context.Background() + + // Find user by username + user, err := store.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("user '%s' not found", username) + } + + if err := store.DeleteUser(ctx, user.ID); err != nil { + return fmt.Errorf("failed to delete user: %w", err) + } + + fmt.Printf("User '%s' deleted successfully\n", username) + return nil +} + +func userChangePassword(cmd *cobra.Command, args []string) error { + username := args[0] + + store, err := getStore() + if err != nil { + return err + } + defer store.Close() + + ctx := context.Background() + + // Find user by username + user, err := store.GetUserByUsername(ctx, username) + if err != nil { + return fmt.Errorf("user '%s' not found", username) + } + + // Prompt for new password + fmt.Print("Enter new password: ") + bytePassword, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return fmt.Errorf("failed to read password: %w", err) + } + fmt.Println() + + fmt.Print("Confirm new password: ") + byteConfirm, err := term.ReadPassword(int(syscall.Stdin)) + if err != nil { + return fmt.Errorf("failed to read password confirmation: %w", err) + } + fmt.Println() + + if string(bytePassword) != string(byteConfirm) { + return fmt.Errorf("passwords do not match") + } + + password := string(bytePassword) + if password == "" { + return fmt.Errorf("password cannot be empty") + } + + if err := store.UpdateUserPassword(ctx, user.ID, password); err != nil { + return fmt.Errorf("failed to update password: %w", err) + } + + fmt.Printf("Password updated for user '%s'\n", username) + return nil +} diff --git a/idp/main.go b/idp/main.go new file mode 100644 index 000000000..b344d798f --- /dev/null +++ b/idp/main.go @@ -0,0 +1,13 @@ +package main + +import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/idp/cmd" +) + +func main() { + if err := cmd.Execute(); err != nil { + log.Fatalf("failed to execute command: %v", err) + } +} diff --git a/idp/oidcprovider/client.go b/idp/oidcprovider/client.go new file mode 100644 index 000000000..a538024f6 --- /dev/null +++ b/idp/oidcprovider/client.go @@ -0,0 +1,249 @@ +package oidcprovider + +import ( + "time" + + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" +) + +// OIDCClient wraps the database Client model and implements op.Client interface +type OIDCClient struct { + client *Client + loginURL func(string) string + redirectURIs []string + grantTypes []oidc.GrantType + responseTypes []oidc.ResponseType +} + +// NewOIDCClient creates an OIDCClient from a database Client +func NewOIDCClient(client *Client, loginURL func(string) string) *OIDCClient { + return &OIDCClient{ + client: client, + loginURL: loginURL, + redirectURIs: ParseJSONArray(client.RedirectURIs), + grantTypes: parseGrantTypes(client.GrantTypes), + responseTypes: parseResponseTypes(client.ResponseTypes), + } +} + +// GetID returns the client ID +func (c *OIDCClient) GetID() string { + return c.client.ID +} + +// RedirectURIs returns the registered redirect URIs +func (c *OIDCClient) RedirectURIs() []string { + return c.redirectURIs +} + +// PostLogoutRedirectURIs returns the registered post-logout redirect URIs +func (c *OIDCClient) PostLogoutRedirectURIs() []string { + return ParseJSONArray(c.client.PostLogoutURIs) +} + +// ApplicationType returns the application type (native, web, user_agent) +func (c *OIDCClient) ApplicationType() op.ApplicationType { + switch c.client.ApplicationType { + case "native": + return op.ApplicationTypeNative + case "web": + return op.ApplicationTypeWeb + case "user_agent": + return op.ApplicationTypeUserAgent + default: + return op.ApplicationTypeWeb + } +} + +// AuthMethod returns the authentication method +func (c *OIDCClient) AuthMethod() oidc.AuthMethod { + switch c.client.AuthMethod { + case "none": + return oidc.AuthMethodNone + case "client_secret_basic": + return oidc.AuthMethodBasic + case "client_secret_post": + return oidc.AuthMethodPost + case "private_key_jwt": + return oidc.AuthMethodPrivateKeyJWT + default: + return oidc.AuthMethodNone + } +} + +// ResponseTypes returns the allowed response types +func (c *OIDCClient) ResponseTypes() []oidc.ResponseType { + return c.responseTypes +} + +// GrantTypes returns the allowed grant types +func (c *OIDCClient) GrantTypes() []oidc.GrantType { + return c.grantTypes +} + +// LoginURL returns the login URL for this client +func (c *OIDCClient) LoginURL(authRequestID string) string { + if c.loginURL != nil { + return c.loginURL(authRequestID) + } + return "/login?authRequestID=" + authRequestID +} + +// AccessTokenType returns the access token type +func (c *OIDCClient) AccessTokenType() op.AccessTokenType { + switch c.client.AccessTokenType { + case "jwt": + return op.AccessTokenTypeJWT + default: + return op.AccessTokenTypeBearer + } +} + +// IDTokenLifetime returns the ID token lifetime +func (c *OIDCClient) IDTokenLifetime() time.Duration { + if c.client.IDTokenLifetime > 0 { + return time.Duration(c.client.IDTokenLifetime) * time.Second + } + return time.Hour // default 1 hour +} + +// DevMode returns whether the client is in development mode +func (c *OIDCClient) DevMode() bool { + return c.client.DevMode +} + +// RestrictAdditionalIdTokenScopes returns any restricted scopes for ID tokens +func (c *OIDCClient) RestrictAdditionalIdTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +// RestrictAdditionalAccessTokenScopes returns any restricted scopes for access tokens +func (c *OIDCClient) RestrictAdditionalAccessTokenScopes() func(scopes []string) []string { + return func(scopes []string) []string { + return scopes + } +} + +// IsScopeAllowed checks if a scope is allowed for this client +func (c *OIDCClient) IsScopeAllowed(scope string) bool { + // Allow all standard OIDC scopes + switch scope { + case oidc.ScopeOpenID, oidc.ScopeProfile, oidc.ScopeEmail, oidc.ScopePhone, oidc.ScopeAddress, oidc.ScopeOfflineAccess: + return true + } + return true // Allow custom scopes as well +} + +// IDTokenUserinfoClaimsAssertion returns whether userinfo claims should be included in ID token +func (c *OIDCClient) IDTokenUserinfoClaimsAssertion() bool { + return false +} + +// ClockSkew returns the allowed clock skew for this client +func (c *OIDCClient) ClockSkew() time.Duration { + if c.client.ClockSkew > 0 { + return time.Duration(c.client.ClockSkew) * time.Second + } + return 0 +} + +// Helper functions for parsing grant types and response types + +func parseGrantTypes(jsonStr string) []oidc.GrantType { + types := ParseJSONArray(jsonStr) + if len(types) == 0 { + // Default grant types + return []oidc.GrantType{ + oidc.GrantTypeCode, + oidc.GrantTypeRefreshToken, + } + } + + result := make([]oidc.GrantType, 0, len(types)) + for _, t := range types { + switch t { + case "authorization_code": + result = append(result, oidc.GrantTypeCode) + case "refresh_token": + result = append(result, oidc.GrantTypeRefreshToken) + case "client_credentials": + result = append(result, oidc.GrantTypeClientCredentials) + case "urn:ietf:params:oauth:grant-type:device_code": + result = append(result, oidc.GrantTypeDeviceCode) + case "urn:ietf:params:oauth:grant-type:token-exchange": + result = append(result, oidc.GrantTypeTokenExchange) + } + } + return result +} + +func parseResponseTypes(jsonStr string) []oidc.ResponseType { + types := ParseJSONArray(jsonStr) + if len(types) == 0 { + // Default response types + return []oidc.ResponseType{oidc.ResponseTypeCode} + } + + result := make([]oidc.ResponseType, 0, len(types)) + for _, t := range types { + switch t { + case "code": + result = append(result, oidc.ResponseTypeCode) + case "id_token": + result = append(result, oidc.ResponseTypeIDToken) + } + } + return result +} + +// CreateNativeClient creates a native client configuration (for CLI/mobile apps with PKCE) +func CreateNativeClient(id, name string, redirectURIs []string) *Client { + return &Client{ + ID: id, + Name: name, + RedirectURIs: ToJSONArray(redirectURIs), + ApplicationType: "native", + AuthMethod: "none", // Public client + ResponseTypes: ToJSONArray([]string{"code"}), + GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token", "urn:ietf:params:oauth:grant-type:device_code"}), + AccessTokenType: "bearer", + DevMode: true, + IDTokenLifetime: 3600, + } +} + +// CreateWebClient creates a web client configuration (for SPAs/web apps) +func CreateWebClient(id, secret, name string, redirectURIs []string) *Client { + return &Client{ + ID: id, + Secret: secret, + Name: name, + RedirectURIs: ToJSONArray(redirectURIs), + ApplicationType: "web", + AuthMethod: "client_secret_basic", + ResponseTypes: ToJSONArray([]string{"code"}), + GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}), + AccessTokenType: "bearer", + DevMode: false, + IDTokenLifetime: 3600, + } +} + +// CreateSPAClient creates a Single Page Application client configuration (public client for SPAs) +func CreateSPAClient(id, name string, redirectURIs []string) *Client { + return &Client{ + ID: id, + Name: name, + RedirectURIs: ToJSONArray(redirectURIs), + ApplicationType: "user_agent", + AuthMethod: "none", // Public client for SPA + ResponseTypes: ToJSONArray([]string{"code"}), + GrantTypes: ToJSONArray([]string{"authorization_code", "refresh_token"}), + AccessTokenType: "bearer", + DevMode: true, + IDTokenLifetime: 3600, + } +} diff --git a/idp/oidcprovider/device.go b/idp/oidcprovider/device.go new file mode 100644 index 000000000..ffc242a9c --- /dev/null +++ b/idp/oidcprovider/device.go @@ -0,0 +1,220 @@ +package oidcprovider + +import ( + "encoding/base64" + "html/template" + "net/http" + + "github.com/go-chi/chi/v5" + "github.com/gorilla/securecookie" + log "github.com/sirupsen/logrus" +) + +// DeviceHandler handles the device authorization flow +type DeviceHandler struct { + storage *OIDCStorage + tmpl *template.Template + secureCookie *securecookie.SecureCookie +} + +// NewDeviceHandler creates a new device handler +func NewDeviceHandler(storage *OIDCStorage) (*DeviceHandler, error) { + tmpl, err := template.ParseFS(templateFS, "templates/*.html") + if err != nil { + return nil, err + } + + // Generate secure cookie keys + hashKey := securecookie.GenerateRandomKey(32) + blockKey := securecookie.GenerateRandomKey(32) + + return &DeviceHandler{ + storage: storage, + tmpl: tmpl, + secureCookie: securecookie.New(hashKey, blockKey), + }, nil +} + +// Router returns the device flow router +func (h *DeviceHandler) Router() chi.Router { + r := chi.NewRouter() + r.Get("/", h.userCodePage) + r.Post("/login", h.handleLogin) + r.Post("/confirm", h.handleConfirm) + return r +} + +// userCodePage displays the user code entry form +func (h *DeviceHandler) userCodePage(w http.ResponseWriter, r *http.Request) { + userCode := r.URL.Query().Get("user_code") + + data := map[string]interface{}{ + "UserCode": userCode, + "Error": "", + "Step": "code", // code, login, or confirm + } + + if userCode != "" { + // Verify the user code exists + _, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode) + if err != nil { + data["Error"] = "Invalid or expired user code" + data["UserCode"] = "" + } else { + data["Step"] = "login" + } + } + + if err := h.tmpl.ExecuteTemplate(w, "device.html", data); err != nil { + log.Errorf("failed to render device template: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + } +} + +// handleLogin processes the login form on the device flow +func (h *DeviceHandler) handleLogin(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + userCode := r.FormValue("user_code") + username := r.FormValue("username") + password := r.FormValue("password") + + data := map[string]interface{}{ + "UserCode": userCode, + "Error": "", + "Step": "login", + } + + if userCode == "" || username == "" || password == "" { + data["Error"] = "Please fill in all fields" + h.tmpl.ExecuteTemplate(w, "device.html", data) + return + } + + // Validate credentials + userID, err := h.storage.CheckUsernamePasswordSimple(username, password) + if err != nil { + log.Warnf("device login failed for user %s: %v", username, err) + data["Error"] = "Invalid username or password" + h.tmpl.ExecuteTemplate(w, "device.html", data) + return + } + + // Get device authorization info + authState, err := h.storage.GetDeviceAuthorizationByUserCode(r.Context(), userCode) + if err != nil { + data["Error"] = "Invalid or expired user code" + data["Step"] = "code" + data["UserCode"] = "" + h.tmpl.ExecuteTemplate(w, "device.html", data) + return + } + + // Set secure cookie with user info for confirmation step + cookieValue := map[string]string{ + "user_code": userCode, + "user_id": userID, + } + + encoded, err := h.secureCookie.Encode("device_auth", cookieValue) + if err != nil { + log.Errorf("failed to encode cookie: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + http.SetCookie(w, &http.Cookie{ + Name: "device_auth", + Value: encoded, + Path: "/device", + HttpOnly: true, + Secure: r.TLS != nil, + SameSite: http.SameSiteStrictMode, + }) + + // Show confirmation page + data["Step"] = "confirm" + data["ClientID"] = authState.ClientID + data["Scopes"] = authState.Scopes + data["UserID"] = userID + + h.tmpl.ExecuteTemplate(w, "device.html", data) +} + +// handleConfirm processes the authorization decision +func (h *DeviceHandler) handleConfirm(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + // Get values from cookie + cookie, err := r.Cookie("device_auth") + if err != nil { + http.Redirect(w, r, "/device", http.StatusFound) + return + } + + var cookieValue map[string]string + if err := h.secureCookie.Decode("device_auth", cookie.Value, &cookieValue); err != nil { + http.Redirect(w, r, "/device", http.StatusFound) + return + } + + userCode := cookieValue["user_code"] + userID := cookieValue["user_id"] + action := r.FormValue("action") + + data := map[string]interface{}{ + "Step": "result", + } + + // Clear the cookie + http.SetCookie(w, &http.Cookie{ + Name: "device_auth", + Value: "", + Path: "/device", + MaxAge: -1, + HttpOnly: true, + }) + + if action == "allow" { + if err := h.storage.CompleteDeviceAuthorization(r.Context(), userCode, userID); err != nil { + log.Errorf("failed to complete device authorization: %v", err) + data["Error"] = "Failed to authorize device" + } else { + data["Success"] = true + data["Message"] = "Device authorized successfully! You can now close this window." + } + } else { + if err := h.storage.DenyDeviceAuthorization(r.Context(), userCode); err != nil { + log.Errorf("failed to deny device authorization: %v", err) + } + data["Success"] = false + data["Message"] = "Authorization denied. You can close this window." + } + + h.tmpl.ExecuteTemplate(w, "device.html", data) +} + +// GenerateUserCode generates a user-friendly code for device flow +func GenerateUserCode() string { + // Generate a base20 code (BCDFGHJKLMNPQRSTVWXZ - no vowels to avoid words) + chars := "BCDFGHJKLMNPQRSTVWXZ" + b := securecookie.GenerateRandomKey(8) + result := make([]byte, 8) + for i := range result { + result[i] = chars[int(b[i])%len(chars)] + } + // Format as XXXX-XXXX + return string(result[:4]) + "-" + string(result[4:]) +} + +// GenerateDeviceCode generates a secure device code +func GenerateDeviceCode() string { + b := securecookie.GenerateRandomKey(32) + return base64.RawURLEncoding.EncodeToString(b) +} diff --git a/idp/oidcprovider/login.go b/idp/oidcprovider/login.go new file mode 100644 index 000000000..725f6fb09 --- /dev/null +++ b/idp/oidcprovider/login.go @@ -0,0 +1,105 @@ +package oidcprovider + +import ( + "embed" + "html/template" + "net/http" + + "github.com/go-chi/chi/v5" + log "github.com/sirupsen/logrus" +) + +//go:embed templates/*.html +var templateFS embed.FS + +// LoginHandler handles the login flow +type LoginHandler struct { + storage *OIDCStorage + callback func(string) string + tmpl *template.Template +} + +// NewLoginHandler creates a new login handler +func NewLoginHandler(storage *OIDCStorage, callback func(string) string) (*LoginHandler, error) { + tmpl, err := template.ParseFS(templateFS, "templates/*.html") + if err != nil { + return nil, err + } + + return &LoginHandler{ + storage: storage, + callback: callback, + tmpl: tmpl, + }, nil +} + +// Router returns the login router +func (h *LoginHandler) Router() chi.Router { + r := chi.NewRouter() + r.Get("/", h.loginPage) + r.Post("/", h.handleLogin) + return r +} + +// loginPage displays the login form +func (h *LoginHandler) loginPage(w http.ResponseWriter, r *http.Request) { + authRequestID := r.URL.Query().Get("authRequestID") + if authRequestID == "" { + http.Error(w, "missing auth request ID", http.StatusBadRequest) + return + } + + data := map[string]interface{}{ + "AuthRequestID": authRequestID, + "Error": "", + } + + if err := h.tmpl.ExecuteTemplate(w, "login.html", data); err != nil { + log.Errorf("failed to render login template: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + } +} + +// handleLogin processes the login form submission +func (h *LoginHandler) handleLogin(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + http.Error(w, "invalid form", http.StatusBadRequest) + return + } + + authRequestID := r.FormValue("authRequestID") + username := r.FormValue("username") + password := r.FormValue("password") + + if authRequestID == "" || username == "" || password == "" { + data := map[string]interface{}{ + "AuthRequestID": authRequestID, + "Error": "Please fill in all fields", + } + h.tmpl.ExecuteTemplate(w, "login.html", data) + return + } + + // Validate credentials and get user ID + userID, err := h.storage.CheckUsernamePasswordSimple(username, password) + if err != nil { + log.Warnf("login failed for user %s: %v", username, err) + data := map[string]interface{}{ + "AuthRequestID": authRequestID, + "Error": "Invalid username or password", + } + h.tmpl.ExecuteTemplate(w, "login.html", data) + return + } + + // Complete the auth request + if err := h.storage.CompleteAuthRequest(r.Context(), authRequestID, userID); err != nil { + log.Errorf("failed to complete auth request: %v", err) + http.Error(w, "internal error", http.StatusInternalServerError) + return + } + + // Redirect to callback + callbackURL := h.callback(authRequestID) + http.Redirect(w, r, callbackURL, http.StatusFound) +} diff --git a/idp/oidcprovider/models.go b/idp/oidcprovider/models.go new file mode 100644 index 000000000..3616ec60d --- /dev/null +++ b/idp/oidcprovider/models.go @@ -0,0 +1,136 @@ +package oidcprovider + +import ( + "time" + + "golang.org/x/text/language" +) + +// User represents an OIDC user stored in the database +type User struct { + ID string `gorm:"primaryKey"` + Username string `gorm:"uniqueIndex;not null"` + Password string `gorm:"not null"` // bcrypt hashed + Email string + EmailVerified bool + FirstName string + LastName string + Phone string + PhoneVerified bool + PreferredLanguage string // language tag string + IsAdmin bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// GetPreferredLanguage returns the user's preferred language as a language.Tag +func (u *User) GetPreferredLanguage() language.Tag { + if u.PreferredLanguage == "" { + return language.English + } + tag, err := language.Parse(u.PreferredLanguage) + if err != nil { + return language.English + } + return tag +} + +// Client represents an OIDC client (application) stored in the database +type Client struct { + ID string `gorm:"primaryKey"` + Secret string // bcrypt hashed, empty for public clients + Name string + RedirectURIs string // JSON array of redirect URIs + PostLogoutURIs string // JSON array of post-logout redirect URIs + ApplicationType string // native, web, user_agent + AuthMethod string // none, client_secret_basic, client_secret_post, private_key_jwt + ResponseTypes string // JSON array: code, id_token, token + GrantTypes string // JSON array: authorization_code, refresh_token, client_credentials, urn:ietf:params:oauth:grant-type:device_code + AccessTokenType string // bearer or jwt + DevMode bool // allows non-HTTPS redirect URIs + IDTokenLifetime int64 // in seconds, default 3600 (1 hour) + ClockSkew int64 // in seconds, allowed clock skew + CreatedAt time.Time + UpdatedAt time.Time +} + +// AuthRequest represents an ongoing authorization request +type AuthRequest struct { + ID string `gorm:"primaryKey"` + ClientID string `gorm:"index"` + Scopes string // JSON array of scopes + RedirectURI string + State string + Nonce string + ResponseType string + ResponseMode string + CodeChallenge string + CodeMethod string // S256 or plain + UserID string // set after user authentication + Done bool // true when user has authenticated + AuthTime time.Time + CreatedAt time.Time + MaxAge int64 // max authentication age in seconds + Prompt string // none, login, consent, select_account + UILocales string // space-separated list of locales + LoginHint string + ACRValues string // space-separated list of ACR values +} + +// AuthCode represents an authorization code +type AuthCode struct { + Code string `gorm:"primaryKey"` + AuthRequestID string `gorm:"index"` + CreatedAt time.Time + ExpiresAt time.Time +} + +// AccessToken represents an access token +type AccessToken struct { + ID string `gorm:"primaryKey"` + ApplicationID string `gorm:"index"` + Subject string `gorm:"index"` + Audience string // JSON array + Scopes string // JSON array + Expiration time.Time + CreatedAt time.Time +} + +// RefreshToken represents a refresh token +type RefreshToken struct { + ID string `gorm:"primaryKey"` + Token string `gorm:"uniqueIndex"` + AuthRequestID string + ApplicationID string `gorm:"index"` + Subject string `gorm:"index"` + Audience string // JSON array + Scopes string // JSON array + AMR string // JSON array of authentication methods + AuthTime time.Time + Expiration time.Time + CreatedAt time.Time +} + +// DeviceAuth represents a device authorization request +type DeviceAuth struct { + DeviceCode string `gorm:"primaryKey"` + UserCode string `gorm:"uniqueIndex"` + ClientID string `gorm:"index"` + Scopes string // JSON array + Subject string // set after user authentication + Audience string // JSON array + Done bool // true when user has authorized + Denied bool // true when user has denied + Expiration time.Time + CreatedAt time.Time +} + +// SigningKey represents a signing key for JWTs +type SigningKey struct { + ID string `gorm:"primaryKey"` + Algorithm string // RS256 + PrivateKey []byte // PEM encoded + PublicKey []byte // PEM encoded + CreatedAt time.Time + Active bool +} diff --git a/idp/oidcprovider/oidc_storage.go b/idp/oidcprovider/oidc_storage.go new file mode 100644 index 000000000..628f20e97 --- /dev/null +++ b/idp/oidcprovider/oidc_storage.go @@ -0,0 +1,662 @@ +package oidcprovider + +import ( + "context" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "strings" + "time" + + jose "github.com/go-jose/go-jose/v4" + "github.com/google/uuid" + "github.com/zitadel/oidc/v3/pkg/oidc" + "github.com/zitadel/oidc/v3/pkg/op" + "gorm.io/gorm" +) + +// ErrInvalidRefreshToken is returned when a token is not a valid refresh token +var ErrInvalidRefreshToken = errors.New("invalid refresh token") + +// OIDCStorage implements op.Storage interface for the OIDC provider +type OIDCStorage struct { + store *Store + issuer string + loginURL func(string) string +} + +// NewOIDCStorage creates a new OIDCStorage +func NewOIDCStorage(store *Store, issuer string) *OIDCStorage { + return &OIDCStorage{ + store: store, + issuer: issuer, + } +} + +// SetLoginURL sets the login URL generator function +func (s *OIDCStorage) SetLoginURL(fn func(string) string) { + s.loginURL = fn +} + +// Health checks if the storage is healthy +func (s *OIDCStorage) Health(ctx context.Context) error { + sqlDB, err := s.store.db.DB() + if err != nil { + return err + } + return sqlDB.PingContext(ctx) +} + +// CreateAuthRequest creates and stores a new authorization request +func (s *OIDCStorage) CreateAuthRequest(ctx context.Context, authReq *oidc.AuthRequest, userID string) (op.AuthRequest, error) { + req := &AuthRequest{ + ID: uuid.New().String(), + ClientID: authReq.ClientID, + Scopes: ToJSONArray(authReq.Scopes), + RedirectURI: authReq.RedirectURI, + State: authReq.State, + Nonce: authReq.Nonce, + ResponseType: string(authReq.ResponseType), + ResponseMode: string(authReq.ResponseMode), + CodeChallenge: authReq.CodeChallenge, + CodeMethod: string(authReq.CodeChallengeMethod), + UserID: userID, + Done: userID != "", + CreatedAt: time.Now(), + Prompt: spaceSeparated(authReq.Prompt), + UILocales: authReq.UILocales.String(), + LoginHint: authReq.LoginHint, + ACRValues: spaceSeparated(authReq.ACRValues), + } + + if authReq.MaxAge != nil { + req.MaxAge = int64(*authReq.MaxAge) + } + + if userID != "" { + req.AuthTime = time.Now() + } + + if err := s.store.SaveAuthRequest(ctx, req); err != nil { + return nil, err + } + + return &OIDCAuthRequest{req: req, storage: s}, nil +} + +// AuthRequestByID retrieves an authorization request by ID +func (s *OIDCStorage) AuthRequestByID(ctx context.Context, id string) (op.AuthRequest, error) { + req, err := s.store.GetAuthRequestByID(ctx, id) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("auth request not found: %s", id) + } + return nil, err + } + return &OIDCAuthRequest{req: req, storage: s}, nil +} + +// AuthRequestByCode retrieves an authorization request by code +func (s *OIDCStorage) AuthRequestByCode(ctx context.Context, code string) (op.AuthRequest, error) { + authCode, err := s.store.GetAuthCodeByCode(ctx, code) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("auth code not found: %s", code) + } + return nil, err + } + + if time.Now().After(authCode.ExpiresAt) { + _ = s.store.DeleteAuthCode(ctx, code) + return nil, errors.New("auth code expired") + } + + req, err := s.store.GetAuthRequestByID(ctx, authCode.AuthRequestID) + if err != nil { + return nil, err + } + + return &OIDCAuthRequest{req: req, storage: s}, nil +} + +// SaveAuthCode saves an authorization code linked to an auth request +func (s *OIDCStorage) SaveAuthCode(ctx context.Context, id, code string) error { + authCode := &AuthCode{ + Code: code, + AuthRequestID: id, + ExpiresAt: time.Now().Add(10 * time.Minute), + } + return s.store.SaveAuthCode(ctx, authCode) +} + +// DeleteAuthRequest deletes an authorization request +func (s *OIDCStorage) DeleteAuthRequest(ctx context.Context, id string) error { + return s.store.DeleteAuthRequest(ctx, id) +} + +// CreateAccessToken creates and stores an access token +func (s *OIDCStorage) CreateAccessToken(ctx context.Context, request op.TokenRequest) (string, time.Time, error) { + tokenID := uuid.New().String() + expiration := time.Now().Add(5 * time.Minute) + + // Get client ID from the request if possible + var clientID string + if authReq, ok := request.(op.AuthRequest); ok { + clientID = authReq.GetClientID() + } else if refreshReq, ok := request.(op.RefreshTokenRequest); ok { + clientID = refreshReq.GetClientID() + } + + token := &AccessToken{ + ID: tokenID, + ApplicationID: clientID, + Subject: request.GetSubject(), + Audience: ToJSONArray(request.GetAudience()), + Scopes: ToJSONArray(request.GetScopes()), + Expiration: expiration, + } + + if err := s.store.SaveAccessToken(ctx, token); err != nil { + return "", time.Time{}, err + } + + return tokenID, expiration, nil +} + +// CreateAccessAndRefreshTokens creates both access and refresh tokens +func (s *OIDCStorage) CreateAccessAndRefreshTokens(ctx context.Context, request op.TokenRequest, currentRefreshToken string) (accessTokenID string, newRefreshToken string, expiration time.Time, err error) { + // Delete old refresh token if provided + if currentRefreshToken != "" { + _ = s.store.DeleteRefreshTokenByToken(ctx, currentRefreshToken) + } + + // Create access token + accessTokenID, expiration, err = s.CreateAccessToken(ctx, request) + if err != nil { + return "", "", time.Time{}, err + } + + // Get additional info from the request if possible + var clientID string + var authTime time.Time + var amr []string + + if authReq, ok := request.(op.AuthRequest); ok { + clientID = authReq.GetClientID() + authTime = authReq.GetAuthTime() + amr = authReq.GetAMR() + } else if refreshReq, ok := request.(op.RefreshTokenRequest); ok { + clientID = refreshReq.GetClientID() + authTime = refreshReq.GetAuthTime() + amr = refreshReq.GetAMR() + } + + // Create refresh token + refreshToken := &RefreshToken{ + ID: uuid.New().String(), + Token: uuid.New().String(), + ApplicationID: clientID, + Subject: request.GetSubject(), + Audience: ToJSONArray(request.GetAudience()), + Scopes: ToJSONArray(request.GetScopes()), + AuthTime: authTime, + AMR: ToJSONArray(amr), + Expiration: time.Now().Add(5 * time.Hour), // 5 hour refresh token lifetime + } + + if authReq, ok := request.(op.AuthRequest); ok { + refreshToken.AuthRequestID = authReq.GetID() + } + + if err := s.store.SaveRefreshToken(ctx, refreshToken); err != nil { + return "", "", time.Time{}, err + } + + return accessTokenID, refreshToken.Token, expiration, nil +} + +// TokenRequestByRefreshToken retrieves token request info from refresh token +func (s *OIDCStorage) TokenRequestByRefreshToken(ctx context.Context, refreshToken string) (op.RefreshTokenRequest, error) { + token, err := s.store.GetRefreshToken(ctx, refreshToken) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("refresh token not found") + } + return nil, err + } + + if time.Now().After(token.Expiration) { + _ = s.store.DeleteRefreshTokenByToken(ctx, refreshToken) + return nil, errors.New("refresh token expired") + } + + return &OIDCRefreshToken{token: token}, nil +} + +// TerminateSession terminates a user session +func (s *OIDCStorage) TerminateSession(ctx context.Context, userID, clientID string) error { + // For now, we don't track sessions separately + return nil +} + +// RevokeToken revokes a token +func (s *OIDCStorage) RevokeToken(ctx context.Context, tokenOrID string, userID string, clientID string) *oidc.Error { + // Try to delete as refresh token + if err := s.store.DeleteRefreshTokenByToken(ctx, tokenOrID); err == nil { + return nil + } + + // Try to delete as access token + if err := s.store.DeleteAccessToken(ctx, tokenOrID); err == nil { + return nil + } + + return nil // Silently succeed even if token not found (per spec) +} + +// GetRefreshTokenInfo returns info about a refresh token +func (s *OIDCStorage) GetRefreshTokenInfo(ctx context.Context, clientID string, token string) (userID string, tokenID string, err error) { + refreshToken, err := s.store.GetRefreshToken(ctx, token) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return "", "", ErrInvalidRefreshToken + } + return "", "", err + } + + if refreshToken.ApplicationID != clientID { + return "", "", ErrInvalidRefreshToken + } + + return refreshToken.Subject, refreshToken.ID, nil +} + +// GetClientByClientID retrieves a client by ID +func (s *OIDCStorage) GetClientByClientID(ctx context.Context, clientID string) (op.Client, error) { + client, err := s.store.GetClientByID(ctx, clientID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, fmt.Errorf("client not found: %s", clientID) + } + return nil, err + } + return NewOIDCClient(client, s.loginURL), nil +} + +// AuthorizeClientIDSecret validates client credentials +func (s *OIDCStorage) AuthorizeClientIDSecret(ctx context.Context, clientID, clientSecret string) error { + _, err := s.store.ValidateClientSecret(ctx, clientID, clientSecret) + return err +} + +// SetUserinfoFromScopes sets userinfo claims based on scopes +func (s *OIDCStorage) SetUserinfoFromScopes(ctx context.Context, userinfo *oidc.UserInfo, userID, clientID string, scopes []string) error { + return s.setUserinfo(ctx, userinfo, userID, scopes) +} + +// SetUserinfoFromToken sets userinfo claims from an access token +func (s *OIDCStorage) SetUserinfoFromToken(ctx context.Context, userinfo *oidc.UserInfo, tokenID, subject, origin string) error { + token, err := s.store.GetAccessTokenByID(ctx, tokenID) + if err != nil { + return err + } + return s.setUserinfo(ctx, userinfo, token.Subject, ParseJSONArray(token.Scopes)) +} + +// setUserinfo populates userinfo based on user data and scopes +func (s *OIDCStorage) setUserinfo(ctx context.Context, userinfo *oidc.UserInfo, userID string, scopes []string) error { + user, err := s.store.GetUserByID(ctx, userID) + if err != nil { + return err + } + + for _, scope := range scopes { + switch scope { + case oidc.ScopeOpenID: + userinfo.Subject = user.ID + case oidc.ScopeProfile: + userinfo.Name = fmt.Sprintf("%s %s", user.FirstName, user.LastName) + userinfo.GivenName = user.FirstName + userinfo.FamilyName = user.LastName + userinfo.PreferredUsername = user.Username + userinfo.Locale = oidc.NewLocale(user.GetPreferredLanguage()) + case oidc.ScopeEmail: + userinfo.Email = user.Email + userinfo.EmailVerified = oidc.Bool(user.EmailVerified) + case oidc.ScopePhone: + userinfo.PhoneNumber = user.Phone + userinfo.PhoneNumberVerified = user.PhoneVerified + } + } + + return nil +} + +// SetIntrospectionFromToken sets introspection response from token +func (s *OIDCStorage) SetIntrospectionFromToken(ctx context.Context, introspection *oidc.IntrospectionResponse, tokenID, subject, clientID string) error { + token, err := s.store.GetAccessTokenByID(ctx, tokenID) + if err != nil { + return err + } + + introspection.Active = true + introspection.Subject = token.Subject + introspection.ClientID = token.ApplicationID + introspection.Scope = ParseJSONArray(token.Scopes) + introspection.Expiration = oidc.FromTime(token.Expiration) + introspection.IssuedAt = oidc.FromTime(token.CreatedAt) + introspection.Audience = ParseJSONArray(token.Audience) + introspection.Issuer = s.issuer + + return nil +} + +// GetPrivateClaimsFromScopes returns additional claims based on scopes +func (s *OIDCStorage) GetPrivateClaimsFromScopes(ctx context.Context, userID, clientID string, scopes []string) (map[string]any, error) { + return nil, nil +} + +// GetKeyByIDAndClientID retrieves a key by ID for a client +func (s *OIDCStorage) GetKeyByIDAndClientID(ctx context.Context, keyID, clientID string) (*jose.JSONWebKey, error) { + return nil, errors.New("not implemented") +} + +// ValidateJWTProfileScopes validates scopes for JWT profile grant +func (s *OIDCStorage) ValidateJWTProfileScopes(ctx context.Context, userID string, scopes []string) ([]string, error) { + return scopes, nil +} + +// SigningKey returns the active signing key for token signing +func (s *OIDCStorage) SigningKey(ctx context.Context) (op.SigningKey, error) { + key, err := s.store.GetSigningKey(ctx) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(key.PrivateKey) + if block == nil { + return nil, errors.New("failed to decode private key PEM") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + + return &signingKey{ + id: key.ID, + algorithm: jose.RS256, + privateKey: privateKey, + }, nil +} + +// SignatureAlgorithms returns supported signature algorithms +func (s *OIDCStorage) SignatureAlgorithms(ctx context.Context) ([]jose.SignatureAlgorithm, error) { + return []jose.SignatureAlgorithm{jose.RS256}, nil +} + +// KeySet returns the public key set for token verification +func (s *OIDCStorage) KeySet(ctx context.Context) ([]op.Key, error) { + key, err := s.store.GetSigningKey(ctx) + if err != nil { + return nil, err + } + + block, _ := pem.Decode(key.PublicKey) + if block == nil { + return nil, errors.New("failed to decode public key PEM") + } + + publicKey, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return nil, fmt.Errorf("failed to parse public key: %w", err) + } + + rsaKey, ok := publicKey.(*rsa.PublicKey) + if !ok { + return nil, errors.New("public key is not RSA") + } + + return []op.Key{ + &publicKeyInfo{ + id: key.ID, + algorithm: jose.RS256, + publicKey: rsaKey, + }, + }, nil +} + +// Device Authorization Flow methods + +// StoreDeviceAuthorization stores a device authorization request +func (s *OIDCStorage) StoreDeviceAuthorization(ctx context.Context, clientID, deviceCode, userCode string, expires time.Time, scopes []string) error { + auth := &DeviceAuth{ + DeviceCode: deviceCode, + UserCode: userCode, + ClientID: clientID, + Scopes: ToJSONArray(scopes), + Expiration: expires, + } + return s.store.SaveDeviceAuth(ctx, auth) +} + +// GetDeviceAuthorizationState retrieves the state of a device authorization +func (s *OIDCStorage) GetDeviceAuthorizationState(ctx context.Context, clientID, deviceCode string) (*op.DeviceAuthorizationState, error) { + auth, err := s.store.GetDeviceAuthByDeviceCode(ctx, deviceCode) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("device authorization not found") + } + return nil, err + } + + if auth.ClientID != clientID { + return nil, errors.New("client ID mismatch") + } + + if time.Now().After(auth.Expiration) { + _ = s.store.DeleteDeviceAuth(ctx, deviceCode) + return &op.DeviceAuthorizationState{Expires: auth.Expiration}, nil + } + + state := &op.DeviceAuthorizationState{ + ClientID: auth.ClientID, + Scopes: ParseJSONArray(auth.Scopes), + Expires: auth.Expiration, + } + + if auth.Denied { + state.Denied = true + } else if auth.Done { + state.Done = true + state.Subject = auth.Subject + } + + return state, nil +} + +// GetDeviceAuthorizationByUserCode retrieves device auth by user code +func (s *OIDCStorage) GetDeviceAuthorizationByUserCode(ctx context.Context, userCode string) (*op.DeviceAuthorizationState, error) { + auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errors.New("device authorization not found") + } + return nil, err + } + + if time.Now().After(auth.Expiration) { + return nil, errors.New("device authorization expired") + } + + return &op.DeviceAuthorizationState{ + ClientID: auth.ClientID, + Scopes: ParseJSONArray(auth.Scopes), + Expires: auth.Expiration, + Done: auth.Done, + Denied: auth.Denied, + Subject: auth.Subject, + }, nil +} + +// CompleteDeviceAuthorization marks a device authorization as complete +func (s *OIDCStorage) CompleteDeviceAuthorization(ctx context.Context, userCode, subject string) error { + auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode) + if err != nil { + return err + } + + auth.Done = true + auth.Subject = subject + return s.store.UpdateDeviceAuth(ctx, auth) +} + +// DenyDeviceAuthorization marks a device authorization as denied +func (s *OIDCStorage) DenyDeviceAuthorization(ctx context.Context, userCode string) error { + auth, err := s.store.GetDeviceAuthByUserCode(ctx, userCode) + if err != nil { + return err + } + + auth.Denied = true + return s.store.UpdateDeviceAuth(ctx, auth) +} + +// User authentication methods + +// CheckUsernamePassword validates user credentials +func (s *OIDCStorage) CheckUsernamePassword(username, password, authRequestID string) error { + ctx := context.Background() + + _, err := s.store.ValidateUserPassword(ctx, username, password) + if err != nil { + return err + } + + return nil +} + +// CheckUsernamePasswordSimple validates user credentials and returns the user ID +func (s *OIDCStorage) CheckUsernamePasswordSimple(username, password string) (string, error) { + ctx := context.Background() + + user, err := s.store.ValidateUserPassword(ctx, username, password) + if err != nil { + return "", err + } + + return user.ID, nil +} + +// CompleteAuthRequest completes an auth request after user authentication +func (s *OIDCStorage) CompleteAuthRequest(ctx context.Context, authRequestID, userID string) error { + req, err := s.store.GetAuthRequestByID(ctx, authRequestID) + if err != nil { + return err + } + + req.UserID = userID + req.Done = true + req.AuthTime = time.Now() + + return s.store.UpdateAuthRequest(ctx, req) +} + +// Helper types + +// signingKey implements op.SigningKey +type signingKey struct { + id string + algorithm jose.SignatureAlgorithm + privateKey *rsa.PrivateKey +} + +func (k *signingKey) SignatureAlgorithm() jose.SignatureAlgorithm { + return k.algorithm +} + +func (k *signingKey) Key() interface{} { + return k.privateKey +} + +func (k *signingKey) ID() string { + return k.id +} + +// publicKeyInfo implements op.Key +type publicKeyInfo struct { + id string + algorithm jose.SignatureAlgorithm + publicKey *rsa.PublicKey +} + +func (k *publicKeyInfo) ID() string { + return k.id +} + +func (k *publicKeyInfo) Algorithm() jose.SignatureAlgorithm { + return k.algorithm +} + +func (k *publicKeyInfo) Use() string { + return "sig" +} + +func (k *publicKeyInfo) Key() interface{} { + return k.publicKey +} + +// OIDCAuthRequest wraps AuthRequest for the op.AuthRequest interface +type OIDCAuthRequest struct { + req *AuthRequest + storage *OIDCStorage +} + +func (r *OIDCAuthRequest) GetID() string { return r.req.ID } +func (r *OIDCAuthRequest) GetACR() string { return "" } +func (r *OIDCAuthRequest) GetAMR() []string { return []string{"pwd"} } +func (r *OIDCAuthRequest) GetAudience() []string { return []string{r.req.ClientID} } +func (r *OIDCAuthRequest) GetAuthTime() time.Time { return r.req.AuthTime } +func (r *OIDCAuthRequest) GetClientID() string { return r.req.ClientID } +func (r *OIDCAuthRequest) GetCodeChallenge() *oidc.CodeChallenge { + if r.req.CodeChallenge == "" { + return nil + } + return &oidc.CodeChallenge{ + Challenge: r.req.CodeChallenge, + Method: oidc.CodeChallengeMethod(r.req.CodeMethod), + } +} +func (r *OIDCAuthRequest) GetNonce() string { return r.req.Nonce } +func (r *OIDCAuthRequest) GetRedirectURI() string { return r.req.RedirectURI } +func (r *OIDCAuthRequest) GetResponseType() oidc.ResponseType { + return oidc.ResponseType(r.req.ResponseType) +} +func (r *OIDCAuthRequest) GetResponseMode() oidc.ResponseMode { + return oidc.ResponseMode(r.req.ResponseMode) +} +func (r *OIDCAuthRequest) GetScopes() []string { return ParseJSONArray(r.req.Scopes) } +func (r *OIDCAuthRequest) GetState() string { return r.req.State } +func (r *OIDCAuthRequest) GetSubject() string { return r.req.UserID } +func (r *OIDCAuthRequest) Done() bool { return r.req.Done } + +// OIDCRefreshToken wraps RefreshToken for the op.RefreshTokenRequest interface +type OIDCRefreshToken struct { + token *RefreshToken +} + +func (r *OIDCRefreshToken) GetAMR() []string { return ParseJSONArray(r.token.AMR) } +func (r *OIDCRefreshToken) GetAudience() []string { return ParseJSONArray(r.token.Audience) } +func (r *OIDCRefreshToken) GetAuthTime() time.Time { return r.token.AuthTime } +func (r *OIDCRefreshToken) GetClientID() string { return r.token.ApplicationID } +func (r *OIDCRefreshToken) GetScopes() []string { return ParseJSONArray(r.token.Scopes) } +func (r *OIDCRefreshToken) GetSubject() string { return r.token.Subject } +func (r *OIDCRefreshToken) SetCurrentScopes(scopes []string) {} + +// Helper functions + +func spaceSeparated(items []string) string { + return strings.Join(items, " ") +} diff --git a/idp/oidcprovider/provider.go b/idp/oidcprovider/provider.go new file mode 100644 index 000000000..847c338bb --- /dev/null +++ b/idp/oidcprovider/provider.go @@ -0,0 +1,265 @@ +package oidcprovider + +import ( + "context" + "crypto/sha256" + "fmt" + "net/http" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + log "github.com/sirupsen/logrus" + "github.com/zitadel/oidc/v3/pkg/op" +) + +// Config holds the configuration for the OIDC provider +type Config struct { + // Issuer is the OIDC issuer URL (e.g., "https://idp.example.com") + Issuer string + // Port is the port to listen on + Port int + // DataDir is the directory to store OIDC data (SQLite database) + DataDir string + // DevMode enables development mode (allows HTTP, localhost) + DevMode bool +} + +// Provider represents the embedded OIDC provider +type Provider struct { + config *Config + store *Store + storage *OIDCStorage + provider op.OpenIDProvider + router chi.Router + httpServer *http.Server +} + +// NewProvider creates a new OIDC provider +func NewProvider(ctx context.Context, config *Config) (*Provider, error) { + // Create the SQLite store + store, err := NewStore(ctx, config.DataDir) + if err != nil { + return nil, fmt.Errorf("failed to create OIDC store: %w", err) + } + + // Create the OIDC storage adapter + storage := NewOIDCStorage(store, config.Issuer) + + p := &Provider{ + config: config, + store: store, + storage: storage, + } + + return p, nil +} + +// Start starts the OIDC provider server +func (p *Provider) Start(ctx context.Context) error { + // Create the router + router := chi.NewRouter() + router.Use(middleware.Logger) + router.Use(middleware.Recoverer) + router.Use(middleware.RequestID) + + // Create the OIDC provider + key := sha256.Sum256([]byte(p.config.Issuer + "encryption-key")) + + opConfig := &op.Config{ + CryptoKey: key, + DefaultLogoutRedirectURI: "/logged-out", + CodeMethodS256: true, + AuthMethodPost: true, + AuthMethodPrivateKeyJWT: true, + GrantTypeRefreshToken: true, + RequestObjectSupported: true, + DeviceAuthorization: op.DeviceAuthorizationConfig{ + Lifetime: 5 * time.Minute, + PollInterval: 5 * time.Second, + UserFormPath: "/device", + UserCode: op.UserCodeBase20, + }, + } + + // Set the login URL generator + p.storage.SetLoginURL(func(authRequestID string) string { + return fmt.Sprintf("/login?authRequestID=%s", authRequestID) + }) + + // Create the provider with options + var opts []op.Option + if p.config.DevMode { + opts = append(opts, op.WithAllowInsecure()) + } + + provider, err := op.NewProvider(opConfig, p.storage, op.StaticIssuer(p.config.Issuer), opts...) + if err != nil { + return fmt.Errorf("failed to create OIDC provider: %w", err) + } + p.provider = provider + + // Set up login handler + loginHandler, err := NewLoginHandler(p.storage, func(authRequestID string) string { + return provider.AuthorizationEndpoint().Absolute("/authorize/callback") + "?id=" + authRequestID + }) + if err != nil { + return fmt.Errorf("failed to create login handler: %w", err) + } + + // Set up device handler + deviceHandler, err := NewDeviceHandler(p.storage) + if err != nil { + return fmt.Errorf("failed to create device handler: %w", err) + } + + // Mount routes + router.Mount("/login", loginHandler.Router()) + router.Mount("/device", deviceHandler.Router()) + router.Get("/logged-out", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(`Logged Out

You have been logged out

You can close this window.

`)) + }) + + // Mount the OIDC provider at root + router.Mount("/", provider) + + p.router = router + + // Create HTTP server + addr := fmt.Sprintf(":%d", p.config.Port) + p.httpServer = &http.Server{ + Addr: addr, + Handler: router, + } + + // Start server in goroutine + go func() { + log.Infof("Starting OIDC provider on %s (issuer: %s)", addr, p.config.Issuer) + if err := p.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Errorf("OIDC provider server error: %v", err) + } + }() + + // Start cleanup goroutine + go p.cleanupLoop(ctx) + + return nil +} + +// Stop stops the OIDC provider server +func (p *Provider) Stop(ctx context.Context) error { + if p.httpServer != nil { + if err := p.httpServer.Shutdown(ctx); err != nil { + return fmt.Errorf("failed to shutdown OIDC server: %w", err) + } + } + if p.store != nil { + if err := p.store.Close(); err != nil { + return fmt.Errorf("failed to close OIDC store: %w", err) + } + } + return nil +} + +// cleanupLoop periodically cleans up expired tokens +func (p *Provider) cleanupLoop(ctx context.Context) { + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := p.store.CleanupExpired(ctx); err != nil { + log.Warnf("OIDC cleanup error: %v", err) + } + } + } +} + +// Store returns the underlying store for user/client management +func (p *Provider) Store() *Store { + return p.store +} + +// GetIssuer returns the issuer URL +func (p *Provider) GetIssuer() string { + return p.config.Issuer +} + +// GetDiscoveryEndpoint returns the OpenID Connect discovery endpoint +func (p *Provider) GetDiscoveryEndpoint() string { + return p.config.Issuer + "/.well-known/openid-configuration" +} + +// GetTokenEndpoint returns the token endpoint +func (p *Provider) GetTokenEndpoint() string { + return p.config.Issuer + "/oauth/token" +} + +// GetAuthorizationEndpoint returns the authorization endpoint +func (p *Provider) GetAuthorizationEndpoint() string { + return p.config.Issuer + "/authorize" +} + +// GetDeviceAuthorizationEndpoint returns the device authorization endpoint +func (p *Provider) GetDeviceAuthorizationEndpoint() string { + return p.config.Issuer + "/device_authorization" +} + +// GetJWKSEndpoint returns the JWKS endpoint +func (p *Provider) GetJWKSEndpoint() string { + return p.config.Issuer + "/keys" +} + +// GetUserInfoEndpoint returns the userinfo endpoint +func (p *Provider) GetUserInfoEndpoint() string { + return p.config.Issuer + "/userinfo" +} + +// EnsureDefaultClients ensures the default NetBird clients exist +func (p *Provider) EnsureDefaultClients(ctx context.Context, dashboardRedirectURIs, cliRedirectURIs []string) error { + // Check if CLI client exists + _, err := p.store.GetClientByID(ctx, "netbird-client") + if err != nil { + // Create CLI client (native, public, supports PKCE and device flow) + cliClient := CreateNativeClient("netbird-client", "NetBird CLI", cliRedirectURIs) + if err := p.store.CreateClient(ctx, cliClient); err != nil { + return fmt.Errorf("failed to create CLI client: %w", err) + } + log.Info("Created default NetBird CLI client") + } + + // Check if dashboard client exists + _, err = p.store.GetClientByID(ctx, "netbird-dashboard") + if err != nil { + // Create dashboard client (SPA, public, supports PKCE) + dashboardClient := CreateSPAClient("netbird-dashboard", "NetBird Dashboard", dashboardRedirectURIs) + if err := p.store.CreateClient(ctx, dashboardClient); err != nil { + return fmt.Errorf("failed to create dashboard client: %w", err) + } + log.Info("Created default NetBird Dashboard client") + } + + return nil +} + +// CreateUser creates a new user (convenience method) +func (p *Provider) CreateUser(ctx context.Context, username, password, email, firstName, lastName string) (*User, error) { + user := &User{ + Username: username, + Password: password, // Will be hashed by store + Email: email, + EmailVerified: false, + FirstName: firstName, + LastName: lastName, + } + + if err := p.store.CreateUser(ctx, user); err != nil { + return nil, err + } + + return user, nil +} diff --git a/idp/oidcprovider/store.go b/idp/oidcprovider/store.go new file mode 100644 index 000000000..c116507a5 --- /dev/null +++ b/idp/oidcprovider/store.go @@ -0,0 +1,493 @@ +package oidcprovider + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "errors" + "fmt" + "sync" + "time" + + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "golang.org/x/crypto/bcrypt" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +// Store handles persistence for OIDC provider data +type Store struct { + db *gorm.DB + mu sync.RWMutex +} + +// NewStore creates a new Store with SQLite backend +func NewStore(ctx context.Context, dataDir string) (*Store, error) { + dbPath := fmt.Sprintf("%s/oidc.db", dataDir) + + db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + }) + if err != nil { + return nil, fmt.Errorf("failed to open OIDC database: %w", err) + } + + // Enable WAL mode for better concurrency + if err := db.Exec("PRAGMA journal_mode=WAL").Error; err != nil { + log.WithContext(ctx).Warnf("failed to enable WAL mode: %v", err) + } + + // Auto-migrate tables + if err := db.AutoMigrate( + &User{}, + &Client{}, + &AuthRequest{}, + &AuthCode{}, + &AccessToken{}, + &RefreshToken{}, + &DeviceAuth{}, + &SigningKey{}, + ); err != nil { + return nil, fmt.Errorf("failed to migrate OIDC database: %w", err) + } + + store := &Store{db: db} + + // Ensure we have a signing key + if err := store.ensureSigningKey(ctx); err != nil { + return nil, fmt.Errorf("failed to ensure signing key: %w", err) + } + + return store, nil +} + +// Close closes the database connection +func (s *Store) Close() error { + sqlDB, err := s.db.DB() + if err != nil { + return err + } + return sqlDB.Close() +} + +// ensureSigningKey creates a signing key if one doesn't exist +func (s *Store) ensureSigningKey(ctx context.Context) error { + var key SigningKey + err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error + if err == nil { + return nil // Key exists + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + // Generate new RSA key pair + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate RSA key: %w", err) + } + + privateKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + + publicKeyBytes, err := x509.MarshalPKIXPublicKey(&privateKey.PublicKey) + if err != nil { + return fmt.Errorf("failed to marshal public key: %w", err) + } + publicKeyPEM := pem.EncodeToMemory(&pem.Block{ + Type: "PUBLIC KEY", + Bytes: publicKeyBytes, + }) + + newKey := &SigningKey{ + ID: uuid.New().String(), + Algorithm: "RS256", + PrivateKey: privateKeyPEM, + PublicKey: publicKeyPEM, + CreatedAt: time.Now(), + Active: true, + } + + return s.db.WithContext(ctx).Create(newKey).Error +} + +// GetSigningKey returns the active signing key +func (s *Store) GetSigningKey(ctx context.Context) (*SigningKey, error) { + var key SigningKey + err := s.db.WithContext(ctx).Where("active = ?", true).First(&key).Error + if err != nil { + return nil, err + } + return &key, nil +} + +// User operations + +// CreateUser creates a new user with bcrypt hashed password +func (s *Store) CreateUser(ctx context.Context, user *User) error { + if user.ID == "" { + user.ID = uuid.New().String() + } + + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + user.Password = string(hashedPassword) + user.CreatedAt = time.Now() + user.UpdatedAt = time.Now() + + return s.db.WithContext(ctx).Create(user).Error +} + +// GetUserByID retrieves a user by ID +func (s *Store) GetUserByID(ctx context.Context, id string) (*User, error) { + var user User + err := s.db.WithContext(ctx).Where("id = ?", id).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// GetUserByUsername retrieves a user by username +func (s *Store) GetUserByUsername(ctx context.Context, username string) (*User, error) { + var user User + err := s.db.WithContext(ctx).Where("username = ?", username).First(&user).Error + if err != nil { + return nil, err + } + return &user, nil +} + +// ValidateUserPassword validates a user's password +func (s *Store) ValidateUserPassword(ctx context.Context, username, password string) (*User, error) { + user, err := s.GetUserByUsername(ctx, username) + if err != nil { + return nil, err + } + + if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil { + return nil, errors.New("invalid password") + } + + return user, nil +} + +// ListUsers returns all users +func (s *Store) ListUsers(ctx context.Context) ([]*User, error) { + var users []*User + err := s.db.WithContext(ctx).Find(&users).Error + return users, err +} + +// UpdateUser updates a user +func (s *Store) UpdateUser(ctx context.Context, user *User) error { + user.UpdatedAt = time.Now() + return s.db.WithContext(ctx).Save(user).Error +} + +// DeleteUser deletes a user +func (s *Store) DeleteUser(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&User{}, "id = ?", id).Error +} + +// UpdateUserPassword updates a user's password +func (s *Store) UpdateUserPassword(ctx context.Context, id, password string) error { + hashedPassword, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash password: %w", err) + } + + return s.db.WithContext(ctx).Model(&User{}).Where("id = ?", id).Updates(map[string]interface{}{ + "password": string(hashedPassword), + "updated_at": time.Now(), + }).Error +} + +// Client operations + +// CreateClient creates a new OIDC client +func (s *Store) CreateClient(ctx context.Context, client *Client) error { + if client.ID == "" { + client.ID = uuid.New().String() + } + + // Hash secret if provided + if client.Secret != "" { + hashedSecret, err := bcrypt.GenerateFromPassword([]byte(client.Secret), bcrypt.DefaultCost) + if err != nil { + return fmt.Errorf("failed to hash client secret: %w", err) + } + client.Secret = string(hashedSecret) + } + + client.CreatedAt = time.Now() + client.UpdatedAt = time.Now() + + return s.db.WithContext(ctx).Create(client).Error +} + +// GetClientByID retrieves a client by ID +func (s *Store) GetClientByID(ctx context.Context, id string) (*Client, error) { + var client Client + err := s.db.WithContext(ctx).Where("id = ?", id).First(&client).Error + if err != nil { + return nil, err + } + return &client, nil +} + +// ValidateClientSecret validates a client's secret +func (s *Store) ValidateClientSecret(ctx context.Context, clientID, secret string) (*Client, error) { + client, err := s.GetClientByID(ctx, clientID) + if err != nil { + return nil, err + } + + // Public clients have no secret + if client.Secret == "" && secret == "" { + return client, nil + } + + if err := bcrypt.CompareHashAndPassword([]byte(client.Secret), []byte(secret)); err != nil { + return nil, errors.New("invalid client secret") + } + + return client, nil +} + +// ListClients returns all clients +func (s *Store) ListClients(ctx context.Context) ([]*Client, error) { + var clients []*Client + err := s.db.WithContext(ctx).Find(&clients).Error + return clients, err +} + +// DeleteClient deletes a client +func (s *Store) DeleteClient(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&Client{}, "id = ?", id).Error +} + +// AuthRequest operations + +// SaveAuthRequest saves an authorization request +func (s *Store) SaveAuthRequest(ctx context.Context, req *AuthRequest) error { + if req.ID == "" { + req.ID = uuid.New().String() + } + req.CreatedAt = time.Now() + return s.db.WithContext(ctx).Create(req).Error +} + +// GetAuthRequestByID retrieves an auth request by ID +func (s *Store) GetAuthRequestByID(ctx context.Context, id string) (*AuthRequest, error) { + var req AuthRequest + err := s.db.WithContext(ctx).Where("id = ?", id).First(&req).Error + if err != nil { + return nil, err + } + return &req, nil +} + +// UpdateAuthRequest updates an auth request +func (s *Store) UpdateAuthRequest(ctx context.Context, req *AuthRequest) error { + return s.db.WithContext(ctx).Save(req).Error +} + +// DeleteAuthRequest deletes an auth request +func (s *Store) DeleteAuthRequest(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&AuthRequest{}, "id = ?", id).Error +} + +// AuthCode operations + +// SaveAuthCode saves an authorization code +func (s *Store) SaveAuthCode(ctx context.Context, code *AuthCode) error { + code.CreatedAt = time.Now() + if code.ExpiresAt.IsZero() { + code.ExpiresAt = time.Now().Add(10 * time.Minute) // 10 minute expiry + } + return s.db.WithContext(ctx).Create(code).Error +} + +// GetAuthCodeByCode retrieves an auth code +func (s *Store) GetAuthCodeByCode(ctx context.Context, code string) (*AuthCode, error) { + var authCode AuthCode + err := s.db.WithContext(ctx).Where("code = ?", code).First(&authCode).Error + if err != nil { + return nil, err + } + return &authCode, nil +} + +// DeleteAuthCode deletes an auth code +func (s *Store) DeleteAuthCode(ctx context.Context, code string) error { + return s.db.WithContext(ctx).Delete(&AuthCode{}, "code = ?", code).Error +} + +// Token operations + +// SaveAccessToken saves an access token +func (s *Store) SaveAccessToken(ctx context.Context, token *AccessToken) error { + if token.ID == "" { + token.ID = uuid.New().String() + } + token.CreatedAt = time.Now() + return s.db.WithContext(ctx).Create(token).Error +} + +// GetAccessTokenByID retrieves an access token +func (s *Store) GetAccessTokenByID(ctx context.Context, id string) (*AccessToken, error) { + var token AccessToken + err := s.db.WithContext(ctx).Where("id = ?", id).First(&token).Error + if err != nil { + return nil, err + } + return &token, nil +} + +// DeleteAccessToken deletes an access token +func (s *Store) DeleteAccessToken(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&AccessToken{}, "id = ?", id).Error +} + +// RefreshToken operations + +// SaveRefreshToken saves a refresh token +func (s *Store) SaveRefreshToken(ctx context.Context, token *RefreshToken) error { + if token.ID == "" { + token.ID = uuid.New().String() + } + if token.Token == "" { + token.Token = uuid.New().String() + } + token.CreatedAt = time.Now() + return s.db.WithContext(ctx).Create(token).Error +} + +// GetRefreshToken retrieves a refresh token by token value +func (s *Store) GetRefreshToken(ctx context.Context, token string) (*RefreshToken, error) { + var rt RefreshToken + err := s.db.WithContext(ctx).Where("token = ?", token).First(&rt).Error + if err != nil { + return nil, err + } + return &rt, nil +} + +// DeleteRefreshToken deletes a refresh token +func (s *Store) DeleteRefreshToken(ctx context.Context, id string) error { + return s.db.WithContext(ctx).Delete(&RefreshToken{}, "id = ?", id).Error +} + +// DeleteRefreshTokenByToken deletes a refresh token by token value +func (s *Store) DeleteRefreshTokenByToken(ctx context.Context, token string) error { + return s.db.WithContext(ctx).Delete(&RefreshToken{}, "token = ?", token).Error +} + +// DeviceAuth operations + +// SaveDeviceAuth saves a device authorization +func (s *Store) SaveDeviceAuth(ctx context.Context, auth *DeviceAuth) error { + auth.CreatedAt = time.Now() + return s.db.WithContext(ctx).Create(auth).Error +} + +// GetDeviceAuthByDeviceCode retrieves device auth by device code +func (s *Store) GetDeviceAuthByDeviceCode(ctx context.Context, deviceCode string) (*DeviceAuth, error) { + var auth DeviceAuth + err := s.db.WithContext(ctx).Where("device_code = ?", deviceCode).First(&auth).Error + if err != nil { + return nil, err + } + return &auth, nil +} + +// GetDeviceAuthByUserCode retrieves device auth by user code +func (s *Store) GetDeviceAuthByUserCode(ctx context.Context, userCode string) (*DeviceAuth, error) { + var auth DeviceAuth + err := s.db.WithContext(ctx).Where("user_code = ?", userCode).First(&auth).Error + if err != nil { + return nil, err + } + return &auth, nil +} + +// UpdateDeviceAuth updates a device authorization +func (s *Store) UpdateDeviceAuth(ctx context.Context, auth *DeviceAuth) error { + return s.db.WithContext(ctx).Save(auth).Error +} + +// DeleteDeviceAuth deletes a device authorization +func (s *Store) DeleteDeviceAuth(ctx context.Context, deviceCode string) error { + return s.db.WithContext(ctx).Delete(&DeviceAuth{}, "device_code = ?", deviceCode).Error +} + +// Cleanup operations + +// CleanupExpired removes expired tokens and auth requests +func (s *Store) CleanupExpired(ctx context.Context) error { + now := time.Now() + + // Delete expired auth codes + if err := s.db.WithContext(ctx).Delete(&AuthCode{}, "expires_at < ?", now).Error; err != nil { + return err + } + + // Delete expired access tokens + if err := s.db.WithContext(ctx).Delete(&AccessToken{}, "expiration < ?", now).Error; err != nil { + return err + } + + // Delete expired refresh tokens + if err := s.db.WithContext(ctx).Delete(&RefreshToken{}, "expiration < ?", now).Error; err != nil { + return err + } + + // Delete expired device authorizations + if err := s.db.WithContext(ctx).Delete(&DeviceAuth{}, "expiration < ?", now).Error; err != nil { + return err + } + + // Delete old auth requests (older than 1 hour) + oneHourAgo := now.Add(-1 * time.Hour) + if err := s.db.WithContext(ctx).Delete(&AuthRequest{}, "created_at < ?", oneHourAgo).Error; err != nil { + return err + } + + return nil +} + +// Helper functions for JSON serialization + +// ParseJSONArray parses a JSON array string into a slice +func ParseJSONArray(jsonStr string) []string { + if jsonStr == "" { + return nil + } + var result []string + if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { + return nil + } + return result +} + +// ToJSONArray converts a slice to a JSON array string +func ToJSONArray(arr []string) string { + if len(arr) == 0 { + return "[]" + } + data, err := json.Marshal(arr) + if err != nil { + return "[]" + } + return string(data) +} diff --git a/idp/oidcprovider/templates/device.html b/idp/oidcprovider/templates/device.html new file mode 100644 index 000000000..2d2b775fb --- /dev/null +++ b/idp/oidcprovider/templates/device.html @@ -0,0 +1,261 @@ + + + + + + Device Authorization - NetBird + + + +
+ + + {{if .Error}} +
{{.Error}}
+ {{end}} + + {{if eq .Step "code"}} + +
+ Enter the code shown on your device to authorize it. +
+
+
+ + +
+ +
+ {{end}} + + {{if eq .Step "login"}} + +
+ Sign in to authorize the device. +
+
+ +
+ + +
+
+ + +
+ +
+ {{end}} + + {{if eq .Step "confirm"}} + +
+ {{.ClientID}} is requesting access to your account. +
+ + {{if .Scopes}} +
+

This application will have access to:

+
    + {{range .Scopes}} +
  • {{.}}
  • + {{end}} +
+
+ {{end}} + +
+
+ + +
+
+ {{end}} + + {{if eq .Step "result"}} + + {{if .Success}} +
+ {{.Message}} +
+ {{else}} +
+ {{.Message}} +
+ {{end}} + {{end}} + + +
+ + \ No newline at end of file diff --git a/idp/oidcprovider/templates/login.html b/idp/oidcprovider/templates/login.html new file mode 100644 index 000000000..25ab59960 --- /dev/null +++ b/idp/oidcprovider/templates/login.html @@ -0,0 +1,129 @@ + + + + + + Login - NetBird + + + + + + \ No newline at end of file