Files
flod/__main.go
2025-06-14 11:27:07 +02:00

781 lines
20 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"context"
"encoding/binary"
"encoding/json"
"expvar"
"fmt"
"io"
"log"
"math/big"
"math/bits"
"net"
"net/http"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/redis/go-redis/v9"
)
// Process-wide state. redisAddr is read from REDIS_ADDR during package
// initialisation; rdb and ipCache are assigned in main.
var (
	ctx = context.Background()

	// Redis settings. For local development point REDIS_ADDR at "localhost:6379".
	redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379")
	redisTTL  = 24 * time.Hour // expiry used by the RIR importer; also reported by /info
	rdb       *redis.Client

	// In-process LRU of lookup results, and the categories /check/ uses when
	// the request carries no ?cats= parameter.
	cacheSize     = 100_000
	ipCache       *lru.Cache[string, []string]
	blocklistCats = []string{"generic"}
)

// Metrics published through expvar.
var (
	hits    = expvar.NewInt("cache_hits")
	misses  = expvar.NewInt("cache_misses")
	queries = expvar.NewInt("ip_queries")

	// Redis key counts, refreshed periodically by updateTotalsFromRedis.
	totalBlockedIPs       = expvar.NewInt("total_blocked_ips")
	totalWhitelistEntries = expvar.NewInt("total_whitelist_entries")
)
func updateTotalsFromRedis() {
go func() {
blockCount := 0
iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator()
for iter.Next(ctx) {
blockCount++
}
totalBlockedIPs.Set(int64(blockCount))
whiteCount := 0
iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator()
for iter.Next(ctx) {
whiteCount++
}
totalWhitelistEntries.Set(int64(whiteCount))
}()
}
// startMetricUpdater refreshes the Redis key-count metrics immediately and then
// once every 10 seconds, from a background goroutine that lives as long as the
// process (the ticker is never stopped).
func startMetricUpdater() {
	tick := time.NewTicker(10 * time.Second)
	go func() {
		updateTotalsFromRedis()
		for range tick.C {
			updateTotalsFromRedis()
		}
	}()
}
// Source is one blocklist category together with the list URLs imported under it.
// syncOnce stores every prefix parsed from these URLs as
// "bl:<Category>:<prefix>" (see keyBlock).
type Source struct {
Category string
URL []string // one or more list URLs, each fetched separately by syncOnce
}
// Config is the runtime configuration that loadConfig assembles from environment variables.
type Config struct {
RedisAddr string // NOTE(review): main connects via the package-level redisAddr; this field is not read in this file
Sources []Source
TTLHours int // expiry in hours for keys written by syncOnce
IsWorker bool // true ⇒ intended to load blocklists and write them to Redis; NOTE(review): not read in this file (main switches on IMPORTER)
}
// loadConfig builds the runtime configuration from environment variables:
//
//   - BLOCKLIST_SOURCES: comma-separated specs of the form
//     "category:url1|url2;url3". When set, it replaces the built-in default
//     sources entirely. Specs without a ':' separator, with an empty category,
//     or without any URL are skipped, and the skip is logged.
//   - TTL_HOURS: key expiry in hours (default 24). A value that is not a plain
//     non-negative integer is logged and ignored, and the default is kept.
//   - ROLE: "worker" (case-insensitive) sets IsWorker.
//   - REDIS_ADDR: Redis address (default "10.10.5.249:6379").
func loadConfig() Config {
	// Default blocklist sources, all imported under the "generic" category.
	srcs := []Source{{
		Category: "generic",
		URL: []string{
			"https://raw.githubusercontent.com/firehol/blocklist-ipsets/master/firehol_level1.netset",
			"https://raw.githubusercontent.com/bitwire-it/ipblocklist/refs/heads/main/ip-list.txt",
			"https://ipv64.net/blocklists/countries/ipv64_blocklist_RU.txt",
			"https://ipv64.net/blocklists/countries/ipv64_blocklist_CN.txt",
		},
	}}
	if env := os.Getenv("BLOCKLIST_SOURCES"); env != "" {
		srcs = parseBlocklistSources(env)
	}

	ttl := 24
	if env := os.Getenv("TTL_HOURS"); env != "" {
		// strconv.Atoi rejects trailing garbage such as "12h", which fmt.Sscanf
		// silently accepted. Negative values are rejected as meaningless.
		if n, err := strconv.Atoi(strings.TrimSpace(env)); err != nil || n < 0 {
			log.Printf("ignoring invalid TTL_HOURS %q, using %d", env, ttl)
		} else {
			ttl = n
		}
	}

	return Config{
		RedisAddr: getenv("REDIS_ADDR", "10.10.5.249:6379"),
		Sources:   srcs,
		TTLHours:  ttl,
		IsWorker:  strings.ToLower(os.Getenv("ROLE")) == "worker",
	}
}

