package main import ( "bufio" "context" "encoding/json" "expvar" "fmt" "log" "math/big" "math/bits" "net" "net/http" "net/netip" "os" "strconv" "strings" "sync" "time" lru "github.com/hashicorp/golang-lru/v2" "github.com/redis/go-redis/v9" ) var ( ctx = context.Background() redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379") //redisAddr = getenv("REDIS_ADDR", "localhost:6379") redisTTL = time.Hour * 24 cacheSize = 100_000 blocklistCats = []string{"generic"} rdb *redis.Client ipCache *lru.Cache[string, []string] // Metrics hits = expvar.NewInt("cache_hits") misses = expvar.NewInt("cache_misses") queries = expvar.NewInt("ip_queries") ) var ( totalBlockedIPs = expvar.NewInt("total_blocked_ips") totalWhitelistEntries = expvar.NewInt("total_whitelist_entries") ) func updateTotalsFromRedis() { go func() { blockCount := 0 iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator() for iter.Next(ctx) { blockCount++ } totalBlockedIPs.Set(int64(blockCount)) whiteCount := 0 iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator() for iter.Next(ctx) { whiteCount++ } totalWhitelistEntries.Set(int64(whiteCount)) }() } func startMetricUpdater() { ticker := time.NewTicker(10 * time.Second) go func() { for { updateTotalsFromRedis() <-ticker.C } }() } // -------------------------------------------- // INIT + MAIN // -------------------------------------------- func main() { var err error // Redis client rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) if err := rdb.Ping(ctx).Err(); err != nil { log.Fatalf("redis: %v", err) } // LRU cache ipCache, err = lru.New[string, []string](cacheSize) if err != nil { log.Fatalf("cache init: %v", err) } startMetricUpdater() // Admin load all blocklists (on demand or scheduled) go func() { if getenv("IMPORT_RIRS", "1") == "1" { log.Println("Lade IP-Ranges aus RIRs...") if err := importRIRDataToRedis(); err != nil { log.Fatalf("import error: %v", err) } log.Println("✅ Import abgeschlossen.") } }() // Routes http.HandleFunc("/check/", handleCheck) http.HandleFunc("/whitelist", handleWhitelist) http.HandleFunc("/info", handleInfo) http.Handle("/debug/vars", http.DefaultServeMux) log.Println("🚀 Server läuft auf :8080") log.Fatal(http.ListenAndServe(":8080", nil)) } func getenv(k, fallback string) string { if v := os.Getenv(k); v != "" { return v } return fallback } // -------------------------------------------- // IP CHECK API // -------------------------------------------- func handleCheck(w http.ResponseWriter, r *http.Request) { ipStr := strings.TrimPrefix(r.URL.Path, "/check/") addr, err := netip.ParseAddr(ipStr) if err != nil { http.Error(w, "invalid IP", 400) return } cats := blocklistCats if q := r.URL.Query().Get("cats"); q != "" { cats = strings.Split(q, ",") } queries.Add(1) blockedCats, err := checkIP(addr, cats) if err != nil { http.Error(w, "lookup error", 500) return } writeJSON(w, map[string]any{ "ip": ipStr, "blocked": len(blockedCats) > 0, "categories": blockedCats, }) } func checkIP(ip netip.Addr, cats []string) ([]string, error) { if res, ok := ipCache.Get(ip.String()); ok { hits.Add(1) return res, nil } matches := []string{} for _, cat := range cats { iter := rdb.Scan(ctx, 0, "bl:"+cat+":*", 0).Iterator() for iter.Next(ctx) { key := iter.Val() parts := strings.SplitN(key, ":", 3) if len(parts) != 3 { continue } pfx, err := netip.ParsePrefix(parts[2]) if err != nil { continue } if pfx.Contains(ip) { matches = append(matches, cat) break } } } misses.Add(1) ipCache.Add(ip.String(), matches) return matches, nil } // -------------------------------------------- // WHITELIST API (optional extension) // -------------------------------------------- func handleWhitelist(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { http.Error(w, "method not allowed", 405) return } var body struct { IP string `json:"ip"` } if err := json.NewDecoder(r.Body).Decode(&body); err != nil { http.Error(w, "bad request", 400) return } addr, err := netip.ParseAddr(body.IP) if err != nil { http.Error(w, "invalid IP", 400) return } // Add to whitelist (Redis key like wl:) if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil { http.Error(w, "redis error", 500) return } ipCache.Add(addr.String(), nil) writeJSON(w, map[string]string{"status": "whitelisted"}) } // -------------------------------------------- // ADMIN INFO // -------------------------------------------- func handleInfo(w http.ResponseWriter, _ *http.Request) { stats := map[string]any{ "cache_size": ipCache.Len(), "ttl_hours": redisTTL.Hours(), "redis": redisAddr, } writeJSON(w, stats) } // -------------------------------------------- // UTIL // -------------------------------------------- func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(v) } // -------------------------------------------- // RIR DATA IMPORT (ALL COUNTRIES) // -------------------------------------------- var rirFiles = []string{ "https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest", "https://ftp.apnic.net/stats/apnic/delegated-apnic-latest", "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest", "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest", "https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest", } func importRIRDataToRedis() error { wg := sync.WaitGroup{} sem := make(chan struct{}, 5) for _, url := range rirFiles { wg.Add(1) sem <- struct{}{} go func(url string) { defer wg.Done() defer func() { <-sem }() fmt.Println("Start: ", url) if err := fetchAndStore(url); err != nil { log.Printf("❌ Fehler bei %s: %v", url, err) } fmt.Println("Done: ", url) }(url) } wg.Wait() return nil } func fetchAndStore(url string) error { resp, err := http.Get(url) if err != nil { return err } defer resp.Body.Close() scanner := bufio.NewScanner(resp.Body) for scanner.Scan() { line := scanner.Text() if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") { continue } fields := strings.Split(line, "|") if len(fields) < 7 { continue } country := strings.ToLower(fields[1]) ipType := fields[2] start := fields[3] count := fields[4] if ipType != "ipv4" && ipType != "ipv6" { continue } if start == "24.152.36.0" { fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count) num, _ := strconv.ParseUint(count, 10, 64) for _, cidr := range summarizeCIDRs(start, num) { fmt.Println(" →", cidr) } } //cidrList := summarizeToCIDRs(start, count, ipType) numIPs, _ := strconv.ParseUint(count, 10, 64) cidrList := summarizeCIDRs(start, numIPs) //log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList)) for _, cidr := range cidrList { prefix, err := netip.ParsePrefix(cidr) if err != nil { continue } key := "bl:" + country + ":" + prefix.String() //fmt.Println(key) _ = rdb.Set(ctx, key, "1", redisTTL).Err() } } return scanner.Err() } // -------------------------------------------- // IP RANGE SUMMARIZER // -------------------------------------------- func summarizeCIDRs(startIP string, count uint64) []string { var result []string if count == 0 { return result } ip := net.ParseIP(startIP) if ip == nil { return result } // IPv4-Pfad --------------------------------------------------------------- if v4 := ip.To4(); v4 != nil { start := ip4ToUint(v4) end := start + uint32(count) - 1 for start <= end { prefix := 32 - uint32(bits.TrailingZeros32(start)) for (start + (1 << (32 - prefix)) - 1) > end { prefix++ } result = append(result, fmt.Sprintf("%s/%d", uintToIP4(start), prefix)) start += 1 << (32 - prefix) } return result } // IPv6-Pfad --------------------------------------------------------------- startBig := ip6ToBig(ip) // Startadresse endBig := new(big.Int).Add(startBig, // Endadresse new(big.Int).Sub(new(big.Int).SetUint64(count), big.NewInt(1))) for startBig.Cmp(endBig) <= 0 { // größter Block, der am Start ausgerichtet ist prefix := 128 - trailingZeros128(bigToIP6(startBig)) // so lange verkleinern, bis Block in Fenster passt for { blockSize := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) blockEnd := new(big.Int).Add(startBig, new(big.Int).Sub(blockSize, big.NewInt(1))) if blockEnd.Cmp(endBig) <= 0 { break } prefix++ } result = append(result, fmt.Sprintf("%s/%d", bigToIP6(startBig), prefix)) // zum nächsten Subnetz springen step := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) startBig = new(big.Int).Add(startBig, step) } return result } /* ---------- Hilfsfunktionen IPv4 ---------- */ func ip4ToUint(ip net.IP) uint32 { return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) } func uintToIP4(v uint32) net.IP { return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) } /* ---------- Hilfsfunktionen IPv6 ---------- */ func ip6ToBig(ip net.IP) *big.Int { return new(big.Int).SetBytes(ip.To16()) // garantiert 16 Byte } func bigToIP6(v *big.Int) net.IP { b := v.Bytes() if len(b) < 16 { // von links auf 16 Byte auffüllen pad := make([]byte, 16-len(b)) b = append(pad, b...) } return net.IP(b) } // Anzahl der Null-Bits am wenigst-signifikanten Ende (LSB) eines IPv6-Werts func trailingZeros128(ip net.IP) int { b := ip.To16() tz := 0 for i := 15; i >= 0; i-- { // letzte Byte zuerst (LSB) if b[i] == 0 { tz += 8 } else { tz += bits.TrailingZeros8(b[i]) break } } return tz }