Refactor posture check validations (#1705)

* Add posture checks validation

* Refactor code to incorporate posture checks validation directly into management.

* Add posture checks validation for geolocation, OS version, network, process, and NB-version

* Fix tests
This commit is contained in:
Bethuel Mmbaga
2024-03-14 23:16:50 +03:00
committed by GitHub
parent 90ab2f7c89
commit 180f5a122e
16 changed files with 518 additions and 324 deletions

View File

@@ -2,6 +2,7 @@ package http
import (
"net/http"
"regexp"
"github.com/gorilla/mux"
@@ -13,6 +14,10 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// GeolocationsHandler is a handler that returns locations.
type GeolocationsHandler struct {
accountManager server.AccountManager
@@ -73,8 +78,8 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http.
}
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}

View File

@@ -4,8 +4,6 @@ import (
"encoding/json"
"net/http"
"net/netip"
"regexp"
"slices"
"github.com/gorilla/mux"
"github.com/rs/xid"
@@ -19,10 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)
var (
countryCodeRegex = regexp.MustCompile("^[a-zA-Z]{2}$")
)
// PostureChecksHandler is a handler that returns posture checks of the account.
type PostureChecksHandler struct {
accountManager server.AccountManager
@@ -165,19 +159,16 @@ func (p *PostureChecksHandler) savePostureChecks(
user *server.User,
postureChecksID string,
) {
var (
err error
req api.PostureCheckUpdate
)
var req api.PostureCheckUpdate
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err := validatePostureChecksUpdate(req)
if err != nil {
util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w)
return
}
if postureChecksID == "" {
postureChecksID = xid.New().String()
}
@@ -206,8 +197,8 @@ func (p *PostureChecksHandler) savePostureChecks(
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if p.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
util.WriteError(status.Errorf(status.PreconditionFailed, "Geo location database is not initialized. "+
"Check the self-hosted Geo database documentation at https://docs.netbird.io/selfhosted/geo-support"), w)
return
}
postureChecks.Checks.GeoLocationCheck = toPostureGeoLocationCheck(geoLocationCheck)
@@ -233,84 +224,6 @@ func (p *PostureChecksHandler) savePostureChecks(
util.WriteJSONObject(w, toPostureChecksResponse(&postureChecks))
}
func validatePostureChecksUpdate(req api.PostureCheckUpdate) error {
if req.Name == "" {
return status.Errorf(status.InvalidArgument, "posture checks name shouldn't be empty")
}
if req.Checks == nil || (req.Checks.NbVersionCheck == nil && req.Checks.OsVersionCheck == nil &&
req.Checks.GeoLocationCheck == nil && req.Checks.PeerNetworkRangeCheck == nil && req.Checks.ProcessCheck == nil) {
return status.Errorf(status.InvalidArgument, "posture checks shouldn't be empty")
}
if req.Checks.NbVersionCheck != nil && req.Checks.NbVersionCheck.MinVersion == "" {
return status.Errorf(status.InvalidArgument, "minimum version for NetBird's version check shouldn't be empty")
}
if osVersionCheck := req.Checks.OsVersionCheck; osVersionCheck != nil {
emptyOS := osVersionCheck.Android == nil && osVersionCheck.Darwin == nil && osVersionCheck.Ios == nil &&
osVersionCheck.Linux == nil && osVersionCheck.Windows == nil
emptyMinVersion := osVersionCheck.Android != nil && osVersionCheck.Android.MinVersion == "" ||
osVersionCheck.Darwin != nil && osVersionCheck.Darwin.MinVersion == "" ||
osVersionCheck.Ios != nil && osVersionCheck.Ios.MinVersion == "" ||
osVersionCheck.Linux != nil && osVersionCheck.Linux.MinKernelVersion == "" ||
osVersionCheck.Windows != nil && osVersionCheck.Windows.MinKernelVersion == ""
if emptyOS || emptyMinVersion {
return status.Errorf(status.InvalidArgument,
"minimum version for at least one OS in the OS version check shouldn't be empty")
}
}
if geoLocationCheck := req.Checks.GeoLocationCheck; geoLocationCheck != nil {
if geoLocationCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for geolocation check shouldn't be empty")
}
allowedActions := []api.GeoLocationCheckAction{api.GeoLocationCheckActionAllow, api.GeoLocationCheckActionDeny}
if !slices.Contains(allowedActions, geoLocationCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for geolocation check is not valid value")
}
if len(geoLocationCheck.Locations) == 0 {
return status.Errorf(status.InvalidArgument, "locations for geolocation check shouldn't be empty")
}
for _, loc := range geoLocationCheck.Locations {
if loc.CountryCode == "" {
return status.Errorf(status.InvalidArgument, "country code for geolocation check shouldn't be empty")
}
if !countryCodeRegex.MatchString(loc.CountryCode) {
return status.Errorf(status.InvalidArgument, "country code must be 2 letters (ISO 3166-1 alpha-2 format)")
}
}
}
if peerNetworkRangeCheck := req.Checks.PeerNetworkRangeCheck; peerNetworkRangeCheck != nil {
if peerNetworkRangeCheck.Action == "" {
return status.Errorf(status.InvalidArgument, "action for peer network range check shouldn't be empty")
}
allowedActions := []api.PeerNetworkRangeCheckAction{api.PeerNetworkRangeCheckActionAllow, api.PeerNetworkRangeCheckActionDeny}
if !slices.Contains(allowedActions, peerNetworkRangeCheck.Action) {
return status.Errorf(status.InvalidArgument, "action for peer network range check is not valid value")
}
if len(peerNetworkRangeCheck.Ranges) == 0 {
return status.Errorf(status.InvalidArgument, "network ranges for peer network range check shouldn't be empty")
}
}
if processCheck := req.Checks.ProcessCheck; processCheck != nil {
if len(processCheck.Processes) == 0 {
return status.Errorf(status.InvalidArgument, "processes for process check shouldn't be empty")
}
for _, process := range processCheck.Processes {
if process.Path == nil && process.WindowsPath == nil {
return status.Errorf(status.InvalidArgument, "path for process check shouldn't be empty")
}
}
}
return nil
}
func toPostureChecksResponse(postureChecks *posture.Checks) *api.PostureCheck {
var checks api.Checks

View File

@@ -43,6 +43,11 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
SavePostureChecksFunc: func(accountID, userID string, postureChecks *posture.Checks) error {
postureChecks.ID = "postureCheck"
testPostureChecks[postureChecks.ID] = postureChecks
if err := postureChecks.Validate(); err != nil {
return status.Errorf(status.InvalidArgument, err.Error())
}
return nil
},
DeletePostureChecksFunc: func(accountID, postureChecksID, userID string) error {
@@ -483,7 +488,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -498,7 +503,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -512,7 +517,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -526,7 +531,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"geo_location_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -700,11 +705,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
setupHandlerFunc: func(handler *PostureChecksHandler) {
handler.geolocationManager = nil
},
},
{
name: "Update Posture Checks Invalid Check",
@@ -719,7 +721,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -734,7 +736,7 @@ func TestPostureCheckUpdate(t *testing.T) {
}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -748,7 +750,7 @@ func TestPostureCheckUpdate(t *testing.T) {
"nb_version_check": {}
}
}`)),
expectedStatus: http.StatusBadRequest,
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
@@ -878,143 +880,3 @@ func TestPostureCheckUpdate(t *testing.T) {
})
}
}
func TestPostureCheck_validatePostureChecksUpdate(t *testing.T) {
str := func(s string) *string { return &s }
// empty name
err := validatePostureChecksUpdate(api.PostureCheckUpdate{})
assert.Error(t, err)
// empty checks
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default"})
assert.Error(t, err)
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{}})
assert.Error(t, err)
// not valid NbVersionCheck
nbVersionCheck := api.NBVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.Error(t, err)
// valid NbVersionCheck
nbVersionCheck = api.NBVersionCheck{MinVersion: "1.0"}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{NbVersionCheck: &nbVersionCheck}})
assert.NoError(t, err)
// not valid OsVersionCheck
osVersionCheck := api.OSVersionCheck{}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// not valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{}, Darwin: &api.MinVersionCheck{MinVersion: "14.2"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.Error(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"}}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid OsVersionCheck
osVersionCheck = api.OSVersionCheck{
Linux: &api.MinKernelVersionCheck{MinKernelVersion: "6.0"},
Darwin: &api.MinVersionCheck{MinVersion: "14.2"},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{OsVersionCheck: &osVersionCheck}})
assert.NoError(t, err)
// valid peer network range check
peerNetworkRangeCheck := api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionAllow,
Ranges: []string{
"192.168.1.0/24", "10.0.0.0/8",
},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.NoError(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: api.PeerNetworkRangeCheckActionDeny,
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
// invalid peer network range check
peerNetworkRangeCheck = api.PeerNetworkRangeCheck{
Action: "unknownAction",
Ranges: []string{},
}
err = validatePostureChecksUpdate(
api.PostureCheckUpdate{
Name: "Default",
Checks: &api.Checks{
PeerNetworkRangeCheck: &peerNetworkRangeCheck,
},
},
)
assert.Error(t, err)
// valid process check
processCheck := api.ProcessCheck{
Processes: []api.Process{
{
Path: str("/usr/local/bin/netbird"),
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
},
},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{ProcessCheck: &processCheck}})
assert.NoError(t, err)
// valid unix process check
processCheck = api.ProcessCheck{
Processes: []api.Process{
{
Path: str("/usr/local/bin/netbird"),
},
},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{ProcessCheck: &processCheck}})
assert.NoError(t, err)
// valid window process check
processCheck = api.ProcessCheck{
Processes: []api.Process{
{
WindowsPath: str("C:\\ProgramData\\NetBird\\netbird.exe"),
},
},
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{ProcessCheck: &processCheck}})
assert.NoError(t, err)
// invalid process check
processCheck = api.ProcessCheck{
Processes: make([]api.Process, 0),
}
err = validatePostureChecksUpdate(api.PostureCheckUpdate{Name: "Default", Checks: &api.Checks{ProcessCheck: &processCheck}})
assert.Error(t, err)
}