53 lines
1.3 KiB
Go
53 lines
1.3 KiB
Go
package security
|
|
|
|
import (
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"net/http"
)
|
|
|
|
// CSRFFuncs bundles the session callbacks that CSRFMiddleware uses to read
// and provision the per-session CSRF token. Both fields are called by the
// middleware without a nil check, so both must be set.
type CSRFFuncs struct {
	// GetCSRF reads the CSRF token from the session associated with r.
	// ok reports whether a token was found.
	GetCSRF func(r *http.Request) (token string, ok bool)

	// EnsureCSRF makes sure the session has a CSRF token and returns it.
	// It receives w as well as r; presumably so an implementation can
	// persist the session (e.g. set a cookie) — confirm against the
	// concrete implementation.
	EnsureCSRF func(w http.ResponseWriter, r *http.Request) (token string, err error)
}
|
|
|
|
// NewCSRFToken returns a fresh random token: 32 bytes read from crypto/rand,
// encoded as unpadded URL-safe base64 (43 characters).
//
// It returns an empty string and the underlying error if the random source
// cannot be read.
func NewCSRFToken() (string, error) {
	// 32 bytes = 256 bits of randomness.
	const tokenBytes = 32

	raw := make([]byte, tokenBytes)
	if _, err := rand.Read(raw); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(raw), nil
}
|
|
|
|
func CSRFMiddleware(f CSRFFuncs) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
// Always ensure csrf exists for HTML GET requests
|
|
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
|
_, _ = f.EnsureCSRF(w, r)
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
// Validate for unsafe methods
|
|
token, ok := f.GetCSRF(r)
|
|
if !ok || token == "" {
|
|
http.Error(w, "csrf missing", http.StatusForbidden)
|
|
return
|
|
}
|
|
// Prefer header (JS), fallback to form value
|
|
got := r.Header.Get("X-CSRF-Token")
|
|
if got == "" {
|
|
_ = r.ParseForm()
|
|
got = r.Form.Get("csrf")
|
|
}
|
|
if got == "" || got != token {
|
|
http.Error(w, "csrf mismatch", http.StatusForbidden)
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|