mirror of
https://github.com/pocket-id/pocket-id.git
synced 2026-04-10 16:46:37 +00:00
Compare commits
5 Commits
v2.5.0
...
snyk-fix-d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
971c746b50 | ||
|
|
626adbf14c | ||
|
|
2b94535ade | ||
|
|
e825a58b39 | ||
|
|
b85a81f9b1 |
@@ -89,12 +89,19 @@ type OidcController struct {
|
|||||||
// @Router /api/oidc/authorize [post]
|
// @Router /api/oidc/authorize [post]
|
||||||
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
func (oc *OidcController) authorizeHandler(c *gin.Context) {
|
||||||
var input dto.AuthorizeOidcClientRequestDto
|
var input dto.AuthorizeOidcClientRequestDto
|
||||||
if err := c.ShouldBindJSON(&input); err != nil {
|
err := c.ShouldBindJSON(&input)
|
||||||
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code, callbackURL, err := oc.oidcService.Authorize(c.Request.Context(), input, c.GetString("userID"), c.ClientIP(), c.Request.UserAgent())
|
code, callbackURL, err := oc.oidcService.Authorize(
|
||||||
|
c.Request.Context(),
|
||||||
|
input,
|
||||||
|
c.GetString("userID"),
|
||||||
|
c.ClientIP(),
|
||||||
|
c.Request.UserAgent(),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ type OidcClientWithAllowedGroupsCountDto struct {
|
|||||||
|
|
||||||
type OidcClientUpdateDto struct {
|
type OidcClientUpdateDto struct {
|
||||||
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
Name string `json:"name" binding:"required,max=50" unorm:"nfc"`
|
||||||
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url"`
|
CallbackURLs []string `json:"callbackURLs" binding:"omitempty,dive,callback_url_pattern"`
|
||||||
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url"`
|
LogoutCallbackURLs []string `json:"logoutCallbackURLs" binding:"omitempty,dive,callback_url_pattern"`
|
||||||
IsPublic bool `json:"isPublic"`
|
IsPublic bool `json:"isPublic"`
|
||||||
PkceEnabled bool `json:"pkceEnabled"`
|
PkceEnabled bool `json:"pkceEnabled"`
|
||||||
RequiresReauthentication bool `json:"requiresReauthentication"`
|
RequiresReauthentication bool `json:"requiresReauthentication"`
|
||||||
@@ -66,7 +66,7 @@ type OidcClientFederatedIdentityDto struct {
|
|||||||
type AuthorizeOidcClientRequestDto struct {
|
type AuthorizeOidcClientRequestDto struct {
|
||||||
ClientID string `json:"clientID" binding:"required"`
|
ClientID string `json:"clientID" binding:"required"`
|
||||||
Scope string `json:"scope" binding:"required"`
|
Scope string `json:"scope" binding:"required"`
|
||||||
CallbackURL string `json:"callbackURL"`
|
CallbackURL string `json:"callbackURL" binding:"omitempty,callback_url"`
|
||||||
Nonce string `json:"nonce"`
|
Nonce string `json:"nonce"`
|
||||||
CodeChallenge string `json:"codeChallenge"`
|
CodeChallenge string `json:"codeChallenge"`
|
||||||
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
CodeChallengeMethod string `json:"codeChallengeMethod"`
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"regexp"
|
"regexp"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
"github.com/pocket-id/pocket-id/backend/internal/utils"
|
||||||
@@ -19,38 +21,38 @@ var validateUsernameRegex = regexp.MustCompile("^[a-zA-Z0-9]([a-zA-Z0-9_.@-]*[a-
|
|||||||
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
var validateClientIDRegex = regexp.MustCompile("^[a-zA-Z0-9._-]+$")
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
v := binding.Validator.Engine().(*validator.Validate)
|
engine := binding.Validator.Engine().(*validator.Validate)
|
||||||
|
|
||||||
// Maximum allowed value for TTLs
|
// Maximum allowed value for TTLs
|
||||||
const maxTTL = 31 * 24 * time.Hour
|
const maxTTL = 31 * 24 * time.Hour
|
||||||
|
|
||||||
if err := v.RegisterValidation("username", func(fl validator.FieldLevel) bool {
|
validators := map[string]validator.Func{
|
||||||
return ValidateUsername(fl.Field().String())
|
"username": func(fl validator.FieldLevel) bool {
|
||||||
}); err != nil {
|
return ValidateUsername(fl.Field().String())
|
||||||
panic("Failed to register custom validation for username: " + err.Error())
|
},
|
||||||
|
"client_id": func(fl validator.FieldLevel) bool {
|
||||||
|
return ValidateClientID(fl.Field().String())
|
||||||
|
},
|
||||||
|
"ttl": func(fl validator.FieldLevel) bool {
|
||||||
|
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Allow zero, which means the field wasn't set
|
||||||
|
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
||||||
|
},
|
||||||
|
"callback_url": func(fl validator.FieldLevel) bool {
|
||||||
|
return ValidateCallbackURL(fl.Field().String())
|
||||||
|
},
|
||||||
|
"callback_url_pattern": func(fl validator.FieldLevel) bool {
|
||||||
|
return ValidateCallbackURLPattern(fl.Field().String())
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
for k, v := range validators {
|
||||||
if err := v.RegisterValidation("client_id", func(fl validator.FieldLevel) bool {
|
err := engine.RegisterValidation(k, v)
|
||||||
return ValidateClientID(fl.Field().String())
|
if err != nil {
|
||||||
}); err != nil {
|
panic("Failed to register custom validation for " + k + ": " + err.Error())
|
||||||
panic("Failed to register custom validation for client_id: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.RegisterValidation("ttl", func(fl validator.FieldLevel) bool {
|
|
||||||
ttl, ok := fl.Field().Interface().(utils.JSONDuration)
|
|
||||||
if !ok {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
// Allow zero, which means the field wasn't set
|
|
||||||
return ttl.Duration == 0 || (ttl.Duration > time.Second && ttl.Duration <= maxTTL)
|
|
||||||
}); err != nil {
|
|
||||||
panic("Failed to register custom validation for ttl: " + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := v.RegisterValidation("callback_url", func(fl validator.FieldLevel) bool {
|
|
||||||
return ValidateCallbackURL(fl.Field().String())
|
|
||||||
}); err != nil {
|
|
||||||
panic("Failed to register custom validation for callback_url: " + err.Error())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -64,8 +66,24 @@ func ValidateClientID(clientID string) bool {
|
|||||||
return validateClientIDRegex.MatchString(clientID)
|
return validateClientIDRegex.MatchString(clientID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateCallbackURL validates callback URLs with support for wildcards
|
// ValidateCallbackURL validates the input callback URL
|
||||||
func ValidateCallbackURL(raw string) bool {
|
func ValidateCallbackURL(str string) bool {
|
||||||
|
// Ensure the URL is a valid one and that the protocol is not "javascript:" or "data:"
|
||||||
|
u, err := url.Parse(str)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch strings.ToLower(u.Scheme) {
|
||||||
|
case "javascript", "data":
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateCallbackURLPattern validates callback URL patterns, with support for wildcards
|
||||||
|
func ValidateCallbackURLPattern(raw string) bool {
|
||||||
err := utils.ValidateCallbackURLPattern(raw)
|
err := utils.ValidateCallbackURLPattern(raw)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,3 +57,28 @@ func TestValidateClientID(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateCallbackURL(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{"valid https URL", "https://example.com/callback", true},
|
||||||
|
{"valid loopback URL", "http://127.0.0.1:49813/callback", true},
|
||||||
|
{"empty scheme", "//127.0.0.1:49813/callback", true},
|
||||||
|
{"valid custom scheme", "pocketid://callback", true},
|
||||||
|
{"invalid malformed URL", "http://[::1", false},
|
||||||
|
{"invalid missing scheme separator", "://example.com/callback", false},
|
||||||
|
{"rejects javascript scheme", "javascript:alert(1)", false},
|
||||||
|
{"rejects mixed case javascript scheme", "JavaScript:alert(1)", false},
|
||||||
|
{"rejects data scheme", "data:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false},
|
||||||
|
{"rejects mixed case data scheme", "DaTa:text/html;base64,PGgxPkhlbGxvPC9oMT4=", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equal(t, tt.expected, ValidateCallbackURL(tt.input))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -125,9 +125,7 @@ func (s *OidcService) getJWKCache(ctx context.Context) (*jwk.Cache, error) {
|
|||||||
|
|
||||||
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
func (s *OidcService) Authorize(ctx context.Context, input dto.AuthorizeOidcClientRequestDto, userID, ipAddress, userAgent string) (string, string, error) {
|
||||||
tx := s.db.Begin()
|
tx := s.db.Begin()
|
||||||
defer func() {
|
defer tx.Rollback()
|
||||||
tx.Rollback()
|
|
||||||
}()
|
|
||||||
|
|
||||||
var client model.OidcClient
|
var client model.OidcClient
|
||||||
err := tx.
|
err := tx.
|
||||||
|
|||||||
@@ -138,7 +138,7 @@ func (s *s3Storage) List(ctx context.Context, path string) ([]ObjectInfo, error)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
objects = append(objects, ObjectInfo{
|
objects = append(objects, ObjectInfo{
|
||||||
Path: aws.ToString(obj.Key),
|
Path: s.pathFromKey(aws.ToString(obj.Key)),
|
||||||
Size: aws.ToInt64(obj.Size),
|
Size: aws.ToInt64(obj.Size),
|
||||||
ModTime: aws.ToTime(obj.LastModified),
|
ModTime: aws.ToTime(obj.LastModified),
|
||||||
})
|
})
|
||||||
@@ -147,6 +147,13 @@ func (s *s3Storage) List(ctx context.Context, path string) ([]ObjectInfo, error)
|
|||||||
return objects, nil
|
return objects, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *s3Storage) pathFromKey(key string) string {
|
||||||
|
if s.prefix == "" {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
return strings.TrimPrefix(key, s.prefix+"/")
|
||||||
|
}
|
||||||
|
|
||||||
func (s *s3Storage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
func (s *s3Storage) Walk(ctx context.Context, root string, fn func(ObjectInfo) error) error {
|
||||||
objects, err := s.List(ctx, root)
|
objects, err := s.List(ctx, root)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -35,6 +35,50 @@ func TestS3Helpers(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
t.Run("pathFromKey strips prefix to honor relative-path contract", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
prefix string
|
||||||
|
key string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{name: "no prefix returns key unchanged", prefix: "", key: "images/logo.png", expected: "images/logo.png"},
|
||||||
|
{name: "no prefix empty key", prefix: "", key: "", expected: ""},
|
||||||
|
{name: "prefix matches and is stripped", prefix: "data/uploads", key: "data/uploads/application-images/logo.svg", expected: "application-images/logo.svg"},
|
||||||
|
{name: "single-segment prefix stripped", prefix: "root", key: "root/foo/bar.txt", expected: "foo/bar.txt"},
|
||||||
|
{name: "prefix equal to key without trailing slash is unchanged", prefix: "root", key: "root", expected: "root"},
|
||||||
|
{name: "key without expected prefix returned unchanged", prefix: "data/uploads", key: "other/path.txt", expected: "other/path.txt"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
s := &s3Storage{
|
||||||
|
bucket: "bucket",
|
||||||
|
prefix: tc.prefix,
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.expected, s.pathFromKey(tc.key))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("pathFromKey is the inverse of buildObjectKey for clean paths", func(t *testing.T) {
|
||||||
|
paths := []string{
|
||||||
|
"images/logo.png",
|
||||||
|
"application-images/logo.svg",
|
||||||
|
"oidc-client-images/abc.png",
|
||||||
|
"deeply/nested/file.bin",
|
||||||
|
}
|
||||||
|
prefixes := []string{"", "root", "data/uploads"}
|
||||||
|
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
for _, p := range paths {
|
||||||
|
s := &s3Storage{bucket: "bucket", prefix: prefix}
|
||||||
|
assert.Equal(t, p, s.pathFromKey(s.buildObjectKey(p)),
|
||||||
|
"round-trip failed for prefix=%q path=%q", prefix, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
t.Run("isS3NotFound detects expected errors", func(t *testing.T) {
|
t.Run("isS3NotFound detects expected errors", func(t *testing.T) {
|
||||||
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NoSuchKey"}))
|
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NoSuchKey"}))
|
||||||
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NotFound"}))
|
assert.True(t, isS3NotFound(&smithy.GenericAPIError{Code: "NotFound"}))
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -490,7 +490,7 @@
|
|||||||
"scim_provisioning_description": "SCIM 프로비저닝을 통해 OIDC 클라이언트에서 사용자 및 그룹을 자동으로 프로비저닝 및 디프로비저닝할 수 있습니다. 자세한 내용은 <link href='https://pocket-id.org/docs/configuration/scim'>문서를</link> 참조하세요.",
|
"scim_provisioning_description": "SCIM 프로비저닝을 통해 OIDC 클라이언트에서 사용자 및 그룹을 자동으로 프로비저닝 및 디프로비저닝할 수 있습니다. 자세한 내용은 <link href='https://pocket-id.org/docs/configuration/scim'>문서를</link> 참조하세요.",
|
||||||
"scim_endpoint": "SCIM 엔드포인트",
|
"scim_endpoint": "SCIM 엔드포인트",
|
||||||
"scim_token": "SCIM 토큰",
|
"scim_token": "SCIM 토큰",
|
||||||
"last_successful_sync_at": "마지막 성공적인 동기화: {time}",
|
"last_successful_sync_at": "마지막 동기화 성공: {time}",
|
||||||
"scim_configuration_updated_successfully": "SCIM 구성이 성공적으로 업데이트되었습니다.",
|
"scim_configuration_updated_successfully": "SCIM 구성이 성공적으로 업데이트되었습니다.",
|
||||||
"scim_enabled_successfully": "SCIM이 성공적으로 활성화되었습니다.",
|
"scim_enabled_successfully": "SCIM이 성공적으로 활성화되었습니다.",
|
||||||
"scim_disabled_successfully": "SCIM이 성공적으로 비활성화되었습니다.",
|
"scim_disabled_successfully": "SCIM이 성공적으로 비활성화되었습니다.",
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@simplewebauthn/browser": "^13.2.2",
|
"@simplewebauthn/browser": "^13.2.2",
|
||||||
"@tailwindcss/vite": "^4.2.0",
|
"@tailwindcss/vite": "^4.2.0",
|
||||||
"axios": "^1.13.5",
|
"axios": "^1.15.0",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
"date-fns": "^4.1.0",
|
"date-fns": "^4.1.0",
|
||||||
"jose": "^6.1.3",
|
"jose": "^6.1.3",
|
||||||
|
|||||||
@@ -71,19 +71,16 @@
|
|||||||
reauthToken = await webauthnService.reauthenticate(authResponse);
|
reauthToken = await webauthnService.reauthenticate(authResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
await oidService
|
const authResult = await oidService.authorize(
|
||||||
.authorize(
|
client!.id,
|
||||||
client!.id,
|
scope,
|
||||||
scope,
|
callbackURL,
|
||||||
callbackURL,
|
nonce,
|
||||||
nonce,
|
codeChallenge,
|
||||||
codeChallenge,
|
codeChallengeMethod,
|
||||||
codeChallengeMethod,
|
reauthToken
|
||||||
reauthToken
|
);
|
||||||
)
|
onSuccess(authResult.code, authResult.callbackURL, authResult.issuer);
|
||||||
.then(async ({ code, callbackURL, issuer }) => {
|
|
||||||
onSuccess(code, callbackURL, issuer);
|
|
||||||
});
|
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
errorMessage = getWebauthnErrorMessage(e);
|
errorMessage = getWebauthnErrorMessage(e);
|
||||||
isLoading = false;
|
isLoading = false;
|
||||||
@@ -91,13 +88,17 @@
|
|||||||
}
|
}
|
||||||
|
|
||||||
function onSuccess(code: string, callbackURL: string, issuer: string) {
|
function onSuccess(code: string, callbackURL: string, issuer: string) {
|
||||||
|
const redirectURL = new URL(callbackURL);
|
||||||
|
if (redirectURL.protocol == 'javascript:' || redirectURL.protocol == 'data:') {
|
||||||
|
throw new Error('Invalid redirect URL protocol');
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectURL.searchParams.append('code', code);
|
||||||
|
redirectURL.searchParams.append('state', authorizeState);
|
||||||
|
redirectURL.searchParams.append('iss', issuer);
|
||||||
|
|
||||||
success = true;
|
success = true;
|
||||||
setTimeout(() => {
|
setTimeout(() => {
|
||||||
const redirectURL = new URL(callbackURL);
|
|
||||||
redirectURL.searchParams.append('code', code);
|
|
||||||
redirectURL.searchParams.append('state', authorizeState);
|
|
||||||
redirectURL.searchParams.append('iss', issuer);
|
|
||||||
|
|
||||||
window.location.href = redirectURL.toString();
|
window.location.href = redirectURL.toString();
|
||||||
}, 1000);
|
}, 1000);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user