Files
flod/main.go
jbergner 159ddfff76
All checks were successful
release-tag / release-image (push) Successful in 1m40s
debug
2025-06-09 17:32:27 +02:00

332 lines
8.2 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/netip"
"os"
"sort"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// -----------------------------------------------------------------------------
// CONFIGURATION
// -----------------------------------------------------------------------------
type (
	// Source groups one or more blocklist URLs under a single category
	// name such as "spam", "tor" or "malware". Every prefix fetched from any
	// of the URLs is recorded under that category.
	Source struct {
		Category string   // category label used in Redis keys and API responses
		URL      []string // blocklist URLs that belong to this category
	}

	// Config is the runtime configuration produced by loadConfig.
	Config struct {
		RedisAddr string   // Redis host:port
		Sources   []Source // blocklist sources, grouped per category
		TTLHours  int      // expiry, in hours, for block entries written to Redis
	}
)
// loadConfig assembles the runtime Config from environment variables:
//
//   - BLOCKLIST_SOURCES: category/URL mapping, parsed by parseSources. When
//     unset, a single "generic" category with one default URL is used; when
//     set but without any usable entry, no sources are configured and a
//     warning is logged.
//   - TTL_HOURS: expiry in hours for Redis block entries (default 720, i.e.
//     30 days). Invalid values are logged and the default is kept.
//   - REDIS_ADDR: Redis host:port (default "localhost:6379").
//
// The resulting configuration is logged once.
func loadConfig() Config {
	// Default one-category fallback, used when BLOCKLIST_SOURCES is unset.
	srcs := []Source{{
		Category: "generic",
		URL:      []string{"https://ipv64.net/blocklists/ipv64_blocklist_firehole_l1.txt"},
	}}
	if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" {
		srcs = parseSources(env) // overrides the default entirely
		if len(srcs) == 0 {
			log.Printf("config: BLOCKLIST_SOURCES=%q has no usable category:url entries; no blocklists will be fetched", env)
		}
	}

	cfg := Config{
		RedisAddr: getenv("REDIS_ADDR", "localhost:6379"),
		Sources:   srcs,
		TTLHours:  parseTTLHours(os.Getenv("TTL_HOURS"), 720),
	}
	log.Printf("config: redis=%s ttl=%dh sources=%+v", cfg.RedisAddr, cfg.TTLHours, cfg.Sources)
	return cfg
}

// parseSources parses the BLOCKLIST_SOURCES syntax, which supports several
// URLs per category:
//
//	spam:https://a.net|https://b.net,tor:https://c.net;https://d.net
//
// Category specs are separated by ','. Within a spec, the category name and
// the URL list are split at the first ':'. URLs are separated by '|' or ';'.
// URLs therefore must not contain ',', '|' or ';'. Fragments without a ':',
// with an empty category name, or without any URL are logged and skipped.
func parseSources(env string) []Source {
	var srcs []Source
	for _, spec := range strings.Split(env, ",") {
		spec = strings.TrimSpace(spec)
		if spec == "" {
			continue
		}
		cat, rest, ok := strings.Cut(spec, ":")
		cat = strings.TrimSpace(cat)
		if !ok || cat == "" {
			log.Printf("config: skipping malformed source %q (want category:url)", spec)
			continue
		}
		var urls []string
		for _, u := range strings.FieldsFunc(rest, func(r rune) bool { return r == '|' || r == ';' }) {
			if u = strings.TrimSpace(u); u != "" {
				urls = append(urls, u)
			}
		}
		if len(urls) == 0 {
			log.Printf("config: skipping source %q: no URLs", spec)
			continue
		}
		srcs = append(srcs, Source{Category: cat, URL: urls})
	}
	return srcs
}

