Files
netbird/proxy/internal/geolocation/download.go

265 lines
6.9 KiB
Go

package geolocation
import (
"archive/tar"
"bufio"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
mmdbTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
mmdbSha256URL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
mmdbInnerName = "GeoLite2-City.mmdb"
downloadTimeout = 2 * time.Minute
maxMMDBSize = 256 << 20 // 256 MB
)
// ensureMMDB checks for an existing MMDB file in dataDir. If none is found,
// it downloads from pkgs.netbird.io with SHA256 verification.
func ensureMMDB(logger *log.Logger, dataDir string) (string, error) {
if err := os.MkdirAll(dataDir, 0o755); err != nil {
return "", fmt.Errorf("create geo data directory %s: %w", dataDir, err)
}
pattern := filepath.Join(dataDir, mmdbGlob)
if files, _ := filepath.Glob(pattern); len(files) > 0 {
mmdbPath := files[len(files)-1]
logger.Debugf("using existing geolocation database: %s", mmdbPath)
return mmdbPath, nil
}
logger.Info("geolocation database not found, downloading from pkgs.netbird.io")
return downloadMMDB(logger, dataDir)
}
func downloadMMDB(logger *log.Logger, dataDir string) (string, error) {
client := &http.Client{Timeout: downloadTimeout}
datedName, err := fetchRemoteFilename(client, mmdbTarGZURL)
if err != nil {
return "", fmt.Errorf("get remote filename: %w", err)
}
mmdbFilename := deriveMMDBFilename(datedName)
mmdbPath := filepath.Join(dataDir, mmdbFilename)
tmp, err := os.MkdirTemp("", "geolite-proxy-*")
if err != nil {
return "", fmt.Errorf("create temp directory: %w", err)
}
defer os.RemoveAll(tmp)
checksumFile := filepath.Join(tmp, "checksum.sha256")
if err := downloadToFile(client, mmdbSha256URL, checksumFile); err != nil {
return "", fmt.Errorf("download checksum: %w", err)
}
expectedHash, err := readChecksumFile(checksumFile)
if err != nil {
return "", fmt.Errorf("read checksum: %w", err)
}
tarFile := filepath.Join(tmp, datedName)
logger.Debugf("downloading geolocation database (%s)", datedName)
if err := downloadToFile(client, mmdbTarGZURL, tarFile); err != nil {
return "", fmt.Errorf("download database: %w", err)
}
if err := verifySHA256(tarFile, expectedHash); err != nil {
return "", fmt.Errorf("verify database checksum: %w", err)
}
if err := extractMMDBFromTarGZ(tarFile, mmdbPath); err != nil {
return "", fmt.Errorf("extract database: %w", err)
}
logger.Infof("geolocation database downloaded: %s", mmdbPath)
return mmdbPath, nil
}
// deriveMMDBFilename converts a tar.gz filename to an MMDB filename.
// Example: GeoLite2-City_20240101.tar.gz -> GeoLite2-City_20240101.mmdb
func deriveMMDBFilename(tarName string) string {
base, _, _ := strings.Cut(tarName, ".")
if !strings.Contains(base, "_") {
return "GeoLite2-City.mmdb"
}
return base + ".mmdb"
}
func fetchRemoteFilename(client *http.Client, url string) (string, error) {
resp, err := client.Head(url)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("HEAD request: HTTP %d", resp.StatusCode)
}
cd := resp.Header.Get("Content-Disposition")
if cd == "" {
return "", errors.New("no Content-Disposition header")
}
_, params, err := mime.ParseMediaType(cd)
if err != nil {
return "", fmt.Errorf("parse Content-Disposition: %w", err)
}
name := filepath.Base(params["filename"])
if name == "" || name == "." {
return "", errors.New("no filename in Content-Disposition")
}
return name, nil
}
func downloadToFile(client *http.Client, url, dest string) error {
resp, err := client.Get(url) //nolint:gosec
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
f, err := os.Create(dest) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
// Cap download at 256 MB to prevent unbounded reads from a compromised server.
if _, err := io.Copy(f, io.LimitReader(resp.Body, maxMMDBSize)); err != nil {
return err
}
return nil
}
func readChecksumFile(path string) (string, error) {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return "", err
}
defer f.Close()
scanner := bufio.NewScanner(f)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", errors.New("empty checksum file")
}
func verifySHA256(path, expected string) error {
f, err := os.Open(path) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return err
}
actual := fmt.Sprintf("%x", h.Sum(nil))
if actual != expected {
return fmt.Errorf("SHA256 mismatch: expected %s, got %s", expected, actual)
}
return nil
}
func extractMMDBFromTarGZ(tarGZPath, destPath string) error {
f, err := os.Open(tarGZPath) //nolint:gosec
if err != nil {
return err
}
defer f.Close()
gz, err := gzip.NewReader(f)
if err != nil {
return err
}
defer gz.Close()
tr := tar.NewReader(gz)
for {
hdr, err := tr.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if hdr.Typeflag == tar.TypeReg && filepath.Base(hdr.Name) == mmdbInnerName {
if hdr.Size < 0 || hdr.Size > maxMMDBSize {
return fmt.Errorf("mmdb entry size %d exceeds limit %d", hdr.Size, maxMMDBSize)
}
if err := extractToFileAtomic(io.LimitReader(tr, hdr.Size), destPath); err != nil {
return err
}
return nil
}
}
return fmt.Errorf("%s not found in archive", mmdbInnerName)
}
// extractToFileAtomic writes r to a temporary file in the same directory as
// destPath, then renames it into place so a crash never leaves a truncated file.
func extractToFileAtomic(r io.Reader, destPath string) error {
dir := filepath.Dir(destPath)
tmp, err := os.CreateTemp(dir, ".mmdb-*.tmp")
if err != nil {
return fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmp.Name()
if _, err := io.Copy(tmp, r); err != nil { //nolint:gosec // G110: caller bounds with LimitReader
if closeErr := tmp.Close(); closeErr != nil {
log.Debugf("failed to close temp file %s: %v", tmpPath, closeErr)
}
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("write mmdb: %w", err)
}
if err := tmp.Close(); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("close temp file: %w", err)
}
if err := os.Rename(tmpPath, destPath); err != nil {
if removeErr := os.Remove(tmpPath); removeErr != nil {
log.Debugf("failed to remove temp file %s: %v", tmpPath, removeErr)
}
return fmt.Errorf("rename to %s: %w", destPath, err)
}
return nil
}