Files
flod/main.go
jbergner bbccee0754
All checks were successful
release-tag / release-image (push) Successful in 1m34s
das sollte klappen!
2025-06-12 06:48:21 +02:00

470 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package main
import (
"bufio"
"context"
"encoding/binary"
"encoding/json"
"expvar"
"fmt"
"log"
"math/big"
"math/bits"
"net"
"net/http"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/redis/go-redis/v9"
)
// Runtime configuration and shared state. rdb and ipCache are assigned in main.
var (
	ctx = context.Background()

	// Redis endpoint, overridable via the REDIS_ADDR environment variable.
	// For local development use REDIS_ADDR=localhost:6379.
	redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379")

	// TTL applied to imported blocklist keys.
	redisTTL = 24 * time.Hour
	// Maximum number of entries in the in-process LRU result cache.
	cacheSize = 100_000
	// Categories checked when a /check request does not pass ?cats=.
	blocklistCats = []string{"generic"}

	rdb     *redis.Client
	ipCache *lru.Cache[string, []string]
)

// Counters published through expvar (served on http.DefaultServeMux at /debug/vars).
var (
	hits    = expvar.NewInt("cache_hits")
	misses  = expvar.NewInt("cache_misses")
	queries = expvar.NewInt("ip_queries")

	totalBlockedIPs       = expvar.NewInt("total_blocked_ips")
	totalWhitelistEntries = expvar.NewInt("total_whitelist_entries")
)
// metricsScanCount is the COUNT hint passed to SCAN while counting keys.
// Without it Redis uses its default of 10, which needs a round trip for every
// handful of keys once the blocklist has been imported.
const metricsScanCount = 1000

// metricsUpdateMu prevents overlapping refreshes: if a previous key count is
// still running when the next refresh starts, the new one is skipped.
var metricsUpdateMu sync.Mutex

// updateTotalsFromRedis refreshes the total_blocked_ips and
// total_whitelist_entries gauges by counting the "bl:*" and "wl:*" keys in
// Redis. The work runs in its own goroutine and the call returns immediately.
// If counting fails, the corresponding gauge keeps its previous value instead
// of being overwritten with a partial count.
func updateTotalsFromRedis() {
	go func() {
		if !metricsUpdateMu.TryLock() {
			return // previous refresh still running
		}
		defer metricsUpdateMu.Unlock()

		if n, err := countKeys("bl:*"); err != nil {
			log.Printf("metrics: counting bl:* keys: %v", err)
		} else {
			totalBlockedIPs.Set(int64(n))
		}

		if n, err := countKeys("wl:*"); err != nil {
			log.Printf("metrics: counting wl:* keys: %v", err)
		} else {
			totalWhitelistEntries.Set(int64(n))
		}
	}()
}