// parseBlocklistSources parses a BLOCKLIST_SOURCES value, i.e. comma-separated
// "category:url1|url2;url3" specs. URLs inside a spec may be separated by '|'
// or ';'. Malformed specs are logged and skipped.
func parseBlocklistSources(env string) []Source {
	var srcs []Source
	for _, spec := range strings.Split(env, ",") {
		spec = strings.TrimSpace(spec)
		if spec == "" {
			continue
		}
		cat, rawURLs, ok := strings.Cut(spec, ":")
		cat = strings.TrimSpace(cat)
		if !ok || cat == "" {
			log.Printf("BLOCKLIST_SOURCES: skipping malformed spec %q", spec)
			continue
		}
		var urls []string
		for _, u := range strings.FieldsFunc(rawURLs, func(r rune) bool { return r == '|' || r == ';' }) {
			if u = strings.TrimSpace(u); u != "" {
				urls = append(urls, u)
			}
		}
		if len(urls) == 0 {
			log.Printf("BLOCKLIST_SOURCES: spec %q has no URLs, skipping", spec)
			continue
		}
		srcs = append(srcs, Source{Category: cat, URL: urls})
	}
	return srcs
}
// allCountryCodes lists ISO 3166-1 alpha-2 country codes that
// LoadAllCountryPrefixesIntoRedisAndRanger iterates over. The list is
// abbreviated and is not the full ISO set (e.g. "PR" and "TW" are missing).
var allCountryCodes = []string{
"AD", "AE", "AF", "AG", "AI", "AL", "AM", "AO", "AR", "AT", "AU", "AZ",
"BA", "BB", "BD", "BE", "BF", "BG", "BH", "BI", "BJ", "BN", "BO", "BR", "BS",
"BT", "BW", "BY", "BZ", "CA", "CD", "CF", "CG", "CH", "CI", "CL", "CM", "CN",
"CO", "CR", "CU", "CV", "CY", "CZ", "DE", "DJ", "DK", "DM", "DO", "DZ", "EC",
"EE", "EG", "ER", "ES", "ET", "FI", "FJ", "FM", "FR", "GA", "GB", "GD", "GE",
"GH", "GM", "GN", "GQ", "GR", "GT", "GW", "GY", "HK", "HN", "HR", "HT", "HU",
"ID", "IE", "IL", "IN", "IQ", "IR", "IS", "IT", "JM", "JO", "JP", "KE", "KG",
"KH", "KI", "KM", "KN", "KP", "KR", "KW", "KZ", "LA", "LB", "LC", "LI", "LK",
"LR", "LS", "LT", "LU", "LV", "LY", "MA", "MC", "MD", "ME", "MG", "MH", "MK",
"ML", "MM", "MN", "MR", "MT", "MU", "MV", "MW", "MX", "MY", "MZ", "NA", "NE",
"NG", "NI", "NL", "NO", "NP", "NR", "NZ", "OM", "PA", "PE", "PG", "PH", "PK",
"PL", "PT", "PW", "PY", "QA", "RO", "RS", "RU", "RW", "SA", "SB", "SC", "SD",
"SE", "SG", "SI", "SK", "SL", "SM", "SN", "SO", "SR", "ST", "SV", "SY", "SZ",
"TD", "TG", "TH", "TJ", "TL", "TM", "TN", "TO", "TR", "TT", "TV", "TZ", "UA",
"UG", "US", "UY", "UZ", "VC", "VE", "VN", "VU", "WS", "YE", "ZA", "ZM", "ZW",
}
// GetIPRangesByCountry returns every IPv4 range, as CIDR strings, that the RIR
// "delegated" files in rirFiles attribute to countryCode (case-insensitive).
// Every call downloads all files again. The first download or read failure
// aborts the call and is returned as the error.
func GetIPRangesByCountry(countryCode string) ([]string, error) {
	upperCode := strings.ToUpper(countryCode)
	var allCIDRs []string
	for _, url := range rirFiles {
		cidrs, err := countryRangesFromRIRFile(url, upperCode)
		if err != nil {
			return nil, err
		}
		allCIDRs = append(allCIDRs, cidrs...)
	}
	return allCIDRs, nil
}

