mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
Revert "Merge branch 'main' into feature/remote-debug"
This reverts commit6d6333058c, reversing changes made to446aded1f7.
This commit is contained in:
@@ -11,11 +11,11 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -198,7 +198,6 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
if req.Settings.Extra != nil {
|
||||
settings.Extra = &types.ExtraSettings{
|
||||
PeerApprovalEnabled: req.Settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: req.Settings.Extra.UserApprovalRequired,
|
||||
FlowEnabled: req.Settings.Extra.NetworkTrafficLogsEnabled,
|
||||
FlowGroups: req.Settings.Extra.NetworkTrafficLogsGroups,
|
||||
FlowPacketCounterEnabled: req.Settings.Extra.NetworkTrafficPacketCounterEnabled,
|
||||
@@ -328,7 +327,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
if settings.Extra != nil {
|
||||
apiSettings.Extra = &api.AccountExtraSettings{
|
||||
PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled,
|
||||
UserApprovalRequired: settings.Extra.UserApprovalRequired,
|
||||
NetworkTrafficLogsEnabled: settings.Extra.FlowEnabled,
|
||||
NetworkTrafficLogsGroups: settings.Extra.FlowGroups,
|
||||
NetworkTrafficPacketCounterEnabled: settings.Extra.FlowPacketCounterEnabled,
|
||||
|
||||
@@ -15,11 +15,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
|
||||
@@ -488,33 +488,33 @@ func toPeerListItemResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dn
|
||||
}
|
||||
|
||||
return &api.PeerBatch{
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.GetLastLogin(),
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
AccessiblePeersCount: accessiblePeersCount,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
CreatedAt: peer.CreatedAt,
|
||||
Id: peer.ID,
|
||||
Name: peer.Name,
|
||||
Ip: peer.IP.String(),
|
||||
ConnectionIp: peer.Location.ConnectionIP.String(),
|
||||
Connected: peer.Status.Connected,
|
||||
LastSeen: peer.Status.LastSeen,
|
||||
Os: fmt.Sprintf("%s %s", peer.Meta.OS, osVersion),
|
||||
KernelVersion: peer.Meta.KernelVersion,
|
||||
GeonameId: int(peer.Location.GeoNameID),
|
||||
Version: peer.Meta.WtVersion,
|
||||
Groups: groupsInfo,
|
||||
SshEnabled: peer.SSHEnabled,
|
||||
Hostname: peer.Meta.Hostname,
|
||||
UserId: peer.UserID,
|
||||
UiVersion: peer.Meta.UIVersion,
|
||||
DnsLabel: fqdn(peer, dnsDomain),
|
||||
ExtraDnsLabels: fqdnList(peer.ExtraDNSLabels, dnsDomain),
|
||||
LoginExpirationEnabled: peer.LoginExpirationEnabled,
|
||||
LastLogin: peer.GetLastLogin(),
|
||||
LoginExpired: peer.Status.LoginExpired,
|
||||
AccessiblePeersCount: accessiblePeersCount,
|
||||
CountryCode: peer.Location.CountryCode,
|
||||
CityName: peer.Location.CityName,
|
||||
SerialNumber: peer.Meta.SystemSerialNumber,
|
||||
|
||||
InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
|
||||
Ephemeral: peer.Ephemeral,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -8,19 +8,17 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
const failedToConvertRoute = "failed to convert route to response: %v"
|
||||
|
||||
const exitNodeCIDR = "0.0.0.0/0"
|
||||
|
||||
// handler is the routes handler of the account
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
@@ -126,16 +124,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
accessControlGroupIds = *req.AccessControlGroups
|
||||
}
|
||||
|
||||
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
|
||||
skipAutoApply := false
|
||||
if req.SkipAutoApply != nil {
|
||||
skipAutoApply = *req.SkipAutoApply
|
||||
} else if newPrefix.String() == exitNodeCIDR {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -152,31 +142,23 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func (h *handler) validateRoute(req api.PostApiRoutesJSONRequestBody) error {
|
||||
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
|
||||
}
|
||||
|
||||
func (h *handler) validateRouteUpdate(req api.PutApiRoutesRouteIdJSONRequestBody) error {
|
||||
return h.validateRouteCommon(req.Network, req.Domains, req.Peer, req.PeerGroups, req.NetworkId)
|
||||
}
|
||||
|
||||
func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *string, peerGroups *[]string, networkId string) error {
|
||||
if network != nil && domains != nil {
|
||||
if req.Network != nil && req.Domains != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if network == nil && domains == nil {
|
||||
if req.Network == nil && req.Domains == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'network' or 'domains' should be provided")
|
||||
}
|
||||
|
||||
if peer == nil && peerGroups == nil {
|
||||
if req.Peer == nil && req.PeerGroups == nil {
|
||||
return status.Errorf(status.InvalidArgument, "either 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if peer != nil && peerGroups != nil {
|
||||
if req.Peer != nil && req.PeerGroups != nil {
|
||||
return status.Errorf(status.InvalidArgument, "only one of 'peer' or 'peer_groups' should be provided")
|
||||
}
|
||||
|
||||
if utf8.RuneCountInString(networkId) > route.MaxNetIDChar || networkId == "" {
|
||||
if utf8.RuneCountInString(req.NetworkId) > route.MaxNetIDChar || req.NetworkId == "" {
|
||||
return status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d characters",
|
||||
route.MaxNetIDChar)
|
||||
}
|
||||
@@ -213,7 +195,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.validateRouteUpdate(req); err != nil {
|
||||
if err := h.validateRoute(req); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -223,24 +205,15 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
peerID = *req.Peer
|
||||
}
|
||||
|
||||
// Set default skipAutoApply value for exit nodes (0.0.0.0/0 routes)
|
||||
skipAutoApply := false
|
||||
if req.SkipAutoApply != nil {
|
||||
skipAutoApply = *req.SkipAutoApply
|
||||
} else if req.Network != nil && *req.Network == exitNodeCIDR {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID(routeID),
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
Masquerade: req.Masquerade,
|
||||
Metric: req.Metric,
|
||||
Description: req.Description,
|
||||
Enabled: req.Enabled,
|
||||
Groups: req.Groups,
|
||||
KeepRoute: req.KeepRoute,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
ID: route.ID(routeID),
|
||||
NetID: route.NetID(req.NetworkId),
|
||||
Masquerade: req.Masquerade,
|
||||
Metric: req.Metric,
|
||||
Description: req.Description,
|
||||
Enabled: req.Enabled,
|
||||
Groups: req.Groups,
|
||||
KeepRoute: req.KeepRoute,
|
||||
}
|
||||
|
||||
if req.Domains != nil {
|
||||
@@ -348,19 +321,18 @@ func toRouteResponse(serverRoute *route.Route) (*api.Route, error) {
|
||||
}
|
||||
network := serverRoute.Network.String()
|
||||
route := &api.Route{
|
||||
Id: string(serverRoute.ID),
|
||||
Description: serverRoute.Description,
|
||||
NetworkId: string(serverRoute.NetID),
|
||||
Enabled: serverRoute.Enabled,
|
||||
Peer: &serverRoute.Peer,
|
||||
Network: &network,
|
||||
Domains: &domains,
|
||||
NetworkType: serverRoute.NetworkType.String(),
|
||||
Masquerade: serverRoute.Masquerade,
|
||||
Metric: serverRoute.Metric,
|
||||
Groups: serverRoute.Groups,
|
||||
KeepRoute: serverRoute.KeepRoute,
|
||||
SkipAutoApply: &serverRoute.SkipAutoApply,
|
||||
Id: string(serverRoute.ID),
|
||||
Description: serverRoute.Description,
|
||||
NetworkId: string(serverRoute.NetID),
|
||||
Enabled: serverRoute.Enabled,
|
||||
Peer: &serverRoute.Peer,
|
||||
Network: &network,
|
||||
Domains: &domains,
|
||||
NetworkType: serverRoute.NetworkType.String(),
|
||||
Masquerade: serverRoute.Masquerade,
|
||||
Metric: serverRoute.Metric,
|
||||
Groups: serverRoute.Groups,
|
||||
KeepRoute: serverRoute.KeepRoute,
|
||||
}
|
||||
|
||||
if len(serverRoute.PeerGroups) > 0 {
|
||||
|
||||
@@ -15,13 +15,13 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -62,22 +62,21 @@ func initRoutesTestData() *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
GetRouteFunc: func(_ context.Context, _ string, routeID route.ID, _ string) (*route.Route, error) {
|
||||
switch routeID {
|
||||
case existingRouteID:
|
||||
if routeID == existingRouteID {
|
||||
return baseExistingRoute, nil
|
||||
case existingRouteID2:
|
||||
}
|
||||
if routeID == existingRouteID2 {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.PeerGroups = []string{existingGroupID}
|
||||
return route, nil
|
||||
case existingRouteID3:
|
||||
} else if routeID == existingRouteID3 {
|
||||
route := baseExistingRoute.Copy()
|
||||
route.Domains = domain.List{existingDomain}
|
||||
return route, nil
|
||||
default:
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
}
|
||||
return nil, status.Errorf(status.NotFound, "route with ID %s not found", routeID)
|
||||
},
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool, skipAutoApply bool) (*route.Route, error) {
|
||||
CreateRouteFunc: func(_ context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peerID string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroups []string, enabled bool, _ string, keepRoute bool) (*route.Route, error) {
|
||||
if peerID == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
|
||||
}
|
||||
@@ -104,7 +103,6 @@ func initRoutesTestData() *handler {
|
||||
Groups: groups,
|
||||
KeepRoute: keepRoute,
|
||||
AccessControlGroups: accessControlGroups,
|
||||
SkipAutoApply: skipAutoApply,
|
||||
}, nil
|
||||
},
|
||||
SaveRouteFunc: func(_ context.Context, _, _ string, r *route.Route) error {
|
||||
@@ -192,20 +190,19 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"],"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf(`{"Description":"Post","Network":"192.168.0.0/16","network_id":"awesomeNet","Peer":"%s","groups":["%s"]}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -213,22 +210,21 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf(`{"description":"Post","domains":["example.com"],"network_id":"domainNet","peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "domainNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
KeepRoute: true,
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "domainNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
KeepRoute: true,
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -236,7 +232,7 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
requestType: http.MethodPost,
|
||||
requestPath: "/api/routes",
|
||||
requestBody: bytes.NewBuffer(
|
||||
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"],\"skip_auto_apply\":false}", existingPeerID, existingGroupID, existingGroupID))),
|
||||
[]byte(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"access_control_groups\":[\"%s\"]}", existingPeerID, existingGroupID, existingGroupID))),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
@@ -250,7 +246,6 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
AccessControlGroups: &[]string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -341,63 +336,60 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
name: "Network PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"],\"is_selected\":true}", existingPeerID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"Peer\":\"%s\",\"groups\":[\"%s\"]}", existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Domains PUT OK",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true,"skip_auto_apply":false}`, existingPeerID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf(`{"Description":"Post","domains":["example.com"],"network_id":"awesomeNet","Peer":"%s","groups":["%s"],"keep_route":true}`, existingPeerID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
KeepRoute: true,
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("invalid Prefix"),
|
||||
Domains: &[]string{existingDomain},
|
||||
Peer: &existingPeerID,
|
||||
NetworkType: route.DomainNetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
KeepRoute: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "PUT OK when peer_groups provided",
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/routes/" + existingRouteID,
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"],\"skip_auto_apply\":false}", existingGroupID, existingGroupID)),
|
||||
requestBody: bytes.NewBufferString(fmt.Sprintf("{\"Description\":\"Post\",\"Network\":\"192.168.0.0/16\",\"network_id\":\"awesomeNet\",\"peer_groups\":[\"%s\"],\"groups\":[\"%s\"]}", existingGroupID, existingGroupID)),
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedBody: true,
|
||||
expectedRoute: &api.Route{
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &emptyString,
|
||||
PeerGroups: &[]string{existingGroupID},
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
SkipAutoApply: util.ToPtr(false),
|
||||
Id: existingRouteID,
|
||||
Description: "Post",
|
||||
NetworkId: "awesomeNet",
|
||||
Network: util.ToPtr("192.168.0.0/16"),
|
||||
Peer: &emptyString,
|
||||
PeerGroups: &[]string{existingGroupID},
|
||||
NetworkType: route.IPv4NetworkString,
|
||||
Masquerade: false,
|
||||
Enabled: false,
|
||||
Groups: []string{existingGroupID},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,11 +9,11 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
@@ -31,8 +31,6 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
}
|
||||
|
||||
@@ -325,76 +323,17 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
isCurrent := user.ID == currenUserID
|
||||
|
||||
return &api.User{
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
PendingApproval: user.PendingApproval,
|
||||
Id: user.ID,
|
||||
Name: user.Name,
|
||||
Email: user.Email,
|
||||
Role: user.Role,
|
||||
AutoGroups: autoGroups,
|
||||
Status: userStatus,
|
||||
IsCurrent: &isCurrent,
|
||||
IsServiceUser: &user.IsServiceUser,
|
||||
IsBlocked: user.IsBlocked,
|
||||
LastLogin: &user.LastLogin,
|
||||
Issued: &user.Issued,
|
||||
}
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
userResponse := toUserResponse(user, userAuth.UserId)
|
||||
util.WriteJSONObject(r.Context(), w, userResponse)
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteErrorResponse("invalid user ID", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
@@ -16,13 +16,13 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -725,133 +725,3 @@ func stringifyPermissionsKeys(permissions roles.Permissions) map[string]map[stri
|
||||
}
|
||||
return modules
|
||||
}
|
||||
|
||||
func TestApproveUserEndpoint(t *testing.T) {
|
||||
adminUser := &types.User{
|
||||
Id: "admin-user",
|
||||
Role: types.UserRoleAdmin,
|
||||
AccountID: existingAccountID,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
pendingUser := &types.User{
|
||||
Id: "pending-user",
|
||||
Role: types.UserRoleUser,
|
||||
AccountID: existingAccountID,
|
||||
Blocked: true,
|
||||
PendingApproval: true,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
expectedBody bool
|
||||
requestingUser *types.User
|
||||
}{
|
||||
{
|
||||
name: "approve user as admin should return 200",
|
||||
expectedStatus: 200,
|
||||
expectedBody: true,
|
||||
requestingUser: adminUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{}
|
||||
am.ApproveUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) (*types.UserInfo, error) {
|
||||
approvedUserInfo := &types.UserInfo{
|
||||
ID: pendingUser.Id,
|
||||
Email: "pending@example.com",
|
||||
Name: "Pending User",
|
||||
Role: string(pendingUser.Role),
|
||||
AutoGroups: []string{},
|
||||
IsServiceUser: false,
|
||||
IsBlocked: false,
|
||||
PendingApproval: false,
|
||||
LastLogin: time.Now(),
|
||||
Issued: types.UserIssuedAPI,
|
||||
}
|
||||
return approvedUserInfo, nil
|
||||
}
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
if tc.expectedBody {
|
||||
var response api.User
|
||||
err = json.Unmarshal(rr.Body.Bytes(), &response)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "pending-user", response.Id)
|
||||
assert.False(t, response.IsBlocked)
|
||||
assert.False(t, response.PendingApproval)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectUserEndpoint(t *testing.T) {
|
||||
adminUser := &types.User{
|
||||
Id: "admin-user",
|
||||
Role: types.UserRoleAdmin,
|
||||
AccountID: existingAccountID,
|
||||
AutoGroups: []string{},
|
||||
}
|
||||
|
||||
tt := []struct {
|
||||
name string
|
||||
expectedStatus int
|
||||
requestingUser *types.User
|
||||
}{
|
||||
{
|
||||
name: "reject user as admin should return 200",
|
||||
expectedStatus: 200,
|
||||
requestingUser: adminUser,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
am := &mock_server.MockAccountManager{}
|
||||
am.RejectUserFunc = func(ctx context.Context, accountID, initiatorUserID, targetUserID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/reject", handler.rejectUser).Methods("DELETE")
|
||||
|
||||
req, err := http.NewRequest("DELETE", "/users/pending-user/reject", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
userAuth := nbcontext.UserAuth{
|
||||
AccountId: existingAccountID,
|
||||
UserId: tc.requestingUser.Id,
|
||||
}
|
||||
ctx := nbcontext.SetUserAuthInContext(req.Context(), userAuth)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,9 +13,9 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error)
|
||||
|
||||
@@ -8,15 +8,16 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -17,9 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
const modulePeers = "peers"
|
||||
@@ -48,7 +47,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -66,7 +65,7 @@ func BenchmarkUpdatePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -83,7 +82,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -93,7 +92,7 @@ func BenchmarkGetOnePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -110,7 +109,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -120,7 +119,7 @@ func BenchmarkGetAllPeers(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func BenchmarkDeletePeer(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesPeers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/peers.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), 1000, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -147,7 +146,7 @@ func BenchmarkDeletePeer(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, modulePeers, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,9 +17,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
// Map to store peers, groups, users, and setupKeys by name
|
||||
@@ -48,7 +47,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -70,7 +69,7 @@ func BenchmarkCreateSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationCreate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -87,7 +86,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -110,7 +109,7 @@ func BenchmarkUpdateSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -127,7 +126,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -137,7 +136,7 @@ func BenchmarkGetOneSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -154,7 +153,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -164,7 +163,7 @@ func BenchmarkGetAllSetupKeys(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -181,7 +180,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesSetupKeys {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/setup_keys.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, 1000)
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -191,7 +190,7 @@ func BenchmarkDeleteSetupKey(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleSetupKeys, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,9 +18,8 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
)
|
||||
|
||||
const moduleUsers = "users"
|
||||
@@ -47,7 +46,7 @@ func BenchmarkUpdateUser(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -72,7 +71,7 @@ func BenchmarkUpdateUser(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -85,18 +84,18 @@ func BenchmarkGetOneUser(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users/"+testing_tools.TestUserId, testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetOne)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -111,18 +110,18 @@ func BenchmarkGetAllUsers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, bc.Users, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
for i := 0; i < b.N; i++ {
|
||||
req := testing_tools.BuildRequest(b, nil, http.MethodGet, "/api/users", testing_tools.TestAdminId)
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationGetAll)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -137,7 +136,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
|
||||
|
||||
for name, bc := range benchCasesUsers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
apiHandler, am, _ := channel.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
apiHandler, am, _ := testing_tools.BuildApiBlackBoxWithDBState(b, "../testdata/users.sql", nil, false)
|
||||
testing_tools.PopulateTestData(b, am.(*server.DefaultAccountManager), bc.Peers, bc.Groups, 1000, bc.SetupKeys)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
@@ -148,7 +147,7 @@ func BenchmarkDeleteUsers(b *testing.B) {
|
||||
apiHandler.ServeHTTP(recorder, req)
|
||||
}
|
||||
|
||||
testing_tools.EvaluateAPIBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
|
||||
testing_tools.EvaluateBenchmarkResults(b, name, time.Since(start), recorder, moduleUsers, testing_tools.OperationDelete)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,10 +15,9 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools/channel"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
func Test_SetupKeys_Create(t *testing.T) {
|
||||
@@ -288,7 +287,7 @@ func Test_SetupKeys_Create(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(user.name+" - "+tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
@@ -573,7 +572,7 @@ func Test_SetupKeys_Update(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
body, err := json.Marshal(tc.requestBody)
|
||||
if err != nil {
|
||||
@@ -752,7 +751,7 @@ func Test_SetupKeys_Get(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
|
||||
|
||||
@@ -904,7 +903,7 @@ func Test_SetupKeys_GetAll(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, tc.requestPath, user.userId)
|
||||
|
||||
@@ -1088,7 +1087,7 @@ func Test_SetupKeys_Delete(t *testing.T) {
|
||||
for _, tc := range tt {
|
||||
for _, user := range users {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
apiHandler, am, done := channel.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
apiHandler, am, done := testing_tools.BuildApiBlackBoxWithDBState(t, "../testdata/setup_keys.sql", nil, true)
|
||||
|
||||
req := testing_tools.BuildRequest(t, []byte{}, tc.requestType, strings.Replace(tc.requestPath, "{id}", tc.requestId, 1), user.userId)
|
||||
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
package channel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
)
|
||||
|
||||
func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPeerUpdate *server.UpdateMessage, validateUpdate bool) (http.Handler, account.Manager, chan struct{}) {
|
||||
store, cleanup, err := store.NewTestStoreFromSQL(context.Background(), sqlFile, t.TempDir())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test store: %v", err)
|
||||
}
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create metrics: %v", err)
|
||||
}
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
updMsg := peersUpdateManager.CreateChannel(context.Background(), testing_tools.TestPeerId)
|
||||
done := make(chan struct{})
|
||||
if validateUpdate {
|
||||
go func() {
|
||||
if expectedPeerUpdate != nil {
|
||||
peerShouldReceiveUpdate(t, updMsg, expectedPeerUpdate)
|
||||
} else {
|
||||
peerShouldNotReceiveUpdate(t, updMsg)
|
||||
}
|
||||
close(done)
|
||||
}()
|
||||
}
|
||||
|
||||
geoMock := &geolocation.Mock{}
|
||||
validatorMock := server.MockIntegratedValidator{}
|
||||
proxyController := integrations.NewController(store)
|
||||
userManager := users.NewManager(store)
|
||||
permissionsManager := permissions.NewManager(store)
|
||||
settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager)
|
||||
am, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create manager: %v", err)
|
||||
}
|
||||
|
||||
// @note this is required so that PAT's validate from store, but JWT's are mocked
|
||||
authManager := auth.NewManager(store, "", "", "", "", []string{}, false)
|
||||
authManagerMock := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: authManager.EnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: authManager.MarkPATUsed,
|
||||
GetPATInfoFunc: authManager.GetPATInfo,
|
||||
}
|
||||
|
||||
networksManagerMock := networks.NewManagerMock()
|
||||
resourcesManagerMock := resources.NewManagerMock()
|
||||
routersManagerMock := routers.NewManagerMock()
|
||||
groupsManagerMock := groups.NewManagerMock()
|
||||
peersManager := peers.NewManager(store, permissionsManager)
|
||||
|
||||
apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManagerMock, resourcesManagerMock, routersManagerMock, groupsManagerMock, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create API handler: %v", err)
|
||||
}
|
||||
|
||||
return apiHandler, am, done
|
||||
}
|
||||
|
||||
func peerShouldNotReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage) {
|
||||
t.Helper()
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
t.Errorf("Unexpected message received: %+v", msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func peerShouldReceiveUpdate(t testing_tools.TB, updateMessage <-chan *server.UpdateMessage, expected *server.UpdateMessage) {
|
||||
t.Helper()
|
||||
|
||||
select {
|
||||
case msg := <-updateMessage:
|
||||
if msg == nil {
|
||||
t.Errorf("Received nil update message, expected valid message")
|
||||
}
|
||||
assert.Equal(t, expected, msg)
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
t.Errorf("Timed out waiting for update message")
|
||||
}
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
userAuth.UserId = token
|
||||
userAuth.AccountId = "testAccountId"
|
||||
userAuth.Domain = "test.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "otherUserId":
|
||||
userAuth.UserId = "otherUserId"
|
||||
userAuth.AccountId = "otherAccountId"
|
||||
userAuth.Domain = "other.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "invalidToken":
|
||||
return userAuth, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
||||
return userAuth, jwtToken, nil
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package testing_tools
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -13,12 +14,32 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbhttp "github.com/netbirdio/netbird/management/server/http"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
)
|
||||
@@ -202,11 +223,11 @@ func ReadResponse(t *testing.T, recorder *httptest.ResponseRecorder, expectedSta
|
||||
return content, expectedStatus == http.StatusOK
|
||||
}
|
||||
|
||||
func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, setupKeys int) {
|
||||
func PopulateTestData(b *testing.B, am *server.DefaultAccountManager, peers, groups, users, setupKeys int) {
|
||||
b.Helper()
|
||||
|
||||
ctx := context.Background()
|
||||
acc, err := am.GetAccount(ctx, TestAccountId)
|
||||
account, err := am.GetAccount(ctx, TestAccountId)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
@@ -222,23 +243,23 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
Status: &nbpeer.PeerStatus{LastSeen: time.Now().UTC(), Connected: true},
|
||||
UserID: TestUserId,
|
||||
}
|
||||
acc.Peers[peer.ID] = peer
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
// Create users
|
||||
for i := 0; i < users; i++ {
|
||||
user := &types.User{
|
||||
Id: fmt.Sprintf("olduser-%d", i),
|
||||
AccountID: acc.Id,
|
||||
AccountID: account.Id,
|
||||
Role: types.UserRoleUser,
|
||||
}
|
||||
acc.Users[user.Id] = user
|
||||
account.Users[user.Id] = user
|
||||
}
|
||||
|
||||
for i := 0; i < setupKeys; i++ {
|
||||
key := &types.SetupKey{
|
||||
Id: fmt.Sprintf("oldkey-%d", i),
|
||||
AccountID: acc.Id,
|
||||
AccountID: account.Id,
|
||||
AutoGroups: []string{"someGroupID"},
|
||||
UpdatedAt: time.Now().UTC(),
|
||||
ExpiresAt: util.ToPtr(time.Now().Add(ExpiresIn * time.Second)),
|
||||
@@ -246,11 +267,11 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
Type: "reusable",
|
||||
UsageLimit: 0,
|
||||
}
|
||||
acc.SetupKeys[key.Id] = key
|
||||
account.SetupKeys[key.Id] = key
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
acc.Policies = make([]*types.Policy, 0, groups)
|
||||
account.Policies = make([]*types.Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &types.Group{
|
||||
@@ -261,7 +282,7 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
acc.Groups[groupID] = group
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &types.Policy{
|
||||
@@ -281,10 +302,10 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
},
|
||||
},
|
||||
}
|
||||
acc.Policies = append(acc.Policies, policy)
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
|
||||
acc.PostureChecks = []*posture.Checks{
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
@@ -296,38 +317,52 @@ func PopulateTestData(b *testing.B, am account.Manager, peers, groups, users, se
|
||||
},
|
||||
}
|
||||
|
||||
store := am.GetStore()
|
||||
|
||||
err = store.SaveAccount(context.Background(), acc)
|
||||
err = am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to save account: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func EvaluateAPIBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
|
||||
b.Helper()
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
|
||||
}
|
||||
|
||||
EvaluateBenchmarkResults(b, testCase, duration, module, operation)
|
||||
|
||||
}
|
||||
|
||||
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, module string, operation string) {
|
||||
func EvaluateBenchmarkResults(b *testing.B, testCase string, duration time.Duration, recorder *httptest.ResponseRecorder, module string, operation string) {
|
||||
b.Helper()
|
||||
|
||||
branch := os.Getenv("GIT_BRANCH")
|
||||
if branch == "" && os.Getenv("CI") == "true" {
|
||||
if branch == "" {
|
||||
b.Fatalf("environment variable GIT_BRANCH is not set")
|
||||
}
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
b.Fatalf("Benchmark %s failed: unexpected status code %d", testCase, recorder.Code)
|
||||
}
|
||||
|
||||
msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6
|
||||
|
||||
gauge := BenchmarkDuration.WithLabelValues(module, operation, testCase, branch)
|
||||
gauge.Set(msPerOp)
|
||||
|
||||
b.ReportMetric(msPerOp, "ms/op")
|
||||
|
||||
}
|
||||
|
||||
func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserAuth, *jwt.Token, error) {
|
||||
userAuth := nbcontext.UserAuth{}
|
||||
|
||||
switch token {
|
||||
case "testUserId", "testAdminId", "testOwnerId", "testServiceUserId", "testServiceAdminId", "blockedUserId":
|
||||
userAuth.UserId = token
|
||||
userAuth.AccountId = "testAccountId"
|
||||
userAuth.Domain = "test.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "otherUserId":
|
||||
userAuth.UserId = "otherUserId"
|
||||
userAuth.AccountId = "otherAccountId"
|
||||
userAuth.Domain = "other.com"
|
||||
userAuth.DomainCategory = "private"
|
||||
case "invalidToken":
|
||||
return userAuth, nil, errors.New("invalid token")
|
||||
}
|
||||
|
||||
jwtToken := jwt.New(jwt.SigningMethodHS256)
|
||||
return userAuth, jwtToken, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user