[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)

This commit is contained in:
Viktor Liu
2026-03-16 22:22:00 +08:00
committed by GitHub
parent 3e6baea405
commit 387e374e4b
34 changed files with 3509 additions and 1380 deletions

View File

@@ -0,0 +1,264 @@
package geolocation
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
mmdbInnerName = "GeoLite2-City.mmdb"
downloadTimeout = 2 * time.Minute
maxMMDBSize = 256 << 20 // 256 MB
)
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
// it downloads from pkgs.netbird.io with SHA256 verification.
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
}
pattern := filepath.Join(dataDir, mmdbGlob)
if files, _ := filepath.Glob(pattern); len(files) > 0 {
mmdbPath := files[len(files)-1]
logger.Debugf("using existing geolocation database: %s", mmdbPath)
return mmdbPath, nil
}
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
return downloadMMDB(logger, dataDir)
}
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
client := &http.Client{Timeout: downloadTimeout}
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
if err != nil {
return "", fmt.Errorf("get remote filename: %w", err)
}
mmdbFilename := deriveMMDBFilename(datedName)
mmdbPath := filepath.Join(dataDir, mmdbFilename)
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
if err != nil {
return "", fmt.Errorf("create temp directory: %w", err)
}
defer os.RemoveAll(tmp)
checksumFile := filepath.Join(tmp, "checksum.sha256")
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
return "", fmt.Errorf("download checksum: %w", err)
}
expectedHash, err := readChecksumFile(checksumFile)
if err != nil {
return "", fmt.Errorf("read checksum: %w", err)
}
tarFile := filepath.Join(tmp, datedName)
logger.Debugf("downloading geolocation database (%s)", datedName)
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
return "", fmt.Errorf("download database: %w", err)
}
if err := verifySHA256(tarFile, expectedHash); err != nil {
return "", fmt.Errorf("verify database checksum: %w", err)
}
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
return "", fmt.Errorf("extract database: %w", err)
}
logger.Infof("geolocation database downloaded: %s", mmdbPath)
return mmdbPath, nil
}
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
base, _, _ := strings.Cut(tarName, ".")
if !strings.Contains(base, "_") {
return "GeoLite2-City.mmdb"
}
return base + ".mmdb"
}
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
resp, err := client.Head(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
}
cd := resp.Header.Get("Content-Disposition")
if cd == "" {
return "", errors.New("no Content-Disposition header")
}
_, params, err := mime.ParseMediaType(cd)
if err != nil {
return "", fmt.Errorf("parse Content-Disposition: %w", err)
}
name := filepath.Base(params["filename"])
if name == "" || name == "." {
return "", errors.New("no filename in Content-Disposition")
}
return name, nil
}
func downloadToFile(client *http.Client, url, dest string) error {
resp, err := client.Get(url) //nolint:gosec
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
f, err := os.Create(dest) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
return err
}
return nil
}
func readChecksumFile(path string) (string, error) {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", errors.New("empty checksum file")
}
func verifySHA256(path, expected string) error {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actual := fmt.Sprintf("%x", h.Sum(nil))
if actual != expected {
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
}
return nil
}
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
f, err := os.Open(tarGZPath) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
}
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
return err
}
return nil
}
}
return fmt.Errorf("%s not found in archive", mmdbInnerName)
}
// extractToFileAtomic writes r to a temporary file in the same directory as
// destPath, then renames it into place so a crash never leaves a truncated file.
func extractToFileAtomic(r io.Reader, destPath string) error {
dir := filepath.Dir(destPath)
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmp.Name()
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
if closeErr := tmp.Close(); closeErr != nil {
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
}
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("write mmdb: %w", err)
}
if err := tmp.Close(); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("close temp file: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("rename to %s: %w", destPath, err)
}
return nil
}

View File

@@ -0,0 +1,152 @@
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
package geolocation
import (
"fmt"
"net/netip"
"os"
"strconv"
"sync"
"github.com/oschwald/maxminddb-golang"
log "github.com/sirupsen/logrus"
)
const (
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
mmdbGlob = "GeoLite2-City_*.mmdb"
)
type record struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
City struct {
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"city"`
Subdivisions []struct {
ISOCode string `maxminddb:"iso_code"`
Names struct {
En string `maxminddb:"en"`
} `maxminddb:"names"`
} `maxminddb:"subdivisions"`
}
// Result holds the outcome of a geo lookup.
type Result struct {
CountryCode string
CityName string
SubdivisionCode string
SubdivisionName string
}
// Lookup provides IP geolocation lookups.
type Lookup struct {
mu sync.RWMutex
db *maxminddb.Reader
logger *log.Logger
}
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
// Returns nil without error if geolocation is disabled via environment
// variable, no data directory is configured, or the download fails
// (graceful degradation: country restrictions will deny all requests).
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
if isDisabledByEnv(logger) {
logger.Info("geolocation disabled via environment variable")
return nil, nil //nolint:nilnil
}
if dataDir == "" {
return nil, nil //nolint:nilnil
}
mmdbPath, err := ensureMMDB(logger, dataDir)
if err != nil {
logger.Warnf("geolocation database unavailable: %v", err)
logger.Warn("country-based access restrictions will deny all requests until a database is available")
return nil, nil //nolint:nilnil
}
db, err := maxminddb.Open(mmdbPath)
if err != nil {
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
}
logger.Infof("geolocation database loaded from %s", mmdbPath)
return &Lookup{db: db, logger: logger}, nil
}
// LookupAddr returns the country ISO code and city name for the given IP.
// Returns an empty Result if the database is nil or the lookup fails.
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
if l == nil {
return Result{}
}
l.mu.RLock()
defer l.mu.RUnlock()
if l.db == nil {
return Result{}
}
addr = addr.Unmap()
var rec record
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
return Result{}
}
r := Result{
CountryCode: rec.Country.ISOCode,
CityName: rec.City.Names.En,
}
if len(rec.Subdivisions) > 0 {
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
r.SubdivisionName = rec.Subdivisions[0].Names.En
}
return r
}
// Available reports whether the lookup has a loaded database.
func (l *Lookup) Available() bool {
if l == nil {
return false
}
l.mu.RLock()
defer l.mu.RUnlock()
return l.db != nil
}
// Close releases the database resources.
func (l *Lookup) Close() error {
if l == nil {
return nil
}
l.mu.Lock()
defer l.mu.Unlock()
if l.db != nil {
err := l.db.Close()
l.db = nil
return err
}
return nil
}
func isDisabledByEnv(logger *log.Logger) bool {
val := os.Getenv(EnvDisable)
if val == "" {
return false
}
disabled, err := strconv.ParseBool(val)
if err != nil {
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
return false
}
return disabled
}