Files
flod/main.go
jbergner 13bfc5500a
All checks were successful
release-tag / release-image (push) Successful in 1m32s
fixes
2025-06-10 00:00:43 +02:00

476 lines
12 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/netip"
"os"
"sort"
"strings"
"sync"
"time"
"github.com/redis/go-redis/v9"
)
// -----------------------------------------------------------------------------
// CONFIGURATION & ENV
// -----------------------------------------------------------------------------
// Source is one blocklist category together with the URLs whose entries are
// loaded into it.
type Source struct {
	Category string   // category name; entries are stored as bl:<Category>:<prefix>
	URL      []string // download URLs; all of them feed the same category
}
// Config is the runtime configuration assembled by loadConfig.
type Config struct {
	RedisAddr string   // host:port of the Redis server
	Sources   []Source // blocklists downloaded by the worker
	TTLHours  int      // lifetime in hours of the bl:* keys the worker writes
	IsWorker  bool     // true => downloads the blocklists and writes them to Redis
}
// loadConfig builds the runtime configuration from environment variables:
//
//   - BLOCKLIST_SOURCES: comma-separated "category:url" specs. Several URLs for
//     one category are separated by '|' or ';', e.g.
//     "spam:https://a/list|https://b/list,tor:https://c/list".
//     When unset, the FireHOL level 1 list is used as category "generic".
//   - TTL_HOURS: lifetime in hours of the blocklist keys written to Redis
//     (default 720). Values that are not a positive integer are rejected with
//     a log message, and the default is kept.
//   - REDIS_ADDR: Redis address (default "redis:6379").
//   - ROLE: "worker" (case-insensitive, surrounding spaces ignored) enables
//     the blocklist sync worker.
func loadConfig() Config {
	srcs := []Source{{
		Category: "generic",
		URL:      []string{"https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset"},
	}}
	if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" {
		srcs = parseSources(env)
		if len(srcs) == 0 {
			log.Printf("BLOCKLIST_SOURCES %q contains no usable source; no blocklists will be synced", env)
		}
	}

	ttl := 720
	if env := os.Getenv("TTL_HOURS"); env != "" {
		// A zero or negative duration would make the worker write keys without
		// any TTL, so only positive values are accepted.
		var n int
		if _, err := fmt.Sscanf(env, "%d", &n); err != nil || n <= 0 {
			log.Printf("ignoring invalid TTL_HOURS %q; using default %d", env, ttl)
		} else {
			ttl = n
		}
	}

	return Config{
		RedisAddr: getenv("REDIS_ADDR", "redis:6379"),
		Sources:   srcs,
		TTLHours:  ttl,
		IsWorker:  strings.EqualFold(strings.TrimSpace(os.Getenv("ROLE")), "worker"),
	}
}

// parseSources parses a BLOCKLIST_SOURCES value (format described on
// loadConfig). It skips malformed specs (no ':' separator, empty category, or
// no URLs) and logs each one, so a misconfiguration is visible.
func parseSources(env string) []Source {
	var srcs []Source
	for _, spec := range strings.Split(env, ",") {
		spec = strings.TrimSpace(spec)
		if spec == "" {
			continue
		}
		// Cut at the first ':' only; the URLs themselves contain colons.
		cat, rawURLs, ok := strings.Cut(spec, ":")
		cat = strings.TrimSpace(cat)
		if !ok || cat == "" {
			log.Printf("BLOCKLIST_SOURCES: skipping %q (want category:url[|url...])", spec)
			continue
		}
		var urls []string
		for _, u := range strings.FieldsFunc(rawURLs, func(r rune) bool { return r == '|' || r == ';' }) {
			if u = strings.TrimSpace(u); u != "" {
				urls = append(urls, u)
			}
		}
		if len(urls) == 0 {
			log.Printf("BLOCKLIST_SOURCES: skipping %q (no URLs)", spec)
			continue
		}
		srcs = append(srcs, Source{Category: cat, URL: urls})
	}
	return srcs
}
// getenv returns the value of environment variable k, or def when k is unset
// or set to the empty string.
func getenv(k, def string) string {
	v := os.Getenv(k)
	if v == "" {
		return def
	}
	return v
}
// -----------------------------------------------------------------------------
// REDIS KEYS
// -----------------------------------------------------------------------------
// keyBlock returns the Redis key of a blocked prefix: "bl:<cat>:<prefix>",
// e.g. "bl:spam:10.0.0.0/8".
func keyBlock(cat string, p netip.Prefix) string {
	return strings.Join([]string{"bl", cat, p.String()}, ":")
}

