Feat rego default policy (#700)

Converts rules to Rego policies and allow users to write raw policies to set up connectivity and firewall on the clients.
This commit is contained in:
Givi Khojanashvili
2023-03-13 15:14:18 +01:00
committed by GitHub
parent 221934447e
commit 3bfa26b13b
25 changed files with 1740 additions and 890 deletions

View File

@@ -1,7 +1,7 @@
openapi: 3.0.1
info:
title: NetBird REST API
description: API to manipulate groups, rules and retrieve information about peers and users
description: API to manipulate groups, rules, policies and retrieve information about peers and users
version: 0.0.1
tags:
- name: Users
@@ -14,6 +14,8 @@ tags:
description: Interact with and view information about groups.
- name: Rules
description: Interact with and view information about rules.
- name: Policies
description: Interact with and view information about policies.
- name: Routes
description: Interact with and view information about routes.
- name: DNS
@@ -393,6 +395,77 @@ components:
enum: [ "name","description","disabled","flow","sources","destinations" ]
required:
- path
PolicyRule:
type: object
properties:
id:
description: Rule ID
type: string
name:
description: Rule name identifier
type: string
description:
description: Rule friendly description
type: string
enabled:
description: Rules status
type: boolean
sources:
description: policy source groups
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
destinations:
description: policy destination groups
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
action:
description: policy accept or drops packets
type: string
enum: ["accept","drop"]
required:
- name
- sources
- destinations
- action
- enabled
PolicyMinimum:
type: object
properties:
name:
description: Policy name identifier
type: string
description:
description: Policy friendly description
type: string
enabled:
description: Policy status
type: boolean
query:
description: Policy Rego query
type: string
rules:
description: Policy rule object for policy UI editor
type: array
items:
$ref: '#/components/schemas/PolicyRule'
required:
- name
- description
- enabled
- query
- rules
Policy:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
- type: object
properties:
id:
description: Policy ID
type: string
required:
- id
RouteRequest:
type: object
properties:
@@ -574,12 +647,12 @@ components:
"setupkey.peer.add", "setupkey.add", "setupkey.update", "setupkey.revoke", "setupkey.overuse",
"setupkey.group.delete", "setupkey.group.add",
"rule.add", "rule.delete", "rule.update",
"policy.add", "policy.delete", "policy.update",
"group.add", "group.update", "dns.setting.disabled.management.group.add", "dns.setting.disabled.management.group.delete",
"account.create", "account.setting.peer.login.expiration.update", "account.setting.peer.login.expiration.disable", "account.setting.peer.login.expiration.enable",
"route.add", "route.delete", "route.update",
"nameserver.group.add", "nameserver.group.delete", "nameserver.group.update",
"peer.ssh.disable", "peer.ssh.enable", "peer.rename", "peer.login.expiration.disable", "peer.login.expiration.enable"
]
"peer.ssh.disable", "peer.ssh.enable", "peer.rename", "peer.login.expiration.disable", "peer.login.expiration.enable" ]
initiator_id:
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
type: string
@@ -1337,41 +1410,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
patch:
summary: Update information about a Rule
tags: [ Rules ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Rule ID
requestBody:
description: Update Rule request using a list of json patch objects
content:
'application/json':
schema:
type: array
items:
$ref: '#/components/schemas/RulePatchOperation'
responses:
'200':
description: A Rule object
content:
application/json:
schema:
$ref: '#/components/schemas/Rule'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Rule
tags: [ Rules ]
@@ -1396,7 +1434,134 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/policies:
get:
summary: Returns a list of all Policies
tags: [ Policies ]
security:
- BearerAuth: [ ]
responses:
'200':
description: A JSON Array of Policies
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Policy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Creates a Policy
tags: [ Policies ]
security:
- BearerAuth: [ ]
requestBody:
description: New Policy request
content:
'application/json':
schema:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
responses:
'200':
description: A Policy Object
content:
application/json:
schema:
$ref: '#/components/schemas/Policy'
/api/policies/{id}:
get:
summary: Get information about a Policies
tags: [ Policies ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Policy ID
responses:
'200':
description: A Policy object
content:
application/json:
schema:
$ref: '#/components/schemas/Policy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update/Replace a Policy
tags: [ Policies ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Policy ID
requestBody:
description: Update Policy request
content:
'application/json':
schema:
allOf:
- $ref: '#/components/schemas/PolicyMinimum'
responses:
'200':
description: A Policy object
content:
application/json:
schema:
$ref: '#/components/schemas/Policy'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Policy
tags: [ Policies ]
security:
- BearerAuth: [ ]
parameters:
- in: path
name: id
required: true
schema:
type: string
description: The Policy ID
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/routes:
get:
summary: Returns a list of all routes

View File

@@ -29,6 +29,9 @@ const (
EventActivityCodePeerRename EventActivityCode = "peer.rename"
EventActivityCodePeerSshDisable EventActivityCode = "peer.ssh.disable"
EventActivityCodePeerSshEnable EventActivityCode = "peer.ssh.enable"
EventActivityCodePolicyAdd EventActivityCode = "policy.add"
EventActivityCodePolicyDelete EventActivityCode = "policy.delete"
EventActivityCodePolicyUpdate EventActivityCode = "policy.update"
EventActivityCodeRouteAdd EventActivityCode = "route.add"
EventActivityCodeRouteDelete EventActivityCode = "route.delete"
EventActivityCodeRouteUpdate EventActivityCode = "route.update"
@@ -94,6 +97,12 @@ const (
PatchMinimumOpReplace PatchMinimumOp = "replace"
)
// Defines values for PolicyRuleAction.
const (
PolicyRuleActionAccept PolicyRuleAction = "accept"
PolicyRuleActionDrop PolicyRuleAction = "drop"
)
// Defines values for RoutePatchOperationOp.
const (
RoutePatchOperationOpAdd RoutePatchOperationOp = "add"
@@ -113,23 +122,6 @@ const (
RoutePatchOperationPathPeer RoutePatchOperationPath = "peer"
)
// Defines values for RulePatchOperationOp.
const (
RulePatchOperationOpAdd RulePatchOperationOp = "add"
RulePatchOperationOpRemove RulePatchOperationOp = "remove"
RulePatchOperationOpReplace RulePatchOperationOp = "replace"
)
// Defines values for RulePatchOperationPath.
const (
RulePatchOperationPathDescription RulePatchOperationPath = "description"
RulePatchOperationPathDestinations RulePatchOperationPath = "destinations"
RulePatchOperationPathDisabled RulePatchOperationPath = "disabled"
RulePatchOperationPathFlow RulePatchOperationPath = "flow"
RulePatchOperationPathName RulePatchOperationPath = "name"
RulePatchOperationPathSources RulePatchOperationPath = "sources"
)
// Defines values for UserStatus.
const (
UserStatusActive UserStatus = "active"
@@ -387,6 +379,72 @@ type PeerMinimum struct {
Name string `json:"name"`
}
// Policy defines model for Policy.
type Policy struct {
// Description Policy friendly description
Description string `json:"description"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Id Policy ID
Id string `json:"id"`
// Name Policy name identifier
Name string `json:"name"`
// Query Policy Rego query
Query string `json:"query"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRule `json:"rules"`
}
// PolicyMinimum defines model for PolicyMinimum.
type PolicyMinimum struct {
// Description Policy friendly description
Description string `json:"description"`
// Enabled Policy status
Enabled bool `json:"enabled"`
// Name Policy name identifier
Name string `json:"name"`
// Query Policy Rego query
Query string `json:"query"`
// Rules Policy rule object for policy UI editor
Rules []PolicyRule `json:"rules"`
}
// PolicyRule defines model for PolicyRule.
type PolicyRule struct {
// Action policy accept or drops packets
Action PolicyRuleAction `json:"action"`
// Description Rule friendly description
Description *string `json:"description,omitempty"`
// Destinations policy destination groups
Destinations []GroupMinimum `json:"destinations"`
// Enabled Rules status
Enabled bool `json:"enabled"`
// Id Rule ID
Id *string `json:"id,omitempty"`
// Name Rule name identifier
Name string `json:"name"`
// Sources policy source groups
Sources []GroupMinimum `json:"sources"`
}
// PolicyRuleAction policy accept or drops packets
type PolicyRuleAction string
// Route defines model for Route.
type Route struct {
// Description Route description
@@ -504,24 +562,6 @@ type RuleMinimum struct {
Name string `json:"name"`
}
// RulePatchOperation defines model for RulePatchOperation.
type RulePatchOperation struct {
// Op Patch operation type
Op RulePatchOperationOp `json:"op"`
// Path Rule field to update in form /<field>
Path RulePatchOperationPath `json:"path"`
// Value Values to be applied
Value []string `json:"value"`
}
// RulePatchOperationOp Patch operation type
type RulePatchOperationOp string
// RulePatchOperationPath Rule field to update in form /<field>
type RulePatchOperationPath string
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AutoGroups Setup key groups to auto-assign to peers registered with this key
@@ -666,6 +706,12 @@ type PutApiPeersIdJSONBody struct {
SshEnabled bool `json:"ssh_enabled"`
}
// PostApiPoliciesJSONBody defines parameters for PostApiPolicies.
type PostApiPoliciesJSONBody = PolicyMinimum
// PutApiPoliciesIdJSONBody defines parameters for PutApiPoliciesId.
type PutApiPoliciesIdJSONBody = PolicyMinimum
// PatchApiRoutesIdJSONBody defines parameters for PatchApiRoutesId.
type PatchApiRoutesIdJSONBody = []RoutePatchOperation
@@ -686,9 +732,6 @@ type PostApiRulesJSONBody struct {
Sources *[]string `json:"sources,omitempty"`
}
// PatchApiRulesIdJSONBody defines parameters for PatchApiRulesId.
type PatchApiRulesIdJSONBody = []RulePatchOperation
// PutApiRulesIdJSONBody defines parameters for PutApiRulesId.
type PutApiRulesIdJSONBody struct {
// Description Rule friendly description
@@ -733,6 +776,12 @@ type PutApiGroupsIdJSONRequestBody PutApiGroupsIdJSONBody
// PutApiPeersIdJSONRequestBody defines body for PutApiPeersId for application/json ContentType.
type PutApiPeersIdJSONRequestBody PutApiPeersIdJSONBody
// PostApiPoliciesJSONRequestBody defines body for PostApiPolicies for application/json ContentType.
type PostApiPoliciesJSONRequestBody = PostApiPoliciesJSONBody
// PutApiPoliciesIdJSONRequestBody defines body for PutApiPoliciesId for application/json ContentType.
type PutApiPoliciesIdJSONRequestBody = PutApiPoliciesIdJSONBody
// PostApiRoutesJSONRequestBody defines body for PostApiRoutes for application/json ContentType.
type PostApiRoutesJSONRequestBody = RouteRequest
@@ -745,9 +794,6 @@ type PutApiRoutesIdJSONRequestBody = RouteRequest
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
type PostApiRulesJSONRequestBody PostApiRulesJSONBody
// PatchApiRulesIdJSONRequestBody defines body for PatchApiRulesId for application/json ContentType.
type PatchApiRulesIdJSONRequestBody = PatchApiRulesIdJSONBody
// PutApiRulesIdJSONRequestBody defines body for PutApiRulesId for application/json ContentType.
type PutApiRulesIdJSONRequestBody PutApiRulesIdJSONBody

View File

@@ -59,6 +59,7 @@ func APIHandler(accountManager s.AccountManager, appMetrics telemetry.AppMetrics
api.addUsersEndpoint()
api.addSetupKeysEndpoint()
api.addRulesEndpoint()
api.addPoliciesEndpoint()
api.addGroupsEndpoint()
api.addRoutesEndpoint()
api.addDNSNameserversEndpoint()
@@ -122,11 +123,19 @@ func (apiHandler *apiHandler) addRulesEndpoint() {
apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.PatchRule).Methods("PATCH", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{id}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addPoliciesEndpoint() {
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies", policiesHandler.CreatePolicy).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.UpdatePolicy).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.GetPolicy).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/policies/{id}", policiesHandler.DeletePolicy).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addGroupsEndpoint() {
groupsHandler := NewGroupsHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/groups", groupsHandler.GetAllGroups).Methods("GET", "OPTIONS")

View File

@@ -0,0 +1,326 @@
package http
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// Policies is a handler that returns policy of the account
type Policies struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewPoliciesHandler creates a new Policies handler
func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Policies {
return &Policies{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllPolicies list for the account
func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, accountPolicies)
}
// UpdatePolicy handles update to a policy identified by a given ID
func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
policyID := vars["id"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policyIdx := -1
for i, policy := range account.Policies {
if policy.ID == policyID {
policyIdx = i
break
}
}
if policyIdx < 0 {
util.WriteError(status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w)
return
}
var req api.PutApiPoliciesIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return
}
policy := server.Policy{
ID: policyID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
Query: req.Query,
}
if req.Rules != nil {
for _, r := range req.Rules {
pr := server.PolicyRule{
Destinations: groupMinimumsToStrings(account, r.Destinations),
Sources: groupMinimumsToStrings(account, r.Sources),
Name: r.Name,
}
pr.Enabled = r.Enabled
if r.Description != nil {
pr.Description = *r.Description
}
if r.Id != nil {
pr.ID = *r.Id
}
switch r.Action {
case api.PolicyRuleActionAccept:
pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleActionDrop:
pr.Action = server.PolicyTrafficActionDrop
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
}
policy.Rules = append(policy.Rules, &pr)
}
}
if err := policy.UpdateQueryFromRules(); err != nil {
log.Errorf("failed to update policy query: %v", err)
util.WriteError(err, w)
return
}
if err = h.accountManager.SavePolicy(account.Id, user.Id, &policy); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPolicyResponse(account, &policy))
}
// CreatePolicy handles policy creation request
func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
var req api.PostApiPoliciesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "policy name shouldn't be empty"), w)
return
}
policy := &server.Policy{
ID: xid.New().String(),
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
Query: req.Query,
}
if req.Rules != nil {
for _, r := range req.Rules {
pr := server.PolicyRule{
ID: xid.New().String(),
Destinations: groupMinimumsToStrings(account, r.Destinations),
Sources: groupMinimumsToStrings(account, r.Sources),
Name: r.Name,
}
pr.Enabled = r.Enabled
if r.Description != nil {
pr.Description = *r.Description
}
switch r.Action {
case api.PolicyRuleActionAccept:
pr.Action = server.PolicyTrafficActionAccept
case api.PolicyRuleActionDrop:
pr.Action = server.PolicyTrafficActionDrop
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown action type"), w)
return
}
policy.Rules = append(policy.Rules, &pr)
}
}
if err := policy.UpdateQueryFromRules(); err != nil {
util.WriteError(err, w)
return
}
if err = h.accountManager.SavePolicy(account.Id, user.Id, policy); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPolicyResponse(account, policy))
}
// DeletePolicy handles policy deletion request
func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
aID := account.Id
vars := mux.Vars(r)
policyID := vars["id"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
if err = h.accountManager.DeletePolicy(aID, policyID, user.Id); err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, "")
}
// GetPolicy handles a group Get request identified by ID
func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
switch r.Method {
case http.MethodGet:
vars := mux.Vars(r)
policyID := vars["id"]
if len(policyID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid policy ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(account.Id, policyID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toPolicyResponse(account, policy))
default:
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
}
}
func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy {
cache := make(map[string]api.GroupMinimum)
ap := &api.Policy{
Id: policy.ID,
Name: policy.Name,
Description: policy.Description,
Enabled: policy.Enabled,
Query: policy.Query,
}
if len(policy.Rules) == 0 {
return ap
}
for _, r := range policy.Rules {
rule := api.PolicyRule{
Id: &r.ID,
Name: r.Name,
Enabled: r.Enabled,
Description: &r.Description,
}
for _, gid := range r.Sources {
_, ok := cache[gid]
if ok {
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
rule.Sources = append(rule.Sources, minimum)
cache[gid] = minimum
}
}
for _, gid := range r.Destinations {
cachedMinimum, ok := cache[gid]
if ok {
rule.Destinations = append(rule.Destinations, cachedMinimum)
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
rule.Destinations = append(rule.Destinations, minimum)
cache[gid] = minimum
}
}
ap.Rules = append(ap.Rules, rule)
}
return ap
}
func groupMinimumsToStrings(account *server.Account, gm []api.GroupMinimum) []string {
result := make([]string, 0, len(gm))
for _, gm := range gm {
if _, ok := account.Groups[gm.Id]; ok {
continue
}
result = append(result, gm.Id)
}
return result
}

View File

@@ -40,14 +40,16 @@ func (h *RulesHandler) GetAllRules(w http.ResponseWriter, r *http.Request) {
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
rules := []*api.Rule{}
for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r))
for _, policy := range accountPolicies {
for _, r := range policy.Rules {
rules = append(rules, toRuleResponse(account, r.ToRule()))
}
}
util.WriteJSONObject(w, rules)
@@ -69,9 +71,9 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
return
}
_, ok := account.Rules[ruleID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule id %s", ruleID), w)
policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
@@ -96,173 +98,40 @@ func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
reqDestinations = *req.Destinations
}
rule := server.Rule{
ID: ruleID,
Name: req.Name,
Source: reqSources,
Destination: reqDestinations,
Disabled: req.Disabled,
Description: req.Description,
if len(policy.Rules) != 1 {
util.WriteError(status.Errorf(status.Internal, "policy should contain exactly one rule"), w)
return
}
policy.Name = req.Name
policy.Description = req.Description
policy.Enabled = !req.Disabled
policy.Rules[0].ID = ruleID
policy.Rules[0].Name = req.Name
policy.Rules[0].Sources = reqSources
policy.Rules[0].Destinations = reqDestinations
policy.Rules[0].Enabled = !req.Disabled
policy.Rules[0].Description = req.Description
if err := policy.UpdateQueryFromRules(); err != nil {
util.WriteError(err, w)
return
}
switch req.Flow {
case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
policy.Rules[0].Action = server.PolicyTrafficActionAccept
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}
err = h.accountManager.SaveRule(account.Id, user.Id, &rule)
err = h.accountManager.SavePolicy(account.Id, user.Id, policy)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, &rule)
util.WriteJSONObject(w, &resp)
}
// PatchRule handles patch updates to a rule identified by a given ID
func (h *RulesHandler) PatchRule(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, _, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
ruleID := vars["id"]
if len(ruleID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
_, ok := account.Rules[ruleID]
if !ok {
util.WriteError(status.Errorf(status.NotFound, "couldn't find rule ID %s", ruleID), w)
return
}
var req api.PatchApiRulesIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if len(req) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "no patch instruction received"), w)
return
}
var operations []server.RuleUpdateOperation
for _, patch := range req {
switch patch.Path {
case api.RulePatchOperationPathName:
if patch.Op != api.RulePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"name field only accepts replace operation, got %s", patch.Op), w)
return
}
if len(patch.Value) == 0 || patch.Value[0] == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateRuleName,
Values: patch.Value,
})
case api.RulePatchOperationPathDescription:
if patch.Op != api.RulePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"description field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateRuleDescription,
Values: patch.Value,
})
case api.RulePatchOperationPathFlow:
if patch.Op != api.RulePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"flow field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateRuleFlow,
Values: patch.Value,
})
case api.RulePatchOperationPathDisabled:
if patch.Op != api.RulePatchOperationOpReplace {
util.WriteError(status.Errorf(status.InvalidArgument,
"disabled field only accepts replace operation, got %s", patch.Op), w)
return
}
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateRuleStatus,
Values: patch.Value,
})
case api.RulePatchOperationPathSources:
switch patch.Op {
case api.RulePatchOperationOpReplace:
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateSourceGroups,
Values: patch.Value,
})
case api.RulePatchOperationOpRemove:
operations = append(operations, server.RuleUpdateOperation{
Type: server.RemoveGroupsFromSource,
Values: patch.Value,
})
case api.RulePatchOperationOpAdd:
operations = append(operations, server.RuleUpdateOperation{
Type: server.InsertGroupsToSource,
Values: patch.Value,
})
default:
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Source field", patch.Op), w)
return
}
case api.RulePatchOperationPathDestinations:
switch patch.Op {
case api.RulePatchOperationOpReplace:
operations = append(operations, server.RuleUpdateOperation{
Type: server.UpdateDestinationGroups,
Values: patch.Value,
})
case api.RulePatchOperationOpRemove:
operations = append(operations, server.RuleUpdateOperation{
Type: server.RemoveGroupsFromDestination,
Values: patch.Value,
})
case api.RulePatchOperationOpAdd:
operations = append(operations, server.RuleUpdateOperation{
Type: server.InsertGroupsToDestination,
Values: patch.Value,
})
default:
util.WriteError(status.Errorf(status.InvalidArgument,
"invalid operation \"%s\" on Destination field", patch.Op), w)
return
}
default:
util.WriteError(status.Errorf(status.InvalidArgument, "invalid patch path"), w)
return
}
}
rule, err := h.accountManager.UpdateRule(account.Id, ruleID, operations)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, rule)
resp := toRuleResponse(account, policy.Rules[0].ToRule())
util.WriteJSONObject(w, &resp)
}
@@ -315,7 +184,12 @@ func (h *RulesHandler) CreateRule(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.SaveRule(account.Id, user.Id, &rule)
policy, err := server.RuleToPolicy(&rule)
if err != nil {
util.WriteError(err, w)
return
}
err = h.accountManager.SavePolicy(account.Id, user.Id, policy)
if err != nil {
util.WriteError(err, w)
return
@@ -342,7 +216,7 @@ func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteRule(aID, rID, user.Id)
err = h.accountManager.DeletePolicy(aID, rID, user.Id)
if err != nil {
util.WriteError(err, w)
return
@@ -368,13 +242,13 @@ func (h *RulesHandler) GetRule(w http.ResponseWriter, r *http.Request) {
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id)
if err != nil {
util.WriteError(status.Errorf(status.NotFound, "rule not found"), w)
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toRuleResponse(account, rule))
util.WriteJSONObject(w, toRuleResponse(account, policy.Rules[0].ToRule()))
default:
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
}

View File

@@ -11,6 +11,7 @@ import (
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux"
@@ -23,8 +24,32 @@ import (
)
func initRulesTestData(rules ...*server.Rule) *RulesHandler {
testPolicies := make(map[string]*server.Policy, len(rules))
for _, rule := range rules {
policy, err := server.RuleToPolicy(rule)
if err != nil {
panic(err)
}
if err := policy.UpdateQueryFromRules(); err != nil {
panic(err)
}
testPolicies[policy.ID] = policy
}
return &RulesHandler{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID]
if !ok {
return nil, status.Errorf(status.NotFound, "policy not found")
}
return policy, nil
},
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
}
return nil
},
SaveRuleFunc: func(_, _ string, rule *server.Rule) error {
if !strings.HasPrefix(rule.ID, "id-") {
rule.ID = "id-was-set"
@@ -43,32 +68,6 @@ func initRulesTestData(rules ...*server.Rule) *RulesHandler {
Flow: server.TrafficFlowBidirect,
}, nil
},
UpdateRuleFunc: func(_ string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) {
var rule server.Rule
rule.ID = ruleID
for _, operation := range operations {
switch operation.Type {
case server.UpdateRuleName:
rule.Name = operation.Values[0]
case server.UpdateRuleDescription:
rule.Description = operation.Values[0]
case server.UpdateRuleFlow:
if server.TrafficFlowBidirectString == operation.Values[0] {
rule.Flow = server.TrafficFlowBidirect
} else {
rule.Flow = 100
}
case server.UpdateSourceGroups, server.InsertGroupsToSource:
rule.Source = operation.Values
case server.UpdateDestinationGroups, server.InsertGroupsToDestination:
rule.Destination = operation.Values
case server.RemoveGroupsFromSource, server.RemoveGroupsFromDestination:
default:
return nil, fmt.Errorf("no operation")
}
}
return &rule, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
@@ -221,58 +220,13 @@ func TestRulesWriteRule(t *testing.T) {
[]byte(`{"Name":"","Flow":"bidirect"}`)),
expectedStatus: http.StatusUnprocessableEntity,
},
{
name: "Write Rule PATCH Name OK",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":["Default POSTed Rule"]}]`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "Write Rule PATCH Invalid Name OP",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"insert","path":"name","value":[""]}]`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Rule PATCH Invalid Name",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"name","value":[]}]`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "Write Rule PATCH Sources OK",
requestType: http.MethodPatch,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`[{"op":"replace","path":"sources","value":["G","F"]}]`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Flow: server.TrafficFlowBidirectString,
Sources: []api.GroupMinimum{
{Id: "G"},
{Id: "F"},
},
},
},
}
p := initRulesTestData()
p := initRulesTestData(&server.Rule{
ID: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirect,
})
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -282,7 +236,6 @@ func TestRulesWriteRule(t *testing.T) {
router := mux.NewRouter()
router.HandleFunc("/api/rules", p.CreateRule).Methods("POST")
router.HandleFunc("/api/rules/{id}", p.UpdateRule).Methods("PUT")
router.HandleFunc("/api/rules/{id}", p.PatchRule).Methods("PATCH")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -307,6 +260,7 @@ func TestRulesWriteRule(t *testing.T) {
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.expectedRule.Id = got.Id
assert.Equal(t, got, tc.expectedRule)
})