// countKeys returns how many keys match pattern, iterating with SCAN (which,
// unlike KEYS, does not block Redis for the whole walk). SCAN may report a key
// more than once, so the result is approximate. If the scan fails part-way,
// the partial count is discarded and the error returned.
func countKeys(pattern string) (int, error) {
	n := 0
	iter := rdb.Scan(ctx, 0, pattern, metricsScanCount).Iterator()
	for iter.Next(ctx) {
		n++
	}
	if err := iter.Err(); err != nil {
		return 0, err
	}
	return n, nil
}
// startMetricUpdater triggers updateTotalsFromRedis once right away and then
// again on every 10-second tick, from a background goroutine that lives for
// the rest of the process.
func startMetricUpdater() {
	tick := time.NewTicker(10 * time.Second)
	go func() {
		updateTotalsFromRedis()
		for range tick.C {
			updateTotalsFromRedis()
		}
	}()
}
// --------------------------------------------
// INIT + MAIN
// --------------------------------------------
// main connects to Redis, creates the LRU result cache, starts the metric
// updater and the optional RIR import, registers the HTTP routes and serves
// on :8080.
func main() {
	var err error

	// Redis client; fail fast if it is unreachable.
	rdb = redis.NewClient(&redis.Options{Addr: redisAddr})
	if err = rdb.Ping(ctx).Err(); err != nil {
		log.Fatalf("redis: %v", err)
	}

	// LRU cache for lookup results.
	ipCache, err = lru.New[string, []string](cacheSize)
	if err != nil {
		log.Fatalf("cache init: %v", err)
	}

	startMetricUpdater()

	// Optional RIR import (IMPORT_RIRS=1). Imported keys are written with
	// redisTTL, so a single import at startup would leave the blocklist empty
	// once that TTL elapses; re-run it every redisTTL/2 so keys are refreshed
	// before they expire.
	if getenv("IMPORT_RIRS", "0") == "1" {
		go func() {
			ticker := time.NewTicker(redisTTL / 2)
			defer ticker.Stop()
			for {
				log.Println("Lade IP-Ranges aus RIRs...")
				if err := importRIRDataToRedis(); err != nil {
					// A failed refresh must not take down a serving process.
					log.Printf("import error: %v", err)
				} else {
					log.Println("✅ Import abgeschlossen.")
				}
				<-ticker.C
			}
		}()
	}

	// Routes. /debug/vars is intentionally not registered here: importing
	// expvar already registers it on http.DefaultServeMux, and registering the
	// same pattern again panics at startup.
	http.HandleFunc("/check/", handleCheck)
	http.HandleFunc("/whitelist", handleWhitelist)
	http.HandleFunc("/info", handleInfo)

	// A nil Handler serves http.DefaultServeMux. ReadHeaderTimeout stops
	// clients from holding connections open by trickling request headers.
	srv := &http.Server{
		Addr:              ":8080",
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Println("🚀 Server läuft auf :8080")
	log.Fatal(srv.ListenAndServe())
}
// getenv returns the value of environment variable k, or fallback when the
// variable is unset or set to the empty string.
func getenv(k, fallback string) string {
	val := os.Getenv(k)
	if val == "" {
		return fallback
	}
	return val
}
// --------------------------------------------
// IP CHECK API
// --------------------------------------------
// maxCheckCategories bounds how many categories a single /check request may
// ask for. Each category costs one EXISTS over all 33 (IPv4) or 129 (IPv6)
// supernet keys, so an unbounded list lets one request build a huge Redis
// pipeline. 300 leaves room for every ISO 3166 country code (about 250) plus
// custom categories.
const maxCheckCategories = 300

// handleCheck serves /check/<ip>[?cats=c1,c2,...] and responds with
// {"ip": <as given>, "blocked": bool, "categories": [matching categories]}.
// Without ?cats=, or when it contains only blank entries, blocklistCats is
// used. Responds 400 for an unparsable IP or too many categories and 500 when
// the lookup fails.
func handleCheck(w http.ResponseWriter, r *http.Request) {
	ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}

	cats := blocklistCats
	if q := r.URL.Query().Get("cats"); q != "" {
		requested := parseCategories(q)
		if len(requested) > maxCheckCategories {
			http.Error(w, "too many categories", http.StatusBadRequest)
			return
		}
		if len(requested) > 0 {
			cats = requested
		}
	}

	queries.Add(1)
	blockedCats, err := checkIP(addr, cats)
	if err != nil {
		http.Error(w, "lookup error", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]any{
		"ip":         ipStr,
		"blocked":    len(blockedCats) > 0,
		"categories": blockedCats,
	})
}

