From 224999bf65e26d4518aa4f85d75e04c4c5c36f0d Mon Sep 17 00:00:00 2001 From: jbergner Date: Sat, 14 Jun 2025 11:27:07 +0200 Subject: [PATCH] test1 --- __main.go | 780 ++++++++++++++++++++++++++++++++++++++++++++ go.mod | 13 +- go.sum | 16 + main.go | 940 ++++++++++++++---------------------------------------- 4 files changed, 1056 insertions(+), 693 deletions(-) create mode 100644 __main.go diff --git a/__main.go b/__main.go new file mode 100644 index 0000000..27ead99 --- /dev/null +++ b/__main.go @@ -0,0 +1,780 @@ +package main + +import ( + "bufio" + "context" + "encoding/binary" + "encoding/json" + "expvar" + "fmt" + "io" + "log" + "math/big" + "math/bits" + "net" + "net/http" + "net/netip" + "os" + "strconv" + "strings" + "sync" + "time" + + lru "github.com/hashicorp/golang-lru/v2" + "github.com/redis/go-redis/v9" +) + +var ( + ctx = context.Background() + redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379") + //redisAddr = getenv("REDIS_ADDR", "localhost:6379") + redisTTL = time.Hour * 24 + cacheSize = 100_000 + blocklistCats = []string{"generic"} + rdb *redis.Client + ipCache *lru.Cache[string, []string] + + // Metrics + hits = expvar.NewInt("cache_hits") + misses = expvar.NewInt("cache_misses") + queries = expvar.NewInt("ip_queries") +) + +var ( + totalBlockedIPs = expvar.NewInt("total_blocked_ips") + totalWhitelistEntries = expvar.NewInt("total_whitelist_entries") +) + +func updateTotalsFromRedis() { + go func() { + blockCount := 0 + iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator() + for iter.Next(ctx) { + blockCount++ + } + totalBlockedIPs.Set(int64(blockCount)) + + whiteCount := 0 + iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator() + for iter.Next(ctx) { + whiteCount++ + } + totalWhitelistEntries.Set(int64(whiteCount)) + }() +} + +func startMetricUpdater() { + ticker := time.NewTicker(10 * time.Second) + go func() { + for { + updateTotalsFromRedis() + <-ticker.C + } + }() +} + +// +// +// + +type Source struct { + Category string + URL []string +} + +type Config struct { + RedisAddr string + Sources []Source + TTLHours int + IsWorker bool // true ⇒ lädt Blocklisten & schreibt sie nach Redis +} + +func loadConfig() Config { + // default Blocklist source + srcs := []Source{{ + Category: "generic", + URL: []string{ + "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset", + "https://raw.githubusercontent.com/bitwire-it/ipblocklist/refs/heads/main/ip-list.txt", + "https://ipv64.net/blocklists/countries/ipv64_blocklist_RU.txt", + "https://ipv64.net/blocklists/countries/ipv64_blocklist_CN.txt", + }, + }, + } + + if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" { + srcs = nil + for _, spec := range strings.Split(env, ",") { + spec = strings.TrimSpace(spec) + if spec == "" { + continue + } + parts := strings.SplitN(spec, ":", 2) + if len(parts) != 2 { + continue + } + cat := strings.TrimSpace(parts[0]) + raw := strings.FieldsFunc(parts[1], func(r rune) bool { return r == '|' || r == ';' }) + var urls []string + for _, u := range raw { + if u = strings.TrimSpace(u); u != "" { + urls = append(urls, u) + } + } + if len(urls) > 0 { + srcs = append(srcs, Source{Category: cat, URL: urls}) + } + } + } + + ttl := 24 + if env := os.Getenv("TTL_HOURS"); env != "" { + fmt.Sscanf(env, "%d", &ttl) + } + + isWorker := strings.ToLower(os.Getenv("ROLE")) == "worker" + + return Config{ + //RedisAddr: getenv("REDIS_ADDR", "redis:6379"), + RedisAddr: getenv("REDIS_ADDR", "10.10.5.249:6379"), + Sources: srcs, + TTLHours: ttl, + IsWorker: isWorker, + } +} + +// Alle gültigen ISO 3166-1 Alpha-2 Ländercodes (abgekürzt, reale Liste ist länger) +var allCountryCodes = []string{ + "AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", + "BA", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BN", "BO", "BR", "BS", + "BT", "BW", "BY", "BZ", "CA", "CD", "CF", "CG", "CH", "CI", "CL", "CM", "CN", + "CO", "CR", "CU", "CV", "CY", "CZ", "DE", "DJ", "DK", "DM", "DO", "DZ", "EC", + "EE", "EG", "ER", "ES", "ET", "FI", "FJ", "FM", "FR", "GA", "GB", "GD", "GE", + "GH", "GM", "GN", "GQ", "GR", "GT", "GW", "GY", "HK", "HN", "HR", "HT", "HU", + "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", + "KH", "KI", "KM", "KN", "KP", "KR", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", + "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "ME", "MG", "MH", "MK", + "ML", "MM", "MN", "MR", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", + "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PG", "PH", "PK", + "PL", "PT", "PW", "PY", "QA", "RO", "RS", "RU", "RW", "SA", "SB", "SC", "SD", + "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", + "TD", "TG", "TH", "TJ", "TL", "TM", "TN", "TO", "TR", "TT", "TV", "TZ", "UA", + "UG", "US", "UY", "UZ", "VC", "VE", "VN", "VU", "WS", "YE", "ZA", "ZM", "ZW", +} + +// Hauptfunktion: gibt alle IPv4-Ranges eines Landes (CIDR) aus allen RIRs zurück +func GetIPRangesByCountry(countryCode string) ([]string, error) { + var allCIDRs []string + upperCode := strings.ToUpper(countryCode) + + for _, url := range rirFiles { + resp, err := http.Get(url) + if err != nil { + return nil, fmt.Errorf("fehler beim abrufen von %s: %w", url, err) + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "2") || strings.HasPrefix(line, "#") { + continue // Kommentar oder Header + } + if strings.Contains(line, "|"+upperCode+"|ipv4|") { + fields := strings.Split(line, "|") + if len(fields) < 5 { + continue + } + ipStart := fields[3] + count, _ := strconv.Atoi(fields[4]) + cidrs := summarizeCIDR(ipStart, count) + allCIDRs = append(allCIDRs, cidrs...) + } + } + } + return allCIDRs, nil +} + +// Hilfsfunktion: Start-IP + Anzahl → []CIDR +func summarizeCIDR(start string, count int) []string { + var cidrs []string + ip := net.ParseIP(start).To4() + startInt := ipToInt(ip) + + for count > 0 { + maxSize := 32 + for maxSize > 0 { + mask := 1 << uint(32-maxSize) + if startInt%uint32(mask) == 0 && mask <= count { + break + } + maxSize-- + } + cidr := fmt.Sprintf("%s/%d", intToIP(startInt), maxSize) + cidrs = append(cidrs, cidr) + count -= 1 << uint(32-maxSize) + startInt += uint32(1 << uint(32-maxSize)) + } + return cidrs +} + +func ipToInt(ip net.IP) uint32 { + return uint32(ip[0])<<24 + uint32(ip[1])<<16 + uint32(ip[2])<<8 + uint32(ip[3]) +} + +func intToIP(i uint32) net.IP { + return net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) +} + +func keyBlock(cat string, p netip.Prefix) string { return "bl:" + cat + ":" + p.String() } + +func LoadAllCountryPrefixesIntoRedisAndRanger( + rdb *redis.Client, + ttlHours int, +) error { + for _, countryCode := range allCountryCodes { + + expiry := time.Duration(ttlHours) * time.Hour + results := make(map[string][]netip.Prefix) + + fmt.Printf("💡 Loading %s...\n", countryCode) + cidrs, err := GetIPRangesByCountry(countryCode) + if err != nil { + log.Printf("Error at %s: %v", countryCode, err) + } + fmt.Println("✅ Got " + strconv.Itoa(len(cidrs)) + " Ranges for Country " + countryCode) + var validPrefixes []netip.Prefix + for _, c := range cidrs { + prefix, err := netip.ParsePrefix(c) + if err != nil { + log.Printf("CIDR invalid [%s]: %v", c, err) + continue + } + validPrefixes = append(validPrefixes, prefix) + } + fmt.Println("✅ Got " + strconv.Itoa(len(validPrefixes)) + " valid Prefixes for Country " + countryCode) + + if len(validPrefixes) > 0 { + results[countryCode] = validPrefixes + } + + // Nach Verarbeitung: alles in Ranger + Redis eintragen + for code, prefixes := range results { + for _, p := range prefixes { + key := keyBlock(code, p) + if err := rdb.Set(ctx, key, "1", expiry).Err(); err != nil { + log.Printf("Redis-Error at %s: %v", key, err) + } + } + fmt.Println("✅ Import Subset " + strconv.Itoa(len(prefixes)) + " Entries") + } + fmt.Println("✅ Import done!") + fmt.Println("--------------------------------------------------") + } + + return nil +} + +func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client) { + + fmt.Println("💡 Loading Lists...") + if err := syncOnce(ctx, cfg, rdb); err != nil { + log.Println("initial sync:", err) + } + fmt.Println("✅ Loading Lists Done.") + ticker := time.NewTicker(30 * time.Minute) + for { + select { + case <-ticker.C: + fmt.Println("💡 Loading Lists Timer...") + if err := syncOnce(ctx, cfg, rdb); err != nil { + log.Println("sync loop:", err) + } + fmt.Println("✅ Loading Lists Timer Done.") + case <-ctx.Done(): + ticker.Stop() + return + } + } +} + +func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client) error { + expiry := time.Duration(cfg.TTLHours) * time.Hour + newBlocks := make(map[string]map[netip.Prefix]struct{}) + + for _, src := range cfg.Sources { + for _, url := range src.URL { + fmt.Println("💡 Loading List " + src.Category + " : " + url) + if err := fetchList(ctx, url, func(p netip.Prefix) { + if _, ok := newBlocks[src.Category]; !ok { + newBlocks[src.Category] = map[netip.Prefix]struct{}{} + } + newBlocks[src.Category][p] = struct{}{} + _ = rdb.Set(ctx, keyBlock(src.Category, p), "1", expiry).Err() + + }); err != nil { + fmt.Println("❌ Fail.") + return err + } + fmt.Println("✅ Done.") + } + } + return nil +} + +func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error { + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("%s -> %s", url, resp.Status) + } + return parseStream(resp.Body, cb) +} + +func parseStream(r io.Reader, cb func(netip.Prefix)) error { + s := bufio.NewScanner(r) + for s.Scan() { + line := strings.TrimSpace(s.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + if p, err := netip.ParsePrefix(line); err == nil { + cb(p) + continue + } + if addr, err := netip.ParseAddr(line); err == nil { + plen := 32 + if addr.Is6() { + plen = 128 + } + cb(netip.PrefixFrom(addr, plen)) + } + } + return s.Err() +} + +// -------------------------------------------- +// INIT + MAIN +// -------------------------------------------- + +func main() { + + if getenv("IMPORTER", "0") == "1" { + //Hier alles doof. selbe funktion wie unten. muss durch individuallisten ersetzt werden... + cfg := loadConfig() + rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) + /*if err := LoadAllCountryPrefixesIntoRedisAndRanger(rdb, cfg.TTLHours); err != nil { + log.Fatalf("Fehler beim Laden aller Länderranges: %v", err) + }*/ + syncLoop(ctx, cfg, rdb) + log.Println("🚀 Import erfolgreich!") + } else { + var err error + + // Redis client + rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) + if err := rdb.Ping(ctx).Err(); err != nil { + log.Fatalf("redis: %v", err) + } + + // LRU cache + ipCache, err = lru.New[string, []string](cacheSize) + if err != nil { + log.Fatalf("cache init: %v", err) + } + + startMetricUpdater() + + // Admin load all blocklists (on demand or scheduled) + go func() { + if getenv("IMPORT_RIRS", "0") == "1" { + log.Println("Lade IP-Ranges aus RIRs...") + if err := importRIRDataToRedis(); err != nil { + log.Fatalf("import error: %v", err) + } + log.Println("✅ Import abgeschlossen.") + } + }() + + // Routes + http.HandleFunc("/check/", handleCheck) + http.HandleFunc("/whitelist", handleWhitelist) + http.HandleFunc("/info", handleInfo) + http.Handle("/debug/vars", http.DefaultServeMux) + + log.Println("🚀 Server läuft auf :8080") + log.Fatal(http.ListenAndServe(":8080", nil)) + } + +} + +func getenv(k, fallback string) string { + if v := os.Getenv(k); v != "" { + return v + } + return fallback +} + +// -------------------------------------------- +// IP CHECK API +// -------------------------------------------- + +func handleCheck(w http.ResponseWriter, r *http.Request) { + ipStr := strings.TrimPrefix(r.URL.Path, "/check/") + addr, err := netip.ParseAddr(ipStr) + if err != nil { + http.Error(w, "invalid IP", 400) + return + } + + cats := blocklistCats + if q := r.URL.Query().Get("cats"); q != "" { + cats = strings.Split(q, ",") + } + + queries.Add(1) + blockedCats, err := checkIP(addr, cats) + if err != nil { + http.Error(w, "lookup error", 500) + return + } + + fmt.Println("----------------") + + writeJSON(w, map[string]any{ + "ip": ipStr, + "blocked": len(blockedCats) > 0, + "categories": blockedCats, + }) +} + +// liefert alle möglichen Präfixe dieser IP, beginnend beim längsten (/32 oder /128) +func supernets(ip netip.Addr) []string { + if ip.Is4() { + fmt.Println("💡 DEBUG: supernets | Is4") + a := ip.As4() // Kopie addressierbar machen + u := binary.BigEndian.Uint32(a[:]) // jetzt darf man slicen + fmt.Println("💡 DEBUG: supernets | a + u", a, u) + + supers := make([]string, 33) // /32 … /0 + for bits := 32; bits >= 0; bits-- { + mask := uint32(0xffffffff) << (32 - bits) + n := u & mask + addr := netip.AddrFrom4([4]byte{ + byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), + }) + supers[32-bits] = fmt.Sprintf("%s/%d", addr, bits) + } + fmt.Println("💡 DEBUG: supernets | supers", supers) + return supers + } + + a := ip.As16() // Kopie addressierbar + supers := make([]string, 129) // /128 … /0 + for bits := 128; bits >= 0; bits-- { + b := a // Wert-Kopie für Modifikation + + // vollständige Bytes auf 0 setzen + full := (128 - bits) / 8 + for i := 0; i < full; i++ { + b[15-i] = 0 + } + // Restbits maskieren + rem := (128 - bits) % 8 + if rem != 0 { + b[15-full] &= 0xFF << rem + } + + addr := netip.AddrFrom16(b) + supers[128-bits] = fmt.Sprintf("%s/%d", addr, bits) + } + fmt.Println("Supernets-v6", supers) + return supers +} + +func checkIP(ip netip.Addr, cats []string) ([]string, error) { + // 1) Cache-Treffer? + if res, ok := ipCache.Get(ip.String()); ok { + fmt.Println("💡 DEBUG: Cache-Hit") + hits.Add(1) + return res, nil + } + + // 2) alle Supernetze der IP (≤32 bzw. ≤128 Stück) + supers := supernets(ip) + fmt.Println("💡 DEBUG: checkIP | supers |", supers) + fmt.Println("💡 DEBUG: checkIP | ip |", ip) + fmt.Println("💡 DEBUG: checkIP | cats |", cats) + + // 3) Pipeline – jeweils *eine* EXISTS-Abfrage pro Kategorie + pipe := rdb.Pipeline() + existsCmds := make([]*redis.IntCmd, len(cats)) + + for i, cat := range cats { + keys := make([]string, len(supers)) + for j, pfx := range supers { + keys[j] = "bl:" + cat + ":" + pfx + } + fmt.Println("💡 DEBUG: checkIP | keys |", keys) + existsCmds[i] = pipe.Exists(ctx, keys...) + } + + if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { + return nil, err + } + + // 4) Ergebnis auswerten + matches := make([]string, 0, len(cats)) + for i, cat := range cats { + if existsCmds[i].Val() > 0 { + matches = append(matches, cat) + fmt.Println("💡 DEBUG: checkIP | matches:cat |", cat) + } + } + fmt.Println("💡 DEBUG: checkIP | matches |", matches) + + // 5) Cache befüllen und zurück + misses.Add(1) + ipCache.Add(ip.String(), matches) + return matches, nil +} + +// -------------------------------------------- +// WHITELIST API (optional extension) +// -------------------------------------------- + +func handleWhitelist(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + var body struct { + IP string `json:"ip"` + } + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, "bad request", 400) + return + } + addr, err := netip.ParseAddr(body.IP) + if err != nil { + http.Error(w, "invalid IP", 400) + return + } + // Add to whitelist (Redis key like wl:) + if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil { + http.Error(w, "redis error", 500) + return + } + ipCache.Add(addr.String(), nil) + writeJSON(w, map[string]string{"status": "whitelisted"}) +} + +// -------------------------------------------- +// ADMIN INFO +// -------------------------------------------- + +func handleInfo(w http.ResponseWriter, _ *http.Request) { + stats := map[string]any{ + "cache_size": ipCache.Len(), + "ttl_hours": redisTTL.Hours(), + "redis": redisAddr, + } + writeJSON(w, stats) +} + +// -------------------------------------------- +// UTIL +// -------------------------------------------- + +func writeJSON(w http.ResponseWriter, v any) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(v) +} + +// -------------------------------------------- +// RIR DATA IMPORT (ALL COUNTRIES) +// -------------------------------------------- + +var rirFiles = []string{ + "https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest", + "https://ftp.apnic.net/stats/apnic/delegated-apnic-latest", + "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest", + "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest", + "https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest", +} + +func importRIRDataToRedis() error { + wg := sync.WaitGroup{} + sem := make(chan struct{}, 5) + + for _, url := range rirFiles { + wg.Add(1) + sem <- struct{}{} + go func(url string) { + defer wg.Done() + defer func() { <-sem }() + fmt.Println("Start: ", url) + if err := fetchAndStore(url); err != nil { + log.Printf("❌ Fehler bei %s: %v", url, err) + } + fmt.Println("Done: ", url) + }(url) + } + wg.Wait() + return nil +} + +func fetchAndStore(url string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") { + continue + } + fields := strings.Split(line, "|") + if len(fields) < 7 { + continue + } + country := strings.ToLower(fields[1]) + ipType := fields[2] + start := fields[3] + count := fields[4] + + if ipType != "ipv4" && ipType != "ipv6" { + continue + } + + if start == "24.152.36.0" { + fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count) + num, _ := strconv.ParseUint(count, 10, 64) + for _, cidr := range summarizeCIDRs(start, num) { + fmt.Println(" →", cidr) + } + } + + //cidrList := summarizeToCIDRs(start, count, ipType) + numIPs, _ := strconv.ParseUint(count, 10, 64) + cidrList := summarizeCIDRs(start, numIPs) + //log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList)) + for _, cidr := range cidrList { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + continue + } + key := "bl:" + country + ":" + prefix.String() + //fmt.Println(key) + _ = rdb.Set(ctx, key, "1", redisTTL).Err() + } + } + return scanner.Err() +} + +// -------------------------------------------- +// IP RANGE SUMMARIZER +// -------------------------------------------- + +func summarizeCIDRs(startIP string, count uint64) []string { + var result []string + + if count == 0 { + return result + } + ip := net.ParseIP(startIP) + if ip == nil { + return result + } + + // IPv4-Pfad --------------------------------------------------------------- + if v4 := ip.To4(); v4 != nil { + start := ip4ToUint(v4) + end := start + uint32(count) - 1 + + for start <= end { + prefix := 32 - uint32(bits.TrailingZeros32(start)) + for (start + (1 << (32 - prefix)) - 1) > end { + prefix++ + } + result = append(result, + fmt.Sprintf("%s/%d", uintToIP4(start), prefix)) + start += 1 << (32 - prefix) + } + return result + } + + // IPv6-Pfad --------------------------------------------------------------- + startBig := ip6ToBig(ip) // Startadresse + endBig := new(big.Int).Add(startBig, // Endadresse + new(big.Int).Sub(new(big.Int).SetUint64(count), big.NewInt(1))) + + for startBig.Cmp(endBig) <= 0 { + // größter Block, der am Start ausgerichtet ist + prefix := 128 - trailingZeros128(bigToIP6(startBig)) + + // so lange verkleinern, bis Block in Fenster passt + for { + blockSize := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) + blockEnd := new(big.Int).Add(startBig, + new(big.Int).Sub(blockSize, big.NewInt(1))) + if blockEnd.Cmp(endBig) <= 0 { + break + } + prefix++ + } + + result = append(result, + fmt.Sprintf("%s/%d", bigToIP6(startBig), prefix)) + + // zum nächsten Subnetz springen + step := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) + startBig = new(big.Int).Add(startBig, step) + } + return result +} + +/* ---------- Hilfsfunktionen IPv4 ---------- */ + +func ip4ToUint(ip net.IP) uint32 { + return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) +} +func uintToIP4(v uint32) net.IP { + return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) +} + +/* ---------- Hilfsfunktionen IPv6 ---------- */ + +func ip6ToBig(ip net.IP) *big.Int { + return new(big.Int).SetBytes(ip.To16()) // garantiert 16 Byte +} +func bigToIP6(v *big.Int) net.IP { + b := v.Bytes() + if len(b) < 16 { // von links auf 16 Byte auffüllen + pad := make([]byte, 16-len(b)) + b = append(pad, b...) + } + return net.IP(b) +} + +// Anzahl der Null-Bits am wenigst-signifikanten Ende (LSB) eines IPv6-Werts +func trailingZeros128(ip net.IP) int { + b := ip.To16() + tz := 0 + for i := 15; i >= 0; i-- { // letzte Byte zuerst (LSB) + if b[i] == 0 { + tz += 8 + } else { + tz += bits.TrailingZeros8(b[i]) + break + } + } + return tz +} diff --git a/go.mod b/go.mod index 289ccd5..a2b04da 100644 --- a/go.mod +++ b/go.mod @@ -3,8 +3,19 @@ module git.send.nrw/sendnrw/flod go 1.24.3 require ( + github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/prometheus/client_golang v1.22.0 +) + +require ( + github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/hashicorp/golang-lru/v2 v2.0.7 + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/prometheus/client_model v0.6.1 // indirect + github.com/prometheus/common v0.62.0 // indirect + github.com/prometheus/procfs v0.15.1 // indirect github.com/redis/go-redis/v9 v9.10.0 // indirect + golang.org/x/sys v0.30.0 // indirect + google.golang.org/protobuf v1.36.5 // indirect ) diff --git a/go.sum b/go.sum index 6a1b628..3d0d41e 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,24 @@ +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= +github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= +github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= +github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/redis/go-redis/v9 v9.10.0 h1:FxwK3eV8p/CQa0Ch276C7u2d0eNC9kCmAYQ7mCXCzVs= github.com/redis/go-redis/v9 v9.10.0/go.mod h1:huWgSWd8mW6+m0VPhJjSSQ+d6Nh1VICQ6Q5lHuCH/Iw= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +google.golang.org/protobuf v1.36.5 h1:tPhr+woSbjfYvY6/GPufUoYizxw1cF/yFoxJ2fmpwlM= +google.golang.org/protobuf v1.36.5/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= diff --git a/main.go b/main.go index ffdb42e..0daed75 100644 --- a/main.go +++ b/main.go @@ -3,764 +3,320 @@ package main import ( "bufio" "context" - "encoding/binary" "encoding/json" - "expvar" "fmt" - "io" - "log" - "math/big" - "math/bits" "net" "net/http" "net/netip" - "os" - "strconv" "strings" "sync" + "testing" "time" - lru "github.com/hashicorp/golang-lru/v2" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/redis/go-redis/v9" ) var ( - ctx = context.Background() - redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379") - //redisAddr = getenv("REDIS_ADDR", "localhost:6379") - redisTTL = time.Hour * 24 - cacheSize = 100_000 - blocklistCats = []string{"generic"} - rdb *redis.Client - ipCache *lru.Cache[string, []string] - - // Metrics - hits = expvar.NewInt("cache_hits") - misses = expvar.NewInt("cache_misses") - queries = expvar.NewInt("ip_queries") -) - -var ( - totalBlockedIPs = expvar.NewInt("total_blocked_ips") - totalWhitelistEntries = expvar.NewInt("total_whitelist_entries") -) - -func updateTotalsFromRedis() { - go func() { - blockCount := 0 - iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator() - for iter.Next(ctx) { - blockCount++ - } - totalBlockedIPs.Set(int64(blockCount)) - - whiteCount := 0 - iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator() - for iter.Next(ctx) { - whiteCount++ - } - totalWhitelistEntries.Set(int64(whiteCount)) - }() -} - -func startMetricUpdater() { - ticker := time.NewTicker(10 * time.Second) - go func() { - for { - updateTotalsFromRedis() - <-ticker.C - } - }() -} - -// -// -// - -type Source struct { - Category string - URL []string -} - -type Config struct { - RedisAddr string - Sources []Source - TTLHours int - IsWorker bool // true ⇒ lädt Blocklisten & schreibt sie nach Redis -} - -func loadConfig() Config { - // default Blocklist source - srcs := []Source{{ - Category: "generic", - URL: []string{ - "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset", - "https://raw.githubusercontent.com/bitwire-it/ipblocklist/refs/heads/main/ip-list.txt", - "", + checkRequests = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "ipcheck_requests_total", + Help: "Total number of IP check requests", }, - }, - } + ) + checkBlocked = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "ipcheck_blocked_total", + Help: "Total number of blocked IPs", + }, + ) + checkWhitelist = prometheus.NewCounter( + prometheus.CounterOpts{ + Name: "ipcheck_whitelisted_total", + Help: "Total number of whitelisted IPs", + }, + ) +) - if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" { - srcs = nil - for _, spec := range strings.Split(env, ",") { - spec = strings.TrimSpace(spec) - if spec == "" { - continue - } - parts := strings.SplitN(spec, ":", 2) - if len(parts) != 2 { - continue - } - cat := strings.TrimSpace(parts[0]) - raw := strings.FieldsFunc(parts[1], func(r rune) bool { return r == '|' || r == ';' }) - var urls []string - for _, u := range raw { - if u = strings.TrimSpace(u); u != "" { - urls = append(urls, u) - } - } - if len(urls) > 0 { - srcs = append(srcs, Source{Category: cat, URL: urls}) - } - } - } - - ttl := 24 - if env := os.Getenv("TTL_HOURS"); env != "" { - fmt.Sscanf(env, "%d", &ttl) - } - - isWorker := strings.ToLower(os.Getenv("ROLE")) == "worker" - - return Config{ - //RedisAddr: getenv("REDIS_ADDR", "redis:6379"), - RedisAddr: getenv("REDIS_ADDR", "10.10.5.249:6379"), - Sources: srcs, - TTLHours: ttl, - IsWorker: isWorker, - } +func init() { + prometheus.MustRegister(checkRequests, checkBlocked, checkWhitelist) } -// Alle gültigen ISO 3166-1 Alpha-2 Ländercodes (abgekürzt, reale Liste ist länger) -var allCountryCodes = []string{ - "AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ", - "BA", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BN", "BO", "BR", "BS", - "BT", "BW", "BY", "BZ", "CA", "CD", "CF", "CG", "CH", "CI", "CL", "CM", "CN", - "CO", "CR", "CU", "CV", "CY", "CZ", "DE", "DJ", "DK", "DM", "DO", "DZ", "EC", - "EE", "EG", "ER", "ES", "ET", "FI", "FJ", "FM", "FR", "GA", "GB", "GD", "GE", - "GH", "GM", "GN", "GQ", "GR", "GT", "GW", "GY", "HK", "HN", "HR", "HT", "HU", - "ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG", - "KH", "KI", "KM", "KN", "KP", "KR", "KW", "KZ", "LA", "LB", "LC", "LI", "LK", - "LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "ME", "MG", "MH", "MK", - "ML", "MM", "MN", "MR", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE", - "NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PG", "PH", "PK", - "PL", "PT", "PW", "PY", "QA", "RO", "RS", "RU", "RW", "SA", "SB", "SC", "SD", - "SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ", - "TD", "TG", "TH", "TJ", "TL", "TM", "TN", "TO", "TR", "TT", "TV", "TZ", "UA", - "UG", "US", "UY", "UZ", "VC", "VE", "VN", "VU", "WS", "YE", "ZA", "ZM", "ZW", -} - -// Hauptfunktion: gibt alle IPv4-Ranges eines Landes (CIDR) aus allen RIRs zurück -func GetIPRangesByCountry(countryCode string) ([]string, error) { - var allCIDRs []string - upperCode := strings.ToUpper(countryCode) - - for _, url := range rirFiles { - resp, err := http.Get(url) - if err != nil { - return nil, fmt.Errorf("fehler beim abrufen von %s: %w", url, err) - } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "2") || strings.HasPrefix(line, "#") { - continue // Kommentar oder Header - } - if strings.Contains(line, "|"+upperCode+"|ipv4|") { - fields := strings.Split(line, "|") - if len(fields) < 5 { - continue - } - ipStart := fields[3] - count, _ := strconv.Atoi(fields[4]) - cidrs := summarizeCIDR(ipStart, count) - allCIDRs = append(allCIDRs, cidrs...) - } - } - } - return allCIDRs, nil -} - -// Hilfsfunktion: Start-IP + Anzahl → []CIDR -func summarizeCIDR(start string, count int) []string { - var cidrs []string - ip := net.ParseIP(start).To4() - startInt := ipToInt(ip) - - for count > 0 { - maxSize := 32 - for maxSize > 0 { - mask := 1 << uint(32-maxSize) - if startInt%uint32(mask) == 0 && mask <= count { - break - } - maxSize-- - } - cidr := fmt.Sprintf("%s/%d", intToIP(startInt), maxSize) - cidrs = append(cidrs, cidr) - count -= 1 << uint(32-maxSize) - startInt += uint32(1 << uint(32-maxSize)) - } - return cidrs -} - -func ipToInt(ip net.IP) uint32 { - return uint32(ip[0])<<24 + uint32(ip[1])<<16 + uint32(ip[2])<<8 + uint32(ip[3]) -} - -func intToIP(i uint32) net.IP { - return net.IPv4(byte(i>>24), byte(i>>16), byte(i>>8), byte(i)) -} - -func keyBlock(cat string, p netip.Prefix) string { return "bl:" + cat + ":" + p.String() } - -func LoadAllCountryPrefixesIntoRedisAndRanger( - rdb *redis.Client, - ttlHours int, -) error { - for _, countryCode := range allCountryCodes { - - expiry := time.Duration(ttlHours) * time.Hour - results := make(map[string][]netip.Prefix) - - fmt.Printf("💡 Loading %s...\n", countryCode) - cidrs, err := GetIPRangesByCountry(countryCode) - if err != nil { - log.Printf("Error at %s: %v", countryCode, err) - } - fmt.Println("✅ Got " + strconv.Itoa(len(cidrs)) + " Ranges for Country " + countryCode) - var validPrefixes []netip.Prefix - for _, c := range cidrs { - prefix, err := netip.ParsePrefix(c) - if err != nil { - log.Printf("CIDR invalid [%s]: %v", c, err) - continue - } - validPrefixes = append(validPrefixes, prefix) - } - fmt.Println("✅ Got " + strconv.Itoa(len(validPrefixes)) + " valid Prefixes for Country " + countryCode) - - if len(validPrefixes) > 0 { - results[countryCode] = validPrefixes - } - - // Nach Verarbeitung: alles in Ranger + Redis eintragen - for code, prefixes := range results { - for _, p := range prefixes { - key := keyBlock(code, p) - if err := rdb.Set(ctx, key, "1", expiry).Err(); err != nil { - log.Printf("Redis-Error at %s: %v", key, err) - } - } - fmt.Println("✅ Import Subset " + strconv.Itoa(len(prefixes)) + " Entries") - } - fmt.Println("✅ Import done!") - fmt.Println("--------------------------------------------------") - } - - return nil -} - -func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client) { - - fmt.Println("💡 Loading Lists...") - if err := syncOnce(ctx, cfg, rdb); err != nil { - log.Println("initial sync:", err) - } - fmt.Println("✅ Loading Lists Done.") - ticker := time.NewTicker(30 * time.Minute) - for { - select { - case <-ticker.C: - fmt.Println("💡 Loading Lists Timer...") - if err := syncOnce(ctx, cfg, rdb); err != nil { - log.Println("sync loop:", err) - } - fmt.Println("✅ Loading Lists Timer Done.") - case <-ctx.Done(): - ticker.Stop() - return - } - } -} - -func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client) error { - expiry := time.Duration(cfg.TTLHours) * time.Hour - newBlocks := make(map[string]map[netip.Prefix]struct{}) - - for _, src := range cfg.Sources { - for _, url := range src.URL { - fmt.Println("💡 Loading List " + src.Category + " : " + url) - if err := fetchList(ctx, url, func(p netip.Prefix) { - if _, ok := newBlocks[src.Category]; !ok { - newBlocks[src.Category] = map[netip.Prefix]struct{}{} - } - newBlocks[src.Category][p] = struct{}{} - _ = rdb.Set(ctx, keyBlock(src.Category, p), "1", expiry).Err() - - }); err != nil { - fmt.Println("❌ Fail.") - return err - } - fmt.Println("✅ Done.") - } - } - return nil -} - -func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error { - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - resp, err := http.DefaultClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("%s -> %s", url, resp.Status) - } - return parseStream(resp.Body, cb) -} - -func parseStream(r io.Reader, cb func(netip.Prefix)) error { - s := bufio.NewScanner(r) - for s.Scan() { - line := strings.TrimSpace(s.Text()) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - if p, err := netip.ParsePrefix(line); err == nil { - cb(p) - continue - } - if addr, err := netip.ParseAddr(line); err == nil { - plen := 32 - if addr.Is6() { - plen = 128 - } - cb(netip.PrefixFrom(addr, plen)) - } - } - return s.Err() -} - -// -------------------------------------------- -// INIT + MAIN -// -------------------------------------------- - -func main() { - - if getenv("IMPORTER", "1") == "1" { - //Hier alles doof. selbe funktion wie unten. muss durch individuallisten ersetzt werden... - cfg := loadConfig() - rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) - /*if err := LoadAllCountryPrefixesIntoRedisAndRanger(rdb, cfg.TTLHours); err != nil { - log.Fatalf("Fehler beim Laden aller Länderranges: %v", err) - }*/ - syncLoop(ctx, cfg, rdb) - log.Println("🚀 Import erfolgreich!") - } else { - var err error - - // Redis client - rdb = redis.NewClient(&redis.Options{Addr: redisAddr}) - if err := rdb.Ping(ctx).Err(); err != nil { - log.Fatalf("redis: %v", err) - } - - // LRU cache - ipCache, err = lru.New[string, []string](cacheSize) - if err != nil { - log.Fatalf("cache init: %v", err) - } - - startMetricUpdater() - - // Admin load all blocklists (on demand or scheduled) - go func() { - if getenv("IMPORT_RIRS", "0") == "1" { - log.Println("Lade IP-Ranges aus RIRs...") - if err := importRIRDataToRedis(); err != nil { - log.Fatalf("import error: %v", err) - } - log.Println("✅ Import abgeschlossen.") - } - }() - - // Routes - http.HandleFunc("/check/", handleCheck) - http.HandleFunc("/whitelist", handleWhitelist) - http.HandleFunc("/info", handleInfo) - http.Handle("/debug/vars", http.DefaultServeMux) - - log.Println("🚀 Server läuft auf :8080") - log.Fatal(http.ListenAndServe(":8080", nil)) - } - -} - -func getenv(k, fallback string) string { - if v := os.Getenv(k); v != "" { - return v - } - return fallback -} - -// -------------------------------------------- -// IP CHECK API -// -------------------------------------------- - func handleCheck(w http.ResponseWriter, r *http.Request) { + checkRequests.Inc() + ipStr := strings.TrimPrefix(r.URL.Path, "/check/") - addr, err := netip.ParseAddr(ipStr) + ip, err := netip.ParseAddr(ipStr) if err != nil { - http.Error(w, "invalid IP", 400) + http.Error(w, "invalid IP", http.StatusBadRequest) return } - cats := blocklistCats + cats := []string{"generic", "test"} if q := r.URL.Query().Get("cats"); q != "" { cats = strings.Split(q, ",") } - queries.Add(1) - blockedCats, err := checkIP(addr, cats) + matches, err := checkIP(ip, cats) if err != nil { - http.Error(w, "lookup error", 500) + http.Error(w, "internal error", http.StatusInternalServerError) return } + if len(matches) == 0 { + wl, _ := rdb.Exists(ctx, "wl:"+ip.String()).Result() + if wl > 0 { + checkWhitelist.Inc() + } + } else { + checkBlocked.Inc() + } + writeJSON(w, map[string]any{ "ip": ipStr, - "blocked": len(blockedCats) > 0, - "categories": blockedCats, + "blocked": len(matches) > 0, + "categories": matches, }) } -// liefert alle möglichen Präfixe dieser IP, beginnend beim längsten (/32 oder /128) -func supernets(ip netip.Addr) []string { - if ip.Is4() { - a := ip.As4() // Kopie addressierbar machen - u := binary.BigEndian.Uint32(a[:]) // jetzt darf man slicen +// Redis Client und Context +var ctx = context.Background() +var rdb = redis.NewClient(&redis.Options{ + Addr: "10.10.5.249:6379", +}) - supers := make([]string, 33) // /32 … /0 - for bits := 32; bits >= 0; bits-- { - mask := uint32(0xffffffff) << (32 - bits) - n := u & mask - addr := netip.AddrFrom4([4]byte{ - byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n), - }) - supers[32-bits] = fmt.Sprintf("%s/%d", addr, bits) - } - return supers - } - - a := ip.As16() // Kopie addressierbar - supers := make([]string, 129) // /128 … /0 - for bits := 128; bits >= 0; bits-- { - b := a // Wert-Kopie für Modifikation - - // vollständige Bytes auf 0 setzen - full := (128 - bits) / 8 - for i := 0; i < full; i++ { - b[15-i] = 0 - } - // Restbits maskieren - rem := (128 - bits) % 8 - if rem != 0 { - b[15-full] &= 0xFF << rem - } - - addr := netip.AddrFrom16(b) - supers[128-bits] = fmt.Sprintf("%s/%d", addr, bits) - } - return supers +// Präfix-Cache (pro Kategorie) +type prefixCacheEntry struct { + prefixes []netip.Prefix + expireAt time.Time } +var ( + prefixCache = map[string]prefixCacheEntry{} + prefixCacheMu sync.Mutex +) + +// Prüfen der IP func checkIP(ip netip.Addr, cats []string) ([]string, error) { - // 1) Cache-Treffer? - if res, ok := ipCache.Get(ip.String()); ok { - hits.Add(1) - return res, nil - } - - // 2) alle Supernetze der IP (≤32 bzw. ≤128 Stück) - supers := supernets(ip) - - // 3) Pipeline – jeweils *eine* EXISTS-Abfrage pro Kategorie - pipe := rdb.Pipeline() - existsCmds := make([]*redis.IntCmd, len(cats)) - - for i, cat := range cats { - keys := make([]string, len(supers)) - for j, pfx := range supers { - keys[j] = "bl:" + cat + ":" + pfx - } - existsCmds[i] = pipe.Exists(ctx, keys...) - } - - if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil { + // Whitelist + wl, err := rdb.Exists(ctx, "wl:"+ip.String()).Result() + if err != nil { return nil, err } - - // 4) Ergebnis auswerten - matches := make([]string, 0, len(cats)) - for i, cat := range cats { - if existsCmds[i].Val() > 0 { - matches = append(matches, cat) - } + if wl > 0 { + return []string{}, nil } - // 5) Cache befüllen und zurück - misses.Add(1) - ipCache.Add(ip.String(), matches) + matches := []string{} + for _, cat := range cats { + prefixes, err := loadCategoryPrefixes(cat) + if err != nil { + return nil, err + } + + for _, pfx := range prefixes { + if pfx.Contains(ip) { + fmt.Printf("💡 MATCH: %s in %s (%s)\n", ip, cat, pfx) + matches = append(matches, cat) + break + } + } + } return matches, nil } -// -------------------------------------------- -// WHITELIST API (optional extension) -// -------------------------------------------- +func loadCategoryPrefixes(cat string) ([]netip.Prefix, error) { + prefixCacheMu.Lock() + defer prefixCacheMu.Unlock() -func handleWhitelist(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "method not allowed", http.StatusMethodNotAllowed) - return + entry, ok := prefixCache[cat] + if ok && time.Now().Before(entry.expireAt) { + return entry.prefixes, nil } - var body struct { - IP string `json:"ip"` - } - if err := json.NewDecoder(r.Body).Decode(&body); err != nil { - http.Error(w, "bad request", 400) - return - } - addr, err := netip.ParseAddr(body.IP) + + // Redis HKEYS holen + keys, err := rdb.HKeys(ctx, "bl:"+cat).Result() if err != nil { - http.Error(w, "invalid IP", 400) - return + return nil, err } - // Add to whitelist (Redis key like wl:) - if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil { - http.Error(w, "redis error", 500) - return + + var prefixes []netip.Prefix + for _, k := range keys { + k = strings.TrimSpace(k) // spaces entfernen! + pfx, err := netip.ParsePrefix(k) + if err == nil { + prefixes = append(prefixes, pfx) + } else { + fmt.Printf("⚠️ Ungültiger Prefix in Redis (%s): %s\n", cat, k) + } } - ipCache.Add(addr.String(), nil) - writeJSON(w, map[string]string{"status": "whitelisted"}) + + prefixCache[cat] = prefixCacheEntry{ + prefixes: prefixes, + expireAt: time.Now().Add(1 * time.Second), + } + + return prefixes, nil } -// -------------------------------------------- -// ADMIN INFO -// -------------------------------------------- - -func handleInfo(w http.ResponseWriter, _ *http.Request) { - stats := map[string]any{ - "cache_size": ipCache.Len(), - "ttl_hours": redisTTL.Hours(), - "redis": redisAddr, - } - writeJSON(w, stats) -} - -// -------------------------------------------- -// UTIL -// -------------------------------------------- - +// JSON-Antwort func writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(v) + enc := json.NewEncoder(w) + enc.SetIndent("", " ") + enc.Encode(v) } -// -------------------------------------------- -// RIR DATA IMPORT (ALL COUNTRIES) -// -------------------------------------------- +// Main + Beispiel-Setup +func main() { -var rirFiles = []string{ - "https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest", - "https://ftp.apnic.net/stats/apnic/delegated-apnic-latest", - "https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest", - "https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest", - "https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest", -} + if 1 == 1 { + err := importBlocklists() + if err != nil { + fmt.Println("Import-Fehler:", err) + return + } -func importRIRDataToRedis() error { - wg := sync.WaitGroup{} - sem := make(chan struct{}, 5) - - for _, url := range rirFiles { - wg.Add(1) - sem <- struct{}{} - go func(url string) { - defer wg.Done() - defer func() { <-sem }() - fmt.Println("Start: ", url) - if err := fetchAndStore(url); err != nil { - log.Printf("❌ Fehler bei %s: %v", url, err) - } - fmt.Println("Done: ", url) - }(url) + fmt.Println("Blocklisten-Import abgeschlossen.") } - wg.Wait() + + http.HandleFunc("/check/", handleCheck) + http.Handle("/metrics", promhttp.Handler()) + + fmt.Println("Server läuft auf :8080") + http.ListenAndServe(":8080", nil) +} + +// Tests + +func TestCheckIP(t *testing.T) { + // Setup Redis-Daten + rdb.FlushDB(ctx) + rdb.HSet(ctx, "bl:generic", map[string]string{ + "81.232.51.35/32": "1", + "150.242.0.0/20": "1", + }) + rdb.HSet(ctx, "bl:test", map[string]string{ + "203.9.56.0/24": "1", + }) + rdb.Set(ctx, "wl:8.8.8.8", "1", 0) + + tests := []struct { + ip string + expected []string + }{ + {"81.232.51.35", []string{"generic"}}, + {"150.242.5.10", []string{"generic"}}, + {"203.9.56.5", []string{"test"}}, + {"8.8.8.8", []string{}}, // Whitelisted + {"1.1.1.1", []string{}}, + } + + for _, tc := range tests { + addr := netip.MustParseAddr(tc.ip) + cats := []string{"generic", "test"} + res, err := checkIP(addr, cats) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if !equalStringSlices(res, tc.expected) { + t.Errorf("for IP %s: expected %v, got %v", tc.ip, tc.expected, res) + } + } +} + +func equalStringSlices(a, b []string) bool { + if len(a) != len(b) { + return false + } + m := make(map[string]int) + for _, v := range a { + m[v]++ + } + for _, v := range b { + if m[v] == 0 { + return false + } + m[v]-- + } + return true +} + +// Import + +// URL-Liste +var blocklistURLs = map[string]string{ + "firehol": "https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset", + "bitwire": "https://raw.githubusercontent.com/bitwire-it/ipblocklist/refs/heads/main/ip-list.txt", + "RU": "https://ipv64.net/blocklists/countries/ipv64_blocklist_RU.txt", + "CN": "https://ipv64.net/blocklists/countries/ipv64_blocklist_CN.txt", +} + +// Importer +func importBlocklists() error { + for category, url := range blocklistURLs { + fmt.Printf("Lade %s (%s)...\n", category, url) + resp, err := http.Get(url) + if err != nil { + return fmt.Errorf("Fehler beim Laden %s: %v", url, err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("Fehler beim Laden %s: HTTP %d", url, resp.StatusCode) + } + + scanner := bufio.NewScanner(resp.Body) + count := 0 + + pipe := rdb.Pipeline() + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Prüfen ob IP oder Prefix + if isValidPrefix(line) { + pipe.HSet(ctx, "bl:"+category, line, 1) + count++ + } else { + fmt.Printf("⚠️ Ungültige Zeile (%s): %s\n", category, line) + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("Lesefehler %s: %v", url, err) + } + + // Commit Pipeline + _, err = pipe.Exec(ctx) + if err != nil { + return fmt.Errorf("Redis-Fehler %s: %v", category, err) + } + + fmt.Printf("✅ %d Einträge in Kategorie %s importiert.\n", count, category) + } + return nil } -func fetchAndStore(url string) error { - resp, err := http.Get(url) - if err != nil { - return err - } - defer resp.Body.Close() - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - line := scanner.Text() - if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") { - continue - } - fields := strings.Split(line, "|") - if len(fields) < 7 { - continue - } - country := strings.ToLower(fields[1]) - ipType := fields[2] - start := fields[3] - count := fields[4] - - if ipType != "ipv4" && ipType != "ipv6" { - continue - } - - if start == "24.152.36.0" { - fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count) - num, _ := strconv.ParseUint(count, 10, 64) - for _, cidr := range summarizeCIDRs(start, num) { - fmt.Println(" →", cidr) +// Prüft ob eine Zeile ein valides IP-Präfix ist +func isValidPrefix(s string) bool { + // Wenn es kein / enthält → vermutlich /32 oder /128 annehmen + if !strings.Contains(s, "/") { + if ip := net.ParseIP(s); ip != nil { + if ip.To4() != nil { + s = s + "/32" + } else { + s = s + "/128" } - } - - //cidrList := summarizeToCIDRs(start, count, ipType) - numIPs, _ := strconv.ParseUint(count, 10, 64) - cidrList := summarizeCIDRs(start, numIPs) - //log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList)) - for _, cidr := range cidrList { - prefix, err := netip.ParsePrefix(cidr) - if err != nil { - continue - } - key := "bl:" + country + ":" + prefix.String() - //fmt.Println(key) - _ = rdb.Set(ctx, key, "1", redisTTL).Err() - } - } - return scanner.Err() -} - -// -------------------------------------------- -// IP RANGE SUMMARIZER -// -------------------------------------------- - -func summarizeCIDRs(startIP string, count uint64) []string { - var result []string - - if count == 0 { - return result - } - ip := net.ParseIP(startIP) - if ip == nil { - return result - } - - // IPv4-Pfad --------------------------------------------------------------- - if v4 := ip.To4(); v4 != nil { - start := ip4ToUint(v4) - end := start + uint32(count) - 1 - - for start <= end { - prefix := 32 - uint32(bits.TrailingZeros32(start)) - for (start + (1 << (32 - prefix)) - 1) > end { - prefix++ - } - result = append(result, - fmt.Sprintf("%s/%d", uintToIP4(start), prefix)) - start += 1 << (32 - prefix) - } - return result - } - - // IPv6-Pfad --------------------------------------------------------------- - startBig := ip6ToBig(ip) // Startadresse - endBig := new(big.Int).Add(startBig, // Endadresse - new(big.Int).Sub(new(big.Int).SetUint64(count), big.NewInt(1))) - - for startBig.Cmp(endBig) <= 0 { - // größter Block, der am Start ausgerichtet ist - prefix := 128 - trailingZeros128(bigToIP6(startBig)) - - // so lange verkleinern, bis Block in Fenster passt - for { - blockSize := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) - blockEnd := new(big.Int).Add(startBig, - new(big.Int).Sub(blockSize, big.NewInt(1))) - if blockEnd.Cmp(endBig) <= 0 { - break - } - prefix++ - } - - result = append(result, - fmt.Sprintf("%s/%d", bigToIP6(startBig), prefix)) - - // zum nächsten Subnetz springen - step := new(big.Int).Lsh(big.NewInt(1), uint(128-prefix)) - startBig = new(big.Int).Add(startBig, step) - } - return result -} - -/* ---------- Hilfsfunktionen IPv4 ---------- */ - -func ip4ToUint(ip net.IP) uint32 { - return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3]) -} -func uintToIP4(v uint32) net.IP { - return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v)) -} - -/* ---------- Hilfsfunktionen IPv6 ---------- */ - -func ip6ToBig(ip net.IP) *big.Int { - return new(big.Int).SetBytes(ip.To16()) // garantiert 16 Byte -} -func bigToIP6(v *big.Int) net.IP { - b := v.Bytes() - if len(b) < 16 { // von links auf 16 Byte auffüllen - pad := make([]byte, 16-len(b)) - b = append(pad, b...) - } - return net.IP(b) -} - -// Anzahl der Null-Bits am wenigst-signifikanten Ende (LSB) eines IPv6-Werts -func trailingZeros128(ip net.IP) int { - b := ip.To16() - tz := 0 - for i := 15; i >= 0; i-- { // letzte Byte zuerst (LSB) - if b[i] == 0 { - tz += 8 } else { - tz += bits.TrailingZeros8(b[i]) - break + return false } } - return tz + + _, err := netip.ParsePrefix(s) + return err == nil }