mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[proxy, management] Add header auth, access restrictions, and session idle timeout (#5587)
This commit is contained in:
264
proxy/internal/geolocation/download.go
Normal file
264
proxy/internal/geolocation/download.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bufio"
|
||||
"compress/gzip"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
|
||||
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
|
||||
mmdbInnerName = "GeoLite2-City.mmdb"
|
||||
|
||||
downloadTimeout = 2 * time.Minute
|
||||
maxMMDBSize = 256 << 20 // 256 MB
|
||||
)
|
||||
|
||||
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
|
||||
// it downloads from pkgs.netbird.io with SHA256 verification.
|
||||
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
|
||||
if err := os.MkdirAll(dataDir, 0o755); err != nil {
|
||||
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
|
||||
}
|
||||
|
||||
pattern := filepath.Join(dataDir, mmdbGlob)
|
||||
if files, _ := filepath.Glob(pattern); len(files) > 0 {
|
||||
mmdbPath := files[len(files)-1]
|
||||
logger.Debugf("using existing geolocation database: %s", mmdbPath)
|
||||
return mmdbPath, nil
|
||||
}
|
||||
|
||||
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
|
||||
return downloadMMDB(logger, dataDir)
|
||||
}
|
||||
|
||||
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
|
||||
client := &http.Client{Timeout: downloadTimeout}
|
||||
|
||||
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get remote filename: %w", err)
|
||||
}
|
||||
|
||||
mmdbFilename := deriveMMDBFilename(datedName)
|
||||
mmdbPath := filepath.Join(dataDir, mmdbFilename)
|
||||
|
||||
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp directory: %w", err)
|
||||
}
|
||||
defer os.RemoveAll(tmp)
|
||||
|
||||
checksumFile := filepath.Join(tmp, "checksum.sha256")
|
||||
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
|
||||
return "", fmt.Errorf("download checksum: %w", err)
|
||||
}
|
||||
|
||||
expectedHash, err := readChecksumFile(checksumFile)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read checksum: %w", err)
|
||||
}
|
||||
|
||||
tarFile := filepath.Join(tmp, datedName)
|
||||
logger.Debugf("downloading geolocation database (%s)", datedName)
|
||||
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
|
||||
return "", fmt.Errorf("download database: %w", err)
|
||||
}
|
||||
|
||||
if err := verifySHA256(tarFile, expectedHash); err != nil {
|
||||
return "", fmt.Errorf("verify database checksum: %w", err)
|
||||
}
|
||||
|
||||
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
|
||||
return "", fmt.Errorf("extract database: %w", err)
|
||||
}
|
||||
|
||||
logger.Infof("geolocation database downloaded: %s", mmdbPath)
|
||||
return mmdbPath, nil
|
||||
}
|
||||
|
||||
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
|
||||
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
|
||||
func deriveMMDBFilename(tarName string) string {
|
||||
base, _, _ := strings.Cut(tarName, ".")
|
||||
if !strings.Contains(base, "_") {
|
||||
return "GeoLite2-City.mmdb"
|
||||
}
|
||||
return base + ".mmdb"
|
||||
}
|
||||
|
||||
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
|
||||
resp, err := client.Head(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
cd := resp.Header.Get("Content-Disposition")
|
||||
if cd == "" {
|
||||
return "", errors.New("no Content-Disposition header")
|
||||
}
|
||||
|
||||
_, params, err := mime.ParseMediaType(cd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("parse Content-Disposition: %w", err)
|
||||
}
|
||||
|
||||
name := filepath.Base(params["filename"])
|
||||
if name == "" || name == "." {
|
||||
return "", errors.New("no filename in Content-Disposition")
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
func downloadToFile(client *http.Client, url, dest string) error {
|
||||
resp, err := client.Get(url) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
||||
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
||||
}
|
||||
|
||||
f, err := os.Create(dest) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
|
||||
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func readChecksumFile(path string) (string, error) {
|
||||
f, err := os.Open(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
scanner := bufio.NewScanner(f)
|
||||
if scanner.Scan() {
|
||||
parts := strings.Fields(scanner.Text())
|
||||
if len(parts) > 0 {
|
||||
return parts[0], nil
|
||||
}
|
||||
}
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", errors.New("empty checksum file")
|
||||
}
|
||||
|
||||
func verifySHA256(path, expected string) error {
|
||||
f, err := os.Open(path) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
h := sha256.New()
|
||||
if _, err := io.Copy(h, f); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
actual := fmt.Sprintf("%x", h.Sum(nil))
|
||||
if actual != expected {
|
||||
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
|
||||
f, err := os.Open(tarGZPath) //nolint:gosec
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
gz, err := gzip.NewReader(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer gz.Close()
|
||||
|
||||
tr := tar.NewReader(gz)
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
|
||||
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
|
||||
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
|
||||
}
|
||||
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("%s not found in archive", mmdbInnerName)
|
||||
}
|
||||
|
||||
// extractToFileAtomic writes r to a temporary file in the same directory as
|
||||
// destPath, then renames it into place so a crash never leaves a truncated file.
|
||||
func extractToFileAtomic(r io.Reader, destPath string) error {
|
||||
dir := filepath.Dir(destPath)
|
||||
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
|
||||
if err != nil {
|
||||
return fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmp.Name()
|
||||
|
||||
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
|
||||
if closeErr := tmp.Close(); closeErr != nil {
|
||||
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
|
||||
}
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("write mmdb: %w", err)
|
||||
}
|
||||
if err := tmp.Close(); err != nil {
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("close temp file: %w", err)
|
||||
}
|
||||
if err := os.Rename(tmpPath, destPath); err != nil {
|
||||
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
||||
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
||||
}
|
||||
return fmt.Errorf("rename to %s: %w", destPath, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
152
proxy/internal/geolocation/geolocation.go
Normal file
152
proxy/internal/geolocation/geolocation.go
Normal file
@@ -0,0 +1,152 @@
|
||||
// Package geolocation provides IP-to-country lookups using MaxMind GeoLite2 databases.
|
||||
package geolocation
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisable disables geolocation lookups entirely when set to a truthy value.
|
||||
EnvDisable = "NB_PROXY_DISABLE_GEOLOCATION"
|
||||
|
||||
mmdbGlob = "GeoLite2-City_*.mmdb"
|
||||
)
|
||||
|
||||
type record struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
City struct {
|
||||
Names struct {
|
||||
En string `maxminddb:"en"`
|
||||
} `maxminddb:"names"`
|
||||
} `maxminddb:"city"`
|
||||
Subdivisions []struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
Names struct {
|
||||
En string `maxminddb:"en"`
|
||||
} `maxminddb:"names"`
|
||||
} `maxminddb:"subdivisions"`
|
||||
}
|
||||
|
||||
// Result holds the outcome of a geo lookup.
|
||||
type Result struct {
|
||||
CountryCode string
|
||||
CityName string
|
||||
SubdivisionCode string
|
||||
SubdivisionName string
|
||||
}
|
||||
|
||||
// Lookup provides IP geolocation lookups.
|
||||
type Lookup struct {
|
||||
mu sync.RWMutex
|
||||
db *maxminddb.Reader
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewLookup opens or downloads the GeoLite2-City MMDB in dataDir.
|
||||
// Returns nil without error if geolocation is disabled via environment
|
||||
// variable, no data directory is configured, or the download fails
|
||||
// (graceful degradation: country restrictions will deny all requests).
|
||||
func NewLookup(logger *log.Logger, dataDir string) (*Lookup, error) {
|
||||
if isDisabledByEnv(logger) {
|
||||
logger.Info("geolocation disabled via environment variable")
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
if dataDir == "" {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
mmdbPath, err := ensureMMDB(logger, dataDir)
|
||||
if err != nil {
|
||||
logger.Warnf("geolocation database unavailable: %v", err)
|
||||
logger.Warn("country-based access restrictions will deny all requests until a database is available")
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
db, err := maxminddb.Open(mmdbPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open GeoLite2 database %s: %w", mmdbPath, err)
|
||||
}
|
||||
|
||||
logger.Infof("geolocation database loaded from %s", mmdbPath)
|
||||
return &Lookup{db: db, logger: logger}, nil
|
||||
}
|
||||
|
||||
// LookupAddr returns the country ISO code and city name for the given IP.
|
||||
// Returns an empty Result if the database is nil or the lookup fails.
|
||||
func (l *Lookup) LookupAddr(addr netip.Addr) Result {
|
||||
if l == nil {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
|
||||
if l.db == nil {
|
||||
return Result{}
|
||||
}
|
||||
|
||||
addr = addr.Unmap()
|
||||
|
||||
var rec record
|
||||
if err := l.db.Lookup(addr.AsSlice(), &rec); err != nil {
|
||||
l.logger.Debugf("geolocation lookup %s: %v", addr, err)
|
||||
return Result{}
|
||||
}
|
||||
r := Result{
|
||||
CountryCode: rec.Country.ISOCode,
|
||||
CityName: rec.City.Names.En,
|
||||
}
|
||||
if len(rec.Subdivisions) > 0 {
|
||||
r.SubdivisionCode = rec.Subdivisions[0].ISOCode
|
||||
r.SubdivisionName = rec.Subdivisions[0].Names.En
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Available reports whether the lookup has a loaded database.
|
||||
func (l *Lookup) Available() bool {
|
||||
if l == nil {
|
||||
return false
|
||||
}
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
return l.db != nil
|
||||
}
|
||||
|
||||
// Close releases the database resources.
|
||||
func (l *Lookup) Close() error {
|
||||
if l == nil {
|
||||
return nil
|
||||
}
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.db != nil {
|
||||
err := l.db.Close()
|
||||
l.db = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDisabledByEnv(logger *log.Logger) bool {
|
||||
val := os.Getenv(EnvDisable)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
disabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
logger.Warnf("parse %s=%q: %v", EnvDisable, val, err)
|
||||
return false
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
Reference in New Issue
Block a user