// countryRangesFromRIRFile downloads one RIR file and returns the CIDRs of its
// ipv4 records for upperCode. It is a separate function so that each response
// body is closed as soon as that file has been read, instead of all bodies
// staying open until the whole loop finishes.
func countryRangesFromRIRFile(url, upperCode string) ([]string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return nil, fmt.Errorf("fehler beim abrufen von %s: %w", url, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("fehler beim abrufen von %s: %s", url, resp.Status)
	}

	needle := "|" + upperCode + "|ipv4|"
	var cidrs []string
	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "2") || strings.HasPrefix(line, "#") {
			continue // version header line or comment
		}
		if !strings.Contains(line, needle) {
			continue
		}
		// registry|cc|type|start|value|...; for ipv4, value is an address count.
		fields := strings.Split(line, "|")
		if len(fields) < 5 {
			continue
		}
		count, err := strconv.Atoi(fields[4])
		if err != nil || count <= 0 {
			continue
		}
		cidrs = append(cidrs, summarizeCIDR(fields[3], count)...)
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("fehler beim lesen von %s: %w", url, err)
	}
	return cidrs, nil
}
// summarizeCIDR converts an IPv4 range, given as a start address plus an
// address count (as in RIR "delegated" files), into the minimal list of CIDR
// strings that exactly covers it. For example, ("1.0.0.0", 768) yields
// ["1.0.0.0/23", "1.0.2.0/24"].
//
// It returns nil when start is not an IPv4 address or count <= 0. A range that
// would run past 255.255.255.255 is clipped at the end of the address space.
func summarizeCIDR(start string, count int) []string {
	ip := net.ParseIP(start).To4()
	if ip == nil || count <= 0 {
		return nil
	}
	// 64-bit arithmetic: the cursor can reach 2^32 without wrapping to 0.0.0.0.
	const space = uint64(1) << 32
	cur := uint64(binary.BigEndian.Uint32(ip))
	end := cur + uint64(count) // exclusive
	if end > space {
		end = space
	}

	var cidrs []string
	for cur < end {
		// Start from the largest block aligned at cur (address 0 aligns to /0)...
		size := space
		if cur != 0 {
			size = uint64(1) << bits.TrailingZeros64(cur)
		}
		// ...and halve it until it no longer extends past the end of the range.
		for size > end-cur {
			size >>= 1
		}
		prefixLen := 32 - bits.TrailingZeros64(size)
		addr := net.IPv4(byte(cur>>24), byte(cur>>16), byte(cur>>8), byte(cur))
		cidrs = append(cidrs, fmt.Sprintf("%s/%d", addr, prefixLen))
		cur += size
	}
	return cidrs
}
// ipToInt packs the first four bytes of ip into a big-endian uint32. Callers
// should pass the 4-byte form (net.IP.To4). A slice shorter than four bytes
// panics.
func ipToInt(ip net.IP) uint32 {
	var v uint32
	for _, octet := range ip[:4] {
		v = v<<8 | uint32(octet)
	}
	return v
}