// keyWhite returns the Redis key of a whitelisted address: "wl:<addr>".
func keyWhite(a netip.Addr) string {
	const prefix = "wl:"
	return prefix + a.String()
}
// -----------------------------------------------------------------------------
// RANGER: thread-safe in-memory index
// -----------------------------------------------------------------------------
// Ranger is a thread-safe in-memory index of blocked prefixes, grouped by
// category, and of whitelisted addresses. Every method takes mu, so a single
// *Ranger can be shared by the HTTP handlers, the keyspace subscriber and the
// sync worker.
type Ranger struct {
	mu     sync.RWMutex
	blocks map[string]map[netip.Prefix]struct{} // category -> set of prefixes
	whites map[netip.Addr]struct{}              // whitelisted addresses
}

// newRanger returns an empty Ranger whose maps are ready for writes.
func newRanger() *Ranger {
	return &Ranger{
		blocks: make(map[string]map[netip.Prefix]struct{}),
		whites: make(map[netip.Addr]struct{}),
	}
}

// resetBlocks replaces the entire block index with m. The Ranger takes
// ownership of m, so the caller must not modify it afterwards. A nil m is
// replaced by an empty map so that later addBlock calls cannot panic.
func (r *Ranger) resetBlocks(m map[string]map[netip.Prefix]struct{}) {
	if m == nil {
		m = make(map[string]map[netip.Prefix]struct{})
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	r.blocks = m
}

// resetWhites replaces the entire whitelist with set and takes ownership of
// it. A nil set is replaced by an empty map so that later addWhite calls
// cannot panic.
func (r *Ranger) resetWhites(set map[netip.Addr]struct{}) {
	if set == nil {
		set = make(map[netip.Addr]struct{})
	}
	r.mu.Lock()
	defer r.mu.Unlock()
	r.whites = set
}

// addBlock records prefix p under category cat, creating the category on
// first use.
func (r *Ranger) addBlock(cat string, p netip.Prefix) {
	r.mu.Lock()
	defer r.mu.Unlock()
	set, ok := r.blocks[cat]
	if !ok {
		set = make(map[netip.Prefix]struct{})
		r.blocks[cat] = set
	}
	set[p] = struct{}{}
}

// removeBlock deletes prefix p from category cat. A category left without any
// prefix is removed too, so it stops showing up as a known category (for
// example in the all-categories lookup of blockedInCats).
func (r *Ranger) removeBlock(cat string, p netip.Prefix) {
	r.mu.Lock()
	defer r.mu.Unlock()
	set, ok := r.blocks[cat]
	if !ok {
		return
	}
	delete(set, p)
	if len(set) == 0 {
		delete(r.blocks, cat)
	}
}

// addWhite adds address a to the whitelist.
func (r *Ranger) addWhite(a netip.Addr) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.whites[a] = struct{}{}
}

// removeWhite removes address a from the whitelist.
func (r *Ranger) removeWhite(a netip.Addr) {
	r.mu.Lock()
	defer r.mu.Unlock()
	delete(r.whites, a)
}

