Compare commits

...

9 Commits

Author SHA1 Message Date
Maycon Santos
5af9bbfec9 check for closed channel and adjust tests 2024-03-02 01:02:22 +01:00
Maycon Santos
19873db1a9 add comment and adjust log position 2024-02-27 18:55:09 +01:00
Maycon Santos
87d1fc3a2f Handle canceling schedule and avoid recursive call
Using time.Ticker allow us to avoid recursive call that may end up in schedule running and possible deadlock if no routine is listening for cancel calls

switch to closing channel
2024-02-27 18:38:04 +01:00
pascal-fischer
b085419ab8 FIx order when validating account settings (#1632)
* moved extraSettings validation to the end

* moved extraSettings validation directly after permission check
2024-02-27 14:17:22 +01:00
Bethuel Mmbaga
d78b652ff7 Rename PrivateNetworkCheck to PeerNetworkRangeCheck (#1629)
* Rename PrivateNetworkCheck to PeerNetworkRangeCheck

* update description and example

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2024-02-27 11:59:48 +01:00
Viktor Liu
7251150c1c Combine update-available and connected/disconnected tray icon states (#1615)
This PR updates the system tray icons to reflect both connection status and availability of updates. Now, the tray will show distinct icons for the following states: connected, disconnected, update available while connected, and update available while disconnected. This change improves user experience by providing a clear visual status indicator.

- Add new icons for connected and disconnected states with update available.
- Implement logic to switch icons based on connection status and update availability.
- Remove old icon references for default and update states.
2024-02-26 23:28:33 +01:00
Bethuel Mmbaga
b65c2f69b0 Add support for downloading Geo databases to the management service (#1626)
Adds support for downloading Geo databases to the management service. If the Geo databases are not found, the service will automatically attempt to download them during startup.
2024-02-26 22:49:28 +01:00
Yury Gargay
d8ce08d898 Extend bypass middleware with support of wildcard paths (#1628)
---------

Co-authored-by: Viktor Liu <viktor@netbird.io>
2024-02-26 17:54:58 +01:00
Maycon Santos
e1c50248d9 Add support for device flow on getting started with zitadel (#1616) 2024-02-26 12:33:16 +01:00
36 changed files with 718 additions and 210 deletions

View File

@@ -199,6 +199,6 @@ jobs:
- name: test script - name: test script
run: bash -x infrastructure_files/download-geolite2.sh run: bash -x infrastructure_files/download-geolite2.sh
- name: test mmdb file exists - name: test mmdb file exists
run: ls -l GeoLite2-City_*/GeoLite2-City.mmdb run: test -f GeoLite2-City.mmdb
- name: test geonames file exists - name: test geonames file exists
run: test -f geonames.db run: test -f geonames.db

View File

@@ -54,7 +54,7 @@ nfpms:
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird
@@ -71,7 +71,7 @@ nfpms:
contents: contents:
- src: client/ui/netbird.desktop - src: client/ui/netbird.desktop
dst: /usr/share/applications/netbird.desktop dst: /usr/share/applications/netbird.desktop
- src: client/ui/netbird-systemtray-default.png - src: client/ui/netbird-systemtray-connected.png
dst: /usr/share/pixmaps/netbird.png dst: /usr/share/pixmaps/netbird.png
dependencies: dependencies:
- netbird - netbird

View File

@@ -82,17 +82,23 @@ var iconConnectedICO []byte
//go:embed netbird-systemtray-connected.png //go:embed netbird-systemtray-connected.png
var iconConnectedPNG []byte var iconConnectedPNG []byte
//go:embed netbird-systemtray-default.ico //go:embed netbird-systemtray-disconnected.ico
var iconDisconnectedICO []byte var iconDisconnectedICO []byte
//go:embed netbird-systemtray-default.png //go:embed netbird-systemtray-disconnected.png
var iconDisconnectedPNG []byte var iconDisconnectedPNG []byte
//go:embed netbird-systemtray-update.ico //go:embed netbird-systemtray-update-disconnected.ico
var iconUpdateICO []byte var iconUpdateDisconnectedICO []byte
//go:embed netbird-systemtray-update.png //go:embed netbird-systemtray-update-disconnected.png
var iconUpdatePNG []byte var iconUpdateDisconnectedPNG []byte
//go:embed netbird-systemtray-update-connected.ico
var iconUpdateConnectedICO []byte
//go:embed netbird-systemtray-update-connected.png
var iconUpdateConnectedPNG []byte
//go:embed netbird-systemtray-update-cloud.ico //go:embed netbird-systemtray-update-cloud.ico
var iconUpdateCloudICO []byte var iconUpdateCloudICO []byte
@@ -105,10 +111,11 @@ type serviceClient struct {
addr string addr string
conn proto.DaemonServiceClient conn proto.DaemonServiceClient
icConnected []byte icConnected []byte
icDisconnected []byte icDisconnected []byte
icUpdate []byte icUpdateConnected []byte
icUpdateCloud []byte icUpdateDisconnected []byte
icUpdateCloud []byte
// systray menu items // systray menu items
mStatus *systray.MenuItem mStatus *systray.MenuItem
@@ -139,6 +146,7 @@ type serviceClient struct {
preSharedKey string preSharedKey string
adminURL string adminURL string
connected bool
update *version.Update update *version.Update
daemonVersion string daemonVersion string
updateIndicationLock sync.Mutex updateIndicationLock sync.Mutex
@@ -161,13 +169,15 @@ func newServiceClient(addr string, a fyne.App, showSettings bool) *serviceClient
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
s.icConnected = iconConnectedICO s.icConnected = iconConnectedICO
s.icDisconnected = iconDisconnectedICO s.icDisconnected = iconDisconnectedICO
s.icUpdate = iconUpdateICO s.icUpdateConnected = iconUpdateConnectedICO
s.icUpdateDisconnected = iconUpdateDisconnectedICO
s.icUpdateCloud = iconUpdateCloudICO s.icUpdateCloud = iconUpdateCloudICO
} else { } else {
s.icConnected = iconConnectedPNG s.icConnected = iconConnectedPNG
s.icDisconnected = iconDisconnectedPNG s.icDisconnected = iconDisconnectedPNG
s.icUpdate = iconUpdatePNG s.icUpdateConnected = iconUpdateConnectedPNG
s.icUpdateDisconnected = iconUpdateDisconnectedPNG
s.icUpdateCloud = iconUpdateCloudPNG s.icUpdateCloud = iconUpdateCloudPNG
} }
@@ -369,7 +379,10 @@ func (s *serviceClient) updateStatus() error {
var systrayIconState bool var systrayIconState bool
if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() { if status.Status == string(internal.StatusConnected) && !s.mUp.Disabled() {
if !s.isUpdateIconActive { s.connected = true
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateConnected)
} else {
systray.SetIcon(s.icConnected) systray.SetIcon(s.icConnected)
} }
systray.SetTooltip("NetBird (Connected)") systray.SetTooltip("NetBird (Connected)")
@@ -378,7 +391,10 @@ func (s *serviceClient) updateStatus() error {
s.mDown.Enable() s.mDown.Enable()
systrayIconState = true systrayIconState = true
} else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() { } else if status.Status != string(internal.StatusConnected) && s.mUp.Disabled() {
if !s.isUpdateIconActive { s.connected = false
if s.isUpdateIconActive {
systray.SetIcon(s.icUpdateDisconnected)
} else {
systray.SetIcon(s.icDisconnected) systray.SetIcon(s.icDisconnected)
} }
systray.SetTooltip("NetBird (Disconnected)") systray.SetTooltip("NetBird (Disconnected)")
@@ -605,10 +621,13 @@ func (s *serviceClient) onUpdateAvailable() {
defer s.updateIndicationLock.Unlock() defer s.updateIndicationLock.Unlock()
s.mUpdate.Show() s.mUpdate.Show()
s.mAbout.SetIcon(s.icUpdateCloud)
s.isUpdateIconActive = true s.isUpdateIconActive = true
systray.SetIcon(s.icUpdate)
if s.connected {
systray.SetIcon(s.icUpdateConnected)
} else {
systray.SetIcon(s.icUpdateDisconnected)
}
} }
func openURL(url string) error { func openURL(url string) error {

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.3 KiB

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.1 KiB

After

Width:  |  Height:  |  Size: 8.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 9.0 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.5 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 7.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.6 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.3 KiB

View File

@@ -43,21 +43,18 @@ download_geolite_mmdb() {
mkdir -p "$EXTRACTION_DIR" mkdir -p "$EXTRACTION_DIR"
tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1 tar -xzvf "$DATABASE_FILE" > /dev/null 2>&1
# Create a SHA256 signature file
MMDB_FILE="GeoLite2-City.mmdb" MMDB_FILE="GeoLite2-City.mmdb"
cd "$EXTRACTION_DIR" cp "$EXTRACTION_DIR"/"$MMDB_FILE" $MMDB_FILE
sha256sum "$MMDB_FILE" > "$MMDB_FILE.sha256"
echo "SHA256 signature created for $MMDB_FILE."
cd - > /dev/null 2>&1
# Remove downloaded files # Remove downloaded files
rm -r "$EXTRACTION_DIR"
rm "$DATABASE_FILE" "$SIGNATURE_FILE" rm "$DATABASE_FILE" "$SIGNATURE_FILE"
# Done. Print next steps # Done. Print next steps
echo "" echo ""
echo "Process completed successfully." echo "Process completed successfully."
echo "Now you can place $EXTRACTION_DIR/$MMDB_FILE to 'datadir' of management service." echo "Now you can place $MMDB_FILE to 'datadir' of management service."
echo -e "Example:\n\tdocker compose cp $EXTRACTION_DIR/$MMDB_FILE management:/var/lib/netbird/" echo -e "Example:\n\tdocker compose cp $MMDB_FILE management:/var/lib/netbird/"
} }

View File

@@ -137,6 +137,13 @@ create_new_application() {
BASE_REDIRECT_URL2=$5 BASE_REDIRECT_URL2=$5
LOGOUT_URL=$6 LOGOUT_URL=$6
ZITADEL_DEV_MODE=$7 ZITADEL_DEV_MODE=$7
DEVICE_CODE=$8
if [[ $DEVICE_CODE == "true" ]]; then
GRANT_TYPES='["OIDC_GRANT_TYPE_AUTHORIZATION_CODE","OIDC_GRANT_TYPE_DEVICE_CODE","OIDC_GRANT_TYPE_REFRESH_TOKEN"]'
else
GRANT_TYPES='["OIDC_GRANT_TYPE_AUTHORIZATION_CODE","OIDC_GRANT_TYPE_REFRESH_TOKEN"]'
fi
RESPONSE=$( RESPONSE=$(
curl -sS -X POST "$INSTANCE_URL/management/v1/projects/$PROJECT_ID/apps/oidc" \ curl -sS -X POST "$INSTANCE_URL/management/v1/projects/$PROJECT_ID/apps/oidc" \
@@ -154,10 +161,7 @@ create_new_application() {
"RESPONSETypes": [ "RESPONSETypes": [
"OIDC_RESPONSE_TYPE_CODE" "OIDC_RESPONSE_TYPE_CODE"
], ],
"grantTypes": [ "grantTypes": '"$GRANT_TYPES"',
"OIDC_GRANT_TYPE_AUTHORIZATION_CODE",
"OIDC_GRANT_TYPE_REFRESH_TOKEN"
],
"appType": "OIDC_APP_TYPE_USER_AGENT", "appType": "OIDC_APP_TYPE_USER_AGENT",
"authMethodType": "OIDC_AUTH_METHOD_TYPE_NONE", "authMethodType": "OIDC_AUTH_METHOD_TYPE_NONE",
"version": "OIDC_VERSION_1_0", "version": "OIDC_VERSION_1_0",
@@ -340,10 +344,10 @@ init_zitadel() {
# create zitadel spa applications # create zitadel spa applications
echo "Creating new Zitadel SPA Dashboard application" echo "Creating new Zitadel SPA Dashboard application"
DASHBOARD_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Dashboard" "$BASE_REDIRECT_URL/nb-auth" "$BASE_REDIRECT_URL/nb-silent-auth" "$BASE_REDIRECT_URL/" "$ZITADEL_DEV_MODE") DASHBOARD_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Dashboard" "$BASE_REDIRECT_URL/nb-auth" "$BASE_REDIRECT_URL/nb-silent-auth" "$BASE_REDIRECT_URL/" "$ZITADEL_DEV_MODE" "false")
echo "Creating new Zitadel SPA Cli application" echo "Creating new Zitadel SPA Cli application"
CLI_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Cli" "http://localhost:53000/" "http://localhost:54000/" "http://localhost:53000/" "true") CLI_APPLICATION_CLIENT_ID=$(create_new_application "$INSTANCE_URL" "$PAT" "Cli" "http://localhost:53000/" "http://localhost:54000/" "http://localhost:53000/" "true" "true")
MACHINE_USER_ID=$(create_service_user "$INSTANCE_URL" "$PAT") MACHINE_USER_ID=$(create_service_user "$INSTANCE_URL" "$PAT")
@@ -561,6 +565,8 @@ renderCaddyfile() {
reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080 reverse_proxy /.well-known/openid-configuration h2c://zitadel:8080
reverse_proxy /openapi/* h2c://zitadel:8080 reverse_proxy /openapi/* h2c://zitadel:8080
reverse_proxy /debug/* h2c://zitadel:8080 reverse_proxy /debug/* h2c://zitadel:8080
reverse_proxy /device/* h2c://zitadel:8080
reverse_proxy /device h2c://zitadel:8080
# Dashboard # Dashboard
reverse_proxy /* dashboard:80 reverse_proxy /* dashboard:80
} }
@@ -629,6 +635,14 @@ renderManagementJson() {
"ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1" "ManagementEndpoint": "$NETBIRD_HTTP_PROTOCOL://$NETBIRD_DOMAIN/management/v1"
} }
}, },
"DeviceAuthorizationFlow": {
"Provider": "hosted",
"ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"ClientID": "$NETBIRD_AUTH_CLIENT_ID_CLI",
"Scope": "openid"
}
},
"PKCEAuthorizationFlow": { "PKCEAuthorizationFlow": {
"ProviderConfig": { "ProviderConfig": {
"Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI", "Audience": "$NETBIRD_AUTH_CLIENT_ID_CLI",

View File

@@ -166,7 +166,7 @@ var (
geo, err := geolocation.NewGeolocation(config.Datadir) geo, err := geolocation.NewGeolocation(config.Datadir)
if err != nil { if err != nil {
log.Warnf("could not initialize geo location service, we proceed without geo support") log.Warnf("could not initialize geo location service: %v, we proceed without geo support", err)
} else { } else {
log.Infof("geo location service has been initialized from %s", config.Datadir) log.Infof("geo location service has been initialized from %s", config.Datadir)
} }

View File

@@ -917,12 +917,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
unlock := am.Store.AcquireAccountLock(accountID) unlock := am.Store.AcquireAccountLock(accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccountByUser(userID) account, err := am.Store.GetAccount(accountID)
if err != nil {
return nil, err
}
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -936,6 +931,11 @@ func (am *DefaultAccountManager) UpdateAccountSettings(accountID, userID string,
return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account") return nil, status.Errorf(status.PermissionDenied, "user is not allowed to update account")
} }
err = additions.ValidateExtraSettings(newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID, am.eventStore)
if err != nil {
return nil, err
}
oldSettings := account.Settings oldSettings := account.Settings
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled

View File

@@ -0,0 +1,187 @@
package geolocation
import (
"encoding/csv"
"fmt"
"net/url"
"os"
"path"
"strconv"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
const (
geoLiteCityTarGZURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz"
geoLiteCityZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip"
geoLiteCitySha256TarURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City/download?suffix=tar.gz.sha256"
geoLiteCitySha256ZipURL = "https://pkgs.netbird.io/geolocation-dbs/GeoLite2-City-CSV/download?suffix=zip.sha256"
)
// loadGeolocationDatabases loads the MaxMind databases.
func loadGeolocationDatabases(dataDir string) error {
files := []string{MMDBFileName, GeoSqliteDBFile}
for _, file := range files {
exists, _ := fileExists(path.Join(dataDir, file))
if exists {
continue
}
switch file {
case MMDBFileName:
extractFunc := func(src string, dst string) error {
if err := decompressTarGzFile(src, dst); err != nil {
return err
}
return os.Rename(path.Join(dst, MMDBFileName), path.Join(dataDir, MMDBFileName))
}
if err := loadDatabase(
geoLiteCitySha256TarURL,
geoLiteCityTarGZURL,
extractFunc,
); err != nil {
return err
}
case GeoSqliteDBFile:
extractFunc := func(src string, dst string) error {
if err := decompressZipFile(src, dst); err != nil {
return err
}
extractedCsvFile := path.Join(dst, "GeoLite2-City-Locations-en.csv")
return importCsvToSqlite(dataDir, extractedCsvFile)
}
if err := loadDatabase(
geoLiteCitySha256ZipURL,
geoLiteCityZipURL,
extractFunc,
); err != nil {
return err
}
}
}
return nil
}
// loadDatabase downloads a file from the specified URL and verifies its checksum.
// It then calls the extract function to perform additional processing on the extracted files.
func loadDatabase(checksumURL string, fileURL string, extractFunc func(src string, dst string) error) error {
temp, err := os.MkdirTemp(os.TempDir(), "geolite")
if err != nil {
return err
}
defer os.RemoveAll(temp)
checksumFile := path.Join(temp, getDatabaseFileName(checksumURL))
err = downloadFile(checksumURL, checksumFile)
if err != nil {
return err
}
sha256sum, err := loadChecksumFromFile(checksumFile)
if err != nil {
return err
}
dbFile := path.Join(temp, getDatabaseFileName(fileURL))
err = downloadFile(fileURL, dbFile)
if err != nil {
return err
}
if err := verifyChecksum(dbFile, sha256sum); err != nil {
return err
}
return extractFunc(dbFile, temp)
}
// importCsvToSqlite imports a CSV file into a SQLite database.
func importCsvToSqlite(dataDir string, csvFile string) error {
geonames, err := loadGeonamesCsv(csvFile)
if err != nil {
return err
}
db, err := gorm.Open(sqlite.Open(path.Join(dataDir, GeoSqliteDBFile)), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
CreateBatchSize: 1000,
PrepareStmt: true,
})
if err != nil {
return err
}
defer func() {
sql, err := db.DB()
if err != nil {
return
}
sql.Close()
}()
if err := db.AutoMigrate(&GeoNames{}); err != nil {
return err
}
return db.Create(geonames).Error
}
func loadGeonamesCsv(filepath string) ([]GeoNames, error) {
f, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer f.Close()
reader := csv.NewReader(f)
records, err := reader.ReadAll()
if err != nil {
return nil, err
}
var geoNames []GeoNames
for index, record := range records {
if index == 0 {
continue
}
geoNameID, err := strconv.Atoi(record[0])
if err != nil {
return nil, err
}
geoName := GeoNames{
GeoNameID: geoNameID,
LocaleCode: record[1],
ContinentCode: record[2],
ContinentName: record[3],
CountryIsoCode: record[4],
CountryName: record[5],
Subdivision1IsoCode: record[6],
Subdivision1Name: record[7],
Subdivision2IsoCode: record[8],
Subdivision2Name: record[9],
CityName: record[10],
MetroCode: record[11],
TimeZone: record[12],
IsInEuropeanUnion: record[13],
}
geoNames = append(geoNames, geoName)
}
return geoNames, nil
}
// getDatabaseFileName extracts the file name from a given URL string.
func getDatabaseFileName(urlStr string) string {
u, err := url.Parse(urlStr)
if err != nil {
panic(err)
}
ext := u.Query().Get("suffix")
fileName := fmt.Sprintf("%s.%s", path.Base(u.Path), ext)
return fileName
}

View File

@@ -2,9 +2,7 @@ package geolocation
import ( import (
"bytes" "bytes"
"crypto/sha256"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"path" "path"
@@ -54,20 +52,23 @@ type Country struct {
CountryName string CountryName string
} }
func NewGeolocation(datadir string) (*Geolocation, error) { func NewGeolocation(dataDir string) (*Geolocation, error) {
mmdbPath := path.Join(datadir, MMDBFileName) if err := loadGeolocationDatabases(dataDir); err != nil {
return nil, fmt.Errorf("failed to load MaxMind databases: %v", err)
}
mmdbPath := path.Join(dataDir, MMDBFileName)
db, err := openDB(mmdbPath) db, err := openDB(mmdbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
sha256sum, err := getSha256sum(mmdbPath) sha256sum, err := calculateFileSHA256(mmdbPath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
locationDB, err := NewSqliteStore(datadir) locationDB, err := NewSqliteStore(dataDir)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -104,21 +105,6 @@ func openDB(mmdbPath string) (*maxminddb.Reader, error) {
return db, nil return db, nil
} }
func getSha256sum(mmdbPath string) ([]byte, error) {
f, err := os.Open(mmdbPath)
if err != nil {
return nil, err
}
defer f.Close()
h := sha256.New()
if _, err := io.Copy(h, f); err != nil {
return nil, err
}
return h.Sum(nil), nil
}
func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) { func (gl *Geolocation) Lookup(ip net.IP) (*Record, error) {
gl.mux.RLock() gl.mux.RLock()
defer gl.mux.RUnlock() defer gl.mux.RUnlock()
@@ -189,7 +175,7 @@ func (gl *Geolocation) reloader() {
log.Errorf("geonames db reload failed: %s", err) log.Errorf("geonames db reload failed: %s", err)
} }
newSha256sum1, err := getSha256sum(gl.mmdbPath) newSha256sum1, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue
@@ -198,7 +184,7 @@ func (gl *Geolocation) reloader() {
// we check sum twice just to avoid possible case when we reload during update of the file // we check sum twice just to avoid possible case when we reload during update of the file
// considering the frequency of file update (few times a week) checking sum twice should be enough // considering the frequency of file update (few times a week) checking sum twice should be enough
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
newSha256sum2, err := getSha256sum(gl.mmdbPath) newSha256sum2, err := calculateFileSHA256(gl.mmdbPath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", gl.mmdbPath, err)
continue continue

View File

@@ -20,6 +20,27 @@ const (
GeoSqliteDBFile = "geonames.db" GeoSqliteDBFile = "geonames.db"
) )
type GeoNames struct {
GeoNameID int `gorm:"column:geoname_id"`
LocaleCode string `gorm:"column:locale_code"`
ContinentCode string `gorm:"column:continent_code"`
ContinentName string `gorm:"column:continent_name"`
CountryIsoCode string `gorm:"column:country_iso_code"`
CountryName string `gorm:"column:country_name"`
Subdivision1IsoCode string `gorm:"column:subdivision_1_iso_code"`
Subdivision1Name string `gorm:"column:subdivision_1_name"`
Subdivision2IsoCode string `gorm:"column:subdivision_2_iso_code"`
Subdivision2Name string `gorm:"column:subdivision_2_name"`
CityName string `gorm:"column:city_name"`
MetroCode string `gorm:"column:metro_code"`
TimeZone string `gorm:"column:time_zone"`
IsInEuropeanUnion string `gorm:"column:is_in_european_union"`
}
func (*GeoNames) TableName() string {
return "geonames"
}
// SqliteStore represents a location storage backed by a Sqlite DB. // SqliteStore represents a location storage backed by a Sqlite DB.
type SqliteStore struct { type SqliteStore struct {
db *gorm.DB db *gorm.DB
@@ -37,7 +58,7 @@ func NewSqliteStore(dataDir string) (*SqliteStore, error) {
return nil, err return nil, err
} }
sha256sum, err := getSha256sum(file) sha256sum, err := calculateFileSHA256(file)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -60,7 +81,7 @@ func (s *SqliteStore) GetAllCountries() ([]Country, error) {
} }
var countries []Country var countries []Country
result := s.db.Table("geonames"). result := s.db.Model(&GeoNames{}).
Select("country_iso_code", "country_name"). Select("country_iso_code", "country_name").
Group("country_name"). Group("country_name").
Scan(&countries) Scan(&countries)
@@ -81,7 +102,7 @@ func (s *SqliteStore) GetCitiesByCountry(countryISOCode string) ([]City, error)
} }
var cities []City var cities []City
result := s.db.Table("geonames"). result := s.db.Model(&GeoNames{}).
Select("geoname_id", "city_name"). Select("geoname_id", "city_name").
Where("country_iso_code = ?", countryISOCode). Where("country_iso_code = ?", countryISOCode).
Group("city_name"). Group("city_name").
@@ -98,7 +119,7 @@ func (s *SqliteStore) reload() error {
s.mux.Lock() s.mux.Lock()
defer s.mux.Unlock() defer s.mux.Unlock()
newSha256sum1, err := getSha256sum(s.filePath) newSha256sum1, err := calculateFileSHA256(s.filePath)
if err != nil { if err != nil {
log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) log.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
} }
@@ -107,7 +128,7 @@ func (s *SqliteStore) reload() error {
// we check sum twice just to avoid possible case when we reload during update of the file // we check sum twice just to avoid possible case when we reload during update of the file
// considering the frequency of file update (few times a week) checking sum twice should be enough // considering the frequency of file update (few times a week) checking sum twice should be enough
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
newSha256sum2, err := getSha256sum(s.filePath) newSha256sum2, err := calculateFileSHA256(s.filePath)
if err != nil { if err != nil {
return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err) return fmt.Errorf("failed to calculate sha256 sum for '%s': %s", s.filePath, err)
} }

View File

@@ -0,0 +1,176 @@
package geolocation
import (
"archive/tar"
"archive/zip"
"bufio"
"bytes"
"compress/gzip"
"crypto/sha256"
"errors"
"fmt"
"io"
"net/http"
"os"
"path"
"strings"
)
// decompressTarGzFile decompresses a .tar.gz file.
func decompressTarGzFile(filepath, destDir string) error {
file, err := os.Open(filepath)
if err != nil {
return err
}
defer file.Close()
gzipReader, err := gzip.NewReader(file)
if err != nil {
return err
}
defer gzipReader.Close()
tarReader := tar.NewReader(gzipReader)
for {
header, err := tarReader.Next()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return err
}
if header.Typeflag == tar.TypeReg {
outFile, err := os.Create(path.Join(destDir, path.Base(header.Name)))
if err != nil {
return err
}
_, err = io.Copy(outFile, tarReader) // #nosec G110
outFile.Close()
if err != nil {
return err
}
}
}
return nil
}
// decompressZipFile decompresses a .zip file.
func decompressZipFile(filepath, destDir string) error {
r, err := zip.OpenReader(filepath)
if err != nil {
return err
}
defer r.Close()
for _, f := range r.File {
if f.FileInfo().IsDir() {
continue
}
outFile, err := os.Create(path.Join(destDir, path.Base(f.Name)))
if err != nil {
return err
}
rc, err := f.Open()
if err != nil {
outFile.Close()
return err
}
_, err = io.Copy(outFile, rc) // #nosec G110
outFile.Close()
rc.Close()
if err != nil {
return err
}
}
return nil
}
// calculateFileSHA256 calculates the SHA256 checksum of a file.
func calculateFileSHA256(filepath string) ([]byte, error) {
file, err := os.Open(filepath)
if err != nil {
return nil, err
}
defer file.Close()
h := sha256.New()
if _, err := io.Copy(h, file); err != nil {
return nil, err
}
return h.Sum(nil), nil
}
// loadChecksumFromFile loads the first checksum from a file.
func loadChecksumFromFile(filepath string) (string, error) {
file, err := os.Open(filepath)
if err != nil {
return "", err
}
defer file.Close()
scanner := bufio.NewScanner(file)
if scanner.Scan() {
parts := strings.Fields(scanner.Text())
if len(parts) > 0 {
return parts[0], nil
}
}
if err := scanner.Err(); err != nil {
return "", err
}
return "", nil
}
// verifyChecksum compares the calculated SHA256 checksum of a file against the expected checksum.
func verifyChecksum(filepath, expectedChecksum string) error {
calculatedChecksum, err := calculateFileSHA256(filepath)
fileCheckSum := fmt.Sprintf("%x", calculatedChecksum)
if err != nil {
return err
}
if fileCheckSum != expectedChecksum {
return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, fileCheckSum)
}
return nil
}
// downloadFile downloads a file from a URL and saves it to a local file path.
func downloadFile(url, filepath string) error {
resp, err := http.Get(url)
if err != nil {
return err
}
defer resp.Body.Close()
bodyBytes, err := io.ReadAll(resp.Body)
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("unexpected error occurred while downloading the file: %s", string(bodyBytes))
}
out, err := os.Create(filepath)
if err != nil {
return err
}
defer out.Close()
_, err = io.Copy(out, bytes.NewBuffer(bodyBytes))
return err
}

View File

@@ -862,8 +862,8 @@ components:
$ref: '#/components/schemas/OSVersionCheck' $ref: '#/components/schemas/OSVersionCheck'
geo_location_check: geo_location_check:
$ref: '#/components/schemas/GeoLocationCheck' $ref: '#/components/schemas/GeoLocationCheck'
private_network_check: peer_network_range_check:
$ref: '#/components/schemas/PrivateNetworkCheck' $ref: '#/components/schemas/PeerNetworkRangeCheck'
NBVersionCheck: NBVersionCheck:
description: Posture check for the version of NetBird description: Posture check for the version of NetBird
type: object type: object
@@ -934,16 +934,16 @@ components:
required: required:
- locations - locations
- action - action
PrivateNetworkCheck: PeerNetworkRangeCheck:
description: Posture check for allow or deny private network description: Posture check for allow or deny access based on peer local network addresses
type: object type: object
properties: properties:
ranges: ranges:
description: List of private network ranges in CIDR notation description: List of peer network ranges in CIDR notation
type: array type: array
items: items:
type: string type: string
example: ["192.168.1.0/24", "10.0.0.0/8"] example: ["192.168.1.0/24", "10.0.0.0/8", "2001:db8:1234:1a00::/56"]
action: action:
description: Action to take upon policy match description: Action to take upon policy match
type: string type: string

View File

@@ -74,6 +74,12 @@ const (
NameserverNsTypeUdp NameserverNsType = "udp" NameserverNsTypeUdp NameserverNsType = "udp"
) )
// Defines values for PeerNetworkRangeCheckAction.
const (
PeerNetworkRangeCheckActionAllow PeerNetworkRangeCheckAction = "allow"
PeerNetworkRangeCheckActionDeny PeerNetworkRangeCheckAction = "deny"
)
// Defines values for PolicyRuleAction. // Defines values for PolicyRuleAction.
const ( const (
PolicyRuleActionAccept PolicyRuleAction = "accept" PolicyRuleActionAccept PolicyRuleAction = "accept"
@@ -116,12 +122,6 @@ const (
PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp" PolicyRuleUpdateProtocolUdp PolicyRuleUpdateProtocol = "udp"
) )
// Defines values for PrivateNetworkCheckAction.
const (
PrivateNetworkCheckActionAllow PrivateNetworkCheckAction = "allow"
PrivateNetworkCheckActionDeny PrivateNetworkCheckAction = "deny"
)
// Defines values for UserStatus. // Defines values for UserStatus.
const ( const (
UserStatusActive UserStatus = "active" UserStatusActive UserStatus = "active"
@@ -199,8 +199,8 @@ type Checks struct {
// OsVersionCheck Posture check for the version of operating system // OsVersionCheck Posture check for the version of operating system
OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"` OsVersionCheck *OSVersionCheck `json:"os_version_check,omitempty"`
// PrivateNetworkCheck Posture check for allow or deny private network // PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
PrivateNetworkCheck *PrivateNetworkCheck `json:"private_network_check,omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:"peer_network_range_check,omitempty"`
} }
// City Describe city geographical location information // City Describe city geographical location information
@@ -656,6 +656,18 @@ type PeerMinimum struct {
Name string `json:"name"` Name string `json:"name"`
} }
// PeerNetworkRangeCheck Posture check for allow or deny access based on peer local network addresses
type PeerNetworkRangeCheck struct {
// Action Action to take upon policy match
Action PeerNetworkRangeCheckAction `json:"action"`
// Ranges List of peer network ranges in CIDR notation
Ranges []string `json:"ranges"`
}
// PeerNetworkRangeCheckAction Action to take upon policy match
type PeerNetworkRangeCheckAction string
// PeerRequest defines model for PeerRequest. // PeerRequest defines model for PeerRequest.
type PeerRequest struct { type PeerRequest struct {
// ApprovalRequired (Cloud only) Indicates whether peer needs approval // ApprovalRequired (Cloud only) Indicates whether peer needs approval
@@ -898,18 +910,6 @@ type PostureCheckUpdate struct {
Name string `json:"name"` Name string `json:"name"`
} }
// PrivateNetworkCheck Posture check for allow or deny private network
type PrivateNetworkCheck struct {
// Action Action to take upon policy match
Action PrivateNetworkCheckAction `json:"action"`
// Ranges List of private network ranges in CIDR notation
Ranges []string `json:"ranges"`
}
// PrivateNetworkCheckAction Action to take upon policy match
type PrivateNetworkCheckAction string
// Route defines model for Route. // Route defines model for Route.
type Route struct { type Route struct {
// Description Route description // Description Route description

View File

@@ -177,7 +177,10 @@ func TestAuthMiddleware_Handler(t *testing.T) {
for _, tc := range tt { for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.shouldBypassAuth { if tc.shouldBypassAuth {
bypass.AddBypassPath(tc.path) err := bypass.AddBypassPath(tc.path)
if err != nil {
t.Fatalf("failed to add bypass path: %v", err)
}
} }
req := httptest.NewRequest("GET", "http://testing"+tc.path, nil) req := httptest.NewRequest("GET", "http://testing"+tc.path, nil)

View File

@@ -1,8 +1,12 @@
package bypass package bypass
import ( import (
"fmt"
"net/http" "net/http"
"path"
"sync" "sync"
log "github.com/sirupsen/logrus"
) )
var byPassMutex sync.RWMutex var byPassMutex sync.RWMutex
@@ -11,10 +15,16 @@ var byPassMutex sync.RWMutex
var bypassPaths = make(map[string]struct{}) var bypassPaths = make(map[string]struct{})
// AddBypassPath adds an exact path to the list of paths that bypass middleware. // AddBypassPath adds an exact path to the list of paths that bypass middleware.
func AddBypassPath(path string) { // Paths can include wildcards, such as /api/*. Paths are matched using path.Match.
// Returns an error if the path has invalid pattern.
func AddBypassPath(path string) error {
byPassMutex.Lock() byPassMutex.Lock()
defer byPassMutex.Unlock() defer byPassMutex.Unlock()
if err := validatePath(path); err != nil {
return fmt.Errorf("validate: %w", err)
}
bypassPaths[path] = struct{}{} bypassPaths[path] = struct{}{}
return nil
} }
// RemovePath removes a path from the list of paths that bypass middleware. // RemovePath removes a path from the list of paths that bypass middleware.
@@ -24,16 +34,41 @@ func RemovePath(path string) {
delete(bypassPaths, path) delete(bypassPaths, path)
} }
// GetList returns a list of all bypass paths.
func GetList() []string {
byPassMutex.RLock()
defer byPassMutex.RUnlock()
list := make([]string, 0, len(bypassPaths))
for k := range bypassPaths {
list = append(list, k)
}
return list
}
// ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed. // ShouldBypass checks if the request path is one of the auth bypass paths and returns true if the middleware should be bypassed.
// This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication. // This can be used to bypass authz/authn middlewares for certain paths, such as webhooks that implement their own authentication.
func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool { func ShouldBypass(requestPath string, h http.Handler, w http.ResponseWriter, r *http.Request) bool {
byPassMutex.RLock() byPassMutex.RLock()
defer byPassMutex.RUnlock() defer byPassMutex.RUnlock()
if _, ok := bypassPaths[requestPath]; ok { for bypassPath := range bypassPaths {
h.ServeHTTP(w, r) matched, err := path.Match(bypassPath, requestPath)
return true if err != nil {
log.Errorf("Error matching path %s with %s from %s: %v", bypassPath, requestPath, GetList(), err)
continue
}
if matched {
h.ServeHTTP(w, r)
return true
}
} }
return false return false
} }
func validatePath(p string) error {
_, err := path.Match(p, "")
return err
}

View File

@@ -11,6 +11,19 @@ import (
"github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/middleware/bypass"
) )
func TestGetList(t *testing.T) {
bypassPaths := []string{"/path1", "/path2", "/path3"}
for _, path := range bypassPaths {
err := bypass.AddBypassPath(path)
require.NoError(t, err, "Adding bypass path should not fail")
}
list := bypass.GetList()
assert.ElementsMatch(t, bypassPaths, list, "Bypass path list did not match expected paths")
}
func TestAuthBypass(t *testing.T) { func TestAuthBypass(t *testing.T) {
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
@@ -31,6 +44,13 @@ func TestAuthBypass(t *testing.T) {
expectBypass: true, expectBypass: true,
expectHTTPCode: http.StatusOK, expectHTTPCode: http.StatusOK,
}, },
{
name: "Wildcard path added to bypass",
pathToAdd: "/bypass/*",
testPath: "/bypass/extra",
expectBypass: true,
expectHTTPCode: http.StatusOK,
},
{ {
name: "Path not added to bypass", name: "Path not added to bypass",
testPath: "/no-bypass", testPath: "/no-bypass",
@@ -59,6 +79,13 @@ func TestAuthBypass(t *testing.T) {
expectBypass: false, expectBypass: false,
expectHTTPCode: http.StatusOK, expectHTTPCode: http.StatusOK,
}, },
{
name: "Wildcard subpath does not match bypass",
pathToAdd: "/webhook/*",
testPath: "/webhook/extra/path",
expectBypass: false,
expectHTTPCode: http.StatusOK,
},
{ {
name: "Similar path does not match bypass", name: "Similar path does not match bypass",
pathToAdd: "/webhook", pathToAdd: "/webhook",
@@ -78,7 +105,8 @@ func TestAuthBypass(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
if tc.pathToAdd != "" { if tc.pathToAdd != "" {
bypass.AddBypassPath(tc.pathToAdd) err := bypass.AddBypassPath(tc.pathToAdd)
require.NoError(t, err, "Adding bypass path should not fail")
defer bypass.RemovePath(tc.pathToAdd) defer bypass.RemovePath(tc.pathToAdd)
} }

View File

@@ -213,8 +213,8 @@ func (p *PostureChecksHandler) savePostureChecks(
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck) postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
} }
if privateNetworkCheck := req.Checks.PrivateNetworkCheck; privateNetworkCheck != nil { if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
postureChecks.Checks.PrivateNetworkCheck, err = toPrivateNetworkCheck(privateNetworkCheck) postureChecks.Checks.PeerNetworkRangeCheck, err = toPeerNetworkRangeCheck(peerNetworkRangeCheck)
if err != nil { if err != nil {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w) util.WriteError(status.Errorf(status.InvalidArgument, "invalid network prefix"), w)
return return
@@ -235,7 +235,7 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
} }
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil && if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PrivateNetworkCheck == nil) { req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty") return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
} }
@@ -278,17 +278,17 @@ func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
} }
} }
if privateNetworkCheck := req.Checks.PrivateNetworkCheck; privateNetworkCheck != nil { if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if privateNetworkCheck.Action == "" { if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for private network check shouldn't be empty") return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
} }
allowedActions := []api.PrivateNetworkCheckAction{api.PrivateNetworkCheckActionAllow, api.PrivateNetworkCheckActionDeny} allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, privateNetworkCheck.Action) { if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for private network check is not valid value") return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
} }
if len(privateNetworkCheck.Ranges) == 0 { if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for private network check shouldn't be empty") return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
} }
} }
@@ -318,8 +318,8 @@ func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck) checks.GeoLocationCheck = toGeoLocationCheckResponse(postureChecks.Checks.GeoLocationCheck)
} }
if postureChecks.Checks.PrivateNetworkCheck != nil { if postureChecks.Checks.PeerNetworkRangeCheck != nil {
checks.PrivateNetworkCheck = toPrivateNetworkCheckResponse(postureChecks.Checks.PrivateNetworkCheck) checks.PeerNetworkRangeCheck = toPeerNetworkRangeCheckResponse(postureChecks.Checks.PeerNetworkRangeCheck)
} }
return &api.PostureCheck{ return &api.PostureCheck{
@@ -369,19 +369,19 @@ func toPostureGeoLocationCheck(apiGeoLocationCheck *api.GeoLocationCheck) *postu
} }
} }
func toPrivateNetworkCheckResponse(check *posture.PrivateNetworkCheck) *api.PrivateNetworkCheck { func toPeerNetworkRangeCheckResponse(check *posture.PeerNetworkRangeCheck) *api.PeerNetworkRangeCheck {
netPrefixes := make([]string, 0, len(check.Ranges)) netPrefixes := make([]string, 0, len(check.Ranges))
for _, netPrefix := range check.Ranges { for _, netPrefix := range check.Ranges {
netPrefixes = append(netPrefixes, netPrefix.String()) netPrefixes = append(netPrefixes, netPrefix.String())
} }
return &api.PrivateNetworkCheck{ return &api.PeerNetworkRangeCheck{
Ranges: netPrefixes, Ranges: netPrefixes,
Action: api.PrivateNetworkCheckAction(check.Action), Action: api.PeerNetworkRangeCheckAction(check.Action),
} }
} }
func toPrivateNetworkCheck(check *api.PrivateNetworkCheck) (*posture.PrivateNetworkCheck, error) { func toPeerNetworkRangeCheck(check *api.PeerNetworkRangeCheck) (*posture.PeerNetworkRangeCheck, error) {
prefixes := make([]netip.Prefix, 0) prefixes := make([]netip.Prefix, 0)
for _, prefix := range check.Ranges { for _, prefix := range check.Ranges {
parsedPrefix, err := netip.ParsePrefix(prefix) parsedPrefix, err := netip.ParsePrefix(prefix)
@@ -391,7 +391,7 @@ func toPrivateNetworkCheck(check *api.PrivateNetworkCheck) (*posture.PrivateNetw
prefixes = append(prefixes, parsedPrefix) prefixes = append(prefixes, parsedPrefix)
} }
return &posture.PrivateNetworkCheck{ return &posture.PeerNetworkRangeCheck{
Ranges: prefixes, Ranges: prefixes,
Action: string(check.Action), Action: string(check.Action),
}, nil }, nil

View File

@@ -131,7 +131,7 @@ func TestGetPostureCheck(t *testing.T) {
ID: "privateNetworkPostureCheck", ID: "privateNetworkPostureCheck",
Name: "privateNetwork", Name: "privateNetwork",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
PrivateNetworkCheck: &posture.PrivateNetworkCheck{ PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
}, },
@@ -375,7 +375,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
{ {
name: "Create Posture Checks Private Network", name: "Create Posture Checks Peer Network Range",
requestType: http.MethodPost, requestType: http.MethodPost,
requestPath: "/api/posture-checks", requestPath: "/api/posture-checks",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
@@ -383,7 +383,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"name": "default", "name": "default",
"description": "default", "description": "default",
"checks": { "checks": {
"private_network_check": { "peer_network_range_check": {
"action": "allow", "action": "allow",
"ranges": [ "ranges": [
"10.0.0.0/8" "10.0.0.0/8"
@@ -398,11 +398,11 @@ func TestPostureCheckUpdate(t *testing.T) {
Name: "default", Name: "default",
Description: str("default"), Description: str("default"),
Checks: api.Checks{ Checks: api.Checks{
PrivateNetworkCheck: &api.PrivateNetworkCheck{ PeerNetworkRangeCheck: &api.PeerNetworkRangeCheck{
Ranges: []string{ Ranges: []string{
"10.0.0.0/8", "10.0.0.0/8",
}, },
Action: api.PrivateNetworkCheckActionAllow, Action: api.PeerNetworkRangeCheckActionAllow,
}, },
}, },
}, },
@@ -715,14 +715,14 @@ func TestPostureCheckUpdate(t *testing.T) {
expectedBody: false, expectedBody: false,
}, },
{ {
name: "Update Posture Checks Private Network", name: "Update Posture Checks Peer Network Range",
requestType: http.MethodPut, requestType: http.MethodPut,
requestPath: "/api/posture-checks/privateNetworkPostureCheck", requestPath: "/api/posture-checks/peerNetworkRangePostureCheck",
requestBody: bytes.NewBuffer( requestBody: bytes.NewBuffer(
[]byte(`{ []byte(`{
"name": "default", "name": "default",
"checks": { "checks": {
"private_network_check": { "peer_network_range_check": {
"action": "deny", "action": "deny",
"ranges": [ "ranges": [
"192.168.1.0/24" "192.168.1.0/24"
@@ -737,11 +737,11 @@ func TestPostureCheckUpdate(t *testing.T) {
Name: "default", Name: "default",
Description: str(""), Description: str(""),
Checks: api.Checks{ Checks: api.Checks{
PrivateNetworkCheck: &api.PrivateNetworkCheck{ PeerNetworkRangeCheck: &api.PeerNetworkRangeCheck{
Ranges: []string{ Ranges: []string{
"192.168.1.0/24", "192.168.1.0/24",
}, },
Action: api.PrivateNetworkCheckActionDeny, Action: api.PeerNetworkRangeCheckActionDeny,
}, },
}, },
}, },
@@ -784,10 +784,10 @@ func TestPostureCheckUpdate(t *testing.T) {
}, },
}, },
&posture.Checks{ &posture.Checks{
ID: "privateNetworkPostureCheck", ID: "peerNetworkRangePostureCheck",
Name: "privateNetwork", Name: "peerNetworkRange",
Checks: posture.ChecksDefinition{ Checks: posture.ChecksDefinition{
PrivateNetworkCheck: &posture.PrivateNetworkCheck{ PeerNetworkRangeCheck: &posture.PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
}, },
@@ -891,29 +891,50 @@ func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}}) err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err) assert.NoError(t, err)
// valid private network check // valid peer network range check
privateNetworkCheck := api.PrivateNetworkCheck{ peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PrivateNetworkCheckActionAllow, Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{ Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8", "192.168.1.0/24", "10.0.0.0/8",
}, },
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err) assert.NoError(t, err)
// invalid private network check // invalid peer network range check
privateNetworkCheck = api.PrivateNetworkCheck{ peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PrivateNetworkCheckActionDeny, Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{}, Ranges: []string{},
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err) assert.Error(t, err)
// invalid private network check // invalid peer network range check
privateNetworkCheck = api.PrivateNetworkCheck{ peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction", Action: "unknownAction",
Ranges: []string{}, Ranges: []string{},
} }
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{PrivateNetworkCheck: &privateNetworkCheck}}) err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err) assert.Error(t, err)
} }

View File

@@ -10,10 +10,10 @@ import (
) )
const ( const (
NBVersionCheckName = "NBVersionCheck" NBVersionCheckName = "NBVersionCheck"
OSVersionCheckName = "OSVersionCheck" OSVersionCheckName = "OSVersionCheck"
GeoLocationCheckName = "GeoLocationCheck" GeoLocationCheckName = "GeoLocationCheck"
PrivateNetworkCheckName = "PrivateNetworkCheck" PeerNetworkRangeCheckName = "PeerNetworkRangeCheck"
CheckActionAllow string = "allow" CheckActionAllow string = "allow"
CheckActionDeny string = "deny" CheckActionDeny string = "deny"
@@ -44,10 +44,10 @@ type Checks struct {
// ChecksDefinition contains definition of actual check // ChecksDefinition contains definition of actual check
type ChecksDefinition struct { type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"` NBVersionCheck *NBVersionCheck `json:",omitempty"`
OSVersionCheck *OSVersionCheck `json:",omitempty"` OSVersionCheck *OSVersionCheck `json:",omitempty"`
GeoLocationCheck *GeoLocationCheck `json:",omitempty"` GeoLocationCheck *GeoLocationCheck `json:",omitempty"`
PrivateNetworkCheck *PrivateNetworkCheck `json:",omitempty"` PeerNetworkRangeCheck *PeerNetworkRangeCheck `json:",omitempty"`
} }
// Copy returns a copy of a checks definition. // Copy returns a copy of a checks definition.
@@ -85,13 +85,13 @@ func (cd ChecksDefinition) Copy() ChecksDefinition {
} }
copy(cdCopy.GeoLocationCheck.Locations, geoCheck.Locations) copy(cdCopy.GeoLocationCheck.Locations, geoCheck.Locations)
} }
if cd.PrivateNetworkCheck != nil { if cd.PeerNetworkRangeCheck != nil {
privateNetCheck := cd.PrivateNetworkCheck peerNetRangeCheck := cd.PeerNetworkRangeCheck
cdCopy.PrivateNetworkCheck = &PrivateNetworkCheck{ cdCopy.PeerNetworkRangeCheck = &PeerNetworkRangeCheck{
Action: privateNetCheck.Action, Action: peerNetRangeCheck.Action,
Ranges: make([]netip.Prefix, len(privateNetCheck.Ranges)), Ranges: make([]netip.Prefix, len(peerNetRangeCheck.Ranges)),
} }
copy(cdCopy.PrivateNetworkCheck.Ranges, privateNetCheck.Ranges) copy(cdCopy.PeerNetworkRangeCheck.Ranges, peerNetRangeCheck.Ranges)
} }
return cdCopy return cdCopy
} }
@@ -130,8 +130,8 @@ func (pc *Checks) GetChecks() []Check {
if pc.Checks.GeoLocationCheck != nil { if pc.Checks.GeoLocationCheck != nil {
checks = append(checks, pc.Checks.GeoLocationCheck) checks = append(checks, pc.Checks.GeoLocationCheck)
} }
if pc.Checks.PrivateNetworkCheck != nil { if pc.Checks.PeerNetworkRangeCheck != nil {
checks = append(checks, pc.Checks.PrivateNetworkCheck) checks = append(checks, pc.Checks.PeerNetworkRangeCheck)
} }
return checks return checks
} }

View File

@@ -254,7 +254,7 @@ func TestChecks_Copy(t *testing.T) {
}, },
Action: CheckActionAllow, Action: CheckActionAllow,
}, },
PrivateNetworkCheck: &PrivateNetworkCheck{ PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("10.0.0.0/8"), netip.MustParsePrefix("10.0.0.0/8"),

View File

@@ -8,16 +8,16 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
type PrivateNetworkCheck struct { type PeerNetworkRangeCheck struct {
Action string Action string
Ranges []netip.Prefix `gorm:"serializer:json"` Ranges []netip.Prefix `gorm:"serializer:json"`
} }
var _ Check = (*PrivateNetworkCheck)(nil) var _ Check = (*PeerNetworkRangeCheck)(nil)
func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) { func (p *PeerNetworkRangeCheck) Check(peer nbpeer.Peer) (bool, error) {
if len(peer.Meta.NetworkAddresses) == 0 { if len(peer.Meta.NetworkAddresses) == 0 {
return false, fmt.Errorf("peer's does not contain private network addresses") return false, fmt.Errorf("peer's does not contain peer network range addresses")
} }
maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges)) maskedPrefixes := make([]netip.Prefix, 0, len(p.Ranges))
@@ -34,7 +34,7 @@ func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) {
case CheckActionAllow: case CheckActionAllow:
return true, nil return true, nil
default: default:
return false, fmt.Errorf("invalid private network check action: %s", p.Action) return false, fmt.Errorf("invalid peer network range check action: %s", p.Action)
} }
} }
} }
@@ -46,9 +46,9 @@ func (p *PrivateNetworkCheck) Check(peer nbpeer.Peer) (bool, error) {
return false, nil return false, nil
} }
return false, fmt.Errorf("invalid private network check action: %s", p.Action) return false, fmt.Errorf("invalid peer network range check action: %s", p.Action)
} }
func (p *PrivateNetworkCheck) Name() string { func (p *PeerNetworkRangeCheck) Name() string {
return PrivateNetworkCheckName return PeerNetworkRangeCheckName
} }

View File

@@ -9,17 +9,17 @@ import (
nbpeer "github.com/netbirdio/netbird/management/server/peer" nbpeer "github.com/netbirdio/netbird/management/server/peer"
) )
func TestPrivateNetworkCheck_Check(t *testing.T) { func TestPeerNetworkRangeCheck_Check(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
check PrivateNetworkCheck check PeerNetworkRangeCheck
peer nbpeer.Peer peer nbpeer.Peer
wantErr bool wantErr bool
isValid bool isValid bool
}{ }{
{ {
name: "Peer private networks matches the allowed range", name: "Peer networks range matches the allowed range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -42,8 +42,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: true, isValid: true,
}, },
{ {
name: "Peer private networks doesn't matches the allowed range", name: "Peer networks range doesn't matches the allowed range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -63,8 +63,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer with no privates network in the allow range", name: "Peer with no network range in the allow range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionAllow, Action: CheckActionAllow,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"), netip.MustParsePrefix("192.168.0.0/16"),
@@ -76,8 +76,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer private networks matches the denied range", name: "Peer networks range matches the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -100,8 +100,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: false, isValid: false,
}, },
{ {
name: "Peer private networks doesn't matches the denied range", name: "Peer networks range doesn't matches the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"), netip.MustParsePrefix("192.168.0.0/24"),
@@ -121,8 +121,8 @@ func TestPrivateNetworkCheck_Check(t *testing.T) {
isValid: true, isValid: true,
}, },
{ {
name: "Peer with no private networks in the denied range", name: "Peer with no networks range in the denied range",
check: PrivateNetworkCheck{ check: PeerNetworkRangeCheck{
Action: CheckActionDeny, Action: CheckActionDeny,
Ranges: []netip.Prefix{ Ranges: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/16"), netip.MustParsePrefix("192.168.0.0/16"),

View File

@@ -1,9 +1,10 @@
package server package server
import ( import (
log "github.com/sirupsen/logrus"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus"
) )
// Scheduler is an interface which implementations can schedule and cancel jobs // Scheduler is an interface which implementations can schedule and cancel jobs
@@ -55,14 +56,8 @@ func (wm *DefaultScheduler) cancel(ID string) bool {
cancel, ok := wm.jobs[ID] cancel, ok := wm.jobs[ID]
if ok { if ok {
delete(wm.jobs, ID) delete(wm.jobs, ID)
select { close(cancel)
case cancel <- struct{}{}: log.Debugf("cancelled scheduled job %s", ID)
log.Debugf("cancelled scheduled job %s", ID)
default:
log.Warnf("couldn't cancel job %s because there was no routine listening on the cancel event", ID)
return false
}
} }
return ok return ok
} }
@@ -90,25 +85,41 @@ func (wm *DefaultScheduler) Schedule(in time.Duration, ID string, job func() (ne
return return
} }
ticker := time.NewTicker(in)
wm.jobs[ID] = cancel wm.jobs[ID] = cancel
log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs)) log.Debugf("scheduled a job %s to run in %s. There are %d total jobs scheduled.", ID, in.String(), len(wm.jobs))
go func() { go func() {
select { for {
case <-time.After(in): select {
log.Debugf("time to do a scheduled job %s", ID) case <-ticker.C:
runIn, reschedule := job() select {
wm.mu.Lock() case <-cancel:
defer wm.mu.Unlock() log.Debugf("scheduled job %s was canceled, stop timer", ID)
delete(wm.jobs, ID) ticker.Stop()
if reschedule { return
go wm.Schedule(runIn, ID, job) default:
log.Debugf("time to do a scheduled job %s", ID)
}
runIn, reschedule := job()
if !reschedule {
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
log.Debugf("job %s is not scheduled to run again", ID)
ticker.Stop()
return
}
// we need this comparison to avoid resetting the ticker with the same duration and missing the current elapsesed time
if runIn != in {
ticker.Reset(runIn)
}
case <-cancel:
log.Debugf("job %s was canceled, stopping timer", ID)
ticker.Stop()
return
} }
case <-cancel:
log.Debugf("stopped scheduled job %s ", ID)
wm.mu.Lock()
defer wm.mu.Unlock()
delete(wm.jobs, ID)
return
} }
}() }()
} }

View File

@@ -2,11 +2,12 @@ package server
import ( import (
"fmt" "fmt"
"github.com/stretchr/testify/assert"
"math/rand" "math/rand"
"sync" "sync"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
) )
func TestScheduler_Performance(t *testing.T) { func TestScheduler_Performance(t *testing.T) {
@@ -36,15 +37,24 @@ func TestScheduler_Cancel(t *testing.T) {
jobID1 := "test-scheduler-job-1" jobID1 := "test-scheduler-job-1"
jobID2 := "test-scheduler-job-2" jobID2 := "test-scheduler-job-2"
scheduler := NewDefaultScheduler() scheduler := NewDefaultScheduler()
scheduler.Schedule(2*time.Second, jobID1, func() (nextRunIn time.Duration, reschedule bool) { tChan := make(chan struct{})
return 0, false p := []string{jobID1, jobID2}
scheduler.Schedule(2*time.Millisecond, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
tt := p[0]
<-tChan
t.Logf("job %s", tt)
return 2 * time.Millisecond, true
}) })
scheduler.Schedule(2*time.Second, jobID2, func() (nextRunIn time.Duration, reschedule bool) { scheduler.Schedule(2*time.Millisecond, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return 0, false return 2 * time.Millisecond, true
}) })
time.Sleep(4 * time.Millisecond)
assert.Len(t, scheduler.jobs, 2) assert.Len(t, scheduler.jobs, 2)
scheduler.Cancel([]string{jobID1}) scheduler.Cancel([]string{jobID1})
close(tChan)
p = []string{}
time.Sleep(4 * time.Millisecond)
assert.Len(t, scheduler.jobs, 1) assert.Len(t, scheduler.jobs, 1)
assert.NotNil(t, scheduler.jobs[jobID2]) assert.NotNil(t, scheduler.jobs[jobID2])
} }