// intToIP is the inverse of ipToInt. It returns the 16-byte form produced by
// net.IPv4.
func intToIP(i uint32) net.IP {
	var octets [4]byte
	for k := len(octets) - 1; k >= 0; k-- {
		octets[k] = byte(i)
		i >>= 8
	}
	return net.IPv4(octets[0], octets[1], octets[2], octets[3])
}
// keyBlock returns the Redis key under which prefix p is stored for blocklist
// category cat, in the form "bl:<cat>:<p>" (e.g. "bl:generic:1.2.3.0/24").
func keyBlock(cat string, p netip.Prefix) string {
	const namespace = "bl"
	const sep = ":"
	return namespace + sep + cat + sep + p.String()
}
// LoadAllCountryPrefixesIntoRedisAndRanger imports the IPv4 ranges of every
// country in allCountryCodes into Redis. Each prefix is stored as
// "bl:<cc>:<prefix>" = "1" with an expiry of ttlHours, where <cc> is the
// lower-case country code. This matches the keys written by fetchAndStore, so
// both importers feed the same "?cats=<cc>" category of /check/.
//
// A country whose RIR download fails is logged and skipped. Redis write errors
// are logged per key. The function currently always returns nil. Despite the
// "Ranger" in its name, it only writes to Redis.
func LoadAllCountryPrefixesIntoRedisAndRanger(
	rdb *redis.Client,
	ttlHours int,
) error {
	expiry := time.Duration(ttlHours) * time.Hour // same for every country
	for _, countryCode := range allCountryCodes {
		fmt.Printf("💡 Loading %s...\n", countryCode)
		cidrs, err := GetIPRangesByCountry(countryCode)
		if err != nil {
			log.Printf("Error at %s: %v", countryCode, err)
			continue
		}
		fmt.Println("✅ Got " + strconv.Itoa(len(cidrs)) + " Ranges for Country " + countryCode)

		validPrefixes := make([]netip.Prefix, 0, len(cidrs))
		for _, c := range cidrs {
			prefix, err := netip.ParsePrefix(c)
			if err != nil {
				log.Printf("CIDR invalid [%s]: %v", c, err)
				continue
			}
			validPrefixes = append(validPrefixes, prefix)
		}
		fmt.Println("✅ Got " + strconv.Itoa(len(validPrefixes)) + " valid Prefixes for Country " + countryCode)

		if len(validPrefixes) > 0 {
			// Lower-case category, consistent with fetchAndStore's "bl:<cc>:" keys.
			category := strings.ToLower(countryCode)
			for _, p := range validPrefixes {
				key := keyBlock(category, p)
				if err := rdb.Set(ctx, key, "1", expiry).Err(); err != nil {
					log.Printf("Redis-Error at %s: %v", key, err)
				}
			}
			fmt.Println("✅ Import Subset " + strconv.Itoa(len(validPrefixes)) + " Entries")
		}
		fmt.Println("✅ Import done!")
		fmt.Println("--------------------------------------------------")
	}
	return nil
}
// syncLoop runs syncOnce immediately and then every 30 minutes until ctx is
// cancelled. Sync failures are logged and do not stop the loop.
func syncLoop(ctx context.Context, cfg Config, rdb *redis.Client) {
	runSync := func(logPrefix string) {
		if err := syncOnce(ctx, cfg, rdb); err != nil {
			log.Println(logPrefix, err)
		}
	}

	fmt.Println("💡 Loading Lists...")
	runSync("initial sync:")
	fmt.Println("✅ Loading Lists Done.")

	ticker := time.NewTicker(30 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
			fmt.Println("💡 Loading Lists Timer...")
			runSync("sync loop:")
			fmt.Println("✅ Loading Lists Timer Done.")
		}
	}
}
// syncOnce downloads every URL of every configured source once. Each prefix it
// contains is stored in Redis as keyBlock(category, prefix) = "1", with an
// expiry of cfg.TTLHours hours.
//
// A failing download does not abort the run; the remaining lists are still
// imported. Afterwards an error is returned that says how many lists failed
// and wraps the first failure. Redis write errors are counted and logged once
// per list. If ctx is cancelled, the run stops before the next list and
// returns ctx's error.
func syncOnce(ctx context.Context, cfg Config, rdb *redis.Client) error {
	expiry := time.Duration(cfg.TTLHours) * time.Hour
	var (
		total, failed int
		firstErr      error
	)
	for _, src := range cfg.Sources {
		for _, url := range src.URL {
			if err := ctx.Err(); err != nil {
				return err
			}
			total++
			fmt.Println("💡 Loading List " + src.Category + " : " + url)
			writeErrs := 0
			err := fetchList(ctx, url, func(p netip.Prefix) {
				if err := rdb.Set(ctx, keyBlock(src.Category, p), "1", expiry).Err(); err != nil {
					writeErrs++
				}
			})
			if writeErrs > 0 {
				log.Printf("%s: %d Redis writes failed", url, writeErrs)
			}
			if err != nil {
				fmt.Println("❌ Fail.")
				failed++
				if firstErr == nil {
					firstErr = err
				}
				continue
			}
			fmt.Println("✅ Done.")
		}
	}
	if firstErr != nil {
		return fmt.Errorf("%d of %d blocklists failed, first error: %w", failed, total, firstErr)
	}
	return nil
}
// fetchListTimeout bounds one blocklist download, including reading the body,
// so that a stalled server cannot block a sync run indefinitely.
const fetchListTimeout = 10 * time.Minute