// blockedInCats returns the categories in cats that contain a prefix matching
// a, sorted alphabetically. An empty cats means "every known category".
// Duplicate names in cats are reported at most once. A whitelisted address is
// never blocked and yields nil; so does an address that matches nothing.
//
// netip.Prefix.Contains never matches across address families or for zoned
// addresses. For that reason a is also compared in its canonical form (an
// IPv4-mapped IPv6 address is unmapped and the zone is dropped), so
// "::ffff:10.0.0.1" matches "10.0.0.0/8". The caller's cats slice is never
// modified.
func (r *Ranger) blockedInCats(a netip.Addr, cats []string) []string {
	canon := a.Unmap().WithZone("")

	r.mu.RLock()
	defer r.mu.RUnlock()

	if _, ok := r.whites[a]; ok {
		return nil
	}
	if _, ok := r.whites[canon]; ok {
		return nil
	}

	// matches reports whether any prefix of category cat covers the address.
	matches := func(cat string) bool {
		for p := range r.blocks[cat] {
			if p.Contains(canon) || p.Contains(a) {
				return true
			}
		}
		return false
	}

	var res []string
	if len(cats) == 0 {
		for cat := range r.blocks {
			if matches(cat) {
				res = append(res, cat)
			}
		}
	} else {
		seen := make(map[string]struct{}, len(cats))
		for _, cat := range cats {
			if _, dup := seen[cat]; dup {
				continue
			}
			seen[cat] = struct{}{}
			if matches(cat) {
				res = append(res, cat)
			}
		}
	}
	sort.Strings(res)
	return res
}
// -----------------------------------------------------------------------------
// INITIAL LOAD FROM REDIS (baseline before keyspace events)
// -----------------------------------------------------------------------------
// loadFromRedis takes a snapshot of every bl:<category>:<cidr> and wl:<ip> key
// currently in Redis and installs it as r's baseline. Keys whose suffix does
// not parse are skipped.
//
// The block index is replaced as soon as its scan completes. If the whitelist
// scan fails afterwards, that error is returned and the whitelist is left
// unchanged.
func loadFromRedis(ctx context.Context, rdb *redis.Client, r *Ranger) error {
	// Blocks: bl:<cat>:<cidr>. IPv6 CIDRs contain ':' as well, so only the
	// first two separators delimit fields.
	byCat := make(map[string]map[netip.Prefix]struct{})
	blockKeys := rdb.Scan(ctx, 0, "bl:*", 0).Iterator()
	for blockKeys.Next(ctx) {
		_, rest, _ := strings.Cut(blockKeys.Val(), ":")
		cat, cidr, found := strings.Cut(rest, ":")
		if !found {
			continue
		}
		prefix, err := netip.ParsePrefix(cidr)
		if err != nil {
			continue
		}
		set := byCat[cat]
		if set == nil {
			set = map[netip.Prefix]struct{}{}
			byCat[cat] = set
		}
		set[prefix] = struct{}{}
	}
	if err := blockKeys.Err(); err != nil {
		return err
	}
	r.resetBlocks(byCat)

	// Whitelist: wl:<ip>.
	allowed := make(map[netip.Addr]struct{})
	whiteKeys := rdb.Scan(ctx, 0, "wl:*", 0).Iterator()
	for whiteKeys.Next(ctx) {
		addr, err := netip.ParseAddr(strings.TrimPrefix(whiteKeys.Val(), "wl:"))
		if err != nil {
			continue
		}
		allowed[addr] = struct{}{}
	}
	if err := whiteKeys.Err(); err != nil {
		return err
	}
	r.resetWhites(allowed)
	return nil
}
// -----------------------------------------------------------------------------
// SYNC WORKER (only on ROLE=worker)
// -----------------------------------------------------------------------------
// syncLoop runs syncOnce immediately and then every 6 hours until ctx is
// cancelled. Sync errors are logged and do not stop the loop.
func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger) {
	run := func(label string) {
		if err := syncOnce(ctx, cfg, rdb, ranger); err != nil {
			log.Println(label, err)
		}
	}

	run("initial sync:")

	tick := time.NewTicker(6 * time.Hour)
	defer tick.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-tick.C:
			run("sync loop:")
		}
	}
}
// syncOnce downloads every configured blocklist and writes each entry to Redis
// as bl:<category>:<prefix>, with a TTL of cfg.TTLHours. When every download
// succeeds, it also replaces the local block index with the fresh data.
//
// A failing URL does not abort the remaining sources, so the other lists are
// still refreshed in Redis. In that case the local index is left untouched,
// because replacing it would drop the categories whose download failed; an
// error describing the failures is returned. Failed Redis writes are counted
// and reported rather than silently discarded.
func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger) error {
	expiry := time.Duration(cfg.TTLHours) * time.Hour
	newBlocks := make(map[string]map[netip.Prefix]struct{})

	var (
		fetchErr   error // first download/parse failure
		fetchFails int
		writeErr   error // first Redis SET failure
		writeFails int
	)
	for _, src := range cfg.Sources {
		cat := src.Category
		for _, url := range src.URL {
			if err := ctx.Err(); err != nil {
				return err
			}
			err := fetchList(ctx, url, func(p netip.Prefix) {
				set, ok := newBlocks[cat]
				if !ok {
					set = make(map[netip.Prefix]struct{})
					newBlocks[cat] = set
				}
				set[p] = struct{}{}
				if err := rdb.Set(ctx, keyBlock(cat, p), "1", expiry).Err(); err != nil {
					writeFails++
					if writeErr == nil {
						writeErr = err
					}
				}
			})
			if err != nil {
				fetchFails++
				if fetchErr == nil {
					fetchErr = err
				}
			}
		}
	}

	if fetchErr != nil {
		return fmt.Errorf("%d blocklist download(s) failed, local index not replaced; first: %w", fetchFails, fetchErr)
	}
	ranger.resetBlocks(newBlocks)
	if writeErr != nil {
		return fmt.Errorf("%d redis write(s) failed; first: %w", writeFails, writeErr)
	}
	return nil
}
// blocklistClient is used for blocklist downloads. Unlike http.DefaultClient it
// has an overall timeout, so a stalled server cannot block a sync forever even
// when the caller's context has no deadline.
var blocklistClient = &http.Client{Timeout: 5 * time.Minute}

