From b65c2f69b0a204600f71688ec47961c16d00d6b5 Mon Sep 17 00:00:00 2001 From: Bethuel Mmbaga Date: Tue, 27 Feb 2024 00:49:28 +0300 Subject: [PATCH] Add support for downloading Geo databases to the management service (#1626) Adds support for downloading Geo databases to the management service. If the Geo databases are not found, the service will automatically attempt to download them during startup. --- .../workflows/test-infrastructure-files.yml | 2 +- infrastructure_files/download-geolite2.sh | 11 +- management/cmd/management.go | 2 +- management/server/geolocation/database.go | 187 ++++++++++++++++++ management/server/geolocation/geolocation.go | 32 +-- management/server/geolocation/store.go | 31 ++- management/server/geolocation/utils.go | 176 +++++++++++++++++ 7 files changed, 404 insertions(+), 37 deletions(-) create mode 100644 management/server/geolocation/database.go create mode 100644 management/server/geolocation/utils.go diff --git a/.github/workflows/test-infrastructure-files.yml b/.github/workflows/test-infrastructure-files.yml index e1261dabc..8badb5b9b 100644 --- a/.github/workflows/test-infrastructure-files.yml +++ b/.github/workflows/test-infrastructure-files.yml @@ -199,6 +199,6 @@ jobs: - name: test script run: bash -x infrastructure_files/download-geolite2.sh - name: test mmdb file exists - run: ls -l GeoLite2-City_*/GeoLite2-City.mmdb + run: test -f GeoLite2-City.mmdb - name: test geonames file exists run: test -f geonames.db diff --git a/infrastructure_files/download-geolite2.sh b/infrastructure_files/download-geolite2.sh index e09873627..4a9db5e01 100755 --- a/infrastructure_files/download-geolite2.sh +++ b/infrastructure_files/download-geolite2.sh @@ -43,21 +43,18 @@ download_geolite_mmdb() { mkdir -p "$EXTRACTION_DIR" tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1 - # Create a SHA256 signature file MMDB_FILE="GeoLite2-City.mmdb" - cd "$EXTRACTION_DIR" - sha256sum "$MMDB_FILE" > "$MMDB_FILE.sha256" - echo "SHA256 signature created for $MMDB_FILE." - cd - > /dev/null 2>&1 + cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE # Remove downloaded files + rm -r "$EXTRACTION_DIR" rm "$DATABASE_FILE" "$SIGNATURE_FILE" # Done. Print next steps echo "" echo "Process completed successfully." - echo "Now you can place $EXTRACTION_DIR/$MMDB_FILE to 'datadir' of management service." - echo -e "Example:\n\tdocker compose cp $EXTRACTION_DIR/$MMDB_FILE management:/var/lib/netbird/" + echo "Now you can place $MMDB_FILE to 'datadir' of management service." + echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/" } diff --git a/management/cmd/management.go b/management/cmd/management.go index 002cf36d8..6b135ba01 100644 --- a/management/cmd/management.go +++ b/management/cmd/management.go @@ -166,7 +166,7 @@ var ( geo, err := geolocation.NewGeolocation(config.Datadir) if err != nil { - log.Warnf("could not initialize geo location service, we proceed without geo support") + log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err) } else { log.Infof("geo location service has been initialized from %s", config.Datadir) } diff --git a/management/server/geolocation/database.go b/management/server/geolocation/database.go new file mode 100644 index 000000000..adf89b282 --- /dev/null +++ b/management/server/geolocation/database.go @@ -0,0 +1,187 @@ +package geolocation + +import ( + "encoding/csv" + "fmt" + "net/url" + "os" + "path" + "strconv" + + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/logger" +) + +const ( + geoLiteCityTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz" + geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip" + geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256" + geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256" +) + +// loadGeolocationDatabases loads the MaxMind databases. +func loadGeolocationDatabases(dataDir string) error { + files := []string{MMDBFileName, GeoSqliteDBFile} + for _, file := range files { + exists, _ := fileExists(path.Join(dataDir, file)) + if exists { + continue + } + + switch file { + case MMDBFileName: + extractFunc := func(src string, dst string) error { + if err := decompressTarGzFile(src, dst); err != nil { + return err + } + return os.Rename(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName)) + } + if err := loadDatabase( + geoLiteCitySha256TarURL, + geoLiteCityTarGZURL, + extractFunc, + ); err != nil { + return err + } + + case GeoSqliteDBFile: + extractFunc := func(src string, dst string) error { + if err := decompressZipFile(src, dst); err != nil { + return err + } + extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv") + return importCsvToSqlite(dataDir, extractedCsvFile) + } + + if err := loadDatabase( + geoLiteCitySha256ZipURL, + geoLiteCityZipURL, + extractFunc, + ); err != nil { + return err + } + } + } + return nil +} + +// loadDatabase downloads a file from the specified URL and verifies its checksum. +// It then calls the extract function to perform additional processing on the extracted files. +func loadDatabase(checksumURL string, fileURL string, extractFunc func(src string, dst string) error) error { + temp, err := os.MkdirTemp(os.TempDir(), "geolite") + if err != nil { + return err + } + defer os.RemoveAll(temp) + + checksumFile := path.Join(temp, getDatabaseFileName(checksumURL)) + err = downloadFile(checksumURL, checksumFile) + if err != nil { + return err + } + + sha256sum, err := loadChecksumFromFile(checksumFile) + if err != nil { + return err + } + + dbFile := path.Join(temp, getDatabaseFileName(fileURL)) + err = downloadFile(fileURL, dbFile) + if err != nil { + return err + } + + if err := verifyChecksum(dbFile, sha256sum); err != nil { + return err + } + + return extractFunc(dbFile, temp) +} + +// importCsvToSqlite imports a CSV file into a SQLite database. +func importCsvToSqlite(dataDir string, csvFile string) error { + geonames, err := loadGeonamesCsv(csvFile) + if err != nil { + return err + } + + db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{ + Logger: logger.Default.LogMode(logger.Silent), + CreateBatchSize: 1000, + PrepareStmt: true, + }) + if err != nil { + return err + } + defer func() { + sql, err := db.DB() + if err != nil { + return + } + sql.Close() + }() + + if err := db.AutoMigrate(&GeoNames{}); err != nil { + return err + } + + return db.Create(geonames).Error +} + +func loadGeonamesCsv(filepath string) ([]GeoNames, error) { + f, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer f.Close() + + reader := csv.NewReader(f) + records, err := reader.ReadAll() + if err != nil { + return nil, err + } + + var geoNames []GeoNames + for index, record := range records { + if index == 0 { + continue + } + geoNameID, err := strconv.Atoi(record[0]) + if err != nil { + return nil, err + } + + geoName := GeoNames{ + GeoNameID: geoNameID, + LocaleCode: record[1], + ContinentCode: record[2], + ContinentName: record[3], + CountryIsoCode: record[4], + CountryName: record[5], + Subdivision1IsoCode: record[6], + Subdivision1Name: record[7], + Subdivision2IsoCode: record[8], + Subdivision2Name: record[9], + CityName: record[10], + MetroCode: record[11], + TimeZone: record[12], + IsInEuropeanUnion: record[13], + } + geoNames = append(geoNames, geoName) + } + + return geoNames, nil +} + +// getDatabaseFileName extracts the file name from a given URL string. +func getDatabaseFileName(urlStr string) string { + u, err := url.Parse(urlStr) + if err != nil { + panic(err) + } + + ext := u.Query().Get("suffix") + fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext) + return fileName +} diff --git a/management/server/geolocation/geolocation.go b/management/server/geolocation/geolocation.go index de7a8af82..88cdfcb9f 100644 --- a/management/server/geolocation/geolocation.go +++ b/management/server/geolocation/geolocation.go @@ -2,9 +2,7 @@ package geolocation import ( "bytes" - "crypto/sha256" "fmt" - "io" "net" "os" "path" @@ -54,20 +52,23 @@ type Country struct { CountryName string } -func NewGeolocation(datadir string) (*Geolocation, error) { - mmdbPath := path.Join(datadir, MMDBFileName) +func NewGeolocation(dataDir string) (*Geolocation, error) { + if err := loadGeolocationDatabases(dataDir); err != nil { + return nil, fmt.Errorf("failed to load MaxMind databases: %v", err) + } + mmdbPath := path.Join(dataDir, MMDBFileName) db, err := openDB(mmdbPath) if err != nil { return nil, err } - sha256sum, err := getSha256sum(mmdbPath) + sha256sum, err := calculateFileSHA256(mmdbPath) if err != nil { return nil, err } - locationDB, err := NewSqliteStore(datadir) + locationDB, err := NewSqliteStore(dataDir) if err != nil { return nil, err } @@ -104,21 +105,6 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) { return db, nil } -func getSha256sum(mmdbPath string) ([]byte, error) { - f, err := os.Open(mmdbPath) - if err != nil { - return nil, err - } - defer f.Close() - - h := sha256.New() - if _, err := io.Copy(h, f); err != nil { - return nil, err - } - - return h.Sum(nil), nil -} - func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { gl.mux.RLock() defer gl.mux.RUnlock() @@ -189,7 +175,7 @@ func (gl *Geolocation) reloader() { log.Errorf("geonames db reload failed: %s", err) } - newSha256sum1, err := getSha256sum(gl.mmdbPath) + newSha256sum1, err := calculateFileSHA256(gl.mmdbPath) if err != nil { log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue @@ -198,7 +184,7 @@ func (gl *Geolocation) reloader() { // we check sum twice just to avoid possible case when we reload during update of the file // considering the frequency of file update (few times a week) checking sum twice should be enough time.Sleep(50 * time.Millisecond) - newSha256sum2, err := getSha256sum(gl.mmdbPath) + newSha256sum2, err := calculateFileSHA256(gl.mmdbPath) if err != nil { log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) continue diff --git a/management/server/geolocation/store.go b/management/server/geolocation/store.go index 9f3638a7c..74401d6ca 100644 --- a/management/server/geolocation/store.go +++ b/management/server/geolocation/store.go @@ -20,6 +20,27 @@ const ( GeoSqliteDBFile = "geonames.db" ) +type GeoNames struct { + GeoNameID int `gorm:"column:geoname_id"` + LocaleCode string `gorm:"column:locale_code"` + ContinentCode string `gorm:"column:continent_code"` + ContinentName string `gorm:"column:continent_name"` + CountryIsoCode string `gorm:"column:country_iso_code"` + CountryName string `gorm:"column:country_name"` + Subdivision1IsoCode string `gorm:"column:subdivision_1_iso_code"` + Subdivision1Name string `gorm:"column:subdivision_1_name"` + Subdivision2IsoCode string `gorm:"column:subdivision_2_iso_code"` + Subdivision2Name string `gorm:"column:subdivision_2_name"` + CityName string `gorm:"column:city_name"` + MetroCode string `gorm:"column:metro_code"` + TimeZone string `gorm:"column:time_zone"` + IsInEuropeanUnion string `gorm:"column:is_in_european_union"` +} + +func (*GeoNames) TableName() string { + return "geonames" +} + // SqliteStore represents a location storage backed by a Sqlite DB. type SqliteStore struct { db *gorm.DB @@ -37,7 +58,7 @@ func NewSqliteStore(dataDir string) (*SqliteStore, error) { return nil, err } - sha256sum, err := getSha256sum(file) + sha256sum, err := calculateFileSHA256(file) if err != nil { return nil, err } @@ -60,7 +81,7 @@ func (s *SqliteStore) GetAllCountries() ([]Country, error) { } var countries []Country - result := s.db.Table("geonames"). + result := s.db.Model(&GeoNames{}). Select("country_iso_code", "country_name"). Group("country_name"). Scan(&countries) @@ -81,7 +102,7 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error) } var cities []City - result := s.db.Table("geonames"). + result := s.db.Model(&GeoNames{}). Select("geoname_id", "city_name"). Where("country_iso_code = ?", countryISOCode). Group("city_name"). @@ -98,7 +119,7 @@ func (s *SqliteStore) reload() error { s.mux.Lock() defer s.mux.Unlock() - newSha256sum1, err := getSha256sum(s.filePath) + newSha256sum1, err := calculateFileSHA256(s.filePath) if err != nil { log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) } @@ -107,7 +128,7 @@ func (s *SqliteStore) reload() error { // we check sum twice just to avoid possible case when we reload during update of the file // considering the frequency of file update (few times a week) checking sum twice should be enough time.Sleep(50 * time.Millisecond) - newSha256sum2, err := getSha256sum(s.filePath) + newSha256sum2, err := calculateFileSHA256(s.filePath) if err != nil { return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) } diff --git a/management/server/geolocation/utils.go b/management/server/geolocation/utils.go new file mode 100644 index 000000000..bdbd4732d --- /dev/null +++ b/management/server/geolocation/utils.go @@ -0,0 +1,176 @@ +package geolocation + +import ( + "archive/tar" + "archive/zip" + "bufio" + "bytes" + "compress/gzip" + "crypto/sha256" + "errors" + "fmt" + "io" + "net/http" + "os" + "path" + "strings" +) + +// decompressTarGzFile decompresses a .tar.gz file. +func decompressTarGzFile(filepath, destDir string) error { + file, err := os.Open(filepath) + if err != nil { + return err + } + defer file.Close() + + gzipReader, err := gzip.NewReader(file) + if err != nil { + return err + } + defer gzipReader.Close() + + tarReader := tar.NewReader(gzipReader) + + for { + header, err := tarReader.Next() + if err != nil { + if errors.Is(err, io.EOF) { + break + } + return err + } + + if header.Typeflag == tar.TypeReg { + outFile, err := os.Create(path.Join(destDir, path.Base(header.Name))) + if err != nil { + return err + } + + _, err = io.Copy(outFile, tarReader) // #nosec G110 + outFile.Close() + if err != nil { + return err + } + } + + } + + return nil +} + +// decompressZipFile decompresses a .zip file. +func decompressZipFile(filepath, destDir string) error { + r, err := zip.OpenReader(filepath) + if err != nil { + return err + } + defer r.Close() + + for _, f := range r.File { + if f.FileInfo().IsDir() { + continue + } + + outFile, err := os.Create(path.Join(destDir, path.Base(f.Name))) + if err != nil { + return err + } + + rc, err := f.Open() + if err != nil { + outFile.Close() + return err + } + + _, err = io.Copy(outFile, rc) // #nosec G110 + outFile.Close() + rc.Close() + if err != nil { + return err + } + } + + return nil +} + +// calculateFileSHA256 calculates the SHA256 checksum of a file. +func calculateFileSHA256(filepath string) ([]byte, error) { + file, err := os.Open(filepath) + if err != nil { + return nil, err + } + defer file.Close() + + h := sha256.New() + if _, err := io.Copy(h, file); err != nil { + return nil, err + } + + return h.Sum(nil), nil +} + +// loadChecksumFromFile loads the first checksum from a file. +func loadChecksumFromFile(filepath string) (string, error) { + file, err := os.Open(filepath) + if err != nil { + return "", err + } + defer file.Close() + + scanner := bufio.NewScanner(file) + if scanner.Scan() { + parts := strings.Fields(scanner.Text()) + if len(parts) > 0 { + return parts[0], nil + } + } + if err := scanner.Err(); err != nil { + return "", err + } + + return "", nil +} + +// verifyChecksum compares the calculated SHA256 checksum of a file against the expected checksum. +func verifyChecksum(filepath, expectedChecksum string) error { + calculatedChecksum, err := calculateFileSHA256(filepath) + + fileCheckSum := fmt.Sprintf("%x", calculatedChecksum) + if err != nil { + return err + } + + if fileCheckSum != expectedChecksum { + return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, fileCheckSum) + } + + return nil +} + +// downloadFile downloads a file from a URL and saves it to a local file path. +func downloadFile(url, filepath string) error { + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected error occurred while downloading the file: %s", string(bodyBytes)) + } + + out, err := os.Create(filepath) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, bytes.NewBuffer(bodyBytes)) + return err +}