// fetchList downloads url and passes every prefix parsed from the response
// body to cb (see parseStream).
//
// It returns an error for any of the following:
//   - a malformed URL
//   - a transport failure
//   - a non-200 status
//   - a read error
//   - ctx being cancelled, or fetchListTimeout expiring
func fetchList(ctx context.Context, url string, cb func(netip.Prefix)) error {
	ctx, cancel := context.WithTimeout(ctx, fetchListTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		// Must not be ignored: http.Client.Do panics on the nil request returned here.
		return fmt.Errorf("%s: %w", url, err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s -> %s", url, resp.Status)
	}
	return parseStream(resp.Body, cb)
}
// parseStream reads a line-oriented blocklist from r and calls cb once per
// entry. Each line may hold either:
//
//   - a CIDR prefix ("1.2.3.0/24", "2001:db8::/32"), or
//   - a bare address ("1.2.3.4", "2001:db8::1"), reported as a /32 or /128.
//
// Empty lines and lines starting with '#' are skipped. A trailing comment
// introduced by '#' or ';' (e.g. "1.2.3.0/24 ; SBL123") is stripped, and only
// the first whitespace-separated field is used. Lines that still do not parse
// are skipped silently.
//
// Prefixes are passed to cb with their host bits cleared, so "1.2.3.4/24"
// becomes "1.2.3.0/24". Lookups build their keys from masked supernets, so an
// unmasked key would never match.
//
// It returns the scanner's read error, if any.
func parseStream(r io.Reader, cb func(netip.Prefix)) error {
	s := bufio.NewScanner(r)
	for s.Scan() {
		entry := s.Text()
		if i := strings.IndexAny(entry, "#;"); i >= 0 {
			entry = entry[:i]
		}
		fields := strings.Fields(entry)
		if len(fields) == 0 {
			continue
		}
		entry = fields[0]
		if p, err := netip.ParsePrefix(entry); err == nil {
			cb(p.Masked())
			continue
		}
		if addr, err := netip.ParseAddr(entry); err == nil {
			// BitLen is 32 for IPv4 and 128 for IPv6 (including IPv4-mapped).
			cb(netip.PrefixFrom(addr, addr.BitLen()))
		}
	}
	return s.Err()
}
// --------------------------------------------
// INIT + MAIN
// --------------------------------------------
// main runs in one of two modes, selected by the IMPORTER environment variable:
//
//   - IMPORTER=1: importer mode. Keeps Redis in sync with the configured
//     blocklists (see runImporterMode).
//   - anything else: HTTP API server on :8080 (see runServerMode).
func main() {
	if getenv("IMPORTER", "0") == "1" {
		runImporterMode()
		return
	}
	runServerMode()
}

// runImporterMode connects to Redis and runs the periodic blocklist sync.
// NOTE(review): an earlier author note said this path should be replaced by
// per-list importers. The country-wide RIR import
// (LoadAllCountryPrefixesIntoRedisAndRanger) is currently not called here.
func runImporterMode() {
	cfg := loadConfig()
	rdb = redis.NewClient(&redis.Options{Addr: redisAddr})
	// Fail fast, as server mode does, instead of syncing against an unreachable Redis.
	if err := rdb.Ping(ctx).Err(); err != nil {
		log.Fatalf("redis: %v", err)
	}
	syncLoop(ctx, cfg, rdb)
	// syncLoop only returns once ctx is cancelled. ctx is context.Background(),
	// so in practice this line is not reached.
	log.Println("🚀 Import erfolgreich!")
}

// runServerMode serves the HTTP API on :8080 using http.DefaultServeMux:
//
//   - /check/<ip>[?cats=a,b]  blocklist lookup (handleCheck)
//   - /whitelist              POST an IP to the whitelist (handleWhitelist)
//   - /info                   cache and config info (handleInfo)
//   - /debug/vars             expvar metrics. The expvar package registers this
//     itself when imported, and registering the same pattern again makes
//     http.Handle panic, so it is not registered here.
//
// With IMPORT_RIRS=1 it also imports all RIR country ranges in the background.
func runServerMode() {
	rdb = redis.NewClient(&redis.Options{Addr: redisAddr})
	if err := rdb.Ping(ctx).Err(); err != nil {
		log.Fatalf("redis: %v", err)
	}

	var err error
	ipCache, err = lru.New[string, []string](cacheSize)
	if err != nil {
		log.Fatalf("cache init: %v", err)
	}

	startMetricUpdater()

	if getenv("IMPORT_RIRS", "0") == "1" {
		go func() {
			log.Println("Lade IP-Ranges aus RIRs...")
			if err := importRIRDataToRedis(); err != nil {
				// Log instead of log.Fatal: a failed background import must not take the API down.
				log.Printf("import error: %v", err)
				return
			}
			log.Println("✅ Import abgeschlossen.")
		}()
	}

	http.HandleFunc("/check/", handleCheck)
	http.HandleFunc("/whitelist", handleWhitelist)
	http.HandleFunc("/info", handleInfo)

	// A nil Handler means http.DefaultServeMux. The timeouts keep slow or idle
	// clients from holding connections open forever.
	srv := &http.Server{
		Addr:              ":8080",
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	log.Println("🚀 Server läuft auf :8080")
	log.Fatal(srv.ListenAndServe())
}
// getenv returns the value of environment variable k. It returns fallback when
// k is unset or set to the empty string.
func getenv(k, fallback string) string {
	v, ok := os.LookupEnv(k)
	if !ok || v == "" {
		return fallback
	}
	return v
}
// --------------------------------------------
// IP CHECK API
// --------------------------------------------
// maxCheckCategories caps how many categories one /check/ request may ask for.
// Every category adds one EXISTS over all supernets of the IP (33 keys for
// IPv4, 129 for IPv6) to the Redis pipeline. Without a cap, a long ?cats= list
// lets a client turn one request into an arbitrarily large Redis workload.
const maxCheckCategories = 64

// handleCheck serves /check/<ip>[?cats=a,b,...]. It reports which of the
// requested blocklist categories contain the IP or a network containing it.
//
// Without ?cats=, or when no usable name remains after trimming, the default
// blocklistCats are used. Category names are trimmed, and empty or duplicate
// names are dropped.
//
// Response: {"ip": ..., "blocked": bool, "categories": [...]}.
// Errors:
//   - 400 for an unparsable IP or more than maxCheckCategories categories
//   - 500 when the Redis lookup fails
func handleCheck(w http.ResponseWriter, r *http.Request) {
	ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}

	cats := blocklistCats
	if q := r.URL.Query().Get("cats"); q != "" {
		if parsed := parseCheckCategories(q); len(parsed) > 0 {
			cats = parsed
		}
	}
	if len(cats) > maxCheckCategories {
		http.Error(w, "too many categories", http.StatusBadRequest)
		return
	}

	queries.Add(1)
	blockedCats, err := checkIP(addr, cats)
	if err != nil {
		http.Error(w, "lookup error", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]any{
		"ip":         ipStr,
		"blocked":    len(blockedCats) > 0,
		"categories": blockedCats,
	})
}

// parseCheckCategories splits a comma-separated ?cats= value into trimmed,
// non-empty, de-duplicated category names, keeping first-seen order.
func parseCheckCategories(q string) []string {
	parts := strings.Split(q, ",")
	seen := make(map[string]struct{}, len(parts))
	cats := make([]string, 0, len(parts))
	for _, c := range parts {
		c = strings.TrimSpace(c)
		if c == "" {
			continue
		}
		if _, dup := seen[c]; dup {
			continue
		}
		seen[c] = struct{}{}
		cats = append(cats, c)
	}
	return cats
}
// supernets returns, as strings, every prefix that contains ip. They are
// ordered from most specific to /0, with host bits cleared: for 1.2.3.4 that
// is "1.2.3.4/32", ..., "1.2.3.0/24", ..., "0.0.0.0/0".
//
// The result has 33 entries for IPv4. Any address that is not plain IPv4,
// including IPv4-mapped IPv6, gets 129 entries (/128 … /0).
//
// checkIP appends these strings to "bl:<cat>:", so a lookup matches any stored
// network that covers the address. The function runs on every uncached
// /check/ request and therefore does no logging.
func supernets(ip netip.Addr) []string {
	if ip.Is4() {
		a := ip.As4() // array copy, so it can be sliced
		u := binary.BigEndian.Uint32(a[:])
		supers := make([]string, 33) // /32 … /0
		for plen := 32; plen >= 0; plen-- {
			// For plen == 0 the shift count is 32, which yields mask 0 for a uint32.
			n := u & (uint32(0xffffffff) << (32 - plen))
			addr := netip.AddrFrom4([4]byte{byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n)})
			supers[32-plen] = fmt.Sprintf("%s/%d", addr, plen)
		}
		return supers
	}

	a := ip.As16()
	supers := make([]string, 129) // /128 … /0
	for plen := 128; plen >= 0; plen-- {
		b := a // value copy, masked below
		hostBits := 128 - plen
		full := hostBits / 8 // whole trailing bytes to clear
		for i := 0; i < full; i++ {
			b[15-i] = 0
		}
		if rem := hostBits % 8; rem != 0 {
			b[15-full] &= 0xFF << rem // clear the remaining low bits of the next byte
		}
		supers[128-plen] = fmt.Sprintf("%s/%d", netip.AddrFrom16(b), plen)
	}
	return supers
}
// checkIP reports which of cats contain ip. A category matches when at least
// one key "bl:<cat>:<supernet>" exists in Redis for some supernet of ip (see
// supernets).
//
// A whitelisted ip reports no categories and returns nil. An IP counts as
// whitelisted when "wl:<ip>" exists in Redis, or when a nil whitelist marker
// is cached under the bare "<ip>" key (handleWhitelist stores that marker).
//
// Results are cached in ipCache under "<ip>|<cat1,cat2,...>", so lookups for
// different category sets never share a cache entry. Cache hits increment
// cache_hits; lookups that go to Redis increment cache_misses.
func checkIP(ip netip.Addr, cats []string) ([]string, error) {
	ipKey := ip.String()
	// Only whitelist markers are ever cached under the bare IP key.
	if marker, ok := ipCache.Get(ipKey); ok && marker == nil {
		hits.Add(1)
		return nil, nil
	}
	cacheKey := ipKey + "|" + strings.Join(cats, ",")
	if res, ok := ipCache.Get(cacheKey); ok {
		hits.Add(1)
		return res, nil
	}

	supers := supernets(ip)

	// One round trip: the whitelist check plus one multi-key EXISTS per category.
	pipe := rdb.Pipeline()
	wlCmd := pipe.Exists(ctx, "wl:"+ipKey)
	existsCmds := make([]*redis.IntCmd, len(cats))
	for i, cat := range cats {
		keys := make([]string, len(supers))
		for j, pfx := range supers {
			keys[j] = "bl:" + cat + ":" + pfx
		}
		existsCmds[i] = pipe.Exists(ctx, keys...)
	}
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return nil, err
	}
	misses.Add(1)

	if wlCmd.Val() > 0 {
		ipCache.Add(ipKey, nil) // remember the whitelist without further Redis calls
		return nil, nil
	}

	matches := make([]string, 0, len(cats))
	for i, cat := range cats {
		if existsCmds[i].Val() > 0 {
			matches = append(matches, cat)
		}
	}
	ipCache.Add(cacheKey, matches)
	return matches, nil
}
// --------------------------------------------
// WHITELIST API (optional extension)
// --------------------------------------------
// maxWhitelistBody bounds the size of the JSON request body that
// handleWhitelist accepts.
const maxWhitelistBody = 4 << 10

