Files
flod/main.go
2025-06-11 23:02:40 +02:00

419 lines
9.7 KiB
Go

package main
import (
"bufio"
"context"
"encoding/json"
"expvar"
"fmt"
"log"
"math/big"
"math/bits"
"net"
"net/http"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"time"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/redis/go-redis/v9"
)
// Runtime configuration and shared clients, initialised in main.
var (
	// ctx is the root context passed to every Redis call.
	ctx = context.Background()
	// redisAddr is read from $REDIS_ADDR; for local development set
	// REDIS_ADDR=localhost:6379.
	redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379")
	// redisTTL is the expiry given to imported "bl:" keys.
	redisTTL = 24 * time.Hour
	// cacheSize is the capacity of the in-process lookup cache.
	cacheSize = 100_000
	// blocklistCats are the categories checked when /check/ has no ?cats=.
	blocklistCats = []string{"generic"}

	rdb     *redis.Client
	ipCache *lru.Cache[string, []string]
)

// Counters published through expvar.
var (
	hits    = expvar.NewInt("cache_hits")
	misses  = expvar.NewInt("cache_misses")
	queries = expvar.NewInt("ip_queries")

	totalBlockedIPs       = expvar.NewInt("total_blocked_ips")
	totalWhitelistEntries = expvar.NewInt("total_whitelist_entries")
)
func updateTotalsFromRedis() {
go func() {
blockCount := 0
iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator()
for iter.Next(ctx) {
blockCount++
}
totalBlockedIPs.Set(int64(blockCount))
whiteCount := 0
iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator()
for iter.Next(ctx) {
whiteCount++
}
totalWhitelistEntries.Set(int64(whiteCount))
}()
}
// startMetricUpdater triggers updateTotalsFromRedis once right away and then
// again every 10 seconds. The ticker and goroutine live for the whole
// process.
func startMetricUpdater() {
	tick := time.NewTicker(10 * time.Second)
	go func() {
		updateTotalsFromRedis()
		for range tick.C {
			updateTotalsFromRedis()
		}
	}()
}
// --------------------------------------------
// INIT + MAIN
// --------------------------------------------
// main performs startup in this order:
//   - connect to Redis (fatal if unreachable);
//   - create the LRU cache;
//   - start the metric refresher;
//   - unless IMPORT_RIRS is set to something other than "1", start the RIR
//     import in the background;
//   - serve the HTTP API on :8080 using http.DefaultServeMux.
func main() {
	var err error
	// Redis client
	rdb = redis.NewClient(&redis.Options{Addr: redisAddr})
	if err := rdb.Ping(ctx).Err(); err != nil {
		log.Fatalf("redis: %v", err)
	}
	// LRU cache
	ipCache, err = lru.New[string, []string](cacheSize)
	if err != nil {
		log.Fatalf("cache init: %v", err)
	}
	startMetricUpdater()
	// Import RIR data in the background so the API is available immediately.
	go func() {
		if getenv("IMPORT_RIRS", "1") == "1" {
			log.Println("Lade IP-Ranges aus RIRs...")
			if err := importRIRDataToRedis(); err != nil {
				log.Fatalf("import error: %v", err)
			}
			log.Println("✅ Import abgeschlossen.")
		}
	}()
	// Routes. /debug/vars is intentionally not registered here: the expvar
	// package already registers its handler on http.DefaultServeMux.
	// Mounting DefaultServeMux on itself either panics with a duplicate
	// pattern (Go < 1.22) or recurses until the stack overflows for non-GET
	// requests (Go >= 1.22).
	http.HandleFunc("/check/", handleCheck)
	http.HandleFunc("/whitelist", handleWhitelist)
	http.HandleFunc("/info", handleInfo)

	// A nil Handler means http.DefaultServeMux. Timeouts keep slow or idle
	// clients from holding connections forever.
	srv := &http.Server{
		Addr:              ":8080",
		ReadHeaderTimeout: 10 * time.Second,
		IdleTimeout:       2 * time.Minute,
	}
	log.Println("🚀 Server läuft auf :8080")
	log.Fatal(srv.ListenAndServe())
}
// getenv returns the value of the environment variable k, or fallback when
// k is unset or set to the empty string.
func getenv(k, fallback string) string {
	v := os.Getenv(k)
	if v == "" {
		return fallback
	}
	return v
}
// --------------------------------------------
// IP CHECK API
// --------------------------------------------
// handleCheck serves /check/<ip>[?cats=a,b,...] and responds with JSON
// {"ip", "blocked", "categories"}.
//
// Categories default to blocklistCats. Requested categories are trimmed and
// lower-cased, because fetchAndStore writes lower-case country codes into
// "bl:<country>:<prefix>" keys. Empty entries are dropped; if none remain,
// the defaults are used. Redis glob metacharacters are rejected because
// checkIP places each category inside a SCAN MATCH pattern.
func handleCheck(w http.ResponseWriter, r *http.Request) {
	ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
	addr, err := netip.ParseAddr(ipStr)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}
	cats := blocklistCats
	if q := r.URL.Query().Get("cats"); q != "" {
		var requested []string
		for _, c := range strings.Split(q, ",") {
			c = strings.ToLower(strings.TrimSpace(c))
			if c == "" {
				continue
			}
			if strings.ContainsAny(c, `*?[]\`) {
				http.Error(w, "invalid category", http.StatusBadRequest)
				return
			}
			requested = append(requested, c)
		}
		if len(requested) > 0 {
			cats = requested
		}
	}
	queries.Add(1)
	blockedCats, err := checkIP(addr, cats)
	if err != nil {
		http.Error(w, "lookup error", http.StatusInternalServerError)
		return
	}
	writeJSON(w, map[string]any{
		"ip":         ipStr,
		"blocked":    len(blockedCats) > 0,
		"categories": blockedCats,
	})
}
func checkIP(ip netip.Addr, cats []string) ([]string, error) {
if res, ok := ipCache.Get(ip.String()); ok {
hits.Add(1)
return res, nil
}
matches := []string{}
for _, cat := range cats {
iter := rdb.Scan(ctx, 0, "bl:"+cat+":*", 0).Iterator()
for iter.Next(ctx) {
key := iter.Val()
parts := strings.SplitN(key, ":", 3)
if len(parts) != 3 {
continue
}
pfx, err := netip.ParsePrefix(parts[2])
if err != nil {
continue
}
if pfx.Contains(ip) {
matches = append(matches, cat)
break
}
}
}
misses.Add(1)
ipCache.Add(ip.String(), matches)
return matches, nil
}
// --------------------------------------------
// WHITELIST API (optional extension)
// --------------------------------------------
// handleWhitelist accepts POST {"ip": "<addr>"} and stores a permanent
// (no TTL) Redis key "wl:<addr>". It also caches an empty category list for
// the address.
//
// The address is normalised (zone dropped, IPv4-mapped IPv6 unmapped) so
// the stored key matches the normalised form used for lookups. The request
// body is capped at 4 KiB because the endpoint only needs a tiny JSON
// object.
func handleWhitelist(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodPost {
		w.Header().Set("Allow", http.MethodPost)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	var body struct {
		IP string `json:"ip"`
	}
	r.Body = http.MaxBytesReader(w, r.Body, 4<<10)
	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
		http.Error(w, "bad request", http.StatusBadRequest)
		return
	}
	addr, err := netip.ParseAddr(body.IP)
	if err != nil {
		http.Error(w, "invalid IP", http.StatusBadRequest)
		return
	}
	key := addr.WithZone("").Unmap().String()
	// Add to whitelist (Redis key like wl:<ip>)
	if err := rdb.Set(ctx, "wl:"+key, "1", 0).Err(); err != nil {
		http.Error(w, "redis error", http.StatusInternalServerError)
		return
	}
	// Empty (not nil) so a cached hit encodes as [] rather than null.
	ipCache.Add(key, []string{})
	writeJSON(w, map[string]string{"status": "whitelisted"})
}
// --------------------------------------------
// ADMIN INFO
// --------------------------------------------
// handleInfo reports basic runtime information as JSON:
//   - cache_size: the current number of LRU entries;
//   - ttl_hours: the TTL applied to imported blocklist keys;
//   - redis: the configured Redis address.
func handleInfo(w http.ResponseWriter, _ *http.Request) {
	stats := make(map[string]any, 3)
	stats["cache_size"] = ipCache.Len()
	stats["ttl_hours"] = redisTTL.Hours()
	stats["redis"] = redisAddr
	writeJSON(w, stats)
}
// --------------------------------------------
// UTIL
// --------------------------------------------
// writeJSON sets a JSON Content-Type and encodes v into w, followed by a
// trailing newline. The encode error is deliberately discarded: once
// encoding has started the status line is already sent, so there is no
// better response left to give.
func writeJSON(w http.ResponseWriter, v any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	enc := json.NewEncoder(w)
	_ = enc.Encode(v)
}
// --------------------------------------------
// RIR DATA IMPORT (ALL COUNTRIES)
// --------------------------------------------
// rirFiles lists the "delegated" statistics files that importRIRDataToRedis
// downloads, one per Regional Internet Registry: RIPE NCC, APNIC, ARIN,
// LACNIC and AFRINIC. The ARIN and AFRINIC URLs point at the "extended"
// variants.
var rirFiles = []string{
"https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest",
"https://ftp.apnic.net/stats/apnic/delegated-apnic-latest",
"https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest",
"https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest",
"https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest",
}
// importRIRDataToRedis runs fetchAndStore for every URL in rirFiles, at
// most five downloads at a time, and waits for all of them. A failure for
// one file is logged and does not stop the others. The function currently
// always returns nil.
func importRIRDataToRedis() error {
	var wg sync.WaitGroup
	slots := make(chan struct{}, 5) // concurrency limiter
	for _, src := range rirFiles {
		wg.Add(1)
		slots <- struct{}{}
		go func(src string) {
			defer func() {
				<-slots
				wg.Done()
			}()
			fmt.Println("Start: ", src)
			err := fetchAndStore(src)
			if err != nil {
				log.Printf("❌ Fehler bei %s: %v", src, err)
			}
			fmt.Println("Done: ", src)
		}(src)
	}
	wg.Wait()
	return nil
}
func fetchAndStore(url string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
scanner := bufio.NewScanner(resp.Body)
for scanner.Scan() {
line := scanner.Text()
if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") {
continue
}
fields := strings.Split(line, "|")
if len(fields) < 7 {
continue
}
country := strings.ToLower(fields[1])
ipType := fields[2]
start := fields[3]
count := fields[4]
if ipType != "ipv4" && ipType != "ipv6" {
continue
}
if start == "24.152.36.0" {
fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count)
num, _ := strconv.ParseUint(count, 10, 64)
for _, cidr := range summarizeCIDRs(start, num) {
fmt.Println(" →", cidr)
}
}
//cidrList := summarizeToCIDRs(start, count, ipType)
numIPs, _ := strconv.ParseUint(count, 10, 64)
cidrList := summarizeCIDRs(start, numIPs)
//log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList))
for _, cidr := range cidrList {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
continue
}
key := "bl:" + country + ":" + prefix.String()
//fmt.Println(key)
_ = rdb.Set(ctx, key, "1", redisTTL).Err()
}
}
return scanner.Err()
}
// --------------------------------------------
// IP RANGE SUMMARIZER
// --------------------------------------------
// summarizeCIDRs splits the address range that starts at startIP and spans
// count addresses into the minimal list of aligned CIDR blocks. Blocks are
// returned as "addr/bits" strings in ascending order; for example,
// ("10.0.0.0", 768) yields ["10.0.0.0/23", "10.0.2.0/24"].
//
// Input that net.ParseIP recognises as IPv4 (including IPv4-mapped IPv6
// text) produces IPv4 blocks; any other address is handled as IPv6. A zero
// count or an unparsable startIP returns nil.
//
// Arithmetic uses math/big for both families. A range that would run past
// the family's last address is clipped there instead of wrapping around.
// Wrapping would restart the walk at address 0 and never terminate.
func summarizeCIDRs(startIP string, count uint64) []string {
	if count == 0 {
		return nil
	}
	ip := net.ParseIP(startIP)
	if ip == nil {
		return nil
	}
	raw, width := []byte(ip.To16()), 128
	if v4 := ip.To4(); v4 != nil {
		raw, width = []byte(v4), 32
	}

	one := big.NewInt(1)
	cur := new(big.Int).SetBytes(raw)
	// end is exclusive. Clip it to 2^width, one past the last address.
	end := new(big.Int).Add(cur, new(big.Int).SetUint64(count))
	if limit := new(big.Int).Lsh(one, uint(width)); end.Cmp(limit) > 0 {
		end = limit
	}

	buf := make([]byte, width/8)
	var result []string
	for cur.Cmp(end) < 0 {
		// The largest block aligned at cur has 2^(trailing zero bits)
		// addresses. When cur is 0, that is the whole address space.
		size := width
		if cur.Sign() != 0 {
			size = int(cur.TrailingZeroBits())
		}
		// Shrink it to the largest power of two that fits in the remainder.
		if fit := new(big.Int).Sub(end, cur).BitLen() - 1; fit < size {
			size = fit
		}
		addr, _ := netip.AddrFromSlice(cur.FillBytes(buf))
		result = append(result, netip.PrefixFrom(addr, width-size).String())
		cur.Add(cur, new(big.Int).Lsh(one, uint(size)))
	}
	return result
}
/* ---------- IPv4 helpers ---------- */
// ip4ToUint packs the first four bytes of ip into a big-endian uint32. The
// caller must pass the 4-byte form (for example from To4); a 16-byte slice
// would yield its first four bytes instead.
func ip4ToUint(ip net.IP) uint32 {
	var v uint32
	for i := 0; i < 4; i++ {
		v = v<<8 | uint32(ip[i])
	}
	return v
}
// uintToIP4 is the inverse of ip4ToUint. It unpacks a big-endian uint32
// into an IPv4 address, which net.IPv4 returns in 16-byte form.
func uintToIP4(v uint32) net.IP {
	a, b, c, d := byte(v>>24), byte(v>>16), byte(v>>8), byte(v)
	return net.IPv4(a, b, c, d)
}
/* ---------- IPv6 helpers ---------- */
// ip6ToBig interprets the 16-byte form of ip as an unsigned big-endian
// integer. IPv4 addresses become their IPv4-mapped value. A slice of
// invalid length makes To16 return nil, which yields 0.
func ip6ToBig(ip net.IP) *big.Int {
	n := new(big.Int)
	n.SetBytes(ip.To16())
	return n
}
// bigToIP6 converts the magnitude of v back into a net.IP. Values shorter
// than 16 bytes are left-padded with zeros. Values of 16 bytes or more are
// returned as-is, so anything at or above 2^128 gives a slice longer than
// 16 bytes.
func bigToIP6(v *big.Int) net.IP {
	mag := v.Bytes()
	if len(mag) >= 16 {
		return net.IP(mag)
	}
	out := make(net.IP, 16)
	copy(out[16-len(mag):], mag)
	return out
}
// Anzahl der Null-Bits am wenigst-signifikanten Ende (LSB) eines IPv6-Werts
func trailingZeros128(ip net.IP) int {
b := ip.To16()
tz := 0
for i := 15; i >= 0; i-- { // letzte Byte zuerst (LSB)
if b[i] == 0 {
tz += 8
} else {
tz += bits.TrailingZeros8(b[i])
break
}
}
return tz
}