mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
265 lines
6.9 KiB
Go
265 lines
6.9 KiB
Go
package geolocation
|
|
|
|
import (
|
|
"archive/tar"
|
|
"bufio"
|
|
"compress/gzip"
|
|
"crypto/sha256"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// Remote endpoints on pkgs.netbird.io and the name of the database file
// looked up inside the downloaded archive.
const (
	// mmdbTarGZURL serves the GeoLite2-City archive (tar.gz).
	mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
	// mmdbSha256URL serves the SHA256 checksum of that archive.
	mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
	// mmdbInnerName is the base name of the archive entry that gets extracted.
	mmdbInnerName = "GeoLite2-City.mmdb"
)

// Limits applied while downloading and extracting the database.
const (
	// downloadTimeout is the timeout of the HTTP client used for all requests.
	downloadTimeout = 120 * time.Second
	// maxMMDBSize caps both a single download and the extracted entry size.
	maxMMDBSize = 256 * 1024 * 1024 // 256 MB
)
|
|
|
|
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
|
|
// it downloads from pkgs.netbird.io with SHA256 verification.
|
|
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
|
|
if err := os.MkdirAll(dataDir, 0o755); err != nil {
|
|
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
|
|
}
|
|
|
|
pattern := filepath.Join(dataDir, mmdbGlob)
|
|
if files, _ := filepath.Glob(pattern); len(files) > 0 {
|
|
mmdbPath := files[len(files)-1]
|
|
logger.Debugf("using existing geolocation database: %s", mmdbPath)
|
|
return mmdbPath, nil
|
|
}
|
|
|
|
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
|
|
return downloadMMDB(logger, dataDir)
|
|
}
|
|
|
|
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
|
|
client := &http.Client{Timeout: downloadTimeout}
|
|
|
|
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
|
|
if err != nil {
|
|
return "", fmt.Errorf("get remote filename: %w", err)
|
|
}
|
|
|
|
mmdbFilename := deriveMMDBFilename(datedName)
|
|
mmdbPath := filepath.Join(dataDir, mmdbFilename)
|
|
|
|
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
|
|
if err != nil {
|
|
return "", fmt.Errorf("create temp directory: %w", err)
|
|
}
|
|
defer os.RemoveAll(tmp)
|
|
|
|
checksumFile := filepath.Join(tmp, "checksum.sha256")
|
|
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
|
|
return "", fmt.Errorf("download checksum: %w", err)
|
|
}
|
|
|
|
expectedHash, err := readChecksumFile(checksumFile)
|
|
if err != nil {
|
|
return "", fmt.Errorf("read checksum: %w", err)
|
|
}
|
|
|
|
tarFile := filepath.Join(tmp, datedName)
|
|
logger.Debugf("downloading geolocation database (%s)", datedName)
|
|
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
|
|
return "", fmt.Errorf("download database: %w", err)
|
|
}
|
|
|
|
if err := verifySHA256(tarFile, expectedHash); err != nil {
|
|
return "", fmt.Errorf("verify database checksum: %w", err)
|
|
}
|
|
|
|
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
|
|
return "", fmt.Errorf("extract database: %w", err)
|
|
}
|
|
|
|
logger.Infof("geolocation database downloaded: %s", mmdbPath)
|
|
return mmdbPath, nil
|
|
}
|
|
|
|
// deriveMMDBFilename maps an archive file name to the MMDB file name to
// store it under: everything before the first dot, plus ".mmdb". Names whose
// stem has no underscore fall back to "GeoLite2-City.mmdb".
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
	stem := tarName
	if dot := strings.IndexByte(tarName, '.'); dot >= 0 {
		stem = tarName[:dot]
	}

	if strings.IndexByte(stem, '_') >= 0 {
		return stem + ".mmdb"
	}
	return "GeoLite2-City.mmdb"
}
|
|
|
|
// fetchRemoteFilename issues a HEAD request to url and returns the base name
// of the "filename" parameter of the response's Content-Disposition header.
//
// It fails when the request fails, the status is not 200, the header is
// missing or unparsable, or the resulting name is not usable as a plain file
// name. The header is server-controlled, so any directory components are
// stripped and the names "..", "." and a bare path separator are rejected
// rather than handed to callers that join the name into a path.
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
	resp, err := client.Head(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
	}

	cd := resp.Header.Get("Content-Disposition")
	if cd == "" {
		return "", errors.New("no Content-Disposition header")
	}

	_, params, err := mime.ParseMediaType(cd)
	if err != nil {
		return "", fmt.Errorf("parse Content-Disposition: %w", err)
	}

	// filepath.Base maps an empty string to "." and leaves ".." and a lone
	// separator untouched, so those must be checked explicitly.
	name := filepath.Base(params["filename"])
	switch name {
	case "", ".":
		return "", errors.New("no filename in Content-Disposition")
	case "..", string(filepath.Separator):
		return "", fmt.Errorf("invalid filename %q in Content-Disposition", name)
	}
	return name, nil
}
|
|
|
|
func downloadToFile(client *http.Client, url, dest string) error {
|
|
resp, err := client.Get(url) //nolint:gosec
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
|
|
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
|
|
}
|
|
|
|
f, err := os.Create(dest) //nolint:gosec
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
|
|
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// readChecksumFile returns the first whitespace-separated field on the first
// line of the file at path. Only the first line is examined; if it yields no
// field the result is a read error, if one occurred, or "empty checksum file".
func readChecksumFile(path string) (string, error) {
	file, err := os.Open(path) //nolint:gosec
	if err != nil {
		return "", err
	}
	defer file.Close()

	sc := bufio.NewScanner(file)

	var digest string
	if sc.Scan() {
		if fields := strings.Fields(sc.Text()); len(fields) > 0 {
			digest = fields[0]
		}
	}

	switch {
	case digest != "":
		return digest, nil
	case sc.Err() != nil:
		return "", sc.Err()
	default:
		return "", errors.New("empty checksum file")
	}
}
|
|
|
|
// verifySHA256 hashes the file at path with SHA-256 and compares the result
// with expected, a hex-encoded digest. The comparison ignores letter case so
// that a checksum published in upper-case hex is accepted. It returns an
// error if the file cannot be read or the digests differ.
func verifySHA256(path, expected string) error {
	f, err := os.Open(path) //nolint:gosec
	if err != nil {
		return err
	}
	defer f.Close()

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return err
	}

	actual := fmt.Sprintf("%x", h.Sum(nil))
	// %x produces lower-case hex; the expected value comes from a remote
	// file and may use either case.
	if !strings.EqualFold(actual, expected) {
		return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
	}
	return nil
}
|
|
|
|
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
|
|
f, err := os.Open(tarGZPath) //nolint:gosec
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
gz, err := gzip.NewReader(f)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer gz.Close()
|
|
|
|
tr := tar.NewReader(gz)
|
|
for {
|
|
hdr, err := tr.Next()
|
|
if err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
break
|
|
}
|
|
return err
|
|
}
|
|
|
|
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
|
|
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
|
|
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
|
|
}
|
|
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return fmt.Errorf("%s not found in archive", mmdbInnerName)
|
|
}
|
|
|
|
// extractToFileAtomic writes r to a temporary file in the same directory as
|
|
// destPath, then renames it into place so a crash never leaves a truncated file.
|
|
func extractToFileAtomic(r io.Reader, destPath string) error {
|
|
dir := filepath.Dir(destPath)
|
|
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
|
|
if err != nil {
|
|
return fmt.Errorf("create temp file: %w", err)
|
|
}
|
|
tmpPath := tmp.Name()
|
|
|
|
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
|
|
if closeErr := tmp.Close(); closeErr != nil {
|
|
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
|
|
}
|
|
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
|
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
|
}
|
|
return fmt.Errorf("write mmdb: %w", err)
|
|
}
|
|
if err := tmp.Close(); err != nil {
|
|
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
|
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
|
}
|
|
return fmt.Errorf("close temp file: %w", err)
|
|
}
|
|
if err := os.Rename(tmpPath, destPath); err != nil {
|
|
if removeErr := os.Remove(tmpPath); removeErr != nil {
|
|
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
|
|
}
|
|
return fmt.Errorf("rename to %s: %w", destPath, err)
|
|
}
|
|
return nil
|
|
}
|