package main

import (
	"crypto/subtle"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"net/netip"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"
)

// --- IP matcher -------------------------------------------------------------

// ipSet is a set of individual IP addresses and CIDR prefixes that can be
// tested for membership. It performs no locking of its own.
type ipSet struct {
	singles  map[netip.Addr]struct{} // exact addresses, stored unmapped
	prefixes []netip.Prefix          // masked (canonical) prefixes, no duplicates
}

// newIPSet returns an empty, ready-to-use ipSet.
func newIPSet() *ipSet {
	return &ipSet{singles: make(map[netip.Addr]struct{})}
}

// parseIPOrCIDR parses s as either a single IP address or a CIDR prefix.
//
// Surrounding whitespace is ignored. An entry containing "/" is parsed as a
// prefix and returned in masked form; anything else is parsed as an address
// and returned with any IPv4-in-IPv6 mapping removed. Exactly one of addr and
// pfx is non-nil on success; both are nil when err is non-nil.
func parseIPOrCIDR(s string) (addr *netip.Addr, pfx *netip.Prefix, err error) {
	s = strings.TrimSpace(s)
	if s == "" {
		return nil, nil, errors.New("empty entry")
	}
	if strings.Contains(s, "/") {
		pr, perr := netip.ParsePrefix(s)
		if perr != nil {
			return nil, nil, fmt.Errorf("invalid CIDR: %w", perr)
		}
		pr = pr.Masked()
		return nil, &pr, nil
	}
	a, aerr := netip.ParseAddr(s)
	if aerr != nil {
		return nil, nil, fmt.Errorf("invalid IP: %w", aerr)
	}
	a = a.Unmap()
	return &a, nil, nil
}

// add parses entry and inserts it into the set, returning the normalized
// textual form of what was stored. Adding an address or prefix that is already
// present is a no-op that still returns the normalized form, so a later single
// remove is guaranteed to take the entry out of the set.
func (s *ipSet) add(entry string) (string, error) {
	addr, pfx, err := parseIPOrCIDR(entry)
	if err != nil {
		return "", err
	}
	if addr != nil {
		s.singles[*addr] = struct{}{}
		return addr.String(), nil
	}
	for _, pr := range s.prefixes {
		if pr == *pfx {
			return pfx.String(), nil // already present; keep the slice duplicate-free
		}
	}
	s.prefixes = append(s.prefixes, *pfx)
	return pfx.String(), nil
}

// remove parses entry and deletes it from the set. It reports whether
// something was actually removed; an unparsable entry removes nothing.
func (s *ipSet) remove(entry string) bool {
	addr, pfx, err := parseIPOrCIDR(entry)
	if err != nil {
		return false
	}
	if addr != nil {
		if _, ok := s.singles[*addr]; !ok {
			return false
		}
		delete(s.singles, *addr)
		return true
	}
	// Filter in place and drop every matching prefix, so the range stops
	// matching even if duplicates were somehow stored.
	kept := s.prefixes[:0]
	removed := false
	for _, pr := range s.prefixes {
		if pr == *pfx {
			removed = true
			continue
		}
		kept = append(kept, pr)
	}
	s.prefixes = kept
	return removed
}

// contains reports whether a (after unmapping IPv4-in-IPv6) is listed as a
// single address or falls inside any stored prefix.
func (s *ipSet) contains(a netip.Addr) bool {
	a = a.Unmap()
	if _, ok := s.singles[a]; ok {
		return true
	}
	for _, p := range s.prefixes {
		if p.Contains(a) {
			return true
		}
	}
	return false
}

// --- State & persistence ----------------------------------------------------

// Mode selects which list drives the allow/deny decision.
type Mode string

const (
	ModeBlock Mode = "block" // allow unless in block list
	ModeAllow Mode = "allow" // deny unless in allow list
)

// stateFile is the JSON representation of the persisted state.
type stateFile struct {
	Mode  Mode     `json:"mode"`
	Block []string `json:"block"`
	Allow []string `json:"allow"`
}

// state holds the active mode and both IP lists, plus the path they are
// persisted to. decide and save take mu for reading.
type state struct {
	mu    sync.RWMutex
	mode  Mode
	block *ipSet
	allow *ipSet
	path  string
}

// newState returns a state in ModeBlock with empty lists that persists to path.
func newState(path string) *state {
	return &state{mode: ModeBlock, block: newIPSet(), allow: newIPSet(), path: path}
}

// load reads the state file at s.path and merges it into s.
//
// A missing file is not an error. An unrecognized mode is ignored (the current
// mode is kept) and invalid list entries are skipped; both are logged so that
// a corrupted file does not silently change filtering behavior. load does not
// take s.mu.
func (s *state) load() error {
	b, err := os.ReadFile(s.path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return nil
		}
		return fmt.Errorf("reading state file: %w", err)
	}
	var sf stateFile
	if err := json.Unmarshal(b, &sf); err != nil {
		return fmt.Errorf("decoding state file: %w", err)
	}
	switch sf.Mode {
	case "":
		// No mode stored: keep the current one.
	case ModeBlock, ModeAllow:
		s.mode = sf.Mode
	default:
		// An unknown mode would make decide allow every request; refuse it.
		log.Printf("state file %s: ignoring unknown mode %q", s.path, sf.Mode)
	}
	for _, e := range sf.Block {
		if _, err := s.block.add(e); err != nil {
			log.Printf("state file %s: skipping block entry %q: %v", s.path, e, err)
		}
	}
	for _, e := range sf.Allow {
		if _, err := s.allow.add(e); err != nil {
			log.Printf("state file %s: skipping allow entry %q: %v", s.path, e, err)
		}
	}
	return nil
}

