Merge branch 'main' into add-process-posture-check

This commit is contained in:
bcmmbaga
2024-03-18 18:41:38 +03:00
52 changed files with 1311 additions and 1283 deletions

View File

@@ -40,9 +40,6 @@ const (
PublicCategory = "public"
PrivateCategory = "private"
UnknownCategory = "unknown"
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
GroupIssuedIntegration = "integration"
CacheExpirationMax = 7 * 24 * 3600 * time.Second // 7 days
CacheExpirationMin = 3 * 24 * 3600 * time.Second // 3 days
DefaultPeerLoginExpiration = 24 * time.Hour
@@ -227,9 +224,6 @@ type Account struct {
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
// deprecated on store and api level
Rules map[string]*Rule `json:"-" gorm:"-"`
RulesG []Rule `json:"-" gorm:"-"`
}
type UserInfo struct {
@@ -559,6 +553,16 @@ func (a *Account) FindUser(userID string) (*User, error) {
return user, nil
}
// FindGroupByName looks for a given group in the Account by name or returns error if the group wasn't found.
func (a *Account) FindGroupByName(groupName string) (*Group, error) {
for _, group := range a.Groups {
if group.Name == groupName {
return group, nil
}
}
return nil, status.Errorf(status.NotFound, "group %s not found", groupName)
}
// FindSetupKey looks for a given SetupKey in the Account or returns error if it wasn't found.
func (a *Account) FindSetupKey(setupKey string) (*SetupKey, error) {
key := a.SetupKeys[setupKey]
@@ -1361,16 +1365,21 @@ func (am *DefaultAccountManager) removeUserFromCache(accountID, userID string) e
func (am *DefaultAccountManager) updateAccountDomainAttributes(account *Account, claims jwtclaims.AuthorizationClaims,
primaryDomain bool,
) error {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
if claims.Domain != "" {
account.IsDomainPrimaryAccount = primaryDomain
lowerDomain := strings.ToLower(claims.Domain)
userObj := account.Users[claims.UserId]
if account.Domain != lowerDomain && userObj.Role == UserRoleAdmin {
account.Domain = lowerDomain
}
// prevent updating category for different domain until admin logs in
if account.Domain == lowerDomain {
account.DomainCategory = claims.DomainCategory
}
} else {
log.Errorf("claims don't contain a valid domain, skipping domain attributes update. Received claims: %v", claims)
}
err := am.Store.SaveAccount(account)

View File

@@ -258,18 +258,6 @@ func TestStore(t *testing.T) {
t.Errorf("failed to restore a FileStore file - missing Group all")
}
if restoredAccount.Rules["all"] == nil {
t.Errorf("failed to restore a FileStore file - missing Rule all")
return
}
if restoredAccount.Rules["dmz"] == nil {
t.Errorf("failed to restore a FileStore file - missing Rule dmz")
return
}
assert.Equal(t, account.Rules["all"], restoredAccount.Rules["all"], "failed to restore a FileStore file - missing Rule all")
assert.Equal(t, account.Rules["dmz"], restoredAccount.Rules["dmz"], "failed to restore a FileStore file - missing Rule dmz")
if len(restoredAccount.Policies) != 2 {
t.Errorf("failed to restore a FileStore file - missing Policies")
return
@@ -411,7 +399,6 @@ func TestFileStore_GetAccount(t *testing.T) {
assert.Len(t, account.Peers, len(expected.Peers))
assert.Len(t, account.Users, len(expected.Users))
assert.Len(t, account.SetupKeys, len(expected.SetupKeys))
assert.Len(t, account.Rules, len(expected.Rules))
assert.Len(t, account.Routes, len(expected.Routes))
assert.Len(t, account.NameServerGroups, len(expected.NameServerGroups))
}

View File

@@ -3,6 +3,7 @@ package server
import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
@@ -18,6 +19,12 @@ func (e *GroupLinkError) Error() string {
return fmt.Sprintf("group has been linked to %s: %s", e.Resource, e.Name)
}
const (
GroupIssuedAPI = "api"
GroupIssuedJWT = "jwt"
GroupIssuedIntegration = "integration"
)
// Group of the peers for ACL
type Group struct {
// ID of the group
@@ -29,7 +36,7 @@ type Group struct {
// Name visible in the UI
Name string
// Issued of the group
// Issued defines how this group was created (enum of "api", "integration" or "jwt")
Issued string
// Peers list of the group
@@ -116,6 +123,29 @@ func (am *DefaultAccountManager) SaveGroup(accountID, userID string, newGroup *G
return err
}
if newGroup.ID == "" && newGroup.Issued != GroupIssuedAPI {
return status.Errorf(status.InvalidArgument, "%s group without ID set", newGroup.Issued)
}
if newGroup.ID == "" && newGroup.Issued == GroupIssuedAPI {
existingGroup, err := account.FindGroupByName(newGroup.Name)
if err != nil {
s, ok := status.FromError(err)
if !ok || s.ErrorType != status.NotFound {
return err
}
}
// avoid duplicate groups only for the API issued groups. Integration or JWT groups can be duplicated as they are
// coming from the IdP that we don't have control of.
if existingGroup != nil {
return status.Errorf(status.AlreadyExists, "group with name %s already exists", newGroup.Name)
}
newGroup.ID = xid.New().String()
}
for _, peerID := range newGroup.Peers {
if account.Peers[peerID] == nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)

View File

@@ -13,6 +13,41 @@ const (
groupAdminUserID = "testingAdminUser"
)
func TestDefaultAccountManager_CreateGroup(t *testing.T) {
am, err := createManager(t)
if err != nil {
t.Error("failed to create account manager")
}
account, err := initTestGroupAccount(am)
if err != nil {
t.Error("failed to init testing account")
}
for _, group := range account.Groups {
group.Issued = GroupIssuedIntegration
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", GroupIssuedIntegration)
}
}
for _, group := range account.Groups {
group.Issued = GroupIssuedJWT
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err != nil {
t.Errorf("should allow to create %s groups", GroupIssuedJWT)
}
}
for _, group := range account.Groups {
group.Issued = GroupIssuedAPI
group.ID = ""
err = am.SaveGroup(account.Id, groupAdminUserID, group)
if err == nil {
t.Errorf("should not create api group with the same name, %s", group.Name)
}
}
}
func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
am, err := createManager(t)
if err != nil {
@@ -137,7 +172,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
groupForIntegration := &Group{
ID: "grp-for-integration",
AccountID: "account-id",
Name: "Group for users",
Name: "Group for users integration",
Issued: GroupIssuedIntegration,
Peers: make([]string, 0),
}

View File

@@ -17,8 +17,6 @@ tags:
description: Interact with and view information about setup keys.
- name: Groups
description: Interact with and view information about groups.
- name: Rules
description: Interact with and view information about rules.
- name: Policies
description: Interact with and view information about policies.
- name: Posture Checks
@@ -587,7 +585,10 @@ components:
type: integer
example: 2
issued:
description: How group was issued by API or from JWT token
description: How the group was issued (api, integration, jwt)
type: string
enum: ["api", "integration", "jwt"]
example: api
type: string
example: api
required:
@@ -621,73 +622,6 @@ components:
$ref: '#/components/schemas/PeerMinimum'
required:
- peers
RuleMinimum:
type: object
properties:
name:
description: Rule name identifier
type: string
example: Default
description:
description: Rule friendly description
type: string
example: This is a default rule that allows connections between all the resources
disabled:
description: Rules status
type: boolean
example: false
flow:
description: Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
type: string
example: bidirect
required:
- name
- description
- disabled
- flow
RuleRequest:
allOf:
- $ref: '#/components/schemas/RuleMinimum'
- type: object
properties:
sources:
type: array
description: List of source group IDs
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m1"
destinations:
type: array
description: List of destination group IDs
items:
type: string
example: "ch8i4ug6lnn4g9hqv7m0"
Rule:
allOf:
- type: object
properties:
id:
description: Rule ID
type: string
example: ch8i4ug6lnn4g9hqv7mg
required:
- id
- $ref: '#/components/schemas/RuleMinimum'
- type: object
properties:
sources:
description: Rule source group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
destinations:
description: Rule destination group IDs
type: array
items:
$ref: '#/components/schemas/GroupMinimum'
required:
- sources
- destinations
PolicyRuleMinimum:
type: object
properties:
@@ -1339,7 +1273,7 @@ paths:
/api/accounts/{accountId}:
delete:
summary: Delete an Account
description: Deletes an account and all its resources. Only administrators and account owners can delete accounts.
description: Deletes an account and all its resources. Only account owners can delete accounts.
tags: [ Accounts ]
security:
- BearerAuth: [ ]
@@ -2059,147 +1993,6 @@ paths:
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/rules:
get:
summary: List all Rules
description: Returns a list of all rules. This will be deprecated in favour of `/api/policies`.
tags: [ Rules ]
deprecated: true
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
responses:
'200':
description: A JSON Array of Rules
content:
application/json:
schema:
type: array
items:
$ref: '#/components/schemas/Rule'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
post:
summary: Create a Rule
description: Creates a rule. This will be deprecated in favour of `/api/policies`.
deprecated: true
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
requestBody:
description: New Rule request
content:
'application/json':
schema:
$ref: '#/components/schemas/RuleRequest'
responses:
'200':
description: A Rule Object
content:
application/json:
schema:
$ref: '#/components/schemas/Rule'
/api/rules/{ruleId}:
get:
summary: Retrieve a Rule
description: Get information about a rules. This will be deprecated in favour of `/api/policies/{policyID}`.
deprecated: true
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: ruleId
required: true
schema:
type: string
description: The unique identifier of a rule
responses:
'200':
description: A Rule object
content:
application/json:
schema:
$ref: '#/components/schemas/Rule'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
put:
summary: Update a Rule
description: Update/Replace a rule. This will be deprecated in favour of `/api/policies/{policyID}`.
deprecated: true
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: ruleId
required: true
schema:
type: string
description: The unique identifier of a rule
requestBody:
description: Update Rule request
content:
'application/json':
schema:
$ref: '#/components/schemas/RuleRequest'
responses:
'200':
description: A Rule object
content:
application/json:
schema:
$ref: '#/components/schemas/Rule'
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
delete:
summary: Delete a Rule
description: Delete a rule. This will be deprecated in favour of `/api/policies/{policyID}`.
deprecated: true
tags: [ Rules ]
security:
- BearerAuth: [ ]
- TokenAuth: [ ]
parameters:
- in: path
name: ruleId
required: true
schema:
type: string
description: The unique identifier of a rule
responses:
'200':
description: Delete status code
content: { }
'400':
"$ref": "#/components/responses/bad_request"
'401':
"$ref": "#/components/responses/requires_authentication"
'403':
"$ref": "#/components/responses/forbidden"
'500':
"$ref": "#/components/responses/internal_error"
/api/policies:
get:
summary: List all Policies

View File

@@ -993,66 +993,6 @@ type RouteRequest struct {
PeerGroups *[]string `json:"peer_groups,omitempty"`
}
// Rule defines model for Rule.
type Rule struct {
// Description Rule friendly description
Description string `json:"description"`
// Destinations Rule destination group IDs
Destinations []GroupMinimum `json:"destinations"`
// Disabled Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Id Rule ID
Id string `json:"id"`
// Name Rule name identifier
Name string `json:"name"`
// Sources Rule source group IDs
Sources []GroupMinimum `json:"sources"`
}
// RuleMinimum defines model for RuleMinimum.
type RuleMinimum struct {
// Description Rule friendly description
Description string `json:"description"`
// Disabled Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Name Rule name identifier
Name string `json:"name"`
}
// RuleRequest defines model for RuleRequest.
type RuleRequest struct {
// Description Rule friendly description
Description string `json:"description"`
// Destinations List of destination group IDs
Destinations *[]string `json:"destinations,omitempty"`
// Disabled Rules status
Disabled bool `json:"disabled"`
// Flow Rule flow, currently, only "bidirect" for bi-directional traffic is accepted
Flow string `json:"flow"`
// Name Rule name identifier
Name string `json:"name"`
// Sources List of source group IDs
Sources *[]string `json:"sources,omitempty"`
}
// SetupKey defines model for SetupKey.
type SetupKey struct {
// AutoGroups List of group IDs to auto-assign to peers registered with this key
@@ -1236,12 +1176,6 @@ type PostApiRoutesJSONRequestBody = RouteRequest
// PutApiRoutesRouteIdJSONRequestBody defines body for PutApiRoutesRouteId for application/json ContentType.
type PutApiRoutesRouteIdJSONRequestBody = RouteRequest
// PostApiRulesJSONRequestBody defines body for PostApiRules for application/json ContentType.
type PostApiRulesJSONRequestBody = RuleRequest
// PutApiRulesRuleIdJSONRequestBody defines body for PutApiRulesRuleId for application/json ContentType.
type PutApiRulesRuleIdJSONRequestBody = RuleRequest
// PostApiSetupKeysJSONRequestBody defines body for PostApiSetupKeys for application/json ContentType.
type PostApiSetupKeysJSONRequestBody = SetupKeyRequest

View File

@@ -8,8 +8,6 @@ import (
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/status"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/jwtclaims"
@@ -151,7 +149,6 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) {
peers = *req.Peers
}
group := server.Group{
ID: xid.New().String(),
Name: req.Name,
Peers: peers,
Issued: server.GroupIssuedAPI,

View File

@@ -85,7 +85,6 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
api.addUsersEndpoint()
api.addUsersTokensEndpoint()
api.addSetupKeysEndpoint()
api.addRulesEndpoint()
api.addPoliciesEndpoint()
api.addGroupsEndpoint()
api.addRoutesEndpoint()
@@ -158,15 +157,6 @@ func (apiHandler *apiHandler) addSetupKeysEndpoint() {
apiHandler.Router.HandleFunc("/setup-keys/{keyId}", keysHandler.UpdateSetupKey).Methods("PUT", "OPTIONS")
}
func (apiHandler *apiHandler) addRulesEndpoint() {
rulesHandler := NewRulesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/rules", rulesHandler.GetAllRules).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules", rulesHandler.CreateRule).Methods("POST", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.UpdateRule).Methods("PUT", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.GetRule).Methods("GET", "OPTIONS")
apiHandler.Router.HandleFunc("/rules/{ruleId}", rulesHandler.DeleteRule).Methods("DELETE", "OPTIONS")
}
func (apiHandler *apiHandler) addPoliciesEndpoint() {
policiesHandler := NewPoliciesHandler(apiHandler.AccountManager, apiHandler.AuthCfg)
apiHandler.Router.HandleFunc("/policies", policiesHandler.GetAllPolicies).Methods("GET", "OPTIONS")

View File

@@ -55,6 +55,9 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler {
GetPeersFunc: func(accountID, userID string) ([]*nbpeer.Peer, error) {
return peers, nil
},
GetDNSDomainFunc: func() string {
return "netbird.selfhosted"
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{

View File

@@ -3,7 +3,6 @@ package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
@@ -44,24 +43,6 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies {
}
return nil
},
SaveRuleFunc: func(_, _ string, rule *server.Rule) error {
if !strings.HasPrefix(rule.ID, "id-") {
rule.ID = "id-was-set"
}
return nil
},
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
return &server.Rule{
ID: "idoftherule",
Name: "Rule",
Source: []string{"idofsrcrule"},
Destination: []string{"idofdestrule"},
Flow: server.TrafficFlowBidirect,
}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{

View File

@@ -1,305 +0,0 @@
package http
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/rs/xid"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/util"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/netbirdio/netbird/management/server/status"
)
// RulesHandler is a handler that returns rules of the account
type RulesHandler struct {
accountManager server.AccountManager
claimsExtractor *jwtclaims.ClaimsExtractor
}
// NewRulesHandler creates a new RulesHandler HTTP handler
func NewRulesHandler(accountManager server.AccountManager, authCfg AuthCfg) *RulesHandler {
return &RulesHandler{
accountManager: accountManager,
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithAudience(authCfg.Audience),
jwtclaims.WithUserIDClaim(authCfg.UserIDClaim),
),
}
}
// GetAllRules list for the account
func (h *RulesHandler) GetAllRules(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
accountPolicies, err := h.accountManager.ListPolicies(account.Id, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
rules := []*api.Rule{}
for _, policy := range accountPolicies {
for _, r := range policy.Rules {
rules = append(rules, toRuleResponse(account, r.ToRule()))
}
}
util.WriteJSONObject(w, rules)
}
// UpdateRule handles update to a rule identified by a given ID
func (h *RulesHandler) UpdateRule(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
vars := mux.Vars(r)
ruleID := vars["ruleId"]
if len(ruleID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
var req api.PutApiRulesRuleIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
var reqSources []string
if req.Sources != nil {
reqSources = *req.Sources
}
var reqDestinations []string
if req.Destinations != nil {
reqDestinations = *req.Destinations
}
if len(policy.Rules) != 1 {
util.WriteError(status.Errorf(status.Internal, "policy should contain exactly one rule"), w)
return
}
policy.Name = req.Name
policy.Description = req.Description
policy.Enabled = !req.Disabled
policy.Rules[0].ID = ruleID
policy.Rules[0].Name = req.Name
policy.Rules[0].Sources = reqSources
policy.Rules[0].Destinations = reqDestinations
policy.Rules[0].Enabled = !req.Disabled
policy.Rules[0].Description = req.Description
switch req.Flow {
case server.TrafficFlowBidirectString:
policy.Rules[0].Action = server.PolicyTrafficActionAccept
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}
err = h.accountManager.SavePolicy(account.Id, user.Id, policy)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, policy.Rules[0].ToRule())
util.WriteJSONObject(w, &resp)
}
// CreateRule handles rule creation request
func (h *RulesHandler) CreateRule(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
var req api.PostApiRulesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
if req.Name == "" {
util.WriteError(status.Errorf(status.InvalidArgument, "rule name shouldn't be empty"), w)
return
}
var reqSources []string
if req.Sources != nil {
reqSources = *req.Sources
}
var reqDestinations []string
if req.Destinations != nil {
reqDestinations = *req.Destinations
}
rule := server.Rule{
ID: xid.New().String(),
Name: req.Name,
Source: reqSources,
Destination: reqDestinations,
Disabled: req.Disabled,
Description: req.Description,
}
switch req.Flow {
case server.TrafficFlowBidirectString:
rule.Flow = server.TrafficFlowBidirect
default:
util.WriteError(status.Errorf(status.InvalidArgument, "unknown flow type"), w)
return
}
policy, err := server.RuleToPolicy(&rule)
if err != nil {
util.WriteError(err, w)
return
}
err = h.accountManager.SavePolicy(account.Id, user.Id, policy)
if err != nil {
util.WriteError(err, w)
return
}
resp := toRuleResponse(account, &rule)
util.WriteJSONObject(w, &resp)
}
// DeleteRule handles rule deletion request
func (h *RulesHandler) DeleteRule(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
aID := account.Id
rID := mux.Vars(r)["ruleId"]
if len(rID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
err = h.accountManager.DeletePolicy(aID, rID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, emptyObject{})
}
// GetRule handles a group Get request identified by ID
func (h *RulesHandler) GetRule(w http.ResponseWriter, r *http.Request) {
claims := h.claimsExtractor.FromRequestContext(r)
account, user, err := h.accountManager.GetAccountFromToken(claims)
if err != nil {
util.WriteError(err, w)
return
}
switch r.Method {
case http.MethodGet:
ruleID := mux.Vars(r)["ruleId"]
if len(ruleID) == 0 {
util.WriteError(status.Errorf(status.InvalidArgument, "invalid rule ID"), w)
return
}
policy, err := h.accountManager.GetPolicy(account.Id, ruleID, user.Id)
if err != nil {
util.WriteError(err, w)
return
}
util.WriteJSONObject(w, toRuleResponse(account, policy.Rules[0].ToRule()))
default:
util.WriteError(status.Errorf(status.NotFound, "method not found"), w)
}
}
func toRuleResponse(account *server.Account, rule *server.Rule) *api.Rule {
cache := make(map[string]api.GroupMinimum)
gr := api.Rule{
Id: rule.ID,
Name: rule.Name,
Description: rule.Description,
Disabled: rule.Disabled,
}
switch rule.Flow {
case server.TrafficFlowBidirect:
gr.Flow = server.TrafficFlowBidirectString
default:
gr.Flow = "unknown"
}
for _, gid := range rule.Source {
_, ok := cache[gid]
if ok {
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
gr.Sources = append(gr.Sources, minimum)
cache[gid] = minimum
}
}
for _, gid := range rule.Destination {
cachedMinimum, ok := cache[gid]
if ok {
gr.Destinations = append(gr.Destinations, cachedMinimum)
continue
}
if group, ok := account.Groups[gid]; ok {
minimum := api.GroupMinimum{
Id: group.ID,
Name: group.Name,
PeersCount: len(group.Peers),
}
gr.Destinations = append(gr.Destinations, minimum)
cache[gid] = minimum
}
}
return &gr
}

View File

@@ -1,265 +0,0 @@
package http
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/status"
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/jwtclaims"
"github.com/magiconair/properties/assert"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initRulesTestData(rules ...*server.Rule) *RulesHandler {
testPolicies := make(map[string]*server.Policy, len(rules))
for _, rule := range rules {
policy, err := server.RuleToPolicy(rule)
if err != nil {
panic(err)
}
testPolicies[policy.ID] = policy
}
return &RulesHandler{
accountManager: &mock_server.MockAccountManager{
GetPolicyFunc: func(_, policyID, _ string) (*server.Policy, error) {
policy, ok := testPolicies[policyID]
if !ok {
return nil, status.Errorf(status.NotFound, "policy not found")
}
return policy, nil
},
SavePolicyFunc: func(_, _ string, policy *server.Policy) error {
if !strings.HasPrefix(policy.ID, "id-") {
policy.ID = "id-was-set"
}
return nil
},
SaveRuleFunc: func(_, _ string, rule *server.Rule) error {
if !strings.HasPrefix(rule.ID, "id-") {
rule.ID = "id-was-set"
}
return nil
},
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
return &server.Rule{
ID: "idoftherule",
Name: "Rule",
Source: []string{"idofsrcrule"},
Destination: []string{"idofdestrule"},
Flow: server.TrafficFlowBidirect,
}, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
user := server.NewAdminUser("test_user")
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Rules: map[string]*server.Rule{"id-existed": {ID: "id-existed"}},
Groups: map[string]*server.Group{
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": user,
},
}, user, nil
},
},
claimsExtractor: jwtclaims.NewClaimsExtractor(
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
Domain: "hotmail.com",
AccountId: "test_id",
}
}),
),
}
}
func TestRulesGetRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "GetRule OK",
expectedBody: true,
requestType: http.MethodGet,
requestPath: "/api/rules/idoftherule",
expectedStatus: http.StatusOK,
},
{
name: "GetRule not found",
requestType: http.MethodGet,
requestPath: "/api/rules/notexists",
expectedStatus: http.StatusNotFound,
},
}
rule := &server.Rule{
ID: "idoftherule",
Name: "Rule",
}
p := initRulesTestData(rule)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules/{ruleId}", p.GetRule).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatus)
return
}
if !tc.expectedBody {
return
}
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
var got api.Rule
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
assert.Equal(t, got.Id, rule.ID)
assert.Equal(t, got.Name, rule.Name)
})
}
}
func TestRulesWriteRule(t *testing.T) {
tt := []struct {
name string
expectedStatus int
expectedBody bool
expectedRule *api.Rule
requestType string
requestPath string
requestBody io.Reader
}{
{
name: "WriteRule POST OK",
requestType: http.MethodPost,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-was-set",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "WriteRule POST Invalid Name",
requestType: http.MethodPost,
requestPath: "/api/rules",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"","Flow":"bidirect"}`)),
expectedStatus: http.StatusUnprocessableEntity,
expectedBody: false,
},
{
name: "WriteRule PUT OK",
requestType: http.MethodPut,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"Default POSTed Rule","Flow":"bidirect"}`)),
expectedStatus: http.StatusOK,
expectedBody: true,
expectedRule: &api.Rule{
Id: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirectString,
},
},
{
name: "WriteRule PUT Invalid Name",
requestType: http.MethodPut,
requestPath: "/api/rules/id-existed",
requestBody: bytes.NewBuffer(
[]byte(`{"Name":"","Flow":"bidirect"}`)),
expectedStatus: http.StatusUnprocessableEntity,
},
}
p := initRulesTestData(&server.Rule{
ID: "id-existed",
Name: "Default POSTed Rule",
Flow: server.TrafficFlowBidirect,
})
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
router := mux.NewRouter()
router.HandleFunc("/api/rules", p.CreateRule).Methods("POST")
router.HandleFunc("/api/rules/{ruleId}", p.UpdateRule).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
defer res.Body.Close()
content, err := io.ReadAll(res.Body)
if err != nil {
t.Fatalf("I don't know what I expected; %v", err)
}
if status := recorder.Code; status != tc.expectedStatus {
t.Errorf("handler returned wrong status code: got %v want %v, content: %s",
status, tc.expectedStatus, string(content))
return
}
if !tc.expectedBody {
return
}
got := &api.Rule{}
if err = json.Unmarshal(content, &got); err != nil {
t.Fatalf("Sent content is not in correct json format; %v", err)
}
tc.expectedRule.Id = got.Id
assert.Equal(t, got, tc.expectedRule)
})
}
}

