package main import ( "bufio" "context" "encoding/binary" "encoding/json" "expvar" "fmt" "io" "log" "math/big" "math/bits" "net" "net/http" "net/netip" "os" "strconv" "strings" "sync" "time" lru "github.com/hashicorp/golang-lru/v2" "github.com/redis/go-redis/v9" ) var ( ctx = context.Background() redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379") //redisAddr = getenv("REDIS_ADDR", "localhost:6379") redisTTL = time.Hour * 24 cacheSize = 100_000 blocklistCats = []string{"generic"} rdb *redis.Client ipCache *lru.Cache[string, []string] // Metrics hits = expvar.NewInt("cache_hits") misses = expvar.NewInt("cache_misses") queries = expvar.NewInt("ip_queries") ) var ( totalBlockedIPs = expvar.NewInt("total_blocked_ips") totalWhitelistEntries = expvar.NewInt("total_whitelist_entries") ) func updateTotalsFromRedis() { go func() { blockCount := 0 iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator() for iter.Next(ctx) { blockCount++ } totalBlockedIPs.Set(int64(blockCount)) whiteCount := 0 iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator() for iter.Next(ctx) { whiteCount++ } totalWhitelistEntries.Set(int64(whiteCount)) }() } func startMetricUpdater() { ticker := time.NewTicker(10 * time.Second) go func() { for { updateTotalsFromRedis() <-ticker.C } }() } // // // type Source struct { Category string URL []string } type Config struct { RedisAddr string Sources []Source TTLHours int IsWorker bool // true ⇒ lädt Blocklisten & schreibt sie nach Redis } func loadConfig() Config { // default Blocklist source srcs := []Source{{ Category: "generic", URL: []string{ "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset", "https://raw.githubusercontent.com/bitwire-it/ipblocklist/refs/heads/main/ip-list.txt", "https://ipv64.net/blocklists/countries/ipv64_blocklist_RU.txt", "https://ipv64.net/blocklists/countries/ipv64_blocklist_CN.txt", }, }, } if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" { srcs = nil for _, spec := range strings.Split(env, ",") { spec = strings.TrimSpace(spec) if spec == "" { continue } parts := strings.SplitN(spec, ":", 2) if len(parts) != 2 { continue } cat := strings.TrimSpace(parts[0]) raw := strings.FieldsFunc(parts[1], func(r rune) bool { return r == '|' || r == ';' }) var urls []string for _, u := range raw { if u = strings.TrimSpace(u); u != "" { urls = append(urls, u) } } if len(urls) > 0 { srcs = append(srcs, Source{Category: cat, URL: urls}) } } } ttl := 24 if env := os.Getenv("TTL_HOURS"); env != "" { fmt.Sscanf(env, "%d", &ttl) } isWorker := strings.ToLower(os.Getenv("ROLE")) == "worker" return Config{ //RedisAddr: getenv("REDIS_ADDR", "redis:6379"), RedisAddr: getenv("REDIS_ADDR", "10.10.5.249:6379"), Sources: srcs, TTLHours: ttl, IsWorker: isWorker, } } // Alle gültigen ISO 3166-1 Alpha-2 Ländercodes (abgekürzt, reale Liste ist länger) var allCountryCodes = []string{ "AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", "BA", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BN", "BO", "BR", "BS", "BT", "BW", "BY", "BZ", "CA", "CD", "CF", "CG", "CH", "CI", "CL", "CM", "CN", "CO", "CR", "CU", "CV", "CY", "CZ", "DE", "DJ", "DK", "DM", "DO", "DZ", "EC", "EE", "EG", "ER", "ES", "ET", "FI", "FJ", "FM", "FR", "GA", "GB", "GD", "GE", "GH", "GM", "GN", "GQ", "GR", "GT", "GW", "GY", "HK", "HN", "HR", "HT", "HU", "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", "KH", "KI", "KM", "KN", "KP", "KR", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "ME", "MG", "MH", "MK", "ML", "MM", "MN", "MR", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PG", "PH", "PK", "PL", "PT", "PW", "PY", "QA", "RO", "RS", "RU", "RW", "SA", "SB", "SC", "SD", "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", "TD", "TG", "TH", "TJ", "TL", "TM", "TN", "TO", "TR", "TT", "TV", "TZ", "UA", "UG", "US", "UY", "UZ", "VC", "VE", "VN", "VU", "WS", "YE", "ZA", "ZM", "ZW", } // Hauptfunktion: gibt alle IPv4-Ranges eines Landes (CIDR) aus allen RIRs zurück func GetIPRangesByCountry(countryCode string) ([]string, error) { var allCIDRs []string upperCode := strings.ToUpper(countryCode) for _, url := range rirFiles { resp, err := http.Get(url) if err != nil { return nil, fmt.Errorf("fehler beim abrufen von %s: %w", url, err) } defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if strings.HasPrefix(line, "2") || strings.HasPrefix(line, "#") { continue // Kommentar oder Header } if strings.Contains(line, "|"+upperCode+"|ipv4|") { fields := strings.Split(line, "|") if len(fields) < 5 { continue } ipStart := fields[3] count, _ := strconv.Atoi(fields[4]) cidrs := summarizeCIDR(ipStart, count) allCIDRs = append(allCIDRs, cidrs...) } } } return allCIDRs, nil } // Hilfsfunktion: Start-IP + Anzahl → []CIDR func summarizeCIDR(start string, count int) []string { var cidrs []string ip := net.ParseIP(start).To4() startInt := ipToInt(ip) for count > 0 { maxSize := 32 for maxSize > 0 { mask := 1 << uint(32-maxSize) if startInt%uint32(mask) == 0 && mask <= count { break } maxSize-- } cidr := fmt.Sprintf("%s/%d", intToIP(startInt), maxSize) cidrs = append(cidrs, cidr) count -= 1 << uint(32-maxSize) startInt += uint32(1 << uint(32-maxSize)) } return cidrs } func ipToInt(ip net.IP) uint32 { return uint32(ip[0])<<24 + uint32(ip[1])<<16 + uint32(ip[2])<<8 + uint32(ip[3]) } func intToIP(i uint32) net.IP { return net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) } func keyBlock(cat string, p netip.Prefix) string { return "bl:" + cat + ":" + p.String() } func LoadAllCountryPrefixesIntoRedisAndRanger( rdb *redis.Client, ttlHours int, ) error { for _, countryCode := range allCountryCodes { expiry := time.Duration(ttlHours) * time.Hour results := make(map[string][]netip.Prefix) fmt.Printf("💡 Loading %s...\n", countryCode) cidrs, err := GetIPRangesByCountry(countryCode) if err != nil { log.Printf("Error at %s: %v", countryCode, err) } fmt.Println("✅ Got " + strconv.Itoa(len(cidrs)) + " Ranges for Country " + countryCode) var validPrefixes []netip.Prefix for _, c := range cidrs { prefix, err := netip.ParsePrefix(c) if err != nil { log.Printf("CIDR invalid [%s]: %v", c, err) continue } validPrefixes = append(validPrefixes, prefix) } fmt.Println("✅ Got " + strconv.Itoa(len(validPrefixes)) + " valid Prefixes for Country " + countryCode) if len(validPrefixes) > 0 { results[countryCode] = validPrefixes } // Nach Verarbeitung: alles in Ranger + Redis eintragen for code, prefixes := range results { for _, p := range prefixes { key := keyBlock(code, p) if err := rdb.Set(ctx, key, "1", expiry).Err(); err != nil { log.Printf("Redis-Error at %s: %v", key, err) } } fmt.Println("✅ Import Subset " + strconv.Itoa(len(prefixes)) + " Entries") } fmt.Println("✅ Import done!") fmt.Println("--------------------------------------------------") } return nil } func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client) { fmt.Println("💡 Loading Lists...") if err := syncOnce(ctx, cfg, rdb); err != nil { log.Println("initial sync:", err) } fmt.Println("✅ Loading Lists Done.") ticker := time.NewTicker(30 * time.Minute) for { select { case <-ticker.C: fmt.Println("💡 Loading Lists Timer...") if err := syncOnce(ctx, cfg, rdb); err != nil { log.Println("sync loop:", err) } fmt.Println("✅ Loading Lists Timer Done.") case <-ctx.Done(): ticker.Stop() return } } } func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client) error { expiry := time.Duration(cfg.TTLHours) * time.Hour newBlocks := make(map[string]map[netip.Prefix]struct{}) for _, src := range cfg.Sources { for _, url := range src.URL { fmt.Println("💡 Loading List " + src.Category + " : " + url) if err := fetchList(ctx, url, func(p netip.Prefix) { if _, ok := newBlocks[src.Category]; !ok { newBlocks[src.Category] = map[netip.Prefix]struct{}{} } newBlocks[src.Category][p] = struct{}{} _ = rdb.Set(ctx, keyBlock(src.Category, p), "1", expiry).Err() }); err != nil { fmt.Println("❌ Fail.") return err } fmt.Println("✅ Done.") } } return nil } func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error { req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) resp, err := http.DefaultClient.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("%s -> %s", url, resp.Status) } return parseStream(resp.Body, cb) } func parseStream(r io.Reader, cb func(netip.Prefix)) error { s := bufio.NewScanner(r) for s.Scan() { line := strings.TrimSpace(s.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } if p, err := netip.ParsePrefix(line); err == nil { cb(p) continue } if addr, err := netip.ParseAddr(line); err == nil { plen := 32 if addr.Is6() { plen = 128 } cb(netip.PrefixFrom(addr, plen)) } } return s.Err() } // -------------------------------------------- // INIT + MAIN // -------------------------------------------- func main() { if getenv("IMPORTER", "0") == "1" { //Hier alles doof. selbe funktion wie unten. muss durch individuallisten ersetzt werden... cfg := loadConfig() rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) /*if err := LoadAllCountryPrefixesIntoRedisAndRanger(rdb, cfg.TTLHours); err != nil { log.Fatalf("Fehler beim Laden aller Länderranges: %v", err) }*/ syncLoop(ctx, cfg, rdb) log.Println("🚀 Import erfolgreich!") } else { var err error // Redis client rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) if err := rdb.Ping(ctx).Err(); err != nil { log.Fatalf("redis: %v", err) } // LRU cache ipCache, err = lru.New[string, []string](cacheSize) if err != nil { log.Fatalf("cache init: %v", err) } startMetricUpdater() // Admin load all blocklists (on demand or scheduled) go func() { if getenv("IMPORT_RIRS", "0") == "1" { log.Println("Lade IP-Ranges aus RIRs...") if err := importRIRDataToRedis(); err != nil { log.Fatalf("import error: %v", err) } log.Println("✅ Import abgeschlossen.") } }() // Routes http.HandleFunc("/check/", handleCheck) http.HandleFunc("/whitelist", handleWhitelist) http.HandleFunc("/info", handleInfo) http.Handle("/debug/vars", http.DefaultServeMux) log.Println("🚀 Server läuft auf :8080") log.Fatal(http.ListenAndServe(":8080", nil)) } } func getenv(k, fallback string) string { if v := os.Getenv(k); v != "" { return v } return fallback } // -------------------------------------------- // IP CHECK API // -------------------------------------------- func handleCheck(w http.ResponseWriter, r *http.Request) { ipStr := strings.TrimPrefix(r.URL.Path, "/check/") addr, err := netip.ParseAddr(ipStr) if err != nil { http.Error(w, "invalid IP", 400) return } cats := blocklistCats if q := r.URL.Query().Get("cats"); q != "" { cats = strings.Split(q, ",") } queries.Add(1) blockedCats, err := checkIP(addr, cats) if err != nil { http.Error(w, "lookup error", 500) return } fmt.Println("----------------") writeJSON(w, map[string]any{ "ip": ipStr, "blocked": len(blockedCats) > 0, "categories": blockedCats, }) } // liefert alle möglichen Präfixe dieser IP, beginnend beim längsten (/32 oder /128) func supernets(ip netip.Addr) []string { if ip.Is4() { fmt.Println("💡 DEBUG: supernets | Is4") a := ip.As4() // Kopie addressierbar machen u := binary.BigEndian.Uint32(a[:]) // jetzt darf man slicen fmt.Println("💡 DEBUG: supernets | a + u", a, u) supers := make([]string, 33) // /32 … /0 for bits := 32; bits >= 0; bits-- { mask := uint32(0xffffffff) << (32 - bits) n := u & mask addr := netip.AddrFrom4([4]byte{ byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), }) supers[32-bits] = fmt.Sprintf("%s/%d", addr, bits) } fmt.Println("💡 DEBUG: supernets | supers", supers) return supers } a := ip.As16() // Kopie addressierbar supers := make([]string, 129) // /128 … /0 for bits := 128; bits >= 0; bits-- { b := a // Wert-Kopie für Modifikation // vollständige Bytes auf 0 setzen full := (128 - bits) / 8 for i := 0; i < full; i++ { b[15-i] = 0 } // Restbits maskieren rem := (128 - bits) % 8 if rem != 0 { b[15-full] &= 0xFF << rem } addr := netip.AddrFrom16(b) supers[128-bits] = fmt.Sprintf("%s/%d", addr, bits) } fmt.Println("Supernets-v6", supers) return supers } func checkIP(ip netip.Addr, cats []string) ([]string, error) { // 1) Cache-Treffer? if res, ok := ipCache.Get(ip.String()); ok { fmt.Println("💡 DEBUG: Cache-Hit") hits.Add(1) return res, nil } // 2) alle Supernetze der IP (≤32 bzw. ≤128 Stück) supers := supernets(ip) fmt.Println("💡 DEBUG: checkIP | supers |", supers) fmt.Println("💡 DEBUG: checkIP | ip |", ip) fmt.Println("💡 DEBUG: checkIP | cats |", cats) // 3) Pipeline – jeweils *eine* EXISTS-Abfrage pro Kategorie pipe := rdb.Pipeline() existsCmds := make([]*redis.IntCmd, len(cats)) for i, cat := range cats { keys := make([]string, len(supers)) for j, pfx := range supers { keys[j] = "bl:" + cat + ":" + pfx } fmt.Println("💡 DEBUG: checkIP | keys |", keys) existsCmds[i] = pipe.Exists(ctx, keys...) } if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { return nil, err } // 4) Ergebnis auswerten matches := make([]string, 0, len(cats)) for i, cat := range cats { if existsCmds[i].Val() > 0 { matches = append(matches, cat) fmt.Println("💡 DEBUG: checkIP | matches:cat |", cat) } } fmt.Println("💡 DEBUG: checkIP | matches |", matches) // 5) Cache befüllen und zurück misses.Add(1) ipCache.Add(ip.String(), matches) return matches, nil } // -------------------------------------------- // WHITELIST API (optional extension) // -------------------------------------------- func handleWhitelist(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", http.StatusMethodNotAllowed) return } var body struct { IP string `json:"ip"` } if err := json.NewDecoder(r.Body).Decode(&body); err != nil { http.Error(w, "bad request", 400) return } addr, err := netip.ParseAddr(body.IP) if err != nil { http.Error(w, "invalid IP", 400) return } // Add to whitelist (Redis key like wl:) if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil { http.Error(w, "redis error", 500) return } ipCache.Add(addr.String(), nil) writeJSON(w, map[string]string{"status": "whitelisted"}) } // -------------------------------------------- // ADMIN INFO // -------------------------------------------- func handleInfo(w http.ResponseWriter, _ *http.Request) { stats := map[string]any{ "cache_size": ipCache.Len(), "ttl_hours": redisTTL.Hours(), "redis": redisAddr, } writeJSON(w, stats) } // -------------------------------------------- // UTIL // -------------------------------------------- func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(v) } // -------------------------------------------- // RIR DATA IMPORT (ALL COUNTRIES) // -------------------------------------------- var rirFiles = []string{ "https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest", "https://ftp.apnic.net/stats/apnic/delegated-apnic-latest", "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest", "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest", "https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest", } func importRIRDataToRedis() error { wg := sync.WaitGroup{} sem := make(chan struct{}, 5) for _, url := range rirFiles { wg.Add(1) sem <- struct{}{} go func(url string) { defer wg.Done() defer func() { <-sem }() fmt.Println("Start: ", url) if err := fetchAndStore(url); err != nil { log.Printf("❌ Fehler bei %s: %v", url, err) } fmt.Println("Done: ", url) }(url) } wg.Wait() return nil } func fetchAndStore(url string) error { resp, err := http.Get(url) if err != nil { return err } defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") { continue } fields := strings.Split(line, "|") if len(fields) < 7 { continue } country := strings.ToLower(fields[1]) ipType := fields[2] start := fields[3] count := fields[4] if ipType != "ipv4" && ipType != "ipv6" { continue } if start == "24.152.36.0" { fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count) num, _ := strconv.ParseUint(count, 10, 64) for _, cidr := range summarizeCIDRs(start, num) { fmt.Println(" →", cidr) } } //cidrList := summarizeToCIDRs(start, count, ipType) numIPs, _ := strconv.ParseUint(count, 10, 64) cidrList := summarizeCIDRs(start, numIPs) //log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList)) for _, cidr := range cidrList { prefix, err := netip.ParsePrefix(cidr) if err != nil { continue } key := "bl:" + country + ":" + prefix.String() //fmt.Println(key) _ = rdb.Set(ctx, key, "1", redisTTL).Err() } } return scanner.Err() } // -------------------------------------------- // IP RANGE SUMMARIZER // -------------------------------------------- func summarizeCIDRs(startIP string, count uint64) []string { var result []string if count == 0 { return result } ip := net.ParseIP(startIP) if ip == nil { return result } // IPv4-Pfad --------------------------------------------------------------- if v4 := ip.To4(); v4 != nil { start := ip4ToUint(v4) end := start + uint32(count) - 1 for start <= end { prefix := 32 - uint32(bits.TrailingZeros32(start)) for (start + (1 << (32 - prefix)) - 1) > end { prefix++ } result = append(result, fmt.Sprintf("%s/%d", uintToIP4(start), prefix)) start += 1 << (32 - prefix) } return result } // IPv6-Pfad --------------------------------------------------------------- startBig := ip6ToBig(ip) // Startadresse endBig := new(big.Int).Add(startBig, // Endadresse new(big.Int).Sub(new(big.Int).SetUint64(count), big.NewInt(1))) for startBig.Cmp(endBig) <= 0 { // größter Block, der am Start ausgerichtet ist prefix := 128 - trailingZeros128(bigToIP6(startBig)) // so lange verkleinern, bis Block in Fenster passt for { blockSize := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) blockEnd := new(big.Int).Add(startBig, new(big.Int).Sub(blockSize, big.NewInt(1))) if blockEnd.Cmp(endBig) <= 0 { break } prefix++ } result = append(result, fmt.Sprintf("%s/%d", bigToIP6(startBig), prefix)) // zum nächsten Subnetz springen step := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) startBig = new(big.Int).Add(startBig, step) } return result } /* ---------- Hilfsfunktionen IPv4 ---------- */ func ip4ToUint(ip net.IP) uint32 { return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) } func uintToIP4(v uint32) net.IP { return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } /* ---------- Hilfsfunktionen IPv6 ---------- */ func ip6ToBig(ip net.IP) *big.Int { return new(big.Int).SetBytes(ip.To16()) // garantiert 16 Byte } func bigToIP6(v *big.Int) net.IP { b := v.Bytes() if len(b) < 16 { // von links auf 16 Byte auffüllen pad := make([]byte, 16-len(b)) b = append(pad, b...) } return net.IP(b) } // Anzahl der Null-Bits am wenigst-signifikanten Ende (LSB) eines IPv6-Werts func trailingZeros128(ip net.IP) int { b := ip.To16() tz := 0 for i := 15; i >= 0; i-- { // letzte Byte zuerst (LSB) if b[i] == 0 { tz += 8 } else { tz += bits.TrailingZeros8(b[i]) break } } return tz }