This commit is contained in:
2026-01-12 13:51:52 +01:00
parent 90191c50d8
commit 06e55c441e
44 changed files with 3066 additions and 1 deletions

52
internal/security/csrf.go Normal file
View File

@@ -0,0 +1,52 @@
package security
import (
"crypto/rand"
"encoding/base64"
"net/http"
)
type CSRFFuncs struct {
// Read session for csrf value
GetCSRF func(r *http.Request) (token string, ok bool)
// Save ensures session has csrf value
EnsureCSRF func(w http.ResponseWriter, r *http.Request) (token string, err error)
}
func NewCSRFToken() (string, error) {
b := make([]byte, 32)
if _, err := rand.Read(b); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
func CSRFMiddleware(f CSRFFuncs) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Always ensure csrf exists for HTML GET requests
if r.Method == http.MethodGet || r.Method == http.MethodHead {
_, _ = f.EnsureCSRF(w, r)
next.ServeHTTP(w, r)
return
}
// Validate for unsafe methods
token, ok := f.GetCSRF(r)
if !ok || token == "" {
http.Error(w, "csrf missing", http.StatusForbidden)
return
}
// Prefer header (JS), fallback to form value
got := r.Header.Get("X-CSRF-Token")
if got == "" {
_ = r.ParseForm()
got = r.Form.Get("csrf")
}
if got == "" || got != token {
http.Error(w, "csrf mismatch", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,27 @@
package security
import "net/http"
// SecureHeaders adds a baseline of security headers.
// CSP is intentionally conservative; adjust if you add external assets.
func SecureHeaders(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-Frame-Options", "DENY")
w.Header().Set("Referrer-Policy", "no-referrer")
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
w.Header().Set("Cross-Origin-Resource-Policy", "same-origin")
w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
w.Header().Set("Content-Security-Policy",
"default-src 'self'; "+
"script-src 'self'; "+
"style-src 'self'; "+
"img-src 'self' data:; "+
"object-src 'none'; "+
"base-uri 'none'; "+
"frame-ancestors 'none'; "+
"form-action 'self'")
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,96 @@
package security
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"fmt"
"strconv"
"strings"
)
func pbkdf2SHA256(password, salt []byte, iter, keyLen int) []byte {
// PBKDF2 per RFC2898
hLen := 32 // sha256
numBlocks := (keyLen + hLen - 1) / hLen
var out []byte
for block := 1; block <= numBlocks; block++ {
t := pbkdf2F(password, salt, iter, block)
out = append(out, t...)
}
return out[:keyLen]
}
func pbkdf2F(password, salt []byte, iter, blockNum int) []byte {
// U1 = PRF(P, S || INT(blockNum))
// Uc = PRF(P, Uc-1)
// T = U1 XOR U2 XOR ... XOR Uiter
b := make([]byte, len(salt)+4)
copy(b, salt)
b[len(salt)+0] = byte(blockNum >> 24)
b[len(salt)+1] = byte(blockNum >> 16)
b[len(salt)+2] = byte(blockNum >> 8)
b[len(salt)+3] = byte(blockNum)
u := hmacSHA256(password, b)
t := make([]byte, len(u))
copy(t, u)
for i := 2; i <= iter; i++ {
u = hmacSHA256(password, u)
for j := range t {
t[j] ^= u[j]
}
}
return t
}
func hmacSHA256(key, msg []byte) []byte {
m := hmac.New(sha256.New, key)
m.Write(msg)
return m.Sum(nil)
}
func HashPasswordPBKDF2(password string, salt []byte, iter int) string {
key := pbkdf2SHA256([]byte(password), salt, iter, 32)
return fmt.Sprintf("pbkdf2_sha256$%d$%s$%s",
iter,
base64.RawURLEncoding.EncodeToString(salt),
base64.RawURLEncoding.EncodeToString(key),
)
}
func VerifyPasswordPBKDF2(password, encoded string) (bool, error) {
// Go's fmt scanning does not support "scanset" verbs like %[^$]. Parse explicitly.
parts := strings.Split(encoded, "$")
if len(parts) != 4 {
return false, fmt.Errorf("parse hash: expected 4 parts, got %d", len(parts))
}
algo := parts[0]
iter, err := strconv.Atoi(parts[1])
if err != nil {
return false, fmt.Errorf("parse hash iter: %w", err)
}
saltB64 := parts[2]
keyB64 := parts[3]
if algo != "pbkdf2_sha256" {
return false, fmt.Errorf("unsupported algo %q", algo)
}
salt, err := base64.RawURLEncoding.DecodeString(saltB64)
if err != nil {
return false, fmt.Errorf("salt decode: %w", err)
}
want, err := base64.RawURLEncoding.DecodeString(keyB64)
if err != nil {
return false, fmt.Errorf("key decode: %w", err)
}
got := pbkdf2SHA256([]byte(password), salt, iter, len(want))
// constant-time compare
if len(got) != len(want) {
return false, nil
}
var diff byte
for i := range got {
diff |= got[i] ^ want[i]
}
return diff == 0, nil
}

View File

@@ -0,0 +1,80 @@
package security
import (
"net/http"
"sync"
"time"
)
type bucket struct {
tokens float64
last time.Time
blocked time.Time
}
type RateLimiter struct {
mu sync.Mutex
capacity float64
refillPer float64 // tokens/sec
ttl time.Duration
buckets map[string]*bucket
}
func NewRateLimiter(capacity int, refillPerSec float64, ttl time.Duration) *RateLimiter {
return &RateLimiter{
capacity: float64(capacity),
refillPer: refillPerSec,
ttl: ttl,
buckets: map[string]*bucket{},
}
}
func (rl *RateLimiter) Allow(key string) bool {
now := time.Now()
rl.mu.Lock()
defer rl.mu.Unlock()
b := rl.buckets[key]
if b == nil {
b = &bucket{tokens: rl.capacity, last: now}
rl.buckets[key] = b
}
// cleanup occasionally
if len(rl.buckets) > 10000 {
for k, v := range rl.buckets {
if now.Sub(v.last) > rl.ttl {
delete(rl.buckets, k)
}
}
}
elapsed := now.Sub(b.last).Seconds()
if elapsed > 0 {
b.tokens += elapsed * rl.refillPer
if b.tokens > rl.capacity {
b.tokens = rl.capacity
}
b.last = now
}
if b.tokens >= 1 {
b.tokens -= 1
return true
}
return false
}
func (rl *RateLimiter) Middleware(keyFn func(r *http.Request) string) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
key := keyFn(r)
if key == "" {
key = "anon"
}
if !rl.Allow(key) {
w.Header().Set("Retry-After", "2")
http.Error(w, "rate limited", http.StatusTooManyRequests)
return
}
next.ServeHTTP(w, r)
})
}
}

