diff --git a/main.go b/main.go index b075e38..1cb2913 100644 --- a/main.go +++ b/main.go @@ -20,31 +20,28 @@ import ( ) // ----------------------------------------------------------------------------- -// CONFIGURATION +// CONFIGURATION & ENV // ----------------------------------------------------------------------------- type Source struct { - Category string // e.g. "spam", "tor", "malware" - URL []string // one or many URLs belonging to this category + Category string + URL []string } type Config struct { RedisAddr string - Sources []Source // grouped by category - TTLHours int // TTL for block entries in Redis + Sources []Source + TTLHours int + IsWorker bool // true ⇒ lädt Blocklisten & schreibt sie nach Redis } func loadConfig() Config { - // default single source + // default Blocklist source srcs := []Source{{ Category: "generic", URL: []string{"https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset"}, }} - /* - ENV format supporting many URLs per category: - BLOCKLIST_SOURCES="spam:https://a.net|https://b.net,tor:https://c.net;https://d.net" - */ if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" { srcs = nil for _, spec := range strings.Split(env, ",") { @@ -70,15 +67,18 @@ func loadConfig() Config { } } - ttl := 720 // 30 days + ttl := 720 if env := os.Getenv("TTL_HOURS"); env != "" { fmt.Sscanf(env, "%d", &ttl) } + isWorker := strings.ToLower(os.Getenv("ROLE")) == "worker" + return Config{ RedisAddr: getenv("REDIS_ADDR", "redis:6379"), Sources: srcs, TTLHours: ttl, + IsWorker: isWorker, } } @@ -90,14 +90,14 @@ func getenv(k, def string) string { } // ----------------------------------------------------------------------------- -// REDIS KEY HELPERS +// REDIS KEYS // ----------------------------------------------------------------------------- func keyBlock(cat string, p netip.Prefix) string { return "bl:" + cat + ":" + p.String() } func keyWhite(a netip.Addr) string { return "wl:" + a.String() } // ----------------------------------------------------------------------------- -// IN-MEMORY RANGER +// RANGER – thread‑safe in‑memory index // ----------------------------------------------------------------------------- type Ranger struct { @@ -119,12 +119,41 @@ func (r *Ranger) resetBlocks(m map[string]map[netip.Prefix]struct{}) { r.mu.Unlock() } +func (r *Ranger) resetWhites(set map[netip.Addr]struct{}) { + r.mu.Lock() + r.whites = set + r.mu.Unlock() +} + +func (r *Ranger) addBlock(cat string, p netip.Prefix) { + r.mu.Lock() + if _, ok := r.blocks[cat]; !ok { + r.blocks[cat] = make(map[netip.Prefix]struct{}) + } + r.blocks[cat][p] = struct{}{} + r.mu.Unlock() +} + +func (r *Ranger) removeBlock(cat string, p netip.Prefix) { + r.mu.Lock() + if m, ok := r.blocks[cat]; ok { + delete(m, p) + } + r.mu.Unlock() +} + func (r *Ranger) addWhite(a netip.Addr) { r.mu.Lock() r.whites[a] = struct{}{} r.mu.Unlock() } +func (r *Ranger) removeWhite(a netip.Addr) { + r.mu.Lock() + delete(r.whites, a) + r.mu.Unlock() +} + func (r *Ranger) blockedInCats(a netip.Addr, cats []string) []string { r.mu.RLock() defer r.mu.RUnlock() @@ -155,18 +184,81 @@ func (r *Ranger) blockedInCats(a netip.Addr, cats []string) []string { } // ----------------------------------------------------------------------------- -// SYNC WORKER +// INITIAL LOAD FROM REDIS (baseline before keyspace events) // ----------------------------------------------------------------------------- +func loadFromRedis(ctx context.Context, rdb *redis.Client, r *Ranger) error { + // 1) Blocks + blocks := make(map[string]map[netip.Prefix]struct{}) + iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator() + for iter.Next(ctx) { + key := iter.Val() // bl:: + parts := strings.SplitN(key, ":", 3) + if len(parts) != 3 { + continue + } + cat, cidr := parts[1], parts[2] + p, err := netip.ParsePrefix(cidr) + if err != nil { + continue + } + if _, ok := blocks[cat]; !ok { + blocks[cat] = map[netip.Prefix]struct{}{} + } + blocks[cat][p] = struct{}{} + } + if err := iter.Err(); err != nil { + return err + } + r.resetBlocks(blocks) + + // 2) Whites + whites := make(map[netip.Addr]struct{}) + wIter := rdb.Scan(ctx, 0, "wl:*", 0).Iterator() + for wIter.Next(ctx) { + ip := strings.TrimPrefix(wIter.Val(), "wl:") + if a, err := netip.ParseAddr(ip); err == nil { + whites[a] = struct{}{} + } + } + if err := wIter.Err(); err != nil { + return err + } + r.resetWhites(whites) + return nil +} + +// ----------------------------------------------------------------------------- +// SYNC WORKER (only on ROLE=worker) +// ----------------------------------------------------------------------------- + +func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger) { + if err := syncOnce(ctx, cfg, rdb, ranger); err != nil { + log.Println("initial sync:", err) + } + ticker := time.NewTicker(6 * time.Hour) + for { + select { + case <-ticker.C: + if err := syncOnce(ctx, cfg, rdb, ranger); err != nil { + log.Println("sync loop:", err) + } + case <-ctx.Done(): + ticker.Stop() + return + } + } +} + func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger) error { expiry := time.Duration(cfg.TTLHours) * time.Hour newBlocks := make(map[string]map[netip.Prefix]struct{}) for _, src := range cfg.Sources { for _, url := range src.URL { - if err := processURL(ctx, url, func(p netip.Prefix) { + if err := fetchList(ctx, url, func(p netip.Prefix) { if _, ok := newBlocks[src.Category]; !ok { - newBlocks[src.Category] = make(map[netip.Prefix]struct{}) + newBlocks[src.Category] = map[netip.Prefix]struct{}{} } newBlocks[src.Category][p] = struct{}{} _ = rdb.Set(ctx, keyBlock(src.Category, p), "1", expiry).Err() @@ -175,25 +267,11 @@ func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client, ranger *Ranger } } } - ranger.resetBlocks(newBlocks) return nil } -func loadWhites(ctx context.Context, rdb *redis.Client, ranger *Ranger) error { - iter := rdb.Scan(ctx, 0, "wl:*", 0).Iterator() - for iter.Next(ctx) { - key := iter.Val() // "wl:1.2.3.4" - ipStr := strings.TrimPrefix(key, "wl:") - if ip, err := netip.ParseAddr(ipStr); err == nil { - ranger.addWhite(ip) - } - } - return iter.Err() -} - -func processURL(ctx context.Context, url string, cb func(netip.Prefix)) error { - fmt.Println("Process URL:", url) +func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) resp, err := http.DefaultClient.Do(req) if err != nil { @@ -203,7 +281,6 @@ func processURL(ctx context.Context, url string, cb func(netip.Prefix)) error { if resp.StatusCode != http.StatusOK { return fmt.Errorf("%s -> %s", url, resp.Status) } - fmt.Println("Done.") return parseStream(resp.Body, cb) } @@ -229,12 +306,56 @@ func parseStream(r io.Reader, cb func(netip.Prefix)) error { return s.Err() } +// ----------------------------------------------------------------------------- +// KEYSPACE SUBSCRIBER – instant propagation +// ----------------------------------------------------------------------------- + +func subscribeKeyspace(ctx context.Context, rdb *redis.Client, ranger *Ranger) { + pubsub := rdb.PSubscribe(ctx, "__keyspace@0__:bl:*", "__keyspace@0__:wl:*") + go func() { + for msg := range pubsub.Channel() { + key := strings.TrimPrefix(msg.Channel, "__keyspace@0__:") + payload := msg.Payload + if strings.HasPrefix(key, "wl:") { + ipStr := strings.TrimPrefix(key, "wl:") + addr, err := netip.ParseAddr(ipStr) + if err != nil { + continue + } + switch payload { + case "set": + ranger.addWhite(addr) + case "del", "expired": + ranger.removeWhite(addr) + } + continue + } + if strings.HasPrefix(key, "bl:") { + parts := strings.SplitN(key, ":", 3) + if len(parts) != 3 { + continue + } + cat, cidr := parts[1], parts[2] + p, err := netip.ParsePrefix(cidr) + if err != nil { + continue + } + switch payload { + case "set": + ranger.addBlock(cat, p) + case "del", "expired": + ranger.removeBlock(cat, p) + } + } + } + }() +} + // ----------------------------------------------------------------------------- // HTTP SERVER // ----------------------------------------------------------------------------- type Server struct { - cfg Config ranger *Ranger rdb *redis.Client } @@ -249,22 +370,34 @@ func (s *Server) routes() http.Handler { func (s *Server) handleCheck(w http.ResponseWriter, r *http.Request) { ipStr := strings.TrimPrefix(r.URL.Path, "/check/") - addr, err := netip.ParseAddr(ipStr) - if err != nil { - http.Error(w, "bad ip", http.StatusBadRequest) + if ipStr == "" { + http.Error(w, "missing IP", http.StatusBadRequest) return } - var cats []string - if q := strings.TrimSpace(r.URL.Query().Get("cats")); q != "" { - cats = strings.Split(q, ",") + addr, err := netip.ParseAddr(ipStr) + if err != nil { + http.Error(w, "invalid IP", http.StatusBadRequest) + return } + + catsParam := strings.TrimSpace(r.URL.Query().Get("cats")) + var cats []string + if catsParam != "" { + cats = strings.Split(catsParam, ",") + } + blocked := s.ranger.blockedInCats(addr, cats) - writeJSON(w, map[string]any{"ip": ipStr, "blocked": len(blocked) > 0, "categories": blocked}) + writeJSON(w, map[string]any{ + "ip": ipStr, + "blocked": len(blocked) > 0, + "categories": blocked, + }) } +// POST {"ip":"1.2.3.4"} func (s *Server) handleAddWhite(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { - http.Error(w, "POST only", http.StatusMethodNotAllowed) + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } var body struct { @@ -276,14 +409,14 @@ func (s *Server) handleAddWhite(w http.ResponseWriter, r *http.Request) { } addr, err := netip.ParseAddr(strings.TrimSpace(body.IP)) if err != nil { - http.Error(w, "bad ip", http.StatusBadRequest) + http.Error(w, "invalid IP", http.StatusBadRequest) return } if err := s.rdb.Set(r.Context(), keyWhite(addr), "1", 0).Err(); err != nil { http.Error(w, "redis", http.StatusInternalServerError) return } - s.ranger.addWhite(addr) + s.ranger.addWhite(addr) // immediate local effect writeJSON(w, map[string]string{"status": "whitelisted"}) } @@ -309,34 +442,29 @@ func writeJSON(w http.ResponseWriter, v any) { func main() { cfg := loadConfig() + ctx := context.Background() rdb := redis.NewClient(&redis.Options{Addr: cfg.RedisAddr}) - ctx := context.Background() if err := rdb.Ping(ctx).Err(); err != nil { log.Fatalf("redis: %v", err) } + // enable keyspace events (if not already set in redis.conf) + _ = rdb.ConfigSet(ctx, "notify-keyspace-events", "KEx").Err() + ranger := newRanger() - if err := syncOnce(ctx, cfg, rdb, ranger); err != nil { - log.Println("initial sync:", err) + if err := loadFromRedis(ctx, rdb, ranger); err != nil { + log.Println("initial load error:", err) } - if err := loadWhites(ctx, rdb, ranger); err != nil { - log.Println("loadWhites:", err) + subscribeKeyspace(ctx, rdb, ranger) + + if cfg.IsWorker { + go syncLoop(ctx, cfg, rdb, ranger) } - go func() { - ticker := time.NewTicker(2 * time.Hour) - defer ticker.Stop() - for range ticker.C { - if err := syncOnce(ctx, cfg, rdb, ranger); err != nil { - log.Println("sync:", err) - } - } - }() - - srv := &Server{cfg: cfg, ranger: ranger, rdb: rdb} - log.Println("listening on :8080") + srv := &Server{ranger: ranger, rdb: rdb} + log.Println("listening on :8080 (worker:", cfg.IsWorker, ")") if err := http.ListenAndServe(":8080", srv.routes()); err != nil && !errors.Is(err, http.ErrServerClosed) { log.Fatal(err) }