// parseTTLHours returns raw parsed as a whole number of hours, or def when
// raw is empty. An unparsable, non-positive or overflowing value is logged
// and def is returned instead. go-redis treats a zero expiration as
// "no expiry", so non-positive values are rejected rather than letting
// block entries live in Redis forever.
func parseTTLHours(raw string, def int) int {
	// Largest hour count whose time.Duration still fits in an int64.
	const maxTTLHours = int64(1<<63-1) / int64(time.Hour)

	raw = strings.TrimSpace(raw)
	if raw == "" {
		return def
	}
	var n int
	if _, err := fmt.Sscanf(raw, "%d", &n); err != nil || n <= 0 || int64(n) > maxTTLHours {
		log.Printf("config: ignoring invalid TTL_HOURS=%q (want 1..%d), using %d", raw, maxTTLHours, def)
		return def
	}
	return n
}
// getenv returns the value of environment variable k, falling back to def
// when k is unset or set to the empty string.
func getenv(k, def string) string {
	v := os.Getenv(k)
	if v == "" {
		return def
	}
	return v
}
// -----------------------------------------------------------------------------
// REDIS KEY HELPERS
// -----------------------------------------------------------------------------
// keyBlock returns the Redis key for prefix p in category cat, formatted as
// "bl:<category>:<prefix>".
func keyBlock(cat string, p netip.Prefix) string {
	return strings.Join([]string{"bl", cat, p.String()}, ":")
}

// keyWhite returns the Redis key for whitelisted address a, formatted as
// "wl:<addr>".
func keyWhite(a netip.Addr) string {
	return strings.Join([]string{"wl", a.String()}, ":")
}
// -----------------------------------------------------------------------------
// IN-MEMORY RANGER - per-category CIDR map
// -----------------------------------------------------------------------------
// Ranger is the in-memory lookup table behind /check. It holds, per
// category, the set of blocked CIDR prefixes plus a set of whitelisted
// addresses. Both maps are guarded by mu.
type Ranger struct {
	mu     sync.RWMutex
	blocks map[string]map[netip.Prefix]struct{} // category -> set of blocked prefixes
	whites map[netip.Addr]struct{}               // whitelisted addresses (stored unmapped)
}

// newRanger returns an empty Ranger with both maps initialised.
func newRanger() *Ranger {
	return &Ranger{
		blocks: make(map[string]map[netip.Prefix]struct{}),
		whites: make(map[netip.Addr]struct{}),
	}
}

// resetBlocks replaces the whole category -> prefix table with m under the
// write lock. The Ranger keeps m itself (no copy), so the caller must not
// modify m or its inner sets afterwards.
func (r *Ranger) resetBlocks(m map[string]map[netip.Prefix]struct{}) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.blocks = m
}

// addWhite whitelists a; blockedInCats never reports a whitelisted address.
// IPv4-mapped IPv6 input (::ffff:a.b.c.d) is stored in its plain IPv4 form so
// both spellings of the same host hit the same entry.
func (r *Ranger) addWhite(a netip.Addr) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.whites[a.Unmap()] = struct{}{}
}

