Files
ntfywui/internal/security/csrf.go
2026-01-12 13:51:52 +01:00

53 lines
1.3 KiB
Go

package security
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
)
// CSRFFuncs supplies the session accessors that CSRFMiddleware uses to read
// and provision the per-session CSRF token. CSRFMiddleware calls both fields
// without a nil check, so both must be set.
type CSRFFuncs struct {
	// GetCSRF reads the CSRF token stored in the request's session.
	// CSRFMiddleware treats ok == false or an empty token as "no token".
	GetCSRF func(r *http.Request) (token string, ok bool)
	// EnsureCSRF makes sure the request's session holds a CSRF token,
	// creating and saving one if necessary, and returns it. It receives the
	// ResponseWriter so an implementation can persist a newly created token
	// (for example via a session cookie) before the response body is written.
	EnsureCSRF func(w http.ResponseWriter, r *http.Request) (token string, err error)
}
// NewCSRFToken generates a fresh CSRF token from 32 bytes of crypto/rand
// output, encoded as unpadded URL-safe base64 (43 characters). It returns an
// error only if reading from the random source fails.
func NewCSRFToken() (string, error) {
	const entropyBytes = 32

	var entropy [entropyBytes]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}

	token := base64.RawURLEncoding.EncodeToString(entropy[:])
	return token, nil
}
// CSRFMiddleware returns middleware that enforces a session-token CSRF check
// using the accessors in f.
//
// GET and HEAD requests are passed through without validation. Before
// passing them on, the middleware calls f.EnsureCSRF so the session carries
// a token that rendered pages can embed.
//
// Every other method must present the session's token. The token is read
// from the X-CSRF-Token header (JavaScript clients). When that header is
// absent or empty, it is read from the "csrf" form field via r.FormValue.
// That covers the URL query and both application/x-www-form-urlencoded and
// multipart/form-data bodies.
//
// Rejections use status 403:
//   - "csrf missing" when the session has no token;
//   - "csrf mismatch" when the presented token is absent or wrong.
func CSRFMiddleware(f CSRFFuncs) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if r.Method == http.MethodGet || r.Method == http.MethodHead {
				// Best effort: a failure to provision a token must not block
				// read-only pages. Protection is not weakened, because every
				// unsafe request below still needs a matching non-empty token.
				_, _ = f.EnsureCSRF(w, r)
				next.ServeHTTP(w, r)
				return
			}

			token, ok := f.GetCSRF(r)
			if !ok || token == "" {
				http.Error(w, "csrf missing", http.StatusForbidden)
				return
			}

			// Prefer the header (JS clients) and fall back to the form field.
			// FormValue also parses multipart/form-data bodies, which
			// ParseForm alone skips, so multipart forms carrying a "csrf"
			// field validate. FormValue ignores parse errors on purpose: an
			// unparsable body just yields no token and is rejected below.
			got := r.Header.Get("X-CSRF-Token")
			if got == "" {
				got = r.FormValue("csrf")
			}

			// Compare in constant time so response timing does not reveal
			// how many leading bytes of a guessed token were correct.
			if got == "" || subtle.ConstantTimeCompare([]byte(got), []byte(token)) != 1 {
				http.Error(w, "csrf mismatch", http.StatusForbidden)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}