[management] Refactor policy to use store methods (#2878)

This commit is contained in:
Bethuel Mmbaga
2024-11-26 12:46:05 +03:00
committed by GitHub
parent ca12bc6953
commit f118d81d32
18 changed files with 576 additions and 349 deletions

View File

@@ -6,10 +6,8 @@ import (
"strconv"
"github.com/gorilla/mux"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -122,21 +120,22 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
isUpdate := policyID != ""
if policyID == "" {
policyID = xid.New().String()
}
policy := server.Policy{
policy := &server.Policy{
ID: policyID,
AccountID: accountID,
Name: req.Name,
Enabled: req.Enabled,
Description: req.Description,
}
for _, rule := range req.Rules {
var ruleID string
if rule.Id != nil {
ruleID = *rule.Id
}
pr := server.PolicyRule{
ID: policyID, // TODO: when policy can contain multiple rules, need refactor
ID: ruleID,
PolicyID: policyID,
Name: rule.Name,
Destinations: rule.Destinations,
Sources: rule.Sources,
@@ -225,7 +224,8 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
policy.SourcePostureChecks = *req.SourcePostureChecks
}
if err := h.accountManager.SavePolicy(r.Context(), accountID, userID, &policy, isUpdate); err != nil {
policy, err := h.accountManager.SavePolicy(r.Context(), accountID, userID, policy)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -236,7 +236,7 @@ func (h *Policies) savePolicy(w http.ResponseWriter, r *http.Request, accountID
return
}
resp := toPolicyResponse(allGroups, &policy)
resp := toPolicyResponse(allGroups, policy)
if len(resp.Rules) == 0 {
util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w)
return

View File

@@ -38,12 +38,12 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return policy, nil
},
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy, _ bool) error {
SavePolicyFunc: func(_ context.Context, _, _ string, policy *server.Policy) (*server.Policy, error) {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
policy.Rules[0].ID = "id-was-set"
}
return nil
return policy, nil
},
GetAllGroupsFunc: func(ctx context.Context, accountID, userID string) ([]*nbgroup.Group, error) {
return []*nbgroup.Group{{ID: "F"}, {ID: "G"}}, nil