View File

@@ -0,0 +1,53 @@
package security
import (
"net"
"net/http"
"strings"
)
type RealIPConfig struct {
TrustedProxies []*net.IPNet
}
func (c RealIPConfig) IsTrusted(remoteAddr string) bool {
host, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
host = remoteAddr
}
ip := net.ParseIP(host)
if ip == nil {
return false
}
for _, n := range c.TrustedProxies {
if n.Contains(ip) {
return true
}
}
return false
}
// RealIP returns the best-effort client IP.
// It only honors X-Forwarded-For when the direct peer is in TrustedProxies.
func RealIP(r *http.Request, cfg RealIPConfig) string {
if cfg.IsTrusted(r.RemoteAddr) {
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
// First IP is original client
parts := strings.Split(xff, ",")
if len(parts) > 0 {
ip := strings.TrimSpace(parts[0])
if net.ParseIP(ip) != nil {
return ip
}
}
}
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" && net.ParseIP(xrip) != nil {
return xrip
}
}
host, _, err := net.SplitHostPort(r.RemoteAddr)
if err == nil && net.ParseIP(host) != nil {
return host
}
return r.RemoteAddr
}

View File

@@ -0,0 +1,122 @@
package security
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"time"
)
type SessionManager struct {
cookieName string
secure bool
sameSite http.SameSite
maxAge time.Duration
aead cipher.AEAD
}
func NewSessionManager(secret []byte, cookieName string, secure bool) (*SessionManager, error) {
// Derive 32-byte key for AES-256-GCM
key := hmacSHA256(secret, []byte("ntfywui session v1"))
if len(key) != 32 {
return nil, io.ErrUnexpectedEOF
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
aead, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return &SessionManager{
cookieName: cookieName,
secure: secure,
sameSite: http.SameSiteLaxMode,
maxAge: 12 * time.Hour,
aead: aead,
}, nil
}
// Session contents are encrypted+authenticated.
type Session struct {
User string `json:"user"`
Role string `json:"role"`
CSRF string `json:"csrf"`
Flash string `json:"flash,omitempty"`
IssuedAt int64 `json:"iat"`
ExpiresAt int64 `json:"exp"`
}
func (sm *SessionManager) Get(r *http.Request) (*Session, bool) {
c, err := r.Cookie(sm.cookieName)
if err != nil || c.Value == "" {
return &Session{}, false
}
raw, err := base64.RawURLEncoding.DecodeString(c.Value)
if err != nil || len(raw) < sm.aead.NonceSize() {
return &Session{}, false
}
nonce := raw[:sm.aead.NonceSize()]
ct := raw[sm.aead.NonceSize():]
pt, err := sm.aead.Open(nil, nonce, ct, nil)
if err != nil {
return &Session{}, false
}
var s Session
if err := json.Unmarshal(pt, &s); err != nil {
return &Session{}, false
}
now := time.Now().Unix()
if s.ExpiresAt != 0 && now > s.ExpiresAt {
return &Session{}, false
}
return &s, true
}
func (sm *SessionManager) Save(w http.ResponseWriter, s *Session) error {
now := time.Now()
if s.IssuedAt == 0 {
s.IssuedAt = now.Unix()
}
if s.ExpiresAt == 0 {
s.ExpiresAt = now.Add(sm.maxAge).Unix()
}
pt, err := json.Marshal(s)
if err != nil {
return err
}
nonce := make([]byte, sm.aead.NonceSize())
if _, err := rand.Read(nonce); err != nil {
return err
}
ct := sm.aead.Seal(nil, nonce, pt, nil)
raw := append(nonce, ct...)
val := base64.RawURLEncoding.EncodeToString(raw)
http.SetCookie(w, &http.Cookie{
Name: sm.cookieName,
Value: val,
Path: "/",
HttpOnly: true,
Secure: sm.secure,
SameSite: sm.sameSite,
MaxAge: int(sm.maxAge.Seconds()),
})
return nil
}
func (sm *SessionManager) Clear(w http.ResponseWriter) {
http.SetCookie(w, &http.Cookie{
Name: sm.cookieName,
Value: "",
Path: "/",
HttpOnly: true,
Secure: sm.secure,
SameSite: sm.sameSite,
MaxAge: -1,
})
}

63
internal/security/totp.go Normal file
View File

@@ -0,0 +1,63 @@
package security
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha1"
"encoding/base32"
"encoding/binary"
"fmt"
"strings"
"time"
)
// GenerateTOTPSecret returns a base32 secret without padding.
func GenerateTOTPSecret() (string, error) {
b := make([]byte, 20)
if _, err := rand.Read(b); err != nil {
return "", err
}
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
return enc.EncodeToString(b), nil
}
// VerifyTOTP verifies a 6-digit token with ±1 step skew (30s step).
func VerifyTOTP(secretBase32, code string, now time.Time) bool {
code = strings.ReplaceAll(code, " ", "")
if len(code) != 6 {
return false
}
sec, err := decodeBase32NoPad(secretBase32)
if err != nil {
return false
}
t := now.Unix() / 30
for _, drift := range []int64{-1, 0, 1} {
if hotp(sec, uint64(t+drift)) == code {
return true
}
}
return false
}
func hotp(secret []byte, counter uint64) string {
var buf [8]byte
binary.BigEndian.PutUint64(buf[:], counter)
mac := hmac.New(sha1.New, secret)
mac.Write(buf[:])
sum := mac.Sum(nil)
off := sum[len(sum)-1] & 0x0f
bin := (int(sum[off])&0x7f)<<24 |
(int(sum[off+1])&0xff)<<16 |
(int(sum[off+2])&0xff)<<8 |
(int(sum[off+3]) & 0xff)
otp := bin % 1000000
return fmt.Sprintf("%06d", otp)
}
func decodeBase32NoPad(s string) ([]byte, error) {
s = strings.ToUpper(strings.ReplaceAll(s, " ", ""))
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
return enc.DecodeString(s)
}

View File

@@ -0,0 +1,4 @@
package security
// Version is a dummy constant used to avoid unused imports in main.
const Version = "v0"