// handleWhitelist serves POST /whitelist with a JSON body {"ip": "<addr>"}.
// It stores "wl:<addr>" = "1" in Redis without expiry. It also caches a nil
// result under the address, so cached lookups of that IP report no categories
// from then on.
//
// On success it responds {"status": "whitelisted"}. Errors:
//   - 405 for methods other than POST
//   - 400 for an unreadable or oversized body, or an invalid IP
//   - 500 when the Redis write fails
//
// NOTE(review): this endpoint has no authentication; anyone who can reach the
// server can exempt an IP from blocking.
func handleWhitelist(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	// The body is client-controlled. Cap it so an oversized payload is rejected
	// instead of being buffered by the JSON decoder.
	r.Body = http.MaxBytesReader(w, r.Body, maxWhitelistBody)
	var body struct {
		IP string `json:"ip"`
	}
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}
	addr, err := netip.ParseAddr(body.IP)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}
	// Stored without expiry (TTL 0).
	if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil {
		http.Error(w, "redis error", http.StatusInternalServerError)
		return
	}
	ipCache.Add(addr.String(), nil)
	writeJSON(w, map[string]string{"status": "whitelisted"})
}
// --------------------------------------------
// ADMIN INFO
// --------------------------------------------
// handleInfo serves /info with a small JSON snapshot: the number of entries in
// the in-process LRU cache, redisTTL in hours, and the Redis address.
// NOTE(review): syncOnce uses cfg.TTLHours rather than redisTTL, so ttl_hours
// may not match the blocklist importer's expiry.
func handleInfo(w http.ResponseWriter, _ *http.Request) {
	info := make(map[string]any, 3)
	info["cache_size"] = ipCache.Len()
	info["ttl_hours"] = redisTTL.Hours()
	info["redis"] = redisAddr
	writeJSON(w, info)
}
// --------------------------------------------
// UTIL
// --------------------------------------------
// writeJSON encodes v as JSON into w and sets Content-Type to
// application/json. Encoding errors are deliberately ignored (best effort).
func writeJSON(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	_ = enc.Encode(v) // best effort; nothing useful to report to the client
}
// --------------------------------------------
// RIR DATA IMPORT (ALL COUNTRIES)
// --------------------------------------------
// rirFiles are the "delegated" statistics files of the five Regional Internet
// Registries (RIPE NCC, APNIC, ARIN, LACNIC, AFRINIC). Each record is a
// '|'-separated line whose fields include the registry, country code, address
// family, start address and size. GetIPRangesByCountry and
// importRIRDataToRedis (via fetchAndStore) download and parse these files.
var rirFiles = []string{
"https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest",
"https://ftp.apnic.net/stats/apnic/delegated-apnic-latest",
"https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest",
"https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest",
"https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest",
}
func importRIRDataToRedis() error {
wg := sync.WaitGroup{}
sem := make(chan struct{}, 5)
for _, url := range rirFiles {
wg.Add(1)
sem <- struct{}{}
go func(url string) {
defer wg.Done()
defer func() { <-sem }()
fmt.Println("Start: ", url)
if err := fetchAndStore(url); err != nil {
log.Printf("❌ Fehler bei %s: %v", url, err)
}
fmt.Println("Done: ", url)
}(url)
}
wg.Wait()
return nil
}
// fetchAndStore downloads one RIR "delegated" statistics file. It writes each
// allocated or assigned IPv4/IPv6 block to Redis as "bl:<cc>:<prefix>" = "1",
// with expiry redisTTL, where <cc> is the lower-case country code. Redis write
// errors are ignored (best effort).
//
// Record format: registry|cc|type|start|value|date|status[|...].
//   - For ipv4, value is an address count (not necessarily a power of two).
//     summarizeCIDRs splits it into CIDR blocks.
//   - For ipv6, value is already the prefix length, so start/value is the
//     delegated block itself.
//
// Records with any other status (e.g. "available", "reserved") or without a
// country code are skipped, because they do not belong to a country.
// Non-200 responses and read errors are returned.
// NOTE(review): http.Get has no timeout here.
func fetchAndStore(url string) error {
	resp, err := http.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%s -> %s", url, resp.Status)
	}

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") {
			continue
		}
		fields := strings.Split(line, "|")
		if len(fields) < 7 {
			continue // also skips the per-family summary lines
		}
		country := strings.ToLower(fields[1])
		ipType, start, value, status := fields[2], fields[3], fields[4], fields[6]
		if country == "" || (status != "allocated" && status != "assigned") {
			continue
		}

		var cidrList []string
		switch ipType {
		case "ipv4":
			numIPs, err := strconv.ParseUint(value, 10, 64)
			if err != nil {
				continue
			}
			cidrList = summarizeCIDRs(start, numIPs)
		case "ipv6":
			cidrList = []string{start + "/" + value}
		default:
			continue
		}

		for _, cidr := range cidrList {
			prefix, err := netip.ParsePrefix(cidr)
			if err != nil {
				continue
			}
			// Masked keeps keys canonical so they match the supernets used by lookups.
			key := "bl:" + country + ":" + prefix.Masked().String()
			_ = rdb.Set(ctx, key, "1", redisTTL).Err() // best effort
		}
	}
	return scanner.Err()
}
// --------------------------------------------
// IP RANGE SUMMARIZER
// --------------------------------------------
// summarizeCIDRs converts the address range [startIP, startIP+count) into the
// minimal list of CIDR blocks that exactly covers it. For example,
// ("1.0.0.0", 768) yields ["1.0.0.0/23", "1.0.2.0/24"].
//
// Both IPv4 and IPv6 start addresses are supported; an IPv4-mapped IPv6
// address is treated as IPv4. The result is empty for count == 0 or an
// unparsable startIP. A range that would run past the last address of its
// family is clipped there.
func summarizeCIDRs(startIP string, count uint64) []string {
	var result []string
	if count == 0 {
		return result
	}
	ip := net.ParseIP(startIP)
	if ip == nil {
		return result
	}

	// IPv4 path. It uses 64-bit arithmetic so that a count above 2^32, or a
	// range ending at 255.255.255.255, cannot wrap the cursor back to 0.0.0.0
	// (which would loop forever).
	if v4 := ip.To4(); v4 != nil {
		const space = uint64(1) << 32
		cur := uint64(ip4ToUint(v4))
		end := cur + count // exclusive
		if end > space || end < cur { // end < cur: uint64 overflow for huge counts
			end = space
		}
		for cur < end {
			size := space // address 0 is aligned to every block size, up to /0
			if cur != 0 {
				size = uint64(1) << bits.TrailingZeros64(cur)
			}
			for size > end-cur { // shrink until the block ends inside the range
				size >>= 1
			}
			prefix := 32 - bits.TrailingZeros64(size)
			result = append(result, fmt.Sprintf("%s/%d", uintToIP4(uint32(cur)), prefix))
			cur += size
		}
		return result
	}

	// IPv6 path
	one := big.NewInt(1)
	maxAddr := new(big.Int).Sub(new(big.Int).Lsh(one, 128), one) // ffff:…:ffff
	startBig := ip6ToBig(ip)
	endBig := new(big.Int).Add(startBig, new(big.Int).SetUint64(count-1)) // inclusive
	if endBig.Cmp(maxAddr) > 0 {
		endBig = maxAddr
	}
	for startBig.Cmp(endBig) <= 0 {
		// Largest block aligned at startBig...
		prefix := 128 - trailingZeros128(bigToIP6(startBig))
		// ...shrunk until it ends inside the range.
		for {
			blockEnd := new(big.Int).Lsh(one, uint(128-prefix))
			blockEnd.Add(blockEnd, startBig).Sub(blockEnd, one)
			if blockEnd.Cmp(endBig) <= 0 {
				break
			}
			prefix++
		}
		result = append(result, fmt.Sprintf("%s/%d", bigToIP6(startBig), prefix))
		// Jump to the first address after this block.
		step := new(big.Int).Lsh(one, uint(128-prefix))
		startBig = new(big.Int).Add(startBig, step)
	}
	return result
}

