419 lines
9.7 KiB
Go
419 lines
9.7 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"encoding/json"
|
|
"expvar"
|
|
"fmt"
|
|
"log"
|
|
"math/big"
|
|
"math/bits"
|
|
"net"
|
|
"net/http"
|
|
"net/netip"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
"github.com/redis/go-redis/v9"
|
|
)
|
|
|
|
var (
|
|
ctx = context.Background()
|
|
redisAddr = getenv("REDIS_ADDR", "10.10.5.249:6379")
|
|
//redisAddr = getenv("REDIS_ADDR", "localhost:6379")
|
|
redisTTL = time.Hour * 24
|
|
cacheSize = 100_000
|
|
blocklistCats = []string{"generic"}
|
|
rdb *redis.Client
|
|
ipCache *lru.Cache[string, []string]
|
|
|
|
// Metrics
|
|
hits = expvar.NewInt("cache_hits")
|
|
misses = expvar.NewInt("cache_misses")
|
|
queries = expvar.NewInt("ip_queries")
|
|
)
|
|
|
|
var (
|
|
totalBlockedIPs = expvar.NewInt("total_blocked_ips")
|
|
totalWhitelistEntries = expvar.NewInt("total_whitelist_entries")
|
|
)
|
|
|
|
func updateTotalsFromRedis() {
|
|
go func() {
|
|
blockCount := 0
|
|
iter := rdb.Scan(ctx, 0, "bl:*", 0).Iterator()
|
|
for iter.Next(ctx) {
|
|
blockCount++
|
|
}
|
|
totalBlockedIPs.Set(int64(blockCount))
|
|
|
|
whiteCount := 0
|
|
iter = rdb.Scan(ctx, 0, "wl:*", 0).Iterator()
|
|
for iter.Next(ctx) {
|
|
whiteCount++
|
|
}
|
|
totalWhitelistEntries.Set(int64(whiteCount))
|
|
}()
|
|
}
|
|
|
|
func startMetricUpdater() {
|
|
ticker := time.NewTicker(10 * time.Second)
|
|
go func() {
|
|
for {
|
|
updateTotalsFromRedis()
|
|
<-ticker.C
|
|
}
|
|
}()
|
|
}
|
|
|
|
// --------------------------------------------
|
|
// INIT + MAIN
|
|
// --------------------------------------------
|
|
|
|
func main() {
|
|
var err error
|
|
|
|
// Redis client
|
|
rdb = redis.NewClient(&redis.Options{Addr: redisAddr})
|
|
if err := rdb.Ping(ctx).Err(); err != nil {
|
|
log.Fatalf("redis: %v", err)
|
|
}
|
|
|
|
// LRU cache
|
|
ipCache, err = lru.New[string, []string](cacheSize)
|
|
if err != nil {
|
|
log.Fatalf("cache init: %v", err)
|
|
}
|
|
|
|
startMetricUpdater()
|
|
|
|
// Admin load all blocklists (on demand or scheduled)
|
|
go func() {
|
|
if getenv("IMPORT_RIRS", "1") == "1" {
|
|
log.Println("Lade IP-Ranges aus RIRs...")
|
|
if err := importRIRDataToRedis(); err != nil {
|
|
log.Fatalf("import error: %v", err)
|
|
}
|
|
log.Println("✅ Import abgeschlossen.")
|
|
}
|
|
}()
|
|
|
|
// Routes
|
|
http.HandleFunc("/check/", handleCheck)
|
|
http.HandleFunc("/whitelist", handleWhitelist)
|
|
http.HandleFunc("/info", handleInfo)
|
|
http.Handle("/debug/vars", http.DefaultServeMux)
|
|
|
|
log.Println("🚀 Server läuft auf :8080")
|
|
log.Fatal(http.ListenAndServe(":8080", nil))
|
|
}
|
|
|
|
// getenv returns the value of environment variable k, or fallback when the
// variable is unset or set to the empty string.
func getenv(k, fallback string) string {
	val, present := os.LookupEnv(k)
	if !present || val == "" {
		return fallback
	}
	return val
}
|
|
|
|
// --------------------------------------------
|
|
// IP CHECK API
|
|
// --------------------------------------------
|
|
|
|
func handleCheck(w http.ResponseWriter, r *http.Request) {
|
|
ipStr := strings.TrimPrefix(r.URL.Path, "/check/")
|
|
addr, err := netip.ParseAddr(ipStr)
|
|
if err != nil {
|
|
http.Error(w, "invalid IP", 400)
|
|
return
|
|
}
|
|
|
|
cats := blocklistCats
|
|
if q := r.URL.Query().Get("cats"); q != "" {
|
|
cats = strings.Split(q, ",")
|
|
}
|
|
|
|
queries.Add(1)
|
|
blockedCats, err := checkIP(addr, cats)
|
|
if err != nil {
|
|
http.Error(w, "lookup error", 500)
|
|
return
|
|
}
|
|
|
|
writeJSON(w, map[string]any{
|
|
"ip": ipStr,
|
|
"blocked": len(blockedCats) > 0,
|
|
"categories": blockedCats,
|
|
})
|
|
}
|
|
|
|
func checkIP(ip netip.Addr, cats []string) ([]string, error) {
|
|
if res, ok := ipCache.Get(ip.String()); ok {
|
|
hits.Add(1)
|
|
return res, nil
|
|
}
|
|
|
|
matches := []string{}
|
|
for _, cat := range cats {
|
|
iter := rdb.Scan(ctx, 0, "bl:"+cat+":*", 0).Iterator()
|
|
for iter.Next(ctx) {
|
|
key := iter.Val()
|
|
parts := strings.SplitN(key, ":", 3)
|
|
if len(parts) != 3 {
|
|
continue
|
|
}
|
|
pfx, err := netip.ParsePrefix(parts[2])
|
|
if err != nil {
|
|
continue
|
|
}
|
|
if pfx.Contains(ip) {
|
|
matches = append(matches, cat)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
misses.Add(1)
|
|
ipCache.Add(ip.String(), matches)
|
|
return matches, nil
|
|
}
|
|
|
|
// --------------------------------------------
|
|
// WHITELIST API (optional extension)
|
|
// --------------------------------------------
|
|
|
|
func handleWhitelist(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method != http.MethodPost {
|
|
http.Error(w, "method not allowed", 405)
|
|
return
|
|
}
|
|
var body struct {
|
|
IP string `json:"ip"`
|
|
}
|
|
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
|
|
http.Error(w, "bad request", 400)
|
|
return
|
|
}
|
|
addr, err := netip.ParseAddr(body.IP)
|
|
if err != nil {
|
|
http.Error(w, "invalid IP", 400)
|
|
return
|
|
}
|
|
// Add to whitelist (Redis key like wl:<ip>)
|
|
if err := rdb.Set(ctx, "wl:"+addr.String(), "1", 0).Err(); err != nil {
|
|
http.Error(w, "redis error", 500)
|
|
return
|
|
}
|
|
ipCache.Add(addr.String(), nil)
|
|
writeJSON(w, map[string]string{"status": "whitelisted"})
|
|
}
|
|
|
|
// --------------------------------------------
|
|
// ADMIN INFO
|
|
// --------------------------------------------
|
|
|
|
func handleInfo(w http.ResponseWriter, _ *http.Request) {
|
|
stats := map[string]any{
|
|
"cache_size": ipCache.Len(),
|
|
"ttl_hours": redisTTL.Hours(),
|
|
"redis": redisAddr,
|
|
}
|
|
writeJSON(w, stats)
|
|
}
|
|
|
|
// --------------------------------------------
|
|
// UTIL
|
|
// --------------------------------------------
|
|
|
|
// writeJSON writes v as a JSON response body (followed by a newline) with
// Content-Type application/json and the default 200 status.
func writeJSON(w http.ResponseWriter, v any) {
	enc := json.NewEncoder(w)
	w.Header().Set("Content-Type", "application/json")
	// Once encoding starts writing, the status line and headers are committed,
	// so an error here (typically a disconnected client) cannot be reported
	// to the client anymore. It is deliberately dropped.
	_ = enc.Encode(v)
}
|
|
|
|
// --------------------------------------------
|
|
// RIR DATA IMPORT (ALL COUNTRIES)
|
|
// --------------------------------------------
|
|
|
|
// rirFiles are the "latest" delegation statistics files of the five Regional
// Internet Registries. importRIRDataToRedis downloads and imports each one.
var rirFiles = []string{
	// RIPE NCC (Europe, Middle East, Central Asia)
	"https://ftp.ripe.net/pub/stats/ripencc/delegated-ripencc-latest",
	// APNIC (Asia-Pacific)
	"https://ftp.apnic.net/stats/apnic/delegated-apnic-latest",
	// ARIN (North America), extended format
	"https://ftp.arin.net/pub/stats/arin/delegated-arin-extended-latest",
	// LACNIC (Latin America, Caribbean)
	"https://ftp.lacnic.net/pub/stats/lacnic/delegated-lacnic-latest",
	// AFRINIC (Africa), extended format
	"https://ftp.afrinic.net/pub/stats/afrinic/delegated-afrinic-extended-latest",
}
|
|
|
|
func importRIRDataToRedis() error {
|
|
wg := sync.WaitGroup{}
|
|
sem := make(chan struct{}, 5)
|
|
|
|
for _, url := range rirFiles {
|
|
wg.Add(1)
|
|
sem <- struct{}{}
|
|
go func(url string) {
|
|
defer wg.Done()
|
|
defer func() { <-sem }()
|
|
fmt.Println("Start: ", url)
|
|
if err := fetchAndStore(url); err != nil {
|
|
log.Printf("❌ Fehler bei %s: %v", url, err)
|
|
}
|
|
fmt.Println("Done: ", url)
|
|
}(url)
|
|
}
|
|
wg.Wait()
|
|
return nil
|
|
}
|
|
|
|
func fetchAndStore(url string) error {
|
|
resp, err := http.Get(url)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
scanner := bufio.NewScanner(resp.Body)
|
|
for scanner.Scan() {
|
|
line := scanner.Text()
|
|
if strings.HasPrefix(line, "#") || !strings.Contains(line, "|ipv") {
|
|
continue
|
|
}
|
|
fields := strings.Split(line, "|")
|
|
if len(fields) < 7 {
|
|
continue
|
|
}
|
|
country := strings.ToLower(fields[1])
|
|
ipType := fields[2]
|
|
start := fields[3]
|
|
count := fields[4]
|
|
|
|
if ipType != "ipv4" && ipType != "ipv6" {
|
|
continue
|
|
}
|
|
|
|
if start == "24.152.36.0" {
|
|
fmt.Printf("💡 Testing summarizeIPv4CIDRs(%s, %s)\n", start, count)
|
|
num, _ := strconv.ParseUint(count, 10, 64)
|
|
for _, cidr := range summarizeCIDRs(start, num) {
|
|
fmt.Println(" →", cidr)
|
|
}
|
|
}
|
|
|
|
//cidrList := summarizeToCIDRs(start, count, ipType)
|
|
numIPs, _ := strconv.ParseUint(count, 10, 64)
|
|
cidrList := summarizeCIDRs(start, numIPs)
|
|
//log.Printf("[%s] %s/%s (%s) → %d Netze", strings.ToUpper(country), start, count, ipType, len(cidrList))
|
|
for _, cidr := range cidrList {
|
|
prefix, err := netip.ParsePrefix(cidr)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
key := "bl:" + country + ":" + prefix.String()
|
|
//fmt.Println(key)
|
|
_ = rdb.Set(ctx, key, "1", redisTTL).Err()
|
|
}
|
|
}
|
|
return scanner.Err()
|
|
}
|
|
|
|
// --------------------------------------------
|
|
// IP RANGE SUMMARIZER
|
|
// --------------------------------------------
|
|
|
|
// summarizeCIDRs returns the minimal list of CIDR blocks, formatted as
// "addr/len", that exactly covers count consecutive addresses starting at
// startIP.
//
// A startIP that net.ParseIP recognises as IPv4 (including IPv4-mapped IPv6)
// is handled with 32-bit addresses; anything else as IPv6. A range that would
// run past the last address of its family is truncated at that address. The
// result is nil when count is 0 or startIP cannot be parsed.
func summarizeCIDRs(startIP string, count uint64) []string {
	var result []string

	if count == 0 {
		return result
	}
	ip := net.ParseIP(startIP)
	if ip == nil {
		return result
	}

	// IPv4 path. Arithmetic is done in uint64 so that neither the range end nor
	// the cursor wraps around after 255.255.255.255. With 32-bit arithmetic
	// that wrap produced "0.0.0.0/0" entries and an endless loop.
	if v4 := ip.To4(); v4 != nil {
		const lastV4 = 1<<32 - 1
		start := uint64(v4[0])<<24 | uint64(v4[1])<<16 | uint64(v4[2])<<8 | uint64(v4[3])
		end := uint64(lastV4)
		if count-1 <= lastV4-start {
			end = start + count - 1
		}

		for cur := start; cur <= end; {
			// Largest block aligned at cur (address 0 is aligned to the whole space)...
			size := uint64(1) << bits.TrailingZeros32(uint32(cur))
			// ...shrunk until it no longer extends past end.
			for cur+size-1 > end {
				size >>= 1
			}
			prefixLen := 32 - bits.TrailingZeros64(size)
			addr := net.IPv4(byte(cur>>24), byte(cur>>16), byte(cur>>8), byte(cur))
			result = append(result, fmt.Sprintf("%s/%d", addr, prefixLen))
			cur += size
		}
		return result
	}

	// IPv6 path: 128-bit arithmetic via math/big.
	one := big.NewInt(1)
	lastV6 := new(big.Int).Sub(new(big.Int).Lsh(one, 128), one)
	cur := new(big.Int).SetBytes(ip.To16())
	end := new(big.Int).Add(cur, new(big.Int).SetUint64(count-1))
	if end.Cmp(lastV6) > 0 {
		end.Set(lastV6)
	}

	for cur.Cmp(end) <= 0 {
		// Largest block aligned at cur. big.Int reports 0 trailing zero bits
		// for the value 0, but address :: is aligned to the whole space.
		tz := cur.TrailingZeroBits()
		if cur.Sign() == 0 {
			tz = 128
		}
		size := new(big.Int).Lsh(one, tz)

		// Shrink the block until it fits inside [cur, end].
		for {
			blockEnd := new(big.Int).Add(cur, size)
			blockEnd.Sub(blockEnd, one)
			if blockEnd.Cmp(end) <= 0 {
				break
			}
			size.Rsh(size, 1)
			tz--
		}

		// cur <= end <= lastV6, so it always fits into 16 bytes.
		addr := net.IP(cur.FillBytes(make([]byte, net.IPv6len)))
		result = append(result, fmt.Sprintf("%s/%d", addr, 128-tz))

		cur.Add(cur, size)
	}
	return result
}
|
|
|
|
/* ---------- IPv4 helpers ---------- */
|
|
|
|
// ip4ToUint packs the first four bytes of ip into a big-endian uint32. The
// caller must pass the 4-byte form (e.g. the result of To4); a shorter slice
// panics.
func ip4ToUint(ip net.IP) uint32 {
	var packed uint32
	for _, octet := range ip[:4] {
		packed = packed<<8 | uint32(octet)
	}
	return packed
}

// uintToIP4 is the inverse of ip4ToUint: it unpacks a big-endian uint32 into
// an IPv4 address (returned by net.IPv4 in its 16-byte form).
func uintToIP4(v uint32) net.IP {
	a, b, c, d := byte(v>>24), byte(v>>16), byte(v>>8), byte(v)
	return net.IPv4(a, b, c, d)
}
|
|
|
|
/* ---------- IPv6 helpers ---------- */
|
|
|
|
// ip6ToBig interprets the 16-byte form of ip as an unsigned big-endian
// integer. An ip with no 16-byte form (To16 returns nil) yields 0.
func ip6ToBig(ip net.IP) *big.Int {
	var n big.Int
	n.SetBytes(ip.To16())
	return &n
}

// bigToIP6 converts v back to a net.IP, left-padding with zero bytes to 16
// bytes. A value that needs more than 16 bytes is returned unpadded as is.
func bigToIP6(v *big.Int) net.IP {
	raw := v.Bytes()
	if n := len(raw); n < net.IPv6len {
		out := make([]byte, net.IPv6len)
		copy(out[net.IPv6len-n:], raw)
		return net.IP(out)
	}
	return net.IP(raw)
}
|
|
|
|
// trailingZeros128 returns the number of zero bits at the least-significant
// end of the 16-byte form of ip (128 for the all-zero address). It indexes
// the 16-byte form directly, so an ip without one panics.
func trailingZeros128(ip net.IP) int {
	b := ip.To16()
	zeros := 0
	i := 15
	// Whole zero bytes from the low end contribute 8 bits each.
	for ; i >= 0 && b[i] == 0; i-- {
		zeros += 8
	}
	// The first non-zero byte contributes its own trailing zeros.
	if i >= 0 {
		zeros += bits.TrailingZeros8(b[i])
	}
	return zeros
}
|