// fetchList downloads url and passes every prefix parsed from the response
// body to cb. It returns an error if the request cannot be built or sent, if
// the server answers with anything other than 200 OK, or if reading the body
// fails.
func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		// A malformed URL (e.g. from BLOCKLIST_SOURCES) must not reach Do with
		// a nil request.
		return fmt.Errorf("building request for %s: %w", url, err)
	}
	resp, err := blocklistClient.Do(req)
	if err != nil {
		return err // *url.Error already names the URL
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s -> %s", url, resp.Status)
	}
	if err := parseStream(resp.Body, cb); err != nil {
		return fmt.Errorf("reading %s: %w", url, err)
	}
	return nil
}
// parseStream reads a plain-text blocklist from r, one entry per line, and
// calls cb for every entry that parses as an IP prefix or an IP address.
//
// Per line, anything after an inline '#' or ';' is discarded, and only the
// first whitespace-separated token is used. This accepts formats such as
// "1.2.3.0/24 ; SBL123" and "1.2.3.4 # comment". Blank lines, comment lines
// and unparsable tokens are skipped. The returned error is the scanner's read
// error, if any.
func parseStream(r io.Reader, cb func(netip.Prefix)) error {
	s := bufio.NewScanner(r)
	for s.Scan() {
		line := s.Text()
		if i := strings.IndexAny(line, "#;"); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) == 0 {
			continue
		}
		if p, ok := parseBlockEntry(fields[0]); ok {
			cb(p)
		}
	}
	return s.Err()
}

// parseBlockEntry turns one blocklist token into a canonical prefix. A CIDR has
// its host bits masked off ("10.0.0.1/8" -> "10.0.0.0/8") so equivalent
// spellings share one key. A bare address becomes a single-host prefix, with
// IPv4-mapped IPv6 unmapped so that it matches IPv4 lookups.
func parseBlockEntry(tok string) (netip.Prefix, bool) {
	if p, err := netip.ParsePrefix(tok); err == nil {
		return p.Masked(), true
	}
	if a, err := netip.ParseAddr(tok); err == nil {
		a = a.Unmap()
		return netip.PrefixFrom(a, a.BitLen()), true
	}
	return netip.Prefix{}, false
}
// -----------------------------------------------------------------------------
// KEYSPACE SUBSCRIBER: instant propagation
// -----------------------------------------------------------------------------
// subscribeKeyspace keeps ranger in sync with Redis. It listens to the keyevent
// notifications of the client's database and applies every change to wl:* and
// bl:* keys. Keyevent channels (not keyspace channels) are used so that the
// message payload is the affected key itself.
//
// The subscription is established before subscribeKeyspace returns. The
// listening goroutine exits and closes the subscription when ctx is cancelled
// or the message channel is closed. Redis only publishes these events when
// notify-keyspace-events includes the matching flags.
func subscribeKeyspace(ctx context.Context, rdb *redis.Client, ranger *Ranger) {
	// Derive the channel prefix from the configured DB instead of assuming 0.
	prefix := fmt.Sprintf("__keyevent@%d__:", rdb.Options().DB)
	pubsub := rdb.PSubscribe(ctx,
		prefix+"set",
		prefix+"del",
		prefix+"expired",
		prefix+"evicted",
	)
	go func() {
		defer pubsub.Close()
		ch := pubsub.Channel()
		for {
			select {
			case <-ctx.Done():
				return
			case msg, ok := <-ch:
				if !ok {
					return
				}
				applyKeyEvent(ranger, strings.TrimPrefix(msg.Channel, prefix), msg.Payload)
			}
		}
	}()
}