// blockedInCats returns, sorted, the categories whose prefix sets contain a.
// With an empty cats every known category is checked; otherwise only the
// listed ones are (unknown names are ignored, duplicates are reported once).
// A whitelisted address yields nil, as does an address found nowhere.
//
// a is unmapped first: netip.Prefix.Contains never matches an IPv4-mapped
// IPv6 address against an IPv4 prefix, so "::ffff:1.2.3.4" would otherwise
// slip past every IPv4 blocklist entry.
func (r *Ranger) blockedInCats(a netip.Addr, cats []string) []string {
	a = a.Unmap()

	r.mu.RLock()
	defer r.mu.RUnlock()

	if _, ok := r.whites[a]; ok {
		return nil
	}

	var res []string
	// check appends cat to res if any prefix in set contains a.
	check := func(cat string, set map[netip.Prefix]struct{}) {
		for p := range set {
			if p.Contains(a) {
				res = append(res, cat)
				return
			}
		}
	}

	if len(cats) == 0 {
		for cat, set := range r.blocks {
			check(cat, set)
		}
	} else {
		seen := make(map[string]struct{}, len(cats))
		for _, cat := range cats {
			if _, dup := seen[cat]; dup {
				continue
			}
			seen[cat] = struct{}{}
			if set, ok := r.blocks[cat]; ok {
				check(cat, set)
			}
		}
	}
	sort.Strings(res)
	return res
}
// -----------------------------------------------------------------------------
// SYNC WORKER - fetch lists → Redis + Ranger
// -----------------------------------------------------------------------------
// syncOnce downloads every configured blocklist and updates Redis and the
// in-memory ranger.
//
// Each prefix is written to Redis as keyBlock(category, prefix) with a TTL of
// cfg.TTLHours. The freshly built category -> prefix table is then swapped
// into ranger.
//
// A failing URL does not abort the run. The remaining URLs are still
// processed. For every category with at least one failed URL, the prefixes
// ranger already held for that category are carried over, so a transient
// download error never shrinks coverage. Carried-over prefixes are not
// re-written to Redis.
//
// Redis writes are best-effort. After the first failure, further writes in
// this run are skipped, but the in-memory table is still updated.
//
// All problems are summarised in the returned error. nil means every URL was
// fetched and every Redis write succeeded.
func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger) error {
	expiry := time.Duration(cfg.TTLHours) * time.Hour
	newBlocks := make(map[string]map[netip.Prefix]struct{})
	failedCats := make(map[string]struct{})
	var problems []string
	var redisErr error // first Redis write failure; later writes are skipped

	for _, src := range cfg.Sources {
		cat := src.Category
		for _, u := range src.URL {
			err := processURL(ctx, u, func(p netip.Prefix) {
				set, ok := newBlocks[cat]
				if !ok {
					set = make(map[netip.Prefix]struct{})
					newBlocks[cat] = set
				}
				set[p] = struct{}{}
				if redisErr == nil {
					redisErr = rdb.Set(ctx, keyBlock(cat, p), "1", expiry).Err()
				}
			})
			if err != nil {
				failedCats[cat] = struct{}{}
				problems = append(problems, fmt.Sprintf("%s: %v", cat, err))
			}
		}
	}

	// Keep the previous prefixes of categories that could not be fully
	// refreshed instead of silently dropping them from the lookup table.
	if len(failedCats) > 0 {
		ranger.mu.RLock()
		for cat := range failedCats {
			old := ranger.blocks[cat]
			if len(old) == 0 {
				continue
			}
			set, ok := newBlocks[cat]
			if !ok {
				set = make(map[netip.Prefix]struct{}, len(old))
				newBlocks[cat] = set
			}
			for p := range old {
				set[p] = struct{}{}
			}
		}
		ranger.mu.RUnlock()
	}
	ranger.resetBlocks(newBlocks)

	if redisErr != nil {
		problems = append(problems, "redis write: "+redisErr.Error())
	}
	if len(problems) > 0 {
		return fmt.Errorf("%d problem(s) during sync: %s", len(problems), strings.Join(problems, "; "))
	}
	return nil
}
// fetchTimeout bounds one blocklist download: connect, headers and body.
// http.DefaultClient has no timeout of its own, so without this a stalled
// server could block the caller indefinitely.
const fetchTimeout = 2 * time.Minute

// processURL downloads the blocklist at rawURL and calls cb for every entry
// parsed from the body (see parseStream).
//
// It returns an error in any of these cases:
//   - the request cannot be built (e.g. a malformed URL);
//   - the request cannot be sent;
//   - the server answers with a status other than 200;
//   - reading the body fails.
//
// Entries delivered to cb before a read error are not retracted.
func processURL(ctx context.Context, rawURL string, cb func(netip.Prefix)) error {
	ctx, cancel := context.WithTimeout(ctx, fetchTimeout)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil)
	if err != nil {
		// Must be checked: handing a nil request to Do would panic.
		return fmt.Errorf("build request for %q: %w", rawURL, err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err // *url.Error already names the method and URL
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s -> %s", rawURL, resp.Status)
	}
	if err := parseStream(resp.Body, cb); err != nil {
		return fmt.Errorf("read %s: %w", rawURL, err)
	}
	return nil
}