// save writes the current state to s.path as indented JSON. The data is
// written to a sibling ".tmp" file first and then renamed over the target.
func (s *state) save() error {
	s.mu.RLock()
	defer s.mu.RUnlock()

	// normalized returns the set's entries in canonical text form. The result
	// is always non-nil so empty lists encode as [] rather than null.
	normalized := func(set *ipSet) []string {
		out := make([]string, 0, len(set.singles)+len(set.prefixes))
		for a := range set.singles {
			out = append(out, a.String())
		}
		for _, p := range set.prefixes {
			out = append(out, p.String())
		}
		return out
	}

	out := stateFile{Mode: s.mode, Block: normalized(s.block), Allow: normalized(s.allow)}
	data, err := json.MarshalIndent(out, "", " ")
	if err != nil {
		return fmt.Errorf("encoding state: %w", err)
	}
	tmp := s.path + ".tmp"
	if err := os.WriteFile(tmp, data, 0o600); err != nil {
		return err
	}
	if err := os.Rename(tmp, s.path); err != nil {
		_ = os.Remove(tmp) // best effort: do not leave a stale temp file behind
		return err
	}
	return nil
}

// --- Auth logic -------------------------------------------------------------

// firstIPFromXFF returns the first address of an X-Forwarded-For value
// ("client, proxy1, proxy2, ..."), i.e. the original client. It reports false
// when the first element is empty or is not a valid IP address.
func firstIPFromXFF(xff string) (netip.Addr, bool) {
	first, _, _ := strings.Cut(xff, ",")
	first = strings.TrimSpace(first)
	if first == "" {
		return netip.Addr{}, false
	}
	a, err := netip.ParseAddr(first)
	if err != nil {
		return netip.Addr{}, false
	}
	return a.Unmap(), true
}

// extractClientIP determines the client address for r. It tries, in order,
// the first X-Forwarded-For entry, the X-Real-IP header, and finally the host
// part of r.RemoteAddr. Headers that do not parse are skipped.
//
// NOTE(review): both headers are client-controlled unless the fronting proxy
// overwrites them — confirm the proxy is configured to do so, otherwise a
// client can choose the address it is judged by.
func extractClientIP(r *http.Request) (netip.Addr, error) {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if a, ok := firstIPFromXFF(xff); ok {
			return a, nil
		}
	}
	if xr := r.Header.Get("X-Real-IP"); xr != "" {
		if a, err := netip.ParseAddr(strings.TrimSpace(xr)); err == nil {
			return a.Unmap(), nil
		}
	}
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr without a port: try it as a bare address.
		host = r.RemoteAddr
	}
	a, err := netip.ParseAddr(host)
	if err != nil {
		return netip.Addr{}, fmt.Errorf("cannot parse client IP: %w", err)
	}
	return a.Unmap(), nil
}

// decide reports whether ip may pass under the current mode: in ModeBlock
// everything not on the block list passes, in ModeAllow only addresses on the
// allow list pass. Any other mode value allows the request.
func decide(st *state, ip netip.Addr) (allowed bool) {
	st.mu.RLock()
	defer st.mu.RUnlock()
	switch st.mode {
	case ModeBlock:
		return !st.block.contains(ip)
	case ModeAllow:
		return st.allow.contains(ip)
	default:
		return true
	}
}

// --- Admin auth -------------------------------------------------------------

// basicAuthOK reports whether r carries HTTP Basic credentials equal to user
// and pass. Both comparisons are constant-time and both are always evaluated,
// so the response time does not reveal which of the two was wrong.
func basicAuthOK(r *http.Request, user, pass string) bool {
	u, p, ok := r.BasicAuth()
	if !ok {
		return false
	}
	userOK := subtle.ConstantTimeCompare([]byte(u), []byte(user))
	passOK := subtle.ConstantTimeCompare([]byte(p), []byte(pass))
	return userOK&passOK == 1
}

// requireAdmin wraps next so that it is only reachable with valid Basic Auth
// credentials. When enabled is false every request gets 503; when the
// credentials are missing or wrong the client gets 401 with a
// WWW-Authenticate challenge.
func requireAdmin(next http.Handler, enabled bool, user, pass string) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if !enabled {
			http.Error(w, "admin disabled: set ADMIN_USER and ADMIN_PASS", http.StatusServiceUnavailable)
			return
		}
		if !basicAuthOK(r, user, pass) {
			w.Header().Set("WWW-Authenticate", `Basic realm="ipfilter-admin"`)
			http.Error(w, "auth required", http.StatusUnauthorized)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// --- HTTP handlers ----------------------------------------------------------

// authHandler returns the handler that answers allow/deny for a request. It
// replies 200 "OK" when the client IP passes decide and 403 otherwise (also
// when the client IP cannot be determined). The resolved address is echoed in
// the X-IPFilter-Client response header, and denied addresses are logged.
func authHandler(st *state) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ip, err := extractClientIP(r)
		if err != nil {
			http.Error(w, "cannot determine client IP", http.StatusForbidden)
			return
		}
		w.Header().Set("X-IPFilter-Client", ip.String())
		if decide(st, ip) {
			//log.Printf("Allowed: %s", ip)
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte("OK"))
			return
		}
		log.Printf("Blocked: %s", ip)
		http.Error(w, "forbidden by ip filter", http.StatusForbidden)
	})
}

func adminPage() string {
	return `
Manage block/allow lists used by Traefik ForwardAuth. The auth endpoint is /auth. This UI requires HTTP Basic Auth.