// applyKeyEvent applies one keyevent notification to ranger. event is the
// event name ("set", "del", "expired" or "evicted") and key is the affected
// Redis key, e.g. "wl:1.2.3.4" or "bl:spam:10.0.0.0/8". Unknown events, keys
// outside the wl:/bl: namespaces and keys that do not parse are ignored.
func applyKeyEvent(ranger *Ranger, event, key string) {
	var add bool
	switch event {
	case "set":
		add = true
	case "del", "expired", "evicted":
		add = false
	default:
		return
	}

	switch {
	case strings.HasPrefix(key, "wl:"):
		addr, err := netip.ParseAddr(strings.TrimPrefix(key, "wl:"))
		if err != nil {
			return
		}
		if add {
			ranger.addWhite(addr)
		} else {
			ranger.removeWhite(addr)
		}
	case strings.HasPrefix(key, "bl:"):
		// bl:<cat>:<cidr>. IPv6 CIDRs contain ':' as well, so split at most twice.
		parts := strings.SplitN(key, ":", 3)
		if len(parts) != 3 {
			return
		}
		p, err := netip.ParsePrefix(parts[2])
		if err != nil {
			return
		}
		if add {
			ranger.addBlock(parts[1], p)
		} else {
			ranger.removeBlock(parts[1], p)
		}
	}
}
// -----------------------------------------------------------------------------
// HTTP SERVER
// -----------------------------------------------------------------------------
// Server holds the dependencies of the HTTP handlers: the in-memory index used
// to answer lookups and the Redis client used to persist whitelist entries.
type Server struct {
	ranger *Ranger       // in-memory block/white index
	rdb    *redis.Client // Redis connection for whitelist writes
}
// routes returns the HTTP handler for the public API:
//
//	/check/<ip>   address lookup (handleCheck)
//	/whitelist    POST a whitelist entry (handleAddWhite)
//	/categories   list known categories (handleCats)
func (s *Server) routes() http.Handler {
	mux := http.NewServeMux()
	for pattern, h := range map[string]http.HandlerFunc{
		"/check/":     s.handleCheck,
		"/whitelist":  s.handleAddWhite,
		"/categories": s.handleCats,
	} {
		mux.Handle(pattern, h)
	}
	return mux
}
// handleCheck answers /check/<ip>[?cats=a,b] with
// {"ip": <ip as given>, "blocked": bool, "categories": [...]}.
//
// Without ?cats (or with only empty names), every known category is consulted.
// "categories" is always a JSON array (empty when nothing matched), never
// null. It responds 400 when the IP is missing or does not parse.
func (s *Server) handleCheck(w http.ResponseWriter, r *http.Request) {
	ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
	if ipStr == "" {
		http.Error(w, "missing IP", http.StatusBadRequest)
		return
	}
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}

	blocked := s.ranger.blockedInCats(addr, parseCatsParam(r.URL.Query().Get("cats")))
	if blocked == nil {
		// Encode "no match" as [] rather than null so clients can always iterate.
		blocked = []string{}
	}
	writeJSON(w, map[string]any{
		"ip":         ipStr,
		"blocked":    len(blocked) > 0,
		"categories": blocked,
	})
}

