package middleware

import (
	"errors"
	"fmt"
	"net/http"
	"strings"
	"unicode"
	"unicode/utf8"

	"github.com/gin-gonic/gin"
	"github.com/gin-gonic/gin/binding"
	"github.com/go-playground/validator/v10"
	"github.com/pocket-id/pocket-id/backend/internal/common"
	"gorm.io/gorm"
)

// ErrorHandlerMiddleware turns errors recorded on the gin context (via
// c.Error) into JSON error responses.
type ErrorHandlerMiddleware struct{}

// NewErrorHandlerMiddleware returns a new ErrorHandlerMiddleware.
func NewErrorHandlerMiddleware() *ErrorHandlerMiddleware {
	return &ErrorHandlerMiddleware{}
}

// Add returns a gin handler that runs the rest of the handler chain and then,
// if any errors were recorded, writes exactly one JSON error response derived
// from the first recorded error:
//
//   - gorm.ErrRecordNotFound: 404 "Record not found"
//   - validator.ValidationErrors, directly or inside a
//     binding.SliceValidationError: 400 with a readable message
//   - common.AppErrorDescription: its status code, message and description
//   - common.AppError: its status code and message
//   - anything else: 500 "Something went wrong"
func (m *ErrorHandlerMiddleware) Add() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Next()

		if len(c.Errors) == 0 {
			return
		}

		// Only the first recorded error produces a response: writing a second
		// JSON body after the first would corrupt the response.
		err := c.Errors[0]

		if errors.Is(err, gorm.ErrRecordNotFound) {
			errorResponse(c, http.StatusNotFound, "Record not found")
			return
		}

		if validationErrors, ok := errors.AsType[validator.ValidationErrors](err); ok {
			errorResponse(c, http.StatusBadRequest, handleValidationError(validationErrors))
			return
		}

		// A SliceValidationError is a list of errors; report the first entry
		// that is a validator.ValidationErrors. Ranging (instead of indexing
		// [0]) also avoids a panic on an empty list.
		if sliceErrs, ok := errors.AsType[binding.SliceValidationError](err); ok {
			for _, elemErr := range sliceErrs {
				if validationErrors, ok := errors.AsType[validator.ValidationErrors](elemErr); ok {
					errorResponse(c, http.StatusBadRequest, handleValidationError(validationErrors))
					return
				}
			}
		}

		// Checked before AppError so that a description, when present, is
		// included in the response.
		if appDescErr, ok := errors.AsType[common.AppErrorDescription](err); ok {
			errorResponseWithDescription(c, appDescErr.HttpStatusCode(), appDescErr.Error(), appDescErr.Description())
			return
		}

		if appErr, ok := errors.AsType[common.AppError](err); ok {
			errorResponse(c, appErr.HttpStatusCode(), appErr.Error())
			return
		}

		c.JSON(http.StatusInternalServerError, errorResponseBody{
			Error: "Something went wrong",
		})
	}
}

// errorResponseBody is the JSON shape of every error response written by
// this middleware. ErrorDescription is omitted when empty.
type errorResponseBody struct {
	Error            string `json:"error"`
	ErrorDescription string `json:"error_description,omitempty"`
}

// errorResponse writes a JSON error body with the given status code and
// message. The first letter of message is upper-cased.
func errorResponse(c *gin.Context, statusCode int, message string) {
	// An empty description is dropped by omitempty, so the body is identical
	// to one without the error_description field.
	errorResponseWithDescription(c, statusCode, message, "")
}

// errorResponseWithDescription writes a JSON error body with the given status
// code, message and description. The first letter of message is upper-cased;
// an empty description is omitted from the body.
func errorResponseWithDescription(c *gin.Context, statusCode int, message string, description string) {
	// Upper-case the first rune rather than the first byte, so a multi-byte
	// leading character is not split. An empty message (or one starting with
	// invalid UTF-8) is left unchanged instead of panicking on message[:1].
	if r, size := utf8.DecodeRuneInString(message); r != utf8.RuneError {
		message = string(unicode.ToUpper(r)) + message[size:]
	}

	c.JSON(statusCode, errorResponseBody{
		Error:            message,
		ErrorDescription: description,
	})
}

// handleValidationError converts validator field errors into one
// comma-separated, human-readable message, using a fixed phrase per
// validation tag and "<field> is invalid" for unknown tags.
//
// NOTE(review): the "min"/"max" messages say "characters long", which
// assumes string fields; for numeric or slice fields the wording is off.
func handleValidationError(validationErrors validator.ValidationErrors) string {
	errorMessages := make([]string, 0, len(validationErrors))

	for _, ve := range validationErrors {
		fieldName := ve.Field()
		var errorMessage string

		switch ve.Tag() {
		case "required":
			errorMessage = fmt.Sprintf("%s is required", fieldName)
		case "email":
			errorMessage = fmt.Sprintf("%s must be a valid email address", fieldName)
		case "username":
			errorMessage = fmt.Sprintf("%s must only contain letters, numbers, underscores, dots, hyphens, and '@' symbols and not start or end with a special character", fieldName)
		case "url":
			errorMessage = fmt.Sprintf("%s must be a valid URL", fieldName)
		case "min":
			errorMessage = fmt.Sprintf("%s must be at least %s characters long", fieldName, ve.Param())
		case "max":
			errorMessage = fmt.Sprintf("%s must be at most %s characters long", fieldName, ve.Param())
		default:
			errorMessage = fmt.Sprintf("%s is invalid", fieldName)
		}

		errorMessages = append(errorMessages, errorMessage)
	}

	return strings.Join(errorMessages, ", ")
}