/* ---------- IPv4 helpers ---------- */

// ip4ToUint packs a 4-byte IPv4 address into a big-endian uint32.
func ip4ToUint(ip net.IP) uint32 {
	return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

// uintToIP4 is the inverse of ip4ToUint. It returns the 16-byte form produced
// by net.IPv4.
func uintToIP4(v uint32) net.IP {
	return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}

/* ---------- IPv6 helpers ---------- */

// ip6ToBig interprets the 16-byte form of ip as an unsigned big-endian integer.
func ip6ToBig(ip net.IP) *big.Int {
	return new(big.Int).SetBytes(ip.To16())
}

// bigToIP6 converts v back to a 16-byte net.IP, left-padding with zeros.
// Callers must pass values below 2^128; larger values would yield a slice
// longer than 16 bytes.
func bigToIP6(v *big.Int) net.IP {
	b := v.Bytes()
	if len(b) < 16 {
		pad := make([]byte, 16-len(b))
		b = append(pad, b...)
	}
	return net.IP(b)
}

// trailingZeros128 returns the number of zero bits at the least-significant
// end of a 16-byte IP (128 for the all-zero address).
func trailingZeros128(ip net.IP) int {
	b := ip.To16()
	tz := 0
	for i := 15; i >= 0; i-- { // last byte first (LSB)
		if b[i] == 0 {
			tz += 8
		} else {
			tz += bits.TrailingZeros8(b[i])
			break
		}
	}
	return tz
}