// parseStream reads a plain-text blocklist line by line and calls cb with
// each entry as a canonical (host-bits-masked) prefix.
//
// Accepted entries:
//   - CIDR prefixes, e.g. "192.0.2.0/24";
//   - bare addresses, e.g. "192.0.2.7", which become a /32 (IPv4) or /128 (IPv6).
//
// Everything from the first '#' or ';' on a line is treated as a comment.
// Only the first whitespace-separated field is used. This accepts lists such
// as "192.0.2.0/24 ; SBL123". Blank lines and lines that still do not parse
// are skipped. The only error returned is a read error from the scanner.
func parseStream(r io.Reader, cb func(netip.Prefix)) error {
	s := bufio.NewScanner(r)
	for s.Scan() {
		if p, ok := parseEntry(s.Text()); ok {
			cb(p)
		}
	}
	return s.Err()
}

// parseEntry extracts the prefix from one blocklist line, reporting whether
// the line contained one.
func parseEntry(line string) (netip.Prefix, bool) {
	if i := strings.IndexAny(line, "#;"); i >= 0 {
		line = line[:i]
	}
	fields := strings.Fields(line)
	if len(fields) == 0 {
		return netip.Prefix{}, false
	}
	tok := fields[0]
	if p, err := netip.ParsePrefix(tok); err == nil {
		return p.Masked(), true
	}
	if a, err := netip.ParseAddr(tok); err == nil {
		return netip.PrefixFrom(a, a.BitLen()), true
	}
	return netip.Prefix{}, false
}
// -----------------------------------------------------------------------------
// HTTP SERVER
// -----------------------------------------------------------------------------
// Server carries the state shared by the HTTP handlers.
type Server struct {
	// cfg is the configuration the process was started with; none of the
	// handlers currently read it.
	cfg Config
	// ranger answers the in-memory /check and /categories lookups.
	ranger *Ranger
	// rdb is used by handleAddWhite to persist whitelist entries.
	rdb *redis.Client
}
// routes builds the HTTP handler tree:
//
//	/check/<ip>?cats=spam,tor  -> handleCheck    (subtree pattern)
//	/whitelist                 -> handleAddWhite (POST {"ip":"1.2.3.4"})
//	/categories                -> handleCats     (list all categories)
//
// The mux does not restrict methods; handlers enforce that themselves where
// needed.
func (s *Server) routes() http.Handler {
	endpoints := []struct {
		pattern string
		handler http.HandlerFunc
	}{
		{"/check/", s.handleCheck},
		{"/whitelist", s.handleAddWhite},
		{"/categories", s.handleCats},
	}

	mux := http.NewServeMux()
	for _, e := range endpoints {
		mux.HandleFunc(e.pattern, e.handler)
	}
	return mux
}
// handleCheck serves /check/<ip>[?cats=a,b,...].
//
// It answers 400 "bad ip" if <ip> does not parse with netip.ParseAddr.
// Otherwise it writes
//
//	{"ip": "<ip as given>", "blocked": bool, "categories": [...]}
//
// categories lists, sorted, the checked categories whose prefixes contain the
// address. It is always a JSON array, empty when nothing matches.
//
// Names in ?cats= are trimmed and empty names are dropped, so "spam, tor" and
// "spam,tor," both mean {spam, tor}. If no names remain, or cats is absent,
// every known category is checked.
func (s *Server) handleCheck(w http.ResponseWriter, r *http.Request) {
	ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		http.Error(w, "bad ip", http.StatusBadRequest)
		return
	}

	var cats []string
	for _, c := range strings.Split(r.URL.Query().Get("cats"), ",") {
		if c = strings.TrimSpace(c); c != "" {
			cats = append(cats, c)
		}
	}

	blocked := s.ranger.blockedInCats(addr, cats)
	if blocked == nil {
		// A nil slice would encode as null; clients should always get an array.
		blocked = []string{}
	}
	writeJSON(w, map[string]any{
		"ip":         ipStr,
		"blocked":    len(blocked) > 0,
		"categories": blocked,
	})
}
// maxWhitelistBody caps the size of a /whitelist request body. The expected
// payload ({"ip":"..."}) is tiny, so anything larger is rejected as bad JSON.
const maxWhitelistBody = 4 << 10