// parseCategories splits a comma-separated category list. Entries are trimmed
// and lower-cased (the RIR importer stores country codes in lower case), and
// empty or duplicate entries are dropped, keeping first-seen order.
func parseCategories(q string) []string {
	parts := strings.Split(q, ",")
	out := make([]string, 0, len(parts))
	seen := make(map[string]struct{}, len(parts))
	for _, p := range parts {
		c := strings.ToLower(strings.TrimSpace(p))
		if c == "" {
			continue
		}
		if _, dup := seen[c]; dup {
			continue
		}
		seen[c] = struct{}{}
		out = append(out, c)
	}
	return out
}
// liefert alle möglichen Präfixe dieser IP, beginnend beim längsten (/32 oder /128)
func supernets(ip netip.Addr) []string {
if ip.Is4() {
a := ip.As4() // Kopie addressierbar machen
u := binary.BigEndian.Uint32(a[:]) // jetzt darf man slicen
supers := make([]string, 33) // /32 … /0
for bits := 32; bits >= 0; bits-- {
mask := uint32(0xffffffff) << (32 - bits)
n := u & mask
addr := netip.AddrFrom4([4]byte{
byte(n >> 24), byte(n >> 16), byte(n >> 8), byte(n),
})
supers[32-bits] = fmt.Sprintf("%s/%d", addr, bits)
}
return supers
}
a := ip.As16() // Kopie addressierbar
supers := make([]string, 129) // /128 … /0
for bits := 128; bits >= 0; bits-- {
b := a // Wert-Kopie für Modifikation
// vollständige Bytes auf 0 setzen
full := (128 - bits) / 8
for i := 0; i < full; i++ {
b[15-i] = 0
}
// Restbits maskieren
rem := (128 - bits) % 8
if rem != 0 {
b[15-full] &= 0xFF << rem
}
addr := netip.AddrFrom16(b)
supers[128-bits] = fmt.Sprintf("%s/%d", addr, bits)
}
return supers
}
// checkIP returns the subset of cats whose blocklist contains ip. An address
// is listed in category c when any key "bl:<c>:<prefix>" exists for one of
// its supernets (see supernets). An address with a "wl:<ip>" key is never
// listed. Results are cached in ipCache; a Redis failure is returned as error.
func checkIP(ip netip.Addr, cats []string) ([]string, error) {
	ipKey := ip.String()

	// 1) Entry cached under the bare IP. Only handleWhitelist writes such
	//    entries, and they apply regardless of the requested categories.
	if res, ok := ipCache.Get(ipKey); ok {
		hits.Add(1)
		return res, nil
	}

	// 2) Category-specific entry. The answer depends on cats, so they are part
	//    of the key; otherwise a lookup for one category list would be served
	//    to later requests asking about a different list.
	cacheKey := ipKey + "|" + strings.Join(cats, ",")
	if res, ok := ipCache.Get(cacheKey); ok {
		hits.Add(1)
		return res, nil
	}
	misses.Add(1)

	// 3) One pipelined round trip: the whitelist key plus one multi-key
	//    EXISTS per category over all supernets of ip.
	supers := supernets(ip)
	pipe := rdb.Pipeline()
	wlCmd := pipe.Exists(ctx, "wl:"+ipKey)
	existsCmds := make([]*redis.IntCmd, len(cats))
	for i, cat := range cats {
		keys := make([]string, len(supers))
		for j, pfx := range supers {
			keys[j] = "bl:" + cat + ":" + pfx
		}
		existsCmds[i] = pipe.Exists(ctx, keys...)
	}
	if _, err := pipe.Exec(ctx); err != nil && err != redis.Nil {
		return nil, err
	}

	// 4) Evaluate. A whitelisted address matches nothing; the slice is
	//    non-nil so the JSON response contains [] rather than null.
	matches := make([]string, 0, len(cats))
	if wlCmd.Val() == 0 {
		for i, cat := range cats {
			if existsCmds[i].Val() > 0 {
				matches = append(matches, cat)
			}
		}
	}

	// 5) Cache and return.
	ipCache.Add(cacheKey, matches)
	return matches, nil
}
// --------------------------------------------
// WHITELIST API (optional extension)
// --------------------------------------------
func handleWhitelist(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", 405)
return
}
var body struct {
IP string `json:"ip"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
http.Error(w, "bad request", 400)
return
}
addr, err := netip.ParseAddr(body.IP)
if err != nil {
http.Error(w, "invalid IP", 400)
return
}
// Add to whitelist (Redis key like wl:<ip>)
if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil {
http.Error(w, "redis error", 500)
return
}
ipCache.Add(addr.String(), nil)
writeJSON(w, map[string]string{"status": "whitelisted"})
}
// --------------------------------------------
// ADMIN INFO
// --------------------------------------------
// handleInfo reports basic runtime information as JSON: the current number of
// LRU cache entries, the TTL (in hours) used for imported keys, and the
// configured Redis address.
func handleInfo(w http.ResponseWriter, _ *http.Request) {
	writeJSON(w, map[string]any{
		"redis":      redisAddr,
		"ttl_hours":  redisTTL.Hours(),
		"cache_size": ipCache.Len(),
	})
}
// --------------------------------------------
// UTIL
// --------------------------------------------
// writeJSON sends v JSON-encoded (followed by a newline) with the
// Content-Type header set to application/json.
func writeJSON(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	// Best effort: the handler has no further way to report an encoding or
	// write failure (e.g. a client that went away), so the error is dropped.
	_ = enc.Encode(v)
}
// --------------------------------------------
// RIR DATA IMPORT (ALL COUNTRIES)
// --------------------------------------------
// rirFiles lists the "delegated" statistics files of the five Regional
// Internet Registries that importRIRDataToRedis downloads.
var rirFiles = []string{
	"https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest",             // RIPE NCC
	"https://ftp.apnic.net/stats/apnic/delegated-apnic-latest",                    // APNIC
	"https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest",          // ARIN
	"https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest",             // LACNIC
	"https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest", // AFRINIC
}
func importRIRDataToRedis() error {
wg := sync.WaitGroup{}
sem := make(chan struct{}, 5)
for _, url := range rirFiles {
wg.Add(1)
sem <- struct{}{}
go func(url string) {
defer wg.Done()
defer func() { <-sem }()
fmt.Println("Start: ", url)
if err := fetchAndStore(url); err != nil {
log.Printf("❌ Fehler bei %s: %v", url, err)
}
fmt.Println("Done: ", url)
}(url)
}
wg.Wait()
return nil
}
// rirHTTPClient downloads the RIR delegation files. The timeout bounds the
// whole download so a stalled registry server cannot hang an import forever.
var rirHTTPClient = &http.Client{Timeout: 10 * time.Minute}

// rirPipelineBatch is the number of SET commands queued per Redis round trip
// while importing.
const rirPipelineBatch = 1000

// fetchAndStore downloads one RIR "delegated" statistics file from url and
// writes a key "bl:<country>:<prefix>" (value "1", TTL redisTTL) for every
// allocated or assigned IPv4/IPv6 range it lists. Writes are pipelined in
// batches of rirPipelineBatch. It returns an error on transport failure, a
// non-200 response, a read error, or a failed Redis write.
func fetchAndStore(url string) error {
	resp, err := rirHTTPClient.Get(url)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("GET %s: unexpected status %s", url, resp.Status)
	}

	pipe := rdb.Pipeline()
	queued := 0
	flush := func() error {
		if queued == 0 {
			return nil
		}
		queued = 0
		_, err := pipe.Exec(ctx)
		return err
	}

	scanner := bufio.NewScanner(resp.Body)
	for scanner.Scan() {
		line := scanner.Text()
		// Cheap pre-filter: skip comments and records that are not IPv4/IPv6.
		if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") {
			continue
		}
		country, prefixes, ok := delegationPrefixes(line)
		if !ok {
			continue
		}
		for _, pfx := range prefixes {
			pipe.Set(ctx, "bl:"+country+":"+pfx, "1", redisTTL)
			queued++
			if queued >= rirPipelineBatch {
				if err := flush(); err != nil {
					return fmt.Errorf("storing ranges from %s: %w", url, err)
				}
			}
		}
	}

	// Store what was parsed even if reading stopped early, as the previous
	// key-by-key writes did.
	flushErr := flush()
	if err := scanner.Err(); err != nil {
		return fmt.Errorf("reading %s: %w", url, err)
	}
	if flushErr != nil {
		return fmt.Errorf("storing ranges from %s: %w", url, flushErr)
	}
	return nil
}

// delegationPrefixes parses one record of an RIR delegated-statistics file
// (registry|cc|type|start|value|date|status[|...]) and returns the lower-cased
// country code with the canonical CIDR prefixes the record covers. ok is false
// for malformed records, types other than ipv4/ipv6, and records whose status
// is not "allocated" or "assigned".
func delegationPrefixes(line string) (country string, prefixes []string, ok bool) {
	fields := strings.Split(line, "|")
	if len(fields) < 7 {
		return "", nil, false
	}
	ipType, start, value, status := fields[2], fields[3], fields[4], fields[6]

	// Extended files also list "available" and "reserved" space, which is not
	// delegated to any country.
	if status != "allocated" && status != "assigned" {
		return "", nil, false
	}

	switch ipType {
	case "ipv4":
		// For IPv4 records, value is the number of addresses in the range,
		// which need not be a power of two.
		count, err := strconv.ParseUint(value, 10, 64)
		if err != nil {
			return "", nil, false
		}
		for _, cidr := range summarizeCIDRs(start, count) {
			if p, err := netip.ParsePrefix(cidr); err == nil {
				prefixes = append(prefixes, p.String())
			}
		}
	case "ipv6":
		// For IPv6 records, value is the prefix length, not an address count.
		plen, err := strconv.Atoi(value)
		if err != nil {
			return "", nil, false
		}
		addr, err := netip.ParseAddr(start)
		if err != nil {
			return "", nil, false
		}
		p, err := addr.Prefix(plen)
		if err != nil {
			return "", nil, false
		}
		prefixes = append(prefixes, p.String())
	default:
		return "", nil, false
	}

	if len(prefixes) == 0 {
		return "", nil, false
	}
	return strings.ToLower(fields[1]), prefixes, true
}
// --------------------------------------------
// IP RANGE SUMMARIZER
// --------------------------------------------
// summarizeCIDRs splits the address range that starts at startIP and spans
// count addresses into the minimal list of aligned CIDR blocks, in ascending
// address order (e.g. "10.0.0.1", 3 -> 10.0.0.1/32, 10.0.0.2/31). IPv4
// (including IPv4-mapped IPv6) and IPv6 start addresses are supported. A
// range that would run past the end of its address space is truncated at the
// last address. It returns nil for count == 0 or an unparsable startIP.
func summarizeCIDRs(startIP string, count uint64) []string {
	var result []string
	if count == 0 {
		return result
	}
	ip := net.ParseIP(startIP)
	if ip == nil {
		return result
	}

	// IPv4 path. Arithmetic is done in uint64 so a range ending at
	// 255.255.255.255 cannot wrap around to 0.0.0.0 and loop forever, and
	// counts >= 2^32 are not truncated.
	if v4 := ip.To4(); v4 != nil {
		const maxV4 = uint64(1)<<32 - 1
		cur := uint64(ip4ToUint(v4))
		last := cur + count - 1
		if last > maxV4 || last < cur { // past the address space, or uint64 overflow
			last = maxV4
		}
		for cur <= last {
			// Largest block aligned at cur (cur == 0 allows the whole space),
			// halved until it ends inside the range.
			size := uint64(1) << bits.TrailingZeros32(uint32(cur))
			for cur+size-1 > last {
				size >>= 1
			}
			prefixLen := 32 - bits.TrailingZeros64(size)
			result = append(result, fmt.Sprintf("%s/%d", uintToIP4(uint32(cur)), prefixLen))
			cur += size
		}
		return result
	}

	// IPv6 path (big.Int arithmetic).
	one := big.NewInt(1)
	startBig := ip6ToBig(ip)
	endBig := new(big.Int).Add(startBig, new(big.Int).SetUint64(count-1))
	// Clamp to the last IPv6 address so bigToIP6 is never handed a value wider
	// than 128 bits.
	maxV6 := new(big.Int).Sub(new(big.Int).Lsh(one, 128), one)
	if endBig.Cmp(maxV6) > 0 {
		endBig = maxV6
	}
	for startBig.Cmp(endBig) <= 0 {
		// Largest block aligned at the start address ...
		prefixLen := 128 - trailingZeros128(bigToIP6(startBig))
		// ... shrunk until it fits inside the range.
		for {
			blockEnd := new(big.Int).Lsh(one, uint(128-prefixLen))
			blockEnd.Add(blockEnd, startBig)
			blockEnd.Sub(blockEnd, one)
			if blockEnd.Cmp(endBig) <= 0 {
				break
			}
			prefixLen++
		}
		result = append(result, fmt.Sprintf("%s/%d", bigToIP6(startBig), prefixLen))
		// Advance to the first address after this block.
		startBig = new(big.Int).Add(startBig, new(big.Int).Lsh(one, uint(128-prefixLen)))
	}
	return result
}

/* ---------- IPv4 helpers ---------- */

// ip4ToUint converts a 4-byte IPv4 address to its big-endian uint32 value.
// ip must be the 4-byte form (e.g. the result of To4).
func ip4ToUint(ip net.IP) uint32 {
	return uint32(ip[0])<<24 | uint32(ip[1])<<16 | uint32(ip[2])<<8 | uint32(ip[3])
}

// uintToIP4 converts a big-endian uint32 back to an IPv4 net.IP.
func uintToIP4(v uint32) net.IP {
	return net.IPv4(byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
}

/* ---------- IPv6 helpers ---------- */

// ip6ToBig returns the 128-bit value of ip's 16-byte form as a big.Int.
func ip6ToBig(ip net.IP) *big.Int {
	return new(big.Int).SetBytes(ip.To16())
}

// bigToIP6 converts a value to a 16-byte net.IP, left-padding with zeros.
// v must fit in 128 bits; wider values yield an IP longer than 16 bytes.
func bigToIP6(v *big.Int) net.IP {
	b := v.Bytes()
	if len(b) < 16 {
		pad := make([]byte, 16-len(b))
		b = append(pad, b...)
	}
	return net.IP(b)
}

// trailingZeros128 counts the zero bits at the least-significant end of a
// 16-byte address (128 for the all-zero address).
func trailingZeros128(ip net.IP) int {
	b := ip.To16()
	tz := 0
	for i := 15; i >= 0; i-- { // least-significant byte first
		if b[i] == 0 {
			tz += 8
		} else {
			tz += bits.TrailingZeros8(b[i])
			break
		}
	}
	return tz
}