// parseCatsParam splits a comma-separated ?cats= value into category names. It
// trims whitespace around each name ("spam, tor" -> ["spam" "tor"]) and drops
// empty and duplicate entries. It returns nil when no names remain, which
// blockedInCats treats as "all categories".
func parseCatsParam(raw string) []string {
	var cats []string
	seen := make(map[string]struct{})
	for _, c := range strings.Split(raw, ",") {
		c = strings.TrimSpace(c)
		if c == "" {
			continue
		}
		if _, dup := seen[c]; dup {
			continue
		}
		seen[c] = struct{}{}
		cats = append(cats, c)
	}
	return cats
}
// handleAddWhite handles POST /whitelist with a JSON body {"ip":"1.2.3.4"}.
//
// The address is normalized: an IPv4-mapped IPv6 address is unmapped and any
// zone is dropped. It is then stored in Redis as wl:<ip> without expiry and
// added to the local index immediately. It responds 405 for other methods,
// 400 for a bad body or IP, and 500 if Redis rejects the write.
func (s *Server) handleAddWhite(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	// The expected body is a tiny JSON object. Cap it so a client cannot
	// stream an unbounded body into the decoder.
	const maxBody = 4 << 10
	r.Body = http.MaxBytesReader(w, r.Body, maxBody)

	var body struct {
		IP string `json:"ip"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, "bad json", http.StatusBadRequest)
		return
	}
	addr, err := netip.ParseAddr(strings.TrimSpace(body.IP))
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}
	// Store one canonical form: "::ffff:1.2.3.4" and "1.2.3.4" are the same
	// host, and zones are only meaningful on the local machine.
	addr = addr.Unmap().WithZone("")

	if err := s.rdb.Set(r.Context(), keyWhite(addr), "1", 0).Err(); err != nil {
		log.Println("whitelist:", err)
		http.Error(w, "redis", http.StatusInternalServerError)
		return
	}
	s.ranger.addWhite(addr) // immediate local effect
	writeJSON(w, map[string]string{"status": "whitelisted"})
}
// handleCats responds with {"categories": [...]}: every category that
// currently holds at least one prefix, sorted alphabetically. The list is
// always a JSON array, never null.
func (s *Server) handleCats(w http.ResponseWriter, _ *http.Request) {
	s.ranger.mu.RLock()
	cats := make([]string, 0, len(s.ranger.blocks))
	for c, prefixes := range s.ranger.blocks {
		// A category whose entries were all deleted or expired can linger as an
		// empty set. Do not advertise it.
		if len(prefixes) > 0 {
			cats = append(cats, c)
		}
	}
	s.ranger.mu.RUnlock()

	sort.Strings(cats)
	writeJSON(w, map[string]any{"categories": cats})
}
// writeJSON sets Content-Type to application/json and encodes v as the
// response body (with the trailing newline json.Encoder emits).
func writeJSON(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	// Once encoding starts writing, the status is committed, so there is no
	// useful way to report an encoding error to the client. It is deliberately
	// dropped.
	_ = enc.Encode(v)
}
// -----------------------------------------------------------------------------
// MAIN
// -----------------------------------------------------------------------------
// main wires the service together. It loads the configuration, connects to
// Redis and enables the keyspace notifications the subscriber relies on. It
// then builds the in-memory index from a Redis snapshot, starts the blocklist
// sync worker when ROLE=worker, and serves the HTTP API on :8080.
func main() {
	cfg := loadConfig()
	ctx := context.Background()

	rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr})
	if err := rdb.Ping(ctx).Err(); err != nil {
		log.Fatalf("redis: %v", err)
	}

	// The subscriber listens on __keyevent@<db>__:* channels such as set, del
	// and expired. Redis only publishes those when the matching flags are set:
	//   E = keyevent channels, $ = string commands (SET),
	//   g = generic commands (DEL), x = expired, e = evicted;
	// K keeps keyspace channels on as before. "KEx" alone would publish only
	// expirations, so SET/DEL would never propagate. CONFIG SET overwrites any
	// value from redis.conf and may be disabled on hosted Redis, so a failure
	// is logged rather than fatal.
	if err := rdb.ConfigSet(ctx, "notify-keyspace-events", "KE$gxe").Err(); err != nil {
		log.Println("enable keyspace events:", err)
	}

	ranger := newRanger()
	// Subscribe before taking the snapshot, so changes made while the snapshot
	// is loading are not lost.
	subscribeKeyspace(ctx, rdb, ranger)
	if err := loadFromRedis(ctx, rdb, ranger); err != nil {
		log.Println("initial load error:", err)
	}

	if cfg.IsWorker {
		go syncLoop(ctx, cfg, rdb, ranger)
	}

	api := &Server{ranger: ranger, rdb: rdb}
	// Explicit timeouts so that slow or idle clients cannot hold connections
	// open indefinitely.
	httpServer := &http.Server{
		Addr:              ":8080",
		Handler:           api.routes(),
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	log.Println("listening on :8080 (worker:", cfg.IsWorker, ")")
	if err := httpServer.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
		log.Fatal(err)
	}
}