package security import ( "crypto/rand" "encoding/base64" "net/http" ) type CSRFFuncs struct { // Read session for csrf value GetCSRF func(r *http.Request) (token string, ok bool) // Save ensures session has csrf value EnsureCSRF func(w http.ResponseWriter, r *http.Request) (token string, err error) } func NewCSRFToken() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", err } return base64.RawURLEncoding.EncodeToString(b), nil } func CSRFMiddleware(f CSRFFuncs) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Always ensure csrf exists for HTML GET requests if r.Method == http.MethodGet || r.Method == http.MethodHead { _, _ = f.EnsureCSRF(w, r) next.ServeHTTP(w, r) return } // Validate for unsafe methods token, ok := f.GetCSRF(r) if !ok || token == "" { http.Error(w, "csrf missing", http.StatusForbidden) return } // Prefer header (JS), fallback to form value got := r.Header.Get("X-CSRF-Token") if got == "" { _ = r.ParseForm() got = r.Form.Get("csrf") } if got == "" || got != token { http.Error(w, "csrf mismatch", http.StatusForbidden) return } next.ServeHTTP(w, r) }) } }