// handleAddWhite serves POST /whitelist with a JSON body {"ip":"<addr>"}.
//
// The address is stored in Redis under keyWhite(addr) with no expiry and added
// to the in-memory whitelist, so later /check lookups do not report it.
//
// Responses:
//   - 405 for methods other than POST;
//   - 400 for an unreadable or oversized body, or an invalid address;
//   - 500 if the Redis write fails;
//   - otherwise {"status":"whitelisted"}.
//
// NOTE(review): the endpoint is unauthenticated. Any client that can reach
// the service can exempt arbitrary addresses, so restrict access at the
// network or proxy layer until authentication is added.
func (s *Server) handleAddWhite(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "POST only", http.StatusMethodNotAllowed)
		return
	}

	var body struct {
		IP string `json:"ip"`
	}
	if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxWhitelistBody)).Decode(&body); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	addr, err := netip.ParseAddr(strings.TrimSpace(body.IP))
	if err != nil {
		http.Error(w, "bad ip", http.StatusBadRequest)
		return
	}
	// Normalise IPv4-mapped IPv6 input (::ffff:a.b.c.d) to plain IPv4 so the
	// same host is always stored under one Redis key.
	addr = addr.Unmap()

	if err := s.rdb.Set(r.Context(), keyWhite(addr), "1", 0).Err(); err != nil {
		log.Printf("whitelist %s: redis: %v", addr, err)
		http.Error(w, "redis", http.StatusInternalServerError)
		return
	}
	s.ranger.addWhite(addr)
	writeJSON(w, map[string]string{"status": "whitelisted"})
}
// handleCats serves /categories, answering {"categories": [...]} with the
// sorted names of every category currently loaded in the ranger. The list is
// an empty array, never null, when nothing is loaded.
func (s *Server) handleCats(w http.ResponseWriter, _ *http.Request) {
	// Snapshot the names under the read lock; sorting happens outside it.
	names := func() []string {
		s.ranger.mu.RLock()
		defer s.ranger.mu.RUnlock()
		out := make([]string, 0, len(s.ranger.blocks))
		for name := range s.ranger.blocks {
			out = append(out, name)
		}
		return out
	}()
	sort.Strings(names)
	writeJSON(w, map[string]any{"categories": names})
}
// writeJSON sends v as an application/json response body followed by a
// newline. It never calls WriteHeader itself, so the status is whatever the
// caller set, or net/http's implicit 200 on first write.
func writeJSON(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	// The encode error is deliberately dropped. The callers in this file pass
	// plain maps of strings, bools and string slices, so a failure here
	// presumably means the client went away and there is nothing useful left
	// to report.
	_ = enc.Encode(v)
}
// -----------------------------------------------------------------------------
// MAIN
// -----------------------------------------------------------------------------
func main() {
cfg := loadConfig()
rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr})
ctx := context.Background()
if err := rdb.Ping(ctx).Err(); err != nil {
log.Fatalf("redis: %v", err)
}
ranger := newRanger()
if err := syncOnce(ctx, cfg, rdb, ranger); err != nil {
log.Println("initial sync:", err)
}
go func() {
ticker := time.NewTicker(6 * time.Hour)
defer ticker.Stop()
for range ticker.C {
if err := syncOnce(ctx, cfg, rdb, ranger); err != nil {
log.Println("sync:", err)
}
}
}()
srv := &Server{cfg: cfg, ranger: ranger, rdb: rdb}
log.Println("listening on :8080")
if err := http.ListenAndServe(":8080", srv.routes()); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.Fatal(err)
}
}