init
This commit is contained in:
52
internal/security/csrf.go
Normal file
52
internal/security/csrf.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type CSRFFuncs struct {
|
||||
// Read session for csrf value
|
||||
GetCSRF func(r *http.Request) (token string, ok bool)
|
||||
// Save ensures session has csrf value
|
||||
EnsureCSRF func(w http.ResponseWriter, r *http.Request) (token string, err error)
|
||||
}
|
||||
|
||||
func NewCSRFToken() (string, error) {
|
||||
b := make([]byte, 32)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
func CSRFMiddleware(f CSRFFuncs) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// Always ensure csrf exists for HTML GET requests
|
||||
if r.Method == http.MethodGet || r.Method == http.MethodHead {
|
||||
_, _ = f.EnsureCSRF(w, r)
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
// Validate for unsafe methods
|
||||
token, ok := f.GetCSRF(r)
|
||||
if !ok || token == "" {
|
||||
http.Error(w, "csrf missing", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
// Prefer header (JS), fallback to form value
|
||||
got := r.Header.Get("X-CSRF-Token")
|
||||
if got == "" {
|
||||
_ = r.ParseForm()
|
||||
got = r.Form.Get("csrf")
|
||||
}
|
||||
if got == "" || got != token {
|
||||
http.Error(w, "csrf mismatch", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
27
internal/security/headers.go
Normal file
27
internal/security/headers.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package security
|
||||
|
||||
import "net/http"
|
||||
|
||||
// SecureHeaders adds a baseline of security headers.
|
||||
// CSP is intentionally conservative; adjust if you add external assets.
|
||||
func SecureHeaders(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-Frame-Options", "DENY")
|
||||
w.Header().Set("Referrer-Policy", "no-referrer")
|
||||
w.Header().Set("Permissions-Policy", "geolocation=(), microphone=(), camera=()")
|
||||
w.Header().Set("Cross-Origin-Opener-Policy", "same-origin")
|
||||
w.Header().Set("Cross-Origin-Resource-Policy", "same-origin")
|
||||
w.Header().Set("Cross-Origin-Embedder-Policy", "require-corp")
|
||||
w.Header().Set("Content-Security-Policy",
|
||||
"default-src 'self'; "+
|
||||
"script-src 'self'; "+
|
||||
"style-src 'self'; "+
|
||||
"img-src 'self' data:; "+
|
||||
"object-src 'none'; "+
|
||||
"base-uri 'none'; "+
|
||||
"frame-ancestors 'none'; "+
|
||||
"form-action 'self'")
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
96
internal/security/pbkdf2.go
Normal file
96
internal/security/pbkdf2.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func pbkdf2SHA256(password, salt []byte, iter, keyLen int) []byte {
|
||||
// PBKDF2 per RFC2898
|
||||
hLen := 32 // sha256
|
||||
numBlocks := (keyLen + hLen - 1) / hLen
|
||||
var out []byte
|
||||
for block := 1; block <= numBlocks; block++ {
|
||||
t := pbkdf2F(password, salt, iter, block)
|
||||
out = append(out, t...)
|
||||
}
|
||||
return out[:keyLen]
|
||||
}
|
||||
|
||||
func pbkdf2F(password, salt []byte, iter, blockNum int) []byte {
|
||||
// U1 = PRF(P, S || INT(blockNum))
|
||||
// Uc = PRF(P, Uc-1)
|
||||
// T = U1 XOR U2 XOR ... XOR Uiter
|
||||
b := make([]byte, len(salt)+4)
|
||||
copy(b, salt)
|
||||
b[len(salt)+0] = byte(blockNum >> 24)
|
||||
b[len(salt)+1] = byte(blockNum >> 16)
|
||||
b[len(salt)+2] = byte(blockNum >> 8)
|
||||
b[len(salt)+3] = byte(blockNum)
|
||||
|
||||
u := hmacSHA256(password, b)
|
||||
t := make([]byte, len(u))
|
||||
copy(t, u)
|
||||
for i := 2; i <= iter; i++ {
|
||||
u = hmacSHA256(password, u)
|
||||
for j := range t {
|
||||
t[j] ^= u[j]
|
||||
}
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
func hmacSHA256(key, msg []byte) []byte {
|
||||
m := hmac.New(sha256.New, key)
|
||||
m.Write(msg)
|
||||
return m.Sum(nil)
|
||||
}
|
||||
|
||||
func HashPasswordPBKDF2(password string, salt []byte, iter int) string {
|
||||
key := pbkdf2SHA256([]byte(password), salt, iter, 32)
|
||||
return fmt.Sprintf("pbkdf2_sha256$%d$%s$%s",
|
||||
iter,
|
||||
base64.RawURLEncoding.EncodeToString(salt),
|
||||
base64.RawURLEncoding.EncodeToString(key),
|
||||
)
|
||||
}
|
||||
|
||||
func VerifyPasswordPBKDF2(password, encoded string) (bool, error) {
|
||||
// Go's fmt scanning does not support "scanset" verbs like %[^$]. Parse explicitly.
|
||||
parts := strings.Split(encoded, "$")
|
||||
if len(parts) != 4 {
|
||||
return false, fmt.Errorf("parse hash: expected 4 parts, got %d", len(parts))
|
||||
}
|
||||
algo := parts[0]
|
||||
iter, err := strconv.Atoi(parts[1])
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("parse hash iter: %w", err)
|
||||
}
|
||||
saltB64 := parts[2]
|
||||
keyB64 := parts[3]
|
||||
if algo != "pbkdf2_sha256" {
|
||||
return false, fmt.Errorf("unsupported algo %q", algo)
|
||||
}
|
||||
salt, err := base64.RawURLEncoding.DecodeString(saltB64)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("salt decode: %w", err)
|
||||
}
|
||||
want, err := base64.RawURLEncoding.DecodeString(keyB64)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("key decode: %w", err)
|
||||
}
|
||||
got := pbkdf2SHA256([]byte(password), salt, iter, len(want))
|
||||
// constant-time compare
|
||||
if len(got) != len(want) {
|
||||
return false, nil
|
||||
}
|
||||
var diff byte
|
||||
for i := range got {
|
||||
diff |= got[i] ^ want[i]
|
||||
}
|
||||
return diff == 0, nil
|
||||
}
|
||||
80
internal/security/ratelimit.go
Normal file
80
internal/security/ratelimit.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type bucket struct {
|
||||
tokens float64
|
||||
last time.Time
|
||||
blocked time.Time
|
||||
}
|
||||
|
||||
type RateLimiter struct {
|
||||
mu sync.Mutex
|
||||
capacity float64
|
||||
refillPer float64 // tokens/sec
|
||||
ttl time.Duration
|
||||
buckets map[string]*bucket
|
||||
}
|
||||
|
||||
func NewRateLimiter(capacity int, refillPerSec float64, ttl time.Duration) *RateLimiter {
|
||||
return &RateLimiter{
|
||||
capacity: float64(capacity),
|
||||
refillPer: refillPerSec,
|
||||
ttl: ttl,
|
||||
buckets: map[string]*bucket{},
|
||||
}
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Allow(key string) bool {
|
||||
now := time.Now()
|
||||
rl.mu.Lock()
|
||||
defer rl.mu.Unlock()
|
||||
b := rl.buckets[key]
|
||||
if b == nil {
|
||||
b = &bucket{tokens: rl.capacity, last: now}
|
||||
rl.buckets[key] = b
|
||||
}
|
||||
// cleanup occasionally
|
||||
if len(rl.buckets) > 10000 {
|
||||
for k, v := range rl.buckets {
|
||||
if now.Sub(v.last) > rl.ttl {
|
||||
delete(rl.buckets, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := now.Sub(b.last).Seconds()
|
||||
if elapsed > 0 {
|
||||
b.tokens += elapsed * rl.refillPer
|
||||
if b.tokens > rl.capacity {
|
||||
b.tokens = rl.capacity
|
||||
}
|
||||
b.last = now
|
||||
}
|
||||
if b.tokens >= 1 {
|
||||
b.tokens -= 1
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rl *RateLimiter) Middleware(keyFn func(r *http.Request) string) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
key := keyFn(r)
|
||||
if key == "" {
|
||||
key = "anon"
|
||||
}
|
||||
if !rl.Allow(key) {
|
||||
w.Header().Set("Retry-After", "2")
|
||||
http.Error(w, "rate limited", http.StatusTooManyRequests)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
53
internal/security/realip.go
Normal file
53
internal/security/realip.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type RealIPConfig struct {
|
||||
TrustedProxies []*net.IPNet
|
||||
}
|
||||
|
||||
func (c RealIPConfig) IsTrusted(remoteAddr string) bool {
|
||||
host, _, err := net.SplitHostPort(remoteAddr)
|
||||
if err != nil {
|
||||
host = remoteAddr
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
for _, n := range c.TrustedProxies {
|
||||
if n.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RealIP returns the best-effort client IP.
|
||||
// It only honors X-Forwarded-For when the direct peer is in TrustedProxies.
|
||||
func RealIP(r *http.Request, cfg RealIPConfig) string {
|
||||
if cfg.IsTrusted(r.RemoteAddr) {
|
||||
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
||||
// First IP is original client
|
||||
parts := strings.Split(xff, ",")
|
||||
if len(parts) > 0 {
|
||||
ip := strings.TrimSpace(parts[0])
|
||||
if net.ParseIP(ip) != nil {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
if xrip := strings.TrimSpace(r.Header.Get("X-Real-IP")); xrip != "" && net.ParseIP(xrip) != nil {
|
||||
return xrip
|
||||
}
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil && net.ParseIP(host) != nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
122
internal/security/sessions.go
Normal file
122
internal/security/sessions.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionManager struct {
|
||||
cookieName string
|
||||
secure bool
|
||||
sameSite http.SameSite
|
||||
maxAge time.Duration
|
||||
aead cipher.AEAD
|
||||
}
|
||||
|
||||
func NewSessionManager(secret []byte, cookieName string, secure bool) (*SessionManager, error) {
|
||||
// Derive 32-byte key for AES-256-GCM
|
||||
key := hmacSHA256(secret, []byte("ntfywui session v1"))
|
||||
if len(key) != 32 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
aead, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &SessionManager{
|
||||
cookieName: cookieName,
|
||||
secure: secure,
|
||||
sameSite: http.SameSiteLaxMode,
|
||||
maxAge: 12 * time.Hour,
|
||||
aead: aead,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Session contents are encrypted+authenticated.
|
||||
type Session struct {
|
||||
User string `json:"user"`
|
||||
Role string `json:"role"`
|
||||
CSRF string `json:"csrf"`
|
||||
Flash string `json:"flash,omitempty"`
|
||||
IssuedAt int64 `json:"iat"`
|
||||
ExpiresAt int64 `json:"exp"`
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Get(r *http.Request) (*Session, bool) {
|
||||
c, err := r.Cookie(sm.cookieName)
|
||||
if err != nil || c.Value == "" {
|
||||
return &Session{}, false
|
||||
}
|
||||
raw, err := base64.RawURLEncoding.DecodeString(c.Value)
|
||||
if err != nil || len(raw) < sm.aead.NonceSize() {
|
||||
return &Session{}, false
|
||||
}
|
||||
nonce := raw[:sm.aead.NonceSize()]
|
||||
ct := raw[sm.aead.NonceSize():]
|
||||
pt, err := sm.aead.Open(nil, nonce, ct, nil)
|
||||
if err != nil {
|
||||
return &Session{}, false
|
||||
}
|
||||
var s Session
|
||||
if err := json.Unmarshal(pt, &s); err != nil {
|
||||
return &Session{}, false
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
if s.ExpiresAt != 0 && now > s.ExpiresAt {
|
||||
return &Session{}, false
|
||||
}
|
||||
return &s, true
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Save(w http.ResponseWriter, s *Session) error {
|
||||
now := time.Now()
|
||||
if s.IssuedAt == 0 {
|
||||
s.IssuedAt = now.Unix()
|
||||
}
|
||||
if s.ExpiresAt == 0 {
|
||||
s.ExpiresAt = now.Add(sm.maxAge).Unix()
|
||||
}
|
||||
pt, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nonce := make([]byte, sm.aead.NonceSize())
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return err
|
||||
}
|
||||
ct := sm.aead.Seal(nil, nonce, pt, nil)
|
||||
raw := append(nonce, ct...)
|
||||
val := base64.RawURLEncoding.EncodeToString(raw)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sm.cookieName,
|
||||
Value: val,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: sm.secure,
|
||||
SameSite: sm.sameSite,
|
||||
MaxAge: int(sm.maxAge.Seconds()),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
func (sm *SessionManager) Clear(w http.ResponseWriter) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: sm.cookieName,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: sm.secure,
|
||||
SameSite: sm.sameSite,
|
||||
MaxAge: -1,
|
||||
})
|
||||
}
|
||||
63
internal/security/totp.go
Normal file
63
internal/security/totp.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package security
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GenerateTOTPSecret returns a base32 secret without padding.
|
||||
func GenerateTOTPSecret() (string, error) {
|
||||
b := make([]byte, 20)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
return enc.EncodeToString(b), nil
|
||||
}
|
||||
|
||||
// VerifyTOTP verifies a 6-digit token with ±1 step skew (30s step).
|
||||
func VerifyTOTP(secretBase32, code string, now time.Time) bool {
|
||||
code = strings.ReplaceAll(code, " ", "")
|
||||
if len(code) != 6 {
|
||||
return false
|
||||
}
|
||||
sec, err := decodeBase32NoPad(secretBase32)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
t := now.Unix() / 30
|
||||
for _, drift := range []int64{-1, 0, 1} {
|
||||
if hotp(sec, uint64(t+drift)) == code {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hotp(secret []byte, counter uint64) string {
|
||||
var buf [8]byte
|
||||
binary.BigEndian.PutUint64(buf[:], counter)
|
||||
mac := hmac.New(sha1.New, secret)
|
||||
mac.Write(buf[:])
|
||||
sum := mac.Sum(nil)
|
||||
|
||||
off := sum[len(sum)-1] & 0x0f
|
||||
bin := (int(sum[off])&0x7f)<<24 |
|
||||
(int(sum[off+1])&0xff)<<16 |
|
||||
(int(sum[off+2])&0xff)<<8 |
|
||||
(int(sum[off+3]) & 0xff)
|
||||
otp := bin % 1000000
|
||||
return fmt.Sprintf("%06d", otp)
|
||||
}
|
||||
|
||||
func decodeBase32NoPad(s string) ([]byte, error) {
|
||||
s = strings.ToUpper(strings.ReplaceAll(s, " ", ""))
|
||||
enc := base32.StdEncoding.WithPadding(base32.NoPadding)
|
||||
return enc.DecodeString(s)
|
||||
}
|
||||
4
internal/security/version.go
Normal file
4
internal/security/version.go
Normal file
@@ -0,0 +1,4 @@
|
||||
package security
|
||||
|
||||
// Version is a dummy constant used to avoid unused imports in main.
|
||||
const Version = "v0"
|
||||
Reference in New Issue
Block a user