View File

@@ -76,6 +76,10 @@ func NewAuthentikManager(config AuthentikClientConfig,
return nil, fmt.Errorf("authentik IdP configuration is incomplete, TokenEndpoint is missing")
}
if config.Issuer == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, Issuer is missing")
}
if config.GrantType == "" {
return nil, fmt.Errorf("authentik IdP configuration is incomplete, GrantType is missing")
}

View File

@@ -7,9 +7,10 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func TestNewAuthentikManager(t *testing.T) {
@@ -25,6 +26,7 @@ func TestNewAuthentikManager(t *testing.T) {
Username: "username",
Password: "password",
TokenEndpoint: "https://localhost:8080/application/o/token/",
Issuer: "https://localhost:8080/application/o/netbird/",
GrantType: "client_credentials",
}
@@ -75,7 +77,17 @@ func TestNewAuthentikManager(t *testing.T) {
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5} {
testCase6Config := defaultTestConfig
testCase6Config.Issuer = ""
testCase6 := test{
name: "Missing Issuer Configuration",
inputConfig: testCase6Config,
assertErrFunc: require.Error,
assertErrFuncMessage: "should return error when field empty",
}
for _, testCase := range []test{testCase1, testCase2, testCase3, testCase4, testCase5, testCase6} {
t.Run(testCase.name, func(t *testing.T) {
_, err := NewAuthentikManager(testCase.inputConfig, &telemetry.MockAppMetrics{})
testCase.assertErrFunc(t, err, testCase.assertErrFuncMessage)

View File

@@ -75,6 +75,27 @@ type zitadelProfile struct {
Human *zitadelUser `json:"human"`
}
// zitadelUserDetails represents the metadata for the new user that was created
type zitadelUserDetails struct {
Sequence string `json:"sequence"` // uint64 as a string
CreationDate string `json:"creationDate"` // ISO format
ChangeDate string `json:"changeDate"` // ISO format
ResourceOwner string
}
// zitadelPasswordlessRegistration represents the information for the user to complete signup
type zitadelPasswordlessRegistration struct {
Link string `json:"link"`
Expiration string `json:"expiration"` // ex: 3600s
}
// zitadelUser represents an zitadel create user response
type zitadelUserResponse struct {
UserId string `json:"userId"`
Details zitadelUserDetails `json:"details"`
PasswordlessRegistration zitadelPasswordlessRegistration `json:"passwordlessRegistration"`
}
// NewZitadelManager creates a new instance of the ZitadelManager.
func NewZitadelManager(config ZitadelClientConfig, appMetrics telemetry.AppMetrics) (*ZitadelManager, error) {
httpTransport := http.DefaultTransport.(*http.Transport).Clone()
@@ -224,9 +245,57 @@ func (zc *ZitadelCredentials) Authenticate() (JWTToken, error) {
return zc.jwtToken, nil
}
// CreateUser creates a new user in zitadel Idp and sends an invite.
func (zm *ZitadelManager) CreateUser(_, _, _, _ string) (*UserData, error) {
return nil, fmt.Errorf("method CreateUser not implemented")
// CreateUser creates a new user in zitadel Idp and sends an invite via Zitadel.
func (zm *ZitadelManager) CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error) {
firstLast := strings.SplitN(name, " ", 2)
var addUser = map[string]any{
"userName": email,
"profile": map[string]string{
"firstName": firstLast[0],
"lastName": firstLast[0],
"displayName": name,
},
"email": map[string]any{
"email": email,
"isEmailVerified": false,
},
"passwordChangeRequired": true,
"requestPasswordlessRegistration": false, // let Zitadel send the invite for us
}
payload, err := zm.helper.Marshal(addUser)
if err != nil {
return nil, err
}
body, err := zm.post("users/human/_import", string(payload))
if err != nil {
return nil, err
}
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountCreateUser()
}
var newUser zitadelUserResponse
err = zm.helper.Unmarshal(body, &newUser)
if err != nil {
return nil, err
}
var pending bool = true
ret := &UserData{
Email: email,
Name: name,
ID: newUser.UserId,
AppMetadata: AppMetadata{
WTAccountID: accountID,
WTPendingInvite: &pending,
WTInvitedBy: invitedByEmail,
},
}
return ret, nil
}
// GetUserByEmail searches users with a given email.
@@ -354,10 +423,25 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(_ string, _ AppMetadata) error {
return nil
}
type inviteUserRequest struct {
Email string `json:"email"`
}
// InviteUserByID resend invitations to users who haven't activated,
// their accounts prior to the expiration period.
func (zm *ZitadelManager) InviteUserByID(_ string) error {
return fmt.Errorf("method InviteUserByID not implemented")
func (zm *ZitadelManager) InviteUserByID(userID string) error {
inviteUser := inviteUserRequest{
Email: userID,
}
payload, err := zm.helper.Marshal(inviteUser)
if err != nil {
return err
}
// don't care about the body in the response
_, err = zm.post(fmt.Sprintf("users/%s/_resend_initialization", userID), string(payload))
return err
}
// DeleteUser from Zitadel
@@ -411,7 +495,38 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
}
// delete perform Delete requests.
func (zm *ZitadelManager) delete(_ string) error {
func (zm *ZitadelManager) delete(resource string) error {
jwtToken, err := zm.credentials.Authenticate()
if err != nil {
return err
}
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
if err != nil {
return err
}
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
req.Header.Add("content-type", "application/json")
resp, err := zm.httpClient.Do(req)
if err != nil {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestError()
}
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
if zm.appMetrics != nil {
zm.appMetrics.IDPMetrics().CountRequestStatusError()
}
return fmt.Errorf("unable to get %s, statusCode %d", reqURL, resp.StatusCode)
}
return nil
}

View File

@@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
}
peersUpdateManager := NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false)
if err != nil {
return nil, "", err

View File

@@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
}
peersUpdateManager := server.NewPeersUpdateManager(nil)
eventStore := &activity.InMemoryEventStore{}
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "netbird.selfhosted",
eventStore, nil, false)
if err != nil {
log.Fatalf("failed creating a manager: %v", err)

View File

@@ -38,10 +38,7 @@ type MockAccountManager struct {
ListGroupsFunc func(accountID string) ([]*server.Group, error)
GroupAddPeerFunc func(accountID, groupID, peerID string) error
GroupDeletePeerFunc func(accountID, groupID, peerID string) error
GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error)
SaveRuleFunc func(accountID, userID string, rule *server.Rule) error
DeleteRuleFunc func(accountID, ruleID, userID string) error
ListRulesFunc func(accountID, userID string) ([]*server.Rule, error)
GetPolicyFunc func(accountID, policyID, userID string) (*server.Policy, error)
SavePolicyFunc func(accountID, userID string, policy *server.Policy) error
DeletePolicyFunc func(accountID, policyID, userID string) error
@@ -302,22 +299,6 @@ func (am *MockAccountManager) GroupDeletePeer(accountID, groupID, peerID string)
return status.Errorf(codes.Unimplemented, "method GroupDeletePeer is not implemented")
}
// GetRule mock implementation of GetRule from server.AccountManager interface
func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) {
if am.GetRuleFunc != nil {
return am.GetRuleFunc(accountID, ruleID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented")
}
// SaveRule mock implementation of SaveRule from server.AccountManager interface
func (am *MockAccountManager) SaveRule(accountID, userID string, rule *server.Rule) error {
if am.SaveRuleFunc != nil {
return am.SaveRuleFunc(accountID, userID, rule)
}
return status.Errorf(codes.Unimplemented, "method SaveRule is not implemented")
}
// DeleteRule mock implementation of DeleteRule from server.AccountManager interface
func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error {
if am.DeleteRuleFunc != nil {
@@ -326,14 +307,6 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID, userID string) error
return status.Errorf(codes.Unimplemented, "method DeleteRule is not implemented")
}
// ListRules mock implementation of ListRules from server.AccountManager interface
func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) {
if am.ListRulesFunc != nil {
return am.ListRulesFunc(accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented")
}
// GetPolicy mock implementation of GetPolicy from server.AccountManager interface
func (am *MockAccountManager) GetPolicy(accountID, policyID, userID string) (*server.Policy, error) {
if am.GetPolicyFunc != nil {

View File

@@ -759,7 +759,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "", eventStore, nil, false)
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false)
}
func createNSStore(t *testing.T) (Store, error) {

View File

@@ -52,6 +52,17 @@ const (
PolicyRuleFlowBidirect = PolicyRuleDirection("bidirect")
)
const (
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
const (
firewallRuleDirectionIN = 0
firewallRuleDirectionOUT = 1
@@ -119,19 +130,6 @@ func (pm *PolicyRule) Copy() *PolicyRule {
return rule
}
// ToRule converts the PolicyRule to a legacy representation of the Rule (for backwards compatibility)
func (pm *PolicyRule) ToRule() *Rule {
return &Rule{
ID: pm.ID,
Name: pm.Name,
Description: pm.Description,
Disabled: !pm.Enabled,
Flow: TrafficFlowBidirect,
Destination: pm.Destinations,
Source: pm.Sources,
}
}
// Policy of the Rego query
type Policy struct {
// ID of the policy'

View File

@@ -1014,7 +1014,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
return nil, err
}
eventStore := &activity.InMemoryEventStore{}
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "", eventStore, nil, false)
return BuildManager(store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false)
}
func createRouterStore(t *testing.T) (Store, error) {

View File

@@ -1,100 +0,0 @@
package server
import "fmt"
// TrafficFlowType defines allowed direction of the traffic in the rule
type TrafficFlowType int
const (
// TrafficFlowBidirect allows traffic to both direction
TrafficFlowBidirect TrafficFlowType = iota
// TrafficFlowBidirectString allows traffic to both direction
TrafficFlowBidirectString = "bidirect"
// DefaultRuleName is a name for the Default rule that is created for every account
DefaultRuleName = "Default"
// DefaultRuleDescription is a description for the Default rule that is created for every account
DefaultRuleDescription = "This is a default rule that allows connections between all the resources"
// DefaultPolicyName is a name for the Default policy that is created for every account
DefaultPolicyName = "Default"
// DefaultPolicyDescription is a description for the Default policy that is created for every account
DefaultPolicyDescription = "This is a default policy that allows connections between all the resources"
)
// Rule of ACL for groups
type Rule struct {
// ID of the rule
ID string
// AccountID is a reference to Account that this object belongs
AccountID string `json:"-" gorm:"index"`
// Name of the rule visible in the UI
Name string
// Description of the rule visible in the UI
Description string
// Disabled status of rule in the system
Disabled bool
// Source list of groups IDs of peers
Source []string `gorm:"serializer:json"`
// Destination list of groups IDs of peers
Destination []string `gorm:"serializer:json"`
// Flow of the traffic allowed by the rule
Flow TrafficFlowType
}
func (r *Rule) Copy() *Rule {
rule := &Rule{
ID: r.ID,
Name: r.Name,
Description: r.Description,
Disabled: r.Disabled,
Source: make([]string, len(r.Source)),
Destination: make([]string, len(r.Destination)),
Flow: r.Flow,
}
copy(rule.Source, r.Source)
copy(rule.Destination, r.Destination)
return rule
}
// EventMeta returns activity event meta related to this rule
func (r *Rule) EventMeta() map[string]any {
return map[string]any{"name": r.Name}
}
// ToPolicyRule converts a Rule to a PolicyRule object
func (r *Rule) ToPolicyRule() *PolicyRule {
if r == nil {
return nil
}
return &PolicyRule{
ID: r.ID,
Name: r.Name,
Enabled: !r.Disabled,
Description: r.Description,
Destinations: r.Destination,
Sources: r.Source,
Bidirectional: true,
Protocol: PolicyRuleProtocolALL,
Action: PolicyTrafficActionAccept,
}
}
// RuleToPolicy converts a Rule to a Policy query object
func RuleToPolicy(rule *Rule) (*Policy, error) {
if rule == nil {
return nil, fmt.Errorf("rule is empty")
}
return &Policy{
ID: rule.ID,
Name: rule.Name,
Description: rule.Description,
Enabled: !rule.Disabled,
Rules: []*PolicyRule{rule.ToPolicyRule()},
}, nil
}

View File

@@ -3,6 +3,7 @@ package server
import (
"fmt"
"math/rand"
"runtime"
"sync"
"testing"
"time"
@@ -25,7 +26,13 @@ func TestScheduler_Performance(t *testing.T) {
return 0, false
})
}
failed := waitTimeout(wg, 3*time.Second)
timeout := 3 * time.Second
if runtime.GOOS == "windows" {
// sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343
timeout = 5 * time.Second
}
failed := waitTimeout(wg, timeout)
if failed {
t.Fatal("timed out while waiting for test to finish")
return
@@ -39,22 +46,29 @@ func TestScheduler_Cancel(t *testing.T) {
scheduler := NewDefaultScheduler()
tChan := make(chan struct{})
p := []string{jobID1, jobID2}
scheduler.Schedule(2*time.Millisecond, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
scheduletime := 2 * time.Millisecond
sleepTime := 4 * time.Millisecond
if runtime.GOOS == "windows" {
// sleep and ticker are slower on windows see https://github.com/golang/go/issues/44343
sleepTime = 20 * time.Millisecond
}
scheduler.Schedule(scheduletime, jobID1, func() (nextRunIn time.Duration, reschedule bool) {
tt := p[0]
<-tChan
t.Logf("job %s", tt)
return 2 * time.Millisecond, true
return scheduletime, true
})
scheduler.Schedule(2*time.Millisecond, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return 2 * time.Millisecond, true
scheduler.Schedule(scheduletime, jobID2, func() (nextRunIn time.Duration, reschedule bool) {
return scheduletime, true
})
time.Sleep(4 * time.Millisecond)
time.Sleep(sleepTime)
assert.Len(t, scheduler.jobs, 2)
scheduler.Cancel([]string{jobID1})
close(tChan)
p = []string{}
time.Sleep(4 * time.Millisecond)
time.Sleep(sleepTime)
assert.Len(t, scheduler.jobs, 1)
assert.NotNil(t, scheduler.jobs[jobID2])
}

View File

@@ -64,7 +64,7 @@ func NewSqliteStore(dataDir string, metrics telemetry.AppMetrics) (*SqliteStore,
sql.SetMaxOpenConns(conns) // TODO: make it configurable
err = db.AutoMigrate(
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{}, &Rule{},
&SetupKey{}, &nbpeer.Peer{}, &User{}, &PersonalAccessToken{}, &Group{},
&Account{}, &Policy{}, &PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{},
&installation{}, &account.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{},
)

View File

@@ -103,7 +103,14 @@ func NewStoreFromJson(dataDir string, metrics telemetry.AppMetrics) (Store, erro
return nil, err
}
switch kind := getStoreEngineFromEnv(); kind {
// if store engine is not set in the config we first try to evaluate NETBIRD_STORE_ENGINE
kind := getStoreEngineFromEnv()
if kind == "" {
// NETBIRD_STORE_ENGINE is not set we evaluate default based on dataDir
kind = getStoreEngineFromDatadir(dataDir)
}
switch kind {
case FileStoreEngine:
return fstore, nil
case SqliteStoreEngine: