Compare commits

...

7 Commits

Author SHA1 Message Date
Bethuel Mmbaga
f64e73ca70 Fix invalid cross-device link when moving geolocation databases (#1638)
* Fix invalid cross-device link when move geonames db

* Add test for geolocation databases in workflow

This step checks the existence and proper functioning of geolocation databases, including GeoLite2-City.mmdb and Geonames.db. It will help us ensure that geolocation databases are loaded correctly in the management.

* Enable debug mode

* Increase sleep duration in geolocation tests
2024-02-28 16:42:33 +03:00
pascal-fischer
b085419ab8 FIx order when validating account settings (#1632)
* moved extraSettings validation to the end

* moved extraSettings validation directly after permission check
2024-02-27 14:17:22 +01:00
Bethuel Mmbaga
d78b652ff7 Rename PrivateNetworkCheck to PeerNetworkRangeCheck (#1629)
* Rename PrivateNetworkCheck to PeerNetworkRangeCheck

* update description and example

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-02-27 11:59:48 +01:00
Viktor Liu
7251150c1c Combine update-available and connected/disconnected tray icon states (#1615)
This PR updates the system tray icons to reflect both connection status and availability of updates. Now, the tray will show distinct icons for the following states: connected, disconnected, update available while connected, and update available while disconnected. This change improves user experience by providing a clear visual status indicator.

- Add new icons for connected and disconnected states with update available.
- Implement logic to switch icons based on connection status and update availability.
- Remove old icon references for default and update states.
2024-02-26 23:28:33 +01:00
Bethuel Mmbaga
b65c2f69b0 Add support for downloading Geo databases to the management service (#1626)
Adds support for downloading Geo databases to the management service. If the Geo databases are not found, the service will automatically attempt to download them during startup.
2024-02-26 22:49:28 +01:00
Yury Gargay
d8ce08d898 Extend bypass middleware with support of wildcard paths (#1628)
---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2024-02-26 17:54:58 +01:00
Maycon Santos
e1c50248d9 Add support for device flow on getting started with zitadel (#1616) 2024-02-26 12:33:16 +01:00
34 changed files with 698 additions and 181 deletions

View File

@@ -162,6 +162,13 @@ jobs:
test $count -eq 4 test $count -eq 4
working-directory: infrastructure_files/artifacts working-directory: infrastructure_files/artifacts
- name: test geolocation databases
working-directory: infrastructure_files/artifacts
run: |
sleep 30
docker compose exec management ls -l /var/lib/netbird/ | grep -i GeoLite2-City.mmdb
docker compose exec management ls -l /var/lib/netbird/ | grep -i geonames.db
test-getting-started-script: test-getting-started-script:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
@@ -199,6 +206,6 @@ jobs:
- name: test script - name: test script
run: bash -x infrastructure_files/download-geolite2.sh run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists - name: test mmdb file exists
run: ls -l GeoLite2-City_*/GeoLite2-City.mmdb run: test -f GeoLite2-City.mmdb
- name: test geonames file exists - name: test geonames file exists
run: test -f geonames.db run: test -f geonames.db

View File

@@ -54,7 +54,7 @@ nfpms:
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -71,7 +71,7 @@ nfpms:
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -82,17 +82,23 @@ var iconConnectedICO []byte
//go:embed netbird-systemtray-connected.png //go:embed netbird-systemtray-connected.png
var iconConnectedPNG []byte var iconConnectedPNG []byte
//go:embed netbird-systemtray-default.ico //go:embed netbird-systemtray-disconnected.ico
var iconDisconnectedICO []byte var iconDisconnectedICO []byte
//go:embed netbird-systemtray-default.png //go:embed netbird-systemtray-disconnected.png
var iconDisconnectedPNG []byte var iconDisconnectedPNG []byte
//go:embed netbird-systemtray-update.ico //go:embed netbird-systemtray-update-disconnected.ico
var iconUpdateICO []byte var iconUpdateDisconnectedICO []byte
//go:embed netbird-systemtray-update.png //go:embed netbird-systemtray-update-disconnected.png
var iconUpdatePNG []byte var iconUpdateDisconnectedPNG []byte
//go:embed netbird-systemtray-update-connected.ico
var iconUpdateConnectedICO []byte
//go:embed netbird-systemtray-update-connected.png
var iconUpdateConnectedPNG []byte
//go:embed netbird-systemtray-update-cloud.ico //go:embed netbird-systemtray-update-cloud.ico
var iconUpdateCloudICO []byte var iconUpdateCloudICO []byte
@@ -105,10 +111,11 @@ type serviceClient struct {
addr string addr string
conn proto.DaemonServiceClient conn proto.DaemonServiceClient
icConnected []byte icConnected []byte
icDisconnected []byte icDisconnected []byte
icUpdate []byte icUpdateConnected []byte
icUpdateCloud []byte icUpdateDisconnected []byte
icUpdateCloud []byte
// systray menu items // systray menu items
mStatus *systray.MenuItem mStatus *systray.MenuItem
@@ -139,6 +146,7 @@ type serviceClient struct {
preSharedKey string preSharedKey string
adminURL string adminURL string
connected bool
update *version.Update update *version.Update
daemonVersion string daemonVersion string
updateIndicationLock sync.Mutex updateIndicationLock sync.Mutex
@@ -161,13 +169,15 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
s.icConnected = iconConnectedICO s.icConnected = iconConnectedICO
s.icDisconnected = iconDisconnectedICO s.icDisconnected = iconDisconnectedICO
s.icUpdate = iconUpdateICO s.icUpdateConnected = iconUpdateConnectedICO
s.icUpdateDisconnected = iconUpdateDisconnectedICO
s.icUpdateCloud = iconUpdateCloudICO s.icUpdateCloud = iconUpdateCloudICO
} else { } else {
s.icConnected = iconConnectedPNG s.icConnected = iconConnectedPNG
s.icDisconnected = iconDisconnectedPNG s.icDisconnected = iconDisconnectedPNG
s.icUpdate = iconUpdatePNG s.icUpdateConnected = iconUpdateConnectedPNG
s.icUpdateDisconnected = iconUpdateDisconnectedPNG
s.icUpdateCloud = iconUpdateCloudPNG s.icUpdateCloud = iconUpdateCloudPNG
} }
@@ -369,7 +379,10 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool var systrayIconState bool
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
if !s.isUpdateIconActive { s.connected = true
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateConnected)
} else {
systray.SetIcon(s.icConnected) systray.SetIcon(s.icConnected)
} }
systray.SetTooltip("NetBird (Connected)") systray.SetTooltip("NetBird (Connected)")
@@ -378,7 +391,10 @@ func (s *serviceClient) updateStatus() error {
s.mDown.Enable() s.mDown.Enable()
systrayIconState = true systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
if !s.isUpdateIconActive { s.connected = false
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateDisconnected)
} else {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
} }
systray.SetTooltip("NetBird (Disconnected)") systray.SetTooltip("NetBird (Disconnected)")
@@ -605,10 +621,13 @@ func (s *serviceClient) onUpdateAvailable() {
defer s.updateIndicationLock.Unlock() defer s.updateIndicationLock.Unlock()
s.mUpdate.Show() s.mUpdate.Show()
s.mAbout.SetIcon(s.icUpdateCloud)
s.isUpdateIconActive = true s.isUpdateIconActive = true
systray.SetIcon(s.icUpdate)
if s.connected {
systray.SetIcon(s.icUpdateConnected)
} else {
systray.SetIcon(s.icUpdateDisconnected)
}
} }
func openURL(url string) error { func openURL(url string) error {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.3 KiB

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.1 KiB

After

Width:  |  Height:  |  Size: 8.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.3 KiB

View File

@@ -43,21 +43,18 @@ download_geolite_mmdb() {
mkdir -p "$EXTRACTION_DIR" mkdir -p "$EXTRACTION_DIR"
tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1 tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1
# Create a SHA256 signature file
MMDB_FILE="GeoLite2-City.mmdb" MMDB_FILE="GeoLite2-City.mmdb"
cd "$EXTRACTION_DIR" cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE
sha256sum "$MMDB_FILE" > "$MMDB_FILE.sha256"
echo "SHA256 signature created for $MMDB_FILE."
cd - > /dev/null 2>&1
# Remove downloaded files # Remove downloaded files
rm -r "$EXTRACTION_DIR"
rm "$DATABASE_FILE" "$SIGNATURE_FILE" rm "$DATABASE_FILE" "$SIGNATURE_FILE"
# Done. Print next steps # Done. Print next steps
echo "" echo ""
echo "Process completed successfully." echo "Process completed successfully."
echo "Now you can place $EXTRACTION_DIR/$MMDB_FILE to 'datadir' of management service." echo "Now you can place $MMDB_FILE to 'datadir' of management service."
echo -e "Example:\n\tdocker compose cp $EXTRACTION_DIR/$MMDB_FILE management:/var/lib/netbird/" echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/"
} }

View File

@@ -137,6 +137,13 @@ create_new_application() {
BASE_REDIRECT_URL2=$5 BASE_REDIRECT_URL2=$5
LOGOUT_URL=$6 LOGOUT_URL=$6
ZITADEL_DEV_MODE=$7 ZITADEL_DEV_MODE=$7
DEVICE_CODE=$8
if [[ $DEVICE_CODE == "true" ]]; then
GRANT_TYPES='["OIDC_GRANT_TYPE_AUTHORIZATION_CODE","OIDC_GRANT_TYPE_DEVICE_CODE","OIDC_GRANT_TYPE_REFRESH_TOKEN"]'
else
GRANT_TYPES='["OIDC_GRANT_TYPE_AUTHORIZATION_CODE","OIDC_GRANT_TYPE_REFRESH_TOKEN"]'
fi
RESPONSE=$( RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/projects/$PROJECT_ID/apps/oidc" \ curl -sS -X POST "$INSTANCE_URL/management/v1/projects/$PROJECT_ID/apps/oidc" \
@@ -154,10 +161,7 @@ create_new_application() {
"RESPONSETypes": [ "RESPONSETypes": [
"OIDC_RESPONSE_TYPE_CODE" "OIDC_RESPONSE_TYPE_CODE"
], ],
"grantTypes": [ "grantTypes": '"$GRANT_TYPES"',
"OIDC_GRANT_TYPE_AUTHORIZATION_CODE",
"OIDC_GRANT_TYPE_REFRESH_TOKEN"
],
"appType": "OIDC_APP_TYPE_USER_AGENT", "appType": "OIDC_APP_TYPE_USER_AGENT",
"authMethodType": "OIDC_AUTH_METHOD_TYPE_NONE", "authMethodType": "OIDC_AUTH_METHOD_TYPE_NONE",
"version": "OIDC_VERSION_1_0", "version": "OIDC_VERSION_1_0",
@@ -340,10 +344,10 @@ init_zitadel() {
# create zitadel spa applications # create zitadel spa applications
echo "Creating new Zitadel SPA Dashboard application" echo "Creating new Zitadel SPA Dashboard application"
DASHBOARD_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Dashboard" "$BASE_REDIRECT_URL/nb-auth" "$BASE_REDIRECT_URL/nb-silent-auth" "$BASE_REDIRECT_URL/" "$ZITADEL_DEV_MODE") DASHBOARD_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Dashboard" "$BASE_REDIRECT_URL/nb-auth" "$BASE_REDIRECT_URL/nb-silent-auth" "$BASE_REDIRECT_URL/" "$ZITADEL_DEV_MODE" "false")
echo "Creating new Zitadel SPA Cli application" echo "Creating new Zitadel SPA Cli application"
CLI_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Cli" "http://localhost:53000/" "http://localhost:54000/" "http://localhost:53000/" "true") CLI_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Cli" "http://localhost:53000/" "http://localhost:54000/" "http://localhost:53000/" "true" "true")
MACHINE_USER_ID=$(create_service_user "$INSTANCE_URL" "$PAT") MACHINE_USER_ID=$(create_service_user "$INSTANCE_URL" "$PAT")
@@ -561,6 +565,8 @@ renderCaddyfile() {
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080 reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
reverse_proxy /openapi/* h2c://zitadel:8080 reverse_proxy /openapi/* h2c://zitadel:8080
reverse_proxy /debug/* h2c://zitadel:8080 reverse_proxy /debug/* h2c://zitadel:8080
reverse_proxy /device/* h2c://zitadel:8080
reverse_proxy /device h2c://zitadel:8080
# Dashboard # Dashboard
reverse_proxy /* dashboard:80 reverse_proxy /* dashboard:80
} }
@@ -629,6 +635,14 @@ renderManagementJson() {
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1" "ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
} }
}, },
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
"PKCEAuthorizationFlow": { "PKCEAuthorizationFlow": {
"ProviderConfig": { "ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",

View File

@@ -166,7 +166,7 @@ var (
geo, err := geolocation.NewGeolocation(config.Datadir) geo, err := geolocation.NewGeolocation(config.Datadir)
if err != nil { if err != nil {
log.Warnf("could not initialize geo location service, we proceed without geo support") log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else { } else {
log.Infof("geo location service has been initialized from %s", config.Datadir) log.Infof("geo location service has been initialized from %s", config.Datadir)
} }

View File

@@ -917,12 +917,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccountByUser(userID) account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -936,6 +931,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
} }
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore)
if err != nil {
return nil, err
}
oldSettings := account.Settings oldSettings := account.Settings
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled

View File

@@ -0,0 +1,210 @@
package geolocation
import (
"encoding/csv"
"fmt"
"io"
"net/url"
"os"
"path"
"strconv"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
geoLiteCityTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
)
// loadGeolocationDatabases loads the MaxMind databases.
func loadGeolocationDatabases(dataDir string) error {
files := []string{MMDBFileName, GeoSqliteDBFile}
for _, file := range files {
exists, _ := fileExists(path.Join(dataDir, file))
if exists {
continue
}
switch file {
case MMDBFileName:
extractFunc := func(src string, dst string) error {
if err := decompressTarGzFile(src, dst); err != nil {
return err
}
return copyFile(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName))
}
if err := loadDatabase(
geoLiteCitySha256TarURL,
geoLiteCityTarGZURL,
extractFunc,
); err != nil {
return err
}
case GeoSqliteDBFile:
extractFunc := func(src string, dst string) error {
if err := decompressZipFile(src, dst); err != nil {
return err
}
extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv")
return importCsvToSqlite(dataDir, extractedCsvFile)
}
if err := loadDatabase(
geoLiteCitySha256ZipURL,
geoLiteCityZipURL,
extractFunc,
); err != nil {
return err
}
}
}
return nil
}
// loadDatabase downloads a file from the specified URL and verifies its checksum.
// It then calls the extract function to perform additional processing on the extracted files.
func loadDatabase(checksumURL string, fileURL string, extractFunc func(src string, dst string) error) error {
temp, err := os.MkdirTemp(os.TempDir(), "geolite")
if err != nil {
return err
}
defer os.RemoveAll(temp)
checksumFile := path.Join(temp, getDatabaseFileName(checksumURL))
err = downloadFile(checksumURL, checksumFile)
if err != nil {
return err
}
sha256sum, err := loadChecksumFromFile(checksumFile)
if err != nil {
return err
}
dbFile := path.Join(temp, getDatabaseFileName(fileURL))
err = downloadFile(fileURL, dbFile)
if err != nil {
return err
}
if err := verifyChecksum(dbFile, sha256sum); err != nil {
return err
}
return extractFunc(dbFile, temp)
}
// importCsvToSqlite imports a CSV file into a SQLite database.
func importCsvToSqlite(dataDir string, csvFile string) error {
geonames, err := loadGeonamesCsv(csvFile)
if err != nil {
return err
}
db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 1000,
PrepareStmt: true,
})
if err != nil {
return err
}
defer func() {
sql, err := db.DB()
if err != nil {
return
}
sql.Close()
}()
if err := db.AutoMigrate(&GeoNames{}); err != nil {
return err
}
return db.Create(geonames).Error
}
func loadGeonamesCsv(filepath string) ([]GeoNames, error) {
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
reader := csv.NewReader(f)
records, err := reader.ReadAll()
if err != nil {
return nil, err
}
var geoNames []GeoNames
for index, record := range records {
if index == 0 {
continue
}
geoNameID, err := strconv.Atoi(record[0])
if err != nil {
return nil, err
}
geoName := GeoNames{
GeoNameID: geoNameID,
LocaleCode: record[1],
ContinentCode: record[2],
ContinentName: record[3],
CountryIsoCode: record[4],
CountryName: record[5],
Subdivision1IsoCode: record[6],
Subdivision1Name: record[7],
Subdivision2IsoCode: record[8],
Subdivision2Name: record[9],
CityName: record[10],
MetroCode: record[11],
TimeZone: record[12],
IsInEuropeanUnion: record[13],
}
geoNames = append(geoNames, geoName)
}
return geoNames, nil
}
// getDatabaseFileName extracts the file name from a given URL string.
func getDatabaseFileName(urlStr string) string {
u, err := url.Parse(urlStr)
if err != nil {
panic(err)
}
ext := u.Query().Get("suffix")
fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext)
return fileName
}
// copyFile performs a file copy operation from the source file to the destination.
func copyFile(src string, dst string) error {
srcFile, err := os.Open(src)
if err != nil {
return err
}
defer srcFile.Close()
dstFile, err := os.Create(dst)
if err != nil {
return err
}
defer dstFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
return err
}
return nil
}

View File

@@ -2,9 +2,7 @@ package geolocation
import ( import (
"bytes" "bytes"
"crypto/sha256"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"path" "path"
@@ -54,20 +52,23 @@ type Country struct {
CountryName string CountryName string
} }
func NewGeolocation(datadir string) (*Geolocation, error) { func NewGeolocation(dataDir string) (*Geolocation, error) {
mmdbPath := path.Join(datadir, MMDBFileName) if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
mmdbPath := path.Join(dataDir, MMDBFileName)
db, err := openDB(mmdbPath) db, err := openDB(mmdbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sha256sum, err := getSha256sum(mmdbPath) sha256sum, err := calculateFileSHA256(mmdbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
locationDB, err := NewSqliteStore(datadir) locationDB, err := NewSqliteStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -104,21 +105,6 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil return db, nil
} }
func getSha256sum(mmdbPath string) ([]byte, error) {
f, err := os.Open(mmdbPath)
if err != nil {
return nil, err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return nil, err
}
return h.Sum(nil), nil
}
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock() gl.mux.RLock()
defer gl.mux.RUnlock() defer gl.mux.RUnlock()
@@ -189,7 +175,7 @@ func (gl *Geolocation) reloader() {
log.Errorf("geonames db reload failed: %s", err) log.Errorf("geonames db reload failed: %s", err)
} }
newSha256sum1, err := getSha256sum(gl.mmdbPath) newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue
@@ -198,7 +184,7 @@ func (gl *Geolocation) reloader() {
// we check sum twice just to avoid possible case when we reload during update of the file // we check sum twice just to avoid possible case when we reload during update of the file
// considering the frequency of file update (few times a week) checking sum twice should be enough // considering the frequency of file update (few times a week) checking sum twice should be enough
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
newSha256sum2, err := getSha256sum(gl.mmdbPath) newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue

View File

@@ -20,6 +20,27 @@ const (
GeoSqliteDBFile = "geonames.db" GeoSqliteDBFile = "geonames.db"
) )
type GeoNames struct {
GeoNameID int `gorm:"column:geoname_id"`
LocaleCode string `gorm:"column:locale_code"`
ContinentCode string `gorm:"column:continent_code"`
ContinentName string `gorm:"column:continent_name"`
CountryIsoCode string `gorm:"column:country_iso_code"`
CountryName string `gorm:"column:country_name"`
Subdivision1IsoCode string `gorm:"column:subdivision_1_iso_code"`
Subdivision1Name string `gorm:"column:subdivision_1_name"`
Subdivision2IsoCode string `gorm:"column:subdivision_2_iso_code"`
Subdivision2Name string `gorm:"column:subdivision_2_name"`
CityName string `gorm:"column:city_name"`
MetroCode string `gorm:"column:metro_code"`
TimeZone string `gorm:"column:time_zone"`
IsInEuropeanUnion string `gorm:"column:is_in_european_union"`
}
func (*GeoNames) TableName() string {
return "geonames"
}
// SqliteStore represents a location storage backed by a Sqlite DB. // SqliteStore represents a location storage backed by a Sqlite DB.
type SqliteStore struct { type SqliteStore struct {
db *gorm.DB db *gorm.DB
@@ -37,7 +58,7 @@ func NewSqliteStore(dataDir string) (*SqliteStore, error) {
return nil, err return nil, err
} }
sha256sum, err := getSha256sum(file) sha256sum, err := calculateFileSHA256(file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -60,7 +81,7 @@ func (s *SqliteStore) GetAllCountries() ([]Country, error) {
} }
var countries []Country var countries []Country
result := s.db.Table("geonames"). result := s.db.Model(&GeoNames{}).
Select("country_iso_code", "country_name"). Select("country_iso_code", "country_name").
Group("country_name"). Group("country_name").
Scan(&countries) Scan(&countries)
@@ -81,7 +102,7 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
} }
var cities []City var cities []City
result := s.db.Table("geonames"). result := s.db.Model(&GeoNames{}).
Select("geoname_id", "city_name"). Select("geoname_id", "city_name").
Where("country_iso_code = ?", countryISOCode). Where("country_iso_code = ?", countryISOCode).
Group("city_name"). Group("city_name").
@@ -98,7 +119,7 @@ func (s *SqliteStore) reload() error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
newSha256sum1, err := getSha256sum(s.filePath) newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
} }
@@ -107,7 +128,7 @@ func (s *SqliteStore) reload() error {
// we check sum twice just to avoid possible case when we reload during update of the file // we check sum twice just to avoid possible case when we reload during update of the file
// considering the frequency of file update (few times a week) checking sum twice should be enough // considering the frequency of file update (few times a week) checking sum twice should be enough
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
newSha256sum2, err := getSha256sum(s.filePath) newSha256sum2, err := calculateFileSHA256(s.filePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
} }

View File

@@ -0,0 +1,176 @@
package geolocation
import (
"archive/tar"
"archive/zip"
"bufio"
"bytes"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"os"
"path"
"strings"
)
// decompressTarGzFile decompresses a .tar.gz file.
func decompressTarGzFile(filepath, destDir string) error {
file, err := os.Open(filepath)
if err != nil {
return err
}
defer file.Close()
gzipReader, err := gzip.NewReader(file)
if err != nil {
return err
}
defer gzipReader.Close()
tarReader := tar.NewReader(gzipReader)
for {
header, err := tarReader.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if header.Typeflag == tar.TypeReg {
outFile, err := os.Create(path.Join(destDir, path.Base(header.Name)))
if err != nil {
return err
}
_, err = io.Copy(outFile, tarReader) // #nosec G110
outFile.Close()
if err != nil {
return err
}
}
}
return nil
}
// decompressZipFile decompresses a .zip file.
func decompressZipFile(filepath, destDir string) error {
r, err := zip.OpenReader(filepath)
if err != nil {
return err
}
defer r.Close()
for _, f := range r.File {
if f.FileInfo().IsDir() {
continue
}
outFile, err := os.Create(path.Join(destDir, path.Base(f.Name)))
if err != nil {
return err
}
rc, err := f.Open()
if err != nil {
outFile.Close()
return err
}
_, err = io.Copy(outFile, rc) // #nosec G110
outFile.Close()
rc.Close()
if err != nil {
return err
}
}
return nil
}
// calculateFileSHA256 calculates the SHA256 checksum of a file.
func calculateFileSHA256(filepath string) ([]byte, error) {
file, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer file.Close()
h := sha256.New()
if _, err := io.Copy(h, file); err != nil {
return nil, err
}
return h.Sum(nil), nil
}
// loadChecksumFromFile loads the first checksum from a file.
func loadChecksumFromFile(filepath string) (string, error) {
file, err := os.Open(filepath)
if err != nil {
return "", err
}
defer file.Close()
scanner := bufio.NewScanner(file)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", nil
}
// verifyChecksum compares the calculated SHA256 checksum of a file against the expected checksum.
func verifyChecksum(filepath, expectedChecksum string) error {
calculatedChecksum, err := calculateFileSHA256(filepath)
fileCheckSum := fmt.Sprintf("%x", calculatedChecksum)
if err != nil {
return err
}
if fileCheckSum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, fileCheckSum)
}
return nil
}
// downloadFile downloads a file from a URL and saves it to a local file path.
func downloadFile(url, filepath string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected error occurred while downloading the file: %s", string(bodyBytes))
}
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, bytes.NewBuffer(bodyBytes))
return err
}

View File

@@ -862,8 +862,8 @@ components:
$ref: '#/components/schemas/OSVersionCheck' $ref: '#/components/schemas/OSVersionCheck'
geo_location_check: geo_location_check:
$ref: '#/components/schemas/GeoLocationCheck' $ref: '#/components/schemas/GeoLocationCheck'
private_network_check: peer_network_range_check:
$ref: '#/components/schemas/PrivateNetworkCheck' $ref: '#/components/schemas/PeerNetworkRangeCheck'
NBVersionCheck: NBVersionCheck:
description: Posture check for the version of NetBird description: Posture check for the version of NetBird
type: object type: object
@@ -934,16 +934,16 @@ components:
required: required:
- locations - locations
- action - action
PrivateNetworkCheck: PeerNetworkRangeCheck:
description: Posture check for allow or deny private network description: Posture check for allow or deny access based on peer local network addresses
type: object type: object
properties: properties:
ranges: ranges:
description: List of private network ranges in CIDR notation description: List of peer network ranges in CIDR notation
type: array type: array
items: items:
type: string type: string
example: ["192.168.1.0/24", "10.0.0.0/8"] example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action: action:
description: Action to take upon policy match description: Action to take upon policy match
type: string type: string

View File

@@ -74,6 +74,12 @@ const (
NameserverNsTypeUdp NameserverNsType = "udp" NameserverNsTypeUdp NameserverNsType = "udp"
) )
// Defines values for PeerNetworkRangeCheckAction.
const (
PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow"
PeerNetworkRangeCheckActionDeny PeerNetworkRangeCheckAction = "deny"
)
// Defines values for PolicyRuleAction. // Defines values for PolicyRuleAction.
const ( const (
PolicyRuleActionAccept PolicyRuleAction = "accept" PolicyRuleActionAccept PolicyRuleAction = "accept"
@@ -116,12 +122,6 @@ const (
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
) )
// Defines values for PrivateNetworkCheckAction.
const (
PrivateNetworkCheckActionAllow PrivateNetworkCheckAction = "allow"
PrivateNetworkCheckActionDeny PrivateNetworkCheckAction = "deny"
)
// Defines values for UserStatus. // Defines values for UserStatus.
const ( const (
UserStatusActive UserStatus = "active" UserStatusActive UserStatus = "active"
@@ -199,8 +199,8 @@ type Checks struct {
// OsVersionCheck Posture check for the version of operating system // OsVersionCheck Posture check for the version of operating system
OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"`
// PrivateNetworkCheck Posture check for allow or deny private network // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
PrivateNetworkCheck *PrivateNetworkCheck `json:"private_network_check,omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
} }
// City Describe city geographical location information // City Describe city geographical location information
@@ -656,6 +656,18 @@ type PeerMinimum struct {
Name string `json:"name"` Name string `json:"name"`
} }
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
type PeerNetworkRangeCheck struct {
// Action Action to take upon policy match
Action PeerNetworkRangeCheckAction `json:"action"`
// Ranges List of peer network ranges in CIDR notation
Ranges []string `json:"ranges"`
}
// PeerNetworkRangeCheckAction Action to take upon policy match
type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest. // PeerRequest defines model for PeerRequest.
type PeerRequest struct { type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
@@ -898,18 +910,6 @@ type PostureCheckUpdate struct {
Name string `json:"name"` Name string `json:"name"`
} }
// PrivateNetworkCheck Posture check for allow or deny private network
type PrivateNetworkCheck struct {
// Action Action to take upon policy match
Action PrivateNetworkCheckAction `json:"action"`
// Ranges List of private network ranges in CIDR notation
Ranges []string `json:"ranges"`
}
// PrivateNetworkCheckAction Action to take upon policy match
type PrivateNetworkCheckAction string
// Route defines model for Route. // Route defines model for Route.
type Route struct { type Route struct {
// Description Route description // Description Route description

View File

@@ -177,7 +177,10 @@ func TestAuthMiddleware_Handler(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.shouldBypassAuth { if tc.shouldBypassAuth {
bypass.AddBypassPath(tc.path) err := bypass.AddBypassPath(tc.path)
if err != nil {
t.Fatalf("failed to add bypass path: %v", err)
}
} }
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil) req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)

View File

@@ -1,8 +1,12 @@
package bypass package bypass
import ( import (
"fmt"
"net/http" "net/http"
"path"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
var byPassMutex sync.RWMutex var byPassMutex sync.RWMutex
@@ -11,10 +15,16 @@ var byPassMutex sync.RWMutex
var bypassPaths = make(map[string]struct{}) var bypassPaths = make(map[string]struct{})
// AddBypassPath adds an exact path to the list of paths that bypass middleware. // AddBypassPath adds an exact path to the list of paths that bypass middleware.
func AddBypassPath(path string) { // Paths can include wildcards, such as /api/*. Paths are matched using path.Match.
// Returns an error if the path has invalid pattern.
func AddBypassPath(path string) error {
byPassMutex.Lock() byPassMutex.Lock()
defer byPassMutex.Unlock() defer byPassMutex.Unlock()
if err := validatePath(path); err != nil {
return fmt.Errorf("validate: %w", err)
}
bypassPaths[path] = struct{}{} bypassPaths[path] = struct{}{}
return nil
} }
// RemovePath removes a path from the list of paths that bypass middleware. // RemovePath removes a path from the list of paths that bypass middleware.
@@ -24,16 +34,41 @@ func RemovePath(path string) {
delete(bypassPaths, path) delete(bypassPaths, path)
} }
// GetList returns a list of all bypass paths.
func GetList() []string {
byPassMutex.RLock()
defer byPassMutex.RUnlock()
list := make([]string, 0, len(bypassPaths))
for k := range bypassPaths {
list = append(list, k)
}
return list
}
// ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed. // ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed.
// This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. // This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication.
func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool {
byPassMutex.RLock() byPassMutex.RLock()
defer byPassMutex.RUnlock() defer byPassMutex.RUnlock()
if _, ok := bypassPaths[requestPath]; ok { for bypassPath := range bypassPaths {
h.ServeHTTP(w, r) matched, err := path.Match(bypassPath, requestPath)
return true if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {
h.ServeHTTP(w, r)
return true
}
} }
return false return false
} }
func validatePath(p string) error {
_, err := path.Match(p, "")
return err
}

View File

@@ -11,6 +11,19 @@ import (
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
) )
func TestGetList(t *testing.T) {
bypassPaths := []string{"/path1", "/path2", "/path3"}
for _, path := range bypassPaths {
err := bypass.AddBypassPath(path)
require.NoError(t, err, "Adding bypass path should not fail")
}
list := bypass.GetList()
assert.ElementsMatch(t, bypassPaths, list, "Bypass path list did not match expected paths")
}
func TestAuthBypass(t *testing.T) { func TestAuthBypass(t *testing.T) {
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -31,6 +44,13 @@ func TestAuthBypass(t *testing.T) {
expectBypass: true, expectBypass: true,
expectHTTPCode: http.StatusOK, expectHTTPCode: http.StatusOK,
}, },
{
name: "Wildcard path added to bypass",
pathToAdd: "/bypass/*",
testPath: "/bypass/extra",
expectBypass: true,
expectHTTPCode: http.StatusOK,
},
{ {
name: "Path not added to bypass", name: "Path not added to bypass",
testPath: "/no-bypass", testPath: "/no-bypass",
@@ -59,6 +79,13 @@ func TestAuthBypass(t *testing.T) {
expectBypass: false, expectBypass: false,
expectHTTPCode: http.StatusOK, expectHTTPCode: http.StatusOK,
}, },
{
name: "Wildcard subpath does not match bypass",
pathToAdd: "/webhook/*",
testPath: "/webhook/extra/path",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{ {
name: "Similar path does not match bypass", name: "Similar path does not match bypass",
pathToAdd: "/webhook", pathToAdd: "/webhook",
@@ -78,7 +105,8 @@ func TestAuthBypass(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.pathToAdd != "" { if tc.pathToAdd != "" {
bypass.AddBypassPath(tc.pathToAdd) err := bypass.AddBypassPath(tc.pathToAdd)
require.NoError(t, err, "Adding bypass path should not fail")
defer bypass.RemovePath(tc.pathToAdd) defer bypass.RemovePath(tc.pathToAdd)
} }

View File

@@ -213,8 +213,8 @@ func (p *PostureChecksHandler) savePostureChecks(
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck) postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
} }
if privateNetworkCheck := req.Checks.PrivateNetworkCheck; privateNetworkCheck != nil { if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
postureChecks.Checks.PrivateNetworkCheck, err = toPrivateNetworkCheck(privateNetworkCheck) postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w) util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w)
return return
@@ -235,7 +235,7 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
} }
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil && if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PrivateNetworkCheck == nil) { req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty") return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
} }
@@ -278,17 +278,17 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
} }
} }
if privateNetworkCheck := req.Checks.PrivateNetworkCheck; privateNetworkCheck != nil { if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if privateNetworkCheck.Action == "" { if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for private network check shouldn't be empty") return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
} }
allowedActions := []api.PrivateNetworkCheckAction{api.PrivateNetworkCheckActionAllow, api.PrivateNetworkCheckActionDeny} allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, privateNetworkCheck.Action) { if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for private network check is not valid value") return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
} }
if len(privateNetworkCheck.Ranges) == 0 { if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for private network check shouldn't be empty") return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
} }
} }
@@ -318,8 +318,8 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck) checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
} }
if postureChecks.Checks.PrivateNetworkCheck != nil { if postureChecks.Checks.PeerNetworkRangeCheck != nil {
checks.PrivateNetworkCheck = toPrivateNetworkCheckResponse(postureChecks.Checks.PrivateNetworkCheck) checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
} }
return &api.PostureCheck{ return &api.PostureCheck{
@@ -369,19 +369,19 @@ func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *postu
} }
} }
func toPrivateNetworkCheckResponse(check *posture.PrivateNetworkCheck) *api.PrivateNetworkCheck { func toPeerNetworkRangeCheckResponse(check *posture.PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
netPrefixes := make([]string, 0, len(check.Ranges)) netPrefixes := make([]string, 0, len(check.Ranges))
for _, netPrefix := range check.Ranges { for _, netPrefix := range check.Ranges {
netPrefixes = append(netPrefixes, netPrefix.String()) netPrefixes = append(netPrefixes, netPrefix.String())
} }
return &api.PrivateNetworkCheck{ return &api.PeerNetworkRangeCheck{
Ranges: netPrefixes, Ranges: netPrefixes,
Action: api.PrivateNetworkCheckAction(check.Action), Action: api.PeerNetworkRangeCheckAction(check.Action),
} }
} }
func toPrivateNetworkCheck(check *api.PrivateNetworkCheck) (*posture.PrivateNetworkCheck, error) { func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNetworkRangeCheck, error) {
prefixes := make([]netip.Prefix, 0) prefixes := make([]netip.Prefix, 0)
for _, prefix := range check.Ranges { for _, prefix := range check.Ranges {
parsedPrefix, err := netip.ParsePrefix(prefix) parsedPrefix, err := netip.ParsePrefix(prefix)
@@ -391,7 +391,7 @@ func toPrivateNetworkCheck(check *api.PrivateNetworkCheck) (*posture.PrivateNetw
prefixes = append(prefixes, parsedPrefix) prefixes = append(prefixes, parsedPrefix)
} }
return &posture.PrivateNetworkCheck{ return &posture.PeerNetworkRangeCheck{
Ranges: prefixes, Ranges: prefixes,
Action: string(check.Action), Action: string(check.Action),
}, nil }, nil

View File

@@ -131,7 +131,7 @@ func TestGetPostureCheck(t *testing.T) {
ID: "privateNetworkPostureCheck", ID: "privateNetworkPostureCheck",
Name: "privateNetwork", Name: "privateNetwork",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
PrivateNetworkCheck: &posture.PrivateNetworkCheck{ PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
}, },
@@ -375,7 +375,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
{ {
name: "Create Posture Checks Private Network", name: "Create Posture Checks Peer Network Range",
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/posture-checks", requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
@@ -383,7 +383,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"name": "default", "name": "default",
"description": "default", "description": "default",
"checks": { "checks": {
"private_network_check": { "peer_network_range_check": {
"action": "allow", "action": "allow",
"ranges": [ "ranges": [
"10.0.0.0/8" "10.0.0.0/8"
@@ -398,11 +398,11 @@ func TestPostureCheckUpdate(t *testing.T) {
Name: "default", Name: "default",
Description: str("default"), Description: str("default"),
Checks: api.Checks{ Checks: api.Checks{
PrivateNetworkCheck: &api.PrivateNetworkCheck{ PeerNetworkRangeCheck: &api.PeerNetworkRangeCheck{
Ranges: []string{ Ranges: []string{
"10.0.0.0/8", "10.0.0.0/8",
}, },
Action: api.PrivateNetworkCheckActionAllow, Action: api.PeerNetworkRangeCheckActionAllow,
}, },
}, },
}, },
@@ -715,14 +715,14 @@ func TestPostureCheckUpdate(t *testing.T) {
expectedBody: false, expectedBody: false,
}, },
{ {
name: "Update Posture Checks Private Network", name: "Update Posture Checks Peer Network Range",
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/posture-checks/privateNetworkPostureCheck", requestPath: "/api/posture-checks/peerNetworkRangePostureCheck",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`{ []byte(`{
"name": "default", "name": "default",
"checks": { "checks": {
"private_network_check": { "peer_network_range_check": {
"action": "deny", "action": "deny",
"ranges": [ "ranges": [
"192.168.1.0/24" "192.168.1.0/24"
@@ -737,11 +737,11 @@ func TestPostureCheckUpdate(t *testing.T) {
Name: "default", Name: "default",
Description: str(""), Description: str(""),
Checks: api.Checks{ Checks: api.Checks{
PrivateNetworkCheck: &api.PrivateNetworkCheck{ PeerNetworkRangeCheck: &api.PeerNetworkRangeCheck{
Ranges: []string{ Ranges: []string{
"192.168.1.0/24", "192.168.1.0/24",
}, },
Action: api.PrivateNetworkCheckActionDeny, Action: api.PeerNetworkRangeCheckActionDeny,
}, },
}, },
}, },
@@ -784,10 +784,10 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
&posture.Checks{ &posture.Checks{
ID: "privateNetworkPostureCheck", ID: "peerNetworkRangePostureCheck",
Name: "privateNetwork", Name: "peerNetworkRange",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
PrivateNetworkCheck: &posture.PrivateNetworkCheck{ PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
}, },
@@ -891,29 +891,50 @@ func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err) assert.NoError(t, err)
// valid private network check // valid peer network range check
privateNetworkCheck := api.PrivateNetworkCheck{ peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PrivateNetworkCheckActionAllow, Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{ Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8", "192.168.1.0/24", "10.0.0.0/8",
}, },
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err) assert.NoError(t, err)
// invalid private network check // invalid peer network range check
privateNetworkCheck = api.PrivateNetworkCheck{ peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PrivateNetworkCheckActionDeny, Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{}, Ranges: []string{},
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err) assert.Error(t, err)
// invalid private network check // invalid peer network range check
privateNetworkCheck = api.PrivateNetworkCheck{ peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction", Action: "unknownAction",
Ranges: []string{}, Ranges: []string{},
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@@ -10,10 +10,10 @@ import (
) )
const ( const (
NBVersionCheckName = "NBVersionCheck" NBVersionCheckName = "NBVersionCheck"
OSVersionCheckName = "OSVersionCheck" OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck" GeoLocationCheckName = "GeoLocationCheck"
PrivateNetworkCheckName = "PrivateNetworkCheck" PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
CheckActionAllow string = "allow" CheckActionAllow string = "allow"
CheckActionDeny string = "deny" CheckActionDeny string = "deny"
@@ -44,10 +44,10 @@ type Checks struct {
// ChecksDefinition contains definition of actual check // ChecksDefinition contains definition of actual check
type ChecksDefinition struct { type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"` NBVersionCheck *NBVersionCheck `json:",omitempty"`
OSVersionCheck *OSVersionCheck `json:",omitempty"` OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"` GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PrivateNetworkCheck *PrivateNetworkCheck `json:",omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
} }
// Copy returns a copy of a checks definition. // Copy returns a copy of a checks definition.
@@ -85,13 +85,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
} }
copy(cdCopy.GeoLocationCheck.Locations, geoCheck.Locations) copy(cdCopy.GeoLocationCheck.Locations, geoCheck.Locations)
} }
if cd.PrivateNetworkCheck != nil { if cd.PeerNetworkRangeCheck != nil {
privateNetCheck := cd.PrivateNetworkCheck peerNetRangeCheck := cd.PeerNetworkRangeCheck
cdCopy.PrivateNetworkCheck = &PrivateNetworkCheck{ cdCopy.PeerNetworkRangeCheck = &PeerNetworkRangeCheck{
Action: privateNetCheck.Action, Action: peerNetRangeCheck.Action,
Ranges: make([]netip.Prefix, len(privateNetCheck.Ranges)), Ranges: make([]netip.Prefix, len(peerNetRangeCheck.Ranges)),
} }
copy(cdCopy.PrivateNetworkCheck.Ranges, privateNetCheck.Ranges) copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
} }
return cdCopy return cdCopy
} }
@@ -130,8 +130,8 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.GeoLocationCheck != nil { if pc.Checks.GeoLocationCheck != nil {
checks = append(checks, pc.Checks.GeoLocationCheck) checks = append(checks, pc.Checks.GeoLocationCheck)
} }
if pc.Checks.PrivateNetworkCheck != nil { if pc.Checks.PeerNetworkRangeCheck != nil {
checks = append(checks, pc.Checks.PrivateNetworkCheck) checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
} }
return checks return checks
} }

View File

@@ -254,7 +254,7 @@ func TestChecks_Copy(t *testing.T) {
}, },
Action: CheckActionAllow, Action: CheckActionAllow,
}, },
PrivateNetworkCheck: &PrivateNetworkCheck{ PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),

View File

@@ -8,16 +8,16 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
type PrivateNetworkCheck struct { type PeerNetworkRangeCheck struct {
Action string Action string
Ranges []netip.Prefix `gorm:"serializer:json"` Ranges []netip.Prefix `gorm:"serializer:json"`
} }
var _ Check = (*PrivateNetworkCheck)(nil) var _ Check = (*PeerNetworkRangeCheck)(nil)
func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) { func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
if len(peer.Meta.NetworkAddresses) == 0 { if len(peer.Meta.NetworkAddresses) == 0 {
return false, fmt.Errorf("peer's does not contain private network addresses") return false, fmt.Errorf("peer's does not contain peer network range addresses")
} }
maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges)) maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges))
@@ -34,7 +34,7 @@ func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) {
case CheckActionAllow: case CheckActionAllow:
return true, nil return true, nil
default: default:
return false, fmt.Errorf("invalid private network check action: %s", p.Action) return false, fmt.Errorf("invalid peer network range check action: %s", p.Action)
} }
} }
} }
@@ -46,9 +46,9 @@ func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) {
return false, nil return false, nil
} }
return false, fmt.Errorf("invalid private network check action: %s", p.Action) return false, fmt.Errorf("invalid peer network range check action: %s", p.Action)
} }
func (p *PrivateNetworkCheck) Name() string { func (p *PeerNetworkRangeCheck) Name() string {
return PrivateNetworkCheckName return PeerNetworkRangeCheckName
} }

View File

@@ -9,17 +9,17 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
func TestPrivateNetworkCheck_Check(t *testing.T) { func TestPeerNetworkRangeCheck_Check(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
check PrivateNetworkCheck check PeerNetworkRangeCheck
peer nbpeer.Peer peer nbpeer.Peer
wantErr bool wantErr bool
isValid bool isValid bool
}{ }{
{ {
name: "Peer private networks matches the allowed range", name: "Peer networks range matches the allowed range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -42,8 +42,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: true, isValid: true,
}, },
{ {
name: "Peer private networks doesn't matches the allowed range", name: "Peer networks range doesn't matches the allowed range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -63,8 +63,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer with no privates network in the allow range", name: "Peer with no network range in the allow range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"), netip.MustParsePrefix("192.168.0.0/16"),
@@ -76,8 +76,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer private networks matches the denied range", name: "Peer networks range matches the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -100,8 +100,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer private networks doesn't matches the denied range", name: "Peer networks range doesn't matches the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -121,8 +121,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: true, isValid: true,
}, },
{ {
name: "Peer with no private networks in the denied range", name: "Peer with no networks range in the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"), netip.MustParsePrefix("192.168.0.0/16"),