mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 04:06:38 +00:00
Merge branch 'main' into feature/relay-integration
This commit is contained in:
@@ -18,6 +18,8 @@ import (
|
||||
|
||||
"github.com/eko/gocache/v3/cache"
|
||||
cacheStore "github.com/eko/gocache/v3/store"
|
||||
"github.com/hashicorp/go-multierror"
|
||||
"github.com/miekg/dns"
|
||||
gocache "github.com/patrickmn/go-cache"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -37,6 +39,7 @@ import (
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -65,6 +68,7 @@ type AccountManager interface {
|
||||
SaveSetupKey(ctx context.Context, accountID string, key *SetupKey, userID string) (*SetupKey, error)
|
||||
CreateUser(ctx context.Context, accountID, initiatorUserID string, key *UserInfo) (*UserInfo, error)
|
||||
DeleteUser(ctx context.Context, accountID, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
ListSetupKeys(ctx context.Context, accountID, userID string) ([]*SetupKey, error)
|
||||
SaveUser(ctx context.Context, accountID, initiatorUserID string, update *User) (*UserInfo, error)
|
||||
@@ -98,6 +102,7 @@ type AccountManager interface {
|
||||
SaveGroup(ctx context.Context, accountID, userID string, group *nbgroup.Group) error
|
||||
SaveGroups(ctx context.Context, accountID, userID string, newGroups []*nbgroup.Group) error
|
||||
DeleteGroup(ctx context.Context, accountId, userId, groupID string) error
|
||||
DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroups(ctx context.Context, accountId string) ([]*nbgroup.Group, error)
|
||||
GroupAddPeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
GroupDeletePeer(ctx context.Context, accountId, groupID, peerID string) error
|
||||
@@ -170,6 +175,8 @@ type DefaultAccountManager struct {
|
||||
userDeleteFromIDPEnabled bool
|
||||
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
|
||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||
@@ -401,8 +408,16 @@ func (a *Account) GetGroup(groupID string) *nbgroup.Group {
|
||||
return a.Groups[groupID]
|
||||
}
|
||||
|
||||
// GetPeerNetworkMap returns a group by ID if exists, nil otherwise
|
||||
func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain string, validatedPeersMap map[string]struct{}) *NetworkMap {
|
||||
// GetPeerNetworkMap returns the networkmap for the given peer ID.
|
||||
func (a *Account) GetPeerNetworkMap(
|
||||
ctx context.Context,
|
||||
peerID string,
|
||||
peersCustomZone nbdns.CustomZone,
|
||||
validatedPeersMap map[string]struct{},
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
peer := a.Peers[peerID]
|
||||
if peer == nil {
|
||||
return &NetworkMap{
|
||||
@@ -438,7 +453,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
|
||||
if dnsManagementStatus {
|
||||
var zones []nbdns.CustomZone
|
||||
peersCustomZone := getPeersCustomZone(ctx, a, dnsDomain)
|
||||
|
||||
if peersCustomZone.Domain != "" {
|
||||
zones = append(zones, peersCustomZone)
|
||||
}
|
||||
@@ -446,7 +461,7 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
dnsUpdate.NameServerGroups = getPeerNSGroups(a, peerID)
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: a.Network.Copy(),
|
||||
Routes: routesUpdate,
|
||||
@@ -454,6 +469,60 @@ func (a *Account) GetPeerNetworkMap(ctx context.Context, peerID, dnsDomain strin
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(peersToConnect) + len(expiredPeers) + len(routesUpdate) + len(firewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (a *Account) GetPeersCustomZone(ctx context.Context, dnsDomain string) nbdns.CustomZone {
|
||||
var merr *multierror.Error
|
||||
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Error("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
Records: make([]nbdns.SimpleRecord, 0, len(a.Peers)),
|
||||
}
|
||||
|
||||
domainSuffix := "." + dnsDomain
|
||||
|
||||
var sb strings.Builder
|
||||
for _, peer := range a.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
merr = multierror.Append(merr, fmt.Errorf("peer %s has an empty DNS label", peer.Name))
|
||||
continue
|
||||
}
|
||||
|
||||
sb.Grow(len(peer.DNSLabel) + len(domainSuffix))
|
||||
sb.WriteString(peer.DNSLabel)
|
||||
sb.WriteString(domainSuffix)
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: sb.String(),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
})
|
||||
|
||||
sb.Reset()
|
||||
}
|
||||
|
||||
go func() {
|
||||
if merr != nil {
|
||||
log.WithContext(ctx).Errorf("error generating custom zone for account %s: %v", a.Id, merr)
|
||||
}
|
||||
}()
|
||||
|
||||
return customZone
|
||||
}
|
||||
|
||||
// GetExpiredPeers returns peers that have been expired
|
||||
@@ -853,7 +922,7 @@ func (a *Account) UserGroupsAddToPeers(userID string, groups ...string) {
|
||||
func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
for _, gid := range groups {
|
||||
group, ok := a.Groups[gid]
|
||||
if !ok {
|
||||
if !ok || group.Name == "All" {
|
||||
continue
|
||||
}
|
||||
update := make([]string, 0, len(group.Peers))
|
||||
@@ -871,10 +940,18 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
}
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, geo *geolocation.Geolocation,
|
||||
func BuildManager(
|
||||
ctx context.Context,
|
||||
store Store,
|
||||
peersUpdateManager *PeersUpdateManager,
|
||||
idpManager idp.Manager,
|
||||
singleAccountModeDomain string,
|
||||
dnsDomain string,
|
||||
eventStore activity.Store,
|
||||
geo *geolocation.Geolocation,
|
||||
userDeleteFromIDPEnabled bool,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
metrics telemetry.AppMetrics,
|
||||
) (*DefaultAccountManager, error) {
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
@@ -889,6 +966,7 @@ func BuildManager(ctx context.Context, store Store, peersUpdateManager *PeersUpd
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
metrics: metrics,
|
||||
}
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
@@ -1994,6 +2072,28 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
|
||||
return am.Store.GetAccountIDByPeerPubKey(ctx, peerKey)
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) {
|
||||
user, err := am.Store.GetUserByUserID(ctx, peer.UserID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, settings) {
|
||||
err = am.handleExpiredPeer(ctx, user, peer)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/jwtclaims"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -410,7 +411,8 @@ func TestAccount_GetPeerNetworkMap(t *testing.T) {
|
||||
validatedPeers[p] = struct{}{}
|
||||
}
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, "netbird.io", validatedPeers)
|
||||
customZone := account.GetPeersCustomZone(context.Background(), "netbird.io")
|
||||
networkMap := account.GetPeerNetworkMap(context.Background(), testCase.peerID, customZone, validatedPeers, nil)
|
||||
assert.Len(t, networkMap.Peers, len(testCase.expectedPeers))
|
||||
assert.Len(t, networkMap.OfflinePeers, len(testCase.expectedOfflinePeers))
|
||||
}
|
||||
@@ -2293,7 +2295,13 @@ func TestAccount_UserGroupsRemoveFromPeers(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
type TB interface {
|
||||
Cleanup(func())
|
||||
Helper()
|
||||
TempDir() string
|
||||
}
|
||||
|
||||
func createManager(t TB) (*DefaultAccountManager, error) {
|
||||
t.Helper()
|
||||
|
||||
store, err := createStore(t)
|
||||
@@ -2302,7 +2310,12 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{})
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
manager, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -2310,7 +2323,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return manager, nil
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
func createStore(t TB) (Store, error) {
|
||||
t.Helper()
|
||||
dataDir := t.TempDir()
|
||||
store, cleanUp, err := NewTestStoreFromJson(context.Background(), dataDir)
|
||||
|
||||
@@ -4,8 +4,8 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
@@ -17,6 +17,50 @@ import (
|
||||
|
||||
const defaultTTL = 300
|
||||
|
||||
// DNSConfigCache is a thread-safe cache for DNS configuration components
|
||||
type DNSConfigCache struct {
|
||||
CustomZones sync.Map
|
||||
NameServerGroups sync.Map
|
||||
}
|
||||
|
||||
// GetCustomZone retrieves a cached custom zone
|
||||
func (c *DNSConfigCache) GetCustomZone(key string) (*proto.CustomZone, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.CustomZones.Load(key); ok {
|
||||
return value.(*proto.CustomZone), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetCustomZone stores a custom zone in the cache
|
||||
func (c *DNSConfigCache) SetCustomZone(key string, value *proto.CustomZone) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.CustomZones.Store(key, value)
|
||||
}
|
||||
|
||||
// GetNameServerGroup retrieves a cached name server group
|
||||
func (c *DNSConfigCache) GetNameServerGroup(key string) (*proto.NameServerGroup, bool) {
|
||||
if c == nil {
|
||||
return nil, false
|
||||
}
|
||||
if value, ok := c.NameServerGroups.Load(key); ok {
|
||||
return value.(*proto.NameServerGroup), true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// SetNameServerGroup stores a name server group in the cache
|
||||
func (c *DNSConfigCache) SetNameServerGroup(key string, value *proto.NameServerGroup) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.NameServerGroups.Store(key, value)
|
||||
}
|
||||
|
||||
type lookupMap map[string]struct{}
|
||||
|
||||
// DNSSettings defines dns settings at the account level
|
||||
@@ -113,69 +157,73 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
|
||||
return nil
|
||||
}
|
||||
|
||||
func toProtocolDNSConfig(update nbdns.Config) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{ServiceEnable: update.ServiceEnable}
|
||||
// toProtocolDNSConfig converts nbdns.Config to proto.DNSConfig using the cache
|
||||
func toProtocolDNSConfig(update nbdns.Config, cache *DNSConfigCache) *proto.DNSConfig {
|
||||
protoUpdate := &proto.DNSConfig{
|
||||
ServiceEnable: update.ServiceEnable,
|
||||
CustomZones: make([]*proto.CustomZone, 0, len(update.CustomZones)),
|
||||
NameServerGroups: make([]*proto.NameServerGroup, 0, len(update.NameServerGroups)),
|
||||
}
|
||||
|
||||
for _, zone := range update.CustomZones {
|
||||
protoZone := &proto.CustomZone{Domain: zone.Domain}
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
cacheKey := zone.Domain
|
||||
if cachedZone, exists := cache.GetCustomZone(cacheKey); exists {
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, cachedZone)
|
||||
} else {
|
||||
protoZone := convertToProtoCustomZone(zone)
|
||||
cache.SetCustomZone(cacheKey, protoZone)
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
protoUpdate.CustomZones = append(protoUpdate.CustomZones, protoZone)
|
||||
}
|
||||
|
||||
for _, nsGroup := range update.NameServerGroups {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
cacheKey := nsGroup.ID
|
||||
if cachedGroup, exists := cache.GetNameServerGroup(cacheKey); exists {
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, cachedGroup)
|
||||
} else {
|
||||
protoGroup := convertToProtoNameServerGroup(nsGroup)
|
||||
cache.SetNameServerGroup(cacheKey, protoGroup)
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoNS := &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
}
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, protoNS)
|
||||
}
|
||||
protoUpdate.NameServerGroups = append(protoUpdate.NameServerGroups, protoGroup)
|
||||
}
|
||||
|
||||
return protoUpdate
|
||||
}
|
||||
|
||||
func getPeersCustomZone(ctx context.Context, account *Account, dnsDomain string) nbdns.CustomZone {
|
||||
if dnsDomain == "" {
|
||||
log.WithContext(ctx).Errorf("no dns domain is set, returning empty zone")
|
||||
return nbdns.CustomZone{}
|
||||
// Helper function to convert nbdns.CustomZone to proto.CustomZone
|
||||
func convertToProtoCustomZone(zone nbdns.CustomZone) *proto.CustomZone {
|
||||
protoZone := &proto.CustomZone{
|
||||
Domain: zone.Domain,
|
||||
Records: make([]*proto.SimpleRecord, 0, len(zone.Records)),
|
||||
}
|
||||
|
||||
customZone := nbdns.CustomZone{
|
||||
Domain: dns.Fqdn(dnsDomain),
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if peer.DNSLabel == "" {
|
||||
log.WithContext(ctx).Errorf("found a peer with empty dns label. It was probably caused by a invalid character in its name. Peer Name: %s", peer.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
customZone.Records = append(customZone.Records, nbdns.SimpleRecord{
|
||||
Name: dns.Fqdn(peer.DNSLabel + "." + dnsDomain),
|
||||
Type: int(dns.TypeA),
|
||||
Class: nbdns.DefaultClass,
|
||||
TTL: defaultTTL,
|
||||
RData: peer.IP.String(),
|
||||
for _, record := range zone.Records {
|
||||
protoZone.Records = append(protoZone.Records, &proto.SimpleRecord{
|
||||
Name: record.Name,
|
||||
Type: int64(record.Type),
|
||||
Class: record.Class,
|
||||
TTL: int64(record.TTL),
|
||||
RData: record.RData,
|
||||
})
|
||||
}
|
||||
return protoZone
|
||||
}
|
||||
|
||||
return customZone
|
||||
// Helper function to convert nbdns.NameServerGroup to proto.NameServerGroup
|
||||
func convertToProtoNameServerGroup(nsGroup *nbdns.NameServerGroup) *proto.NameServerGroup {
|
||||
protoGroup := &proto.NameServerGroup{
|
||||
Primary: nsGroup.Primary,
|
||||
Domains: nsGroup.Domains,
|
||||
SearchDomainsEnabled: nsGroup.SearchDomainsEnabled,
|
||||
NameServers: make([]*proto.NameServer, 0, len(nsGroup.NameServers)),
|
||||
}
|
||||
for _, ns := range nsGroup.NameServers {
|
||||
protoGroup.NameServers = append(protoGroup.NameServers, &proto.NameServer{
|
||||
IP: ns.IP.String(),
|
||||
Port: int64(ns.Port),
|
||||
NSType: int64(ns.NSType),
|
||||
})
|
||||
}
|
||||
return protoGroup
|
||||
}
|
||||
|
||||
func getPeerNSGroups(account *Account, peerID string) []*nbdns.NameServerGroup {
|
||||
|
||||
@@ -2,9 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/dns"
|
||||
@@ -195,7 +200,11 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
@@ -320,3 +329,150 @@ func initTestDNSAccount(t *testing.T, am *DefaultAccountManager) (*Account, erro
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
}
|
||||
|
||||
func generateTestData(size int) nbdns.Config {
|
||||
config := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: make([]nbdns.CustomZone, size),
|
||||
NameServerGroups: make([]*nbdns.NameServerGroup, size),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
config.CustomZones[i] = nbdns.CustomZone{
|
||||
Domain: fmt.Sprintf("domain%d.com", i),
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{
|
||||
Name: fmt.Sprintf("record%d", i),
|
||||
Type: 1,
|
||||
Class: "IN",
|
||||
TTL: 3600,
|
||||
RData: "192.168.1.1",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config.NameServerGroups[i] = &nbdns.NameServerGroup{
|
||||
ID: fmt.Sprintf("group%d", i),
|
||||
Primary: i == 0,
|
||||
Domains: []string{fmt.Sprintf("domain%d.com", i)},
|
||||
SearchDomainsEnabled: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
Port: 53,
|
||||
NSType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
func BenchmarkToProtocolDNSConfig(b *testing.B) {
|
||||
sizes := []int{10, 100, 1000}
|
||||
|
||||
for _, size := range sizes {
|
||||
testData := generateTestData(size)
|
||||
|
||||
b.Run(fmt.Sprintf("WithCache-Size%d", size), func(b *testing.B) {
|
||||
cache := &DNSConfigCache{}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run(fmt.Sprintf("WithoutCache-Size%d", size), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
cache := &DNSConfigCache{}
|
||||
toProtocolDNSConfig(testData, cache)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
var cache DNSConfigCache
|
||||
|
||||
// Create two different configs
|
||||
config1 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.com",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "www", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.1"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group1",
|
||||
Name: "Group 1",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.8.8"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
config2 := nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
CustomZones: []nbdns.CustomZone{
|
||||
{
|
||||
Domain: "example.org",
|
||||
Records: []nbdns.SimpleRecord{
|
||||
{Name: "mail", Type: 1, Class: "IN", TTL: 300, RData: "192.168.1.2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
ID: "group2",
|
||||
Name: "Group 2",
|
||||
NameServers: []nbdns.NameServer{
|
||||
{IP: netip.MustParseAddr("8.8.4.4"), Port: 53},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// First run with config1
|
||||
result1 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Second run with config2
|
||||
result2 := toProtocolDNSConfig(config2, &cache)
|
||||
|
||||
// Third run with config1 again
|
||||
result3 := toProtocolDNSConfig(config1, &cache)
|
||||
|
||||
// Verify that result1 and result3 are identical
|
||||
if !reflect.DeepEqual(result1, result3) {
|
||||
t.Errorf("Results are not identical when run with the same input. Expected %v, got %v", result1, result3)
|
||||
}
|
||||
|
||||
// Verify that result2 is different from result1 and result3
|
||||
if reflect.DeepEqual(result1, result2) || reflect.DeepEqual(result2, result3) {
|
||||
t.Errorf("Results should be different for different inputs")
|
||||
}
|
||||
|
||||
// Verify that the cache contains elements from both configs
|
||||
if _, exists := cache.GetCustomZone("example.com"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.com")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetCustomZone("example.org"); !exists {
|
||||
t.Errorf("Cache should contain custom zone for example.org")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group1"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group1'")
|
||||
}
|
||||
|
||||
if _, exists := cache.GetNameServerGroup("group2"); !exists {
|
||||
t.Errorf("Cache should contain name server group 'group2'")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -469,6 +469,35 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User,
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) {
|
||||
accountID, ok := s.UserID2AccountID[userID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists")
|
||||
}
|
||||
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return account.Users[userID].Copy(), nil
|
||||
}
|
||||
|
||||
func (s *FileStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
account, err := s.getAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupsSlice := make([]*nbgroup.Group, 0, len(account.Groups))
|
||||
|
||||
for _, group := range account.Groups {
|
||||
groupsSlice = append(groupsSlice, group)
|
||||
}
|
||||
|
||||
return groupsSlice, nil
|
||||
}
|
||||
|
||||
// GetAllAccounts returns all accounts
|
||||
func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) {
|
||||
s.mux.Lock()
|
||||
|
||||
@@ -2,8 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
@@ -243,7 +247,7 @@ func difference(a, b []string) []string {
|
||||
return diff
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
// DeleteGroup object of the peers.
|
||||
func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountId)
|
||||
defer unlock()
|
||||
@@ -253,96 +257,14 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
return err
|
||||
}
|
||||
|
||||
g, ok := account.Groups[groupID]
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if g.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userId]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
if err = validateDeleteGroup(account, group, userId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// check route links
|
||||
for _, r := range account.Routes {
|
||||
for _, g := range r.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
for _, g := range r.PeerGroups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"route", string(r.NetID)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DNS links
|
||||
for _, dns := range account.NameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
return &GroupLinkError{"name server groups", dns.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check ACL links
|
||||
for _, policy := range account.Policies {
|
||||
for _, rule := range policy.Rules {
|
||||
for _, src := range rule.Sources {
|
||||
if src == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if dst == groupID {
|
||||
return &GroupLinkError{"policy", policy.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check setup key links
|
||||
for _, setupKey := range account.SetupKeys {
|
||||
for _, grp := range setupKey.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"setup key", setupKey.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check user links
|
||||
for _, user := range account.Users {
|
||||
for _, grp := range user.AutoGroups {
|
||||
if grp == groupID {
|
||||
return &GroupLinkError{"user", user.Id}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// check DisabledManagementGroups
|
||||
for _, disabledMgmGrp := range account.DNSSettings.DisabledManagementGroups {
|
||||
if disabledMgmGrp == groupID {
|
||||
return &GroupLinkError{"disabled DNS management groups", g.Name}
|
||||
}
|
||||
}
|
||||
|
||||
// check integrated peer validator groups
|
||||
if account.Settings.Extra != nil {
|
||||
for _, integratedPeerValidatorGroups := range account.Settings.Extra.IntegratedValidatorGroups {
|
||||
if groupID == integratedPeerValidatorGroups {
|
||||
return &GroupLinkError{"integrated validator", g.Name}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
|
||||
account.Network.IncSerial()
|
||||
@@ -350,13 +272,57 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta())
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteGroups deletes groups from an account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
//
|
||||
// If an error occurs while deleting a group, the function skips it and continues deleting other groups.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
|
||||
deletedGroups := make([]*nbgroup.Group, 0, len(groupIDs))
|
||||
for _, groupID := range groupIDs {
|
||||
group, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := validateDeleteGroup(account, group, userId); err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete group %s: %w", groupID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Groups, groupID)
|
||||
deletedGroups = append(deletedGroups, group)
|
||||
}
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(ctx, account); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, g := range deletedGroups {
|
||||
am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta())
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
// ListGroups objects of the peers
|
||||
func (am *DefaultAccountManager) ListGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
@@ -440,3 +406,102 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateDeleteGroup(account *Account, group *nbgroup.Group, userID string) error {
|
||||
// disable a deleting integration group if the initiator is not an admin service user
|
||||
if group.Issued == nbgroup.GroupIssuedIntegration {
|
||||
executingUser := account.Users[userID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin || !executingUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
|
||||
}
|
||||
}
|
||||
|
||||
if isLinked, linkedRoute := isGroupLinkedToRoute(account.Routes, group.ID); isLinked {
|
||||
return &GroupLinkError{"route", string(linkedRoute.NetID)}
|
||||
}
|
||||
|
||||
if isLinked, linkedDns := isGroupLinkedToDns(account.NameServerGroups, group.ID); isLinked {
|
||||
return &GroupLinkError{"name server groups", linkedDns.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedPolicy := isGroupLinkedToPolicy(account.Policies, group.ID); isLinked {
|
||||
return &GroupLinkError{"policy", linkedPolicy.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(account.SetupKeys, group.ID); isLinked {
|
||||
return &GroupLinkError{"setup key", linkedSetupKey.Name}
|
||||
}
|
||||
|
||||
if isLinked, linkedUser := isGroupLinkedToUser(account.Users, group.ID); isLinked {
|
||||
return &GroupLinkError{"user", linkedUser.Id}
|
||||
}
|
||||
|
||||
if slices.Contains(account.DNSSettings.DisabledManagementGroups, group.ID) {
|
||||
return &GroupLinkError{"disabled DNS management groups", group.Name}
|
||||
}
|
||||
|
||||
if account.Settings.Extra != nil {
|
||||
if slices.Contains(account.Settings.Extra.IntegratedValidatorGroups, group.ID) {
|
||||
return &GroupLinkError{"integrated validator", group.Name}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
|
||||
func isGroupLinkedToRoute(routes map[route.ID]*route.Route, groupID string) (bool, *route.Route) {
|
||||
for _, r := range routes {
|
||||
if slices.Contains(r.Groups, groupID) || slices.Contains(r.PeerGroups, groupID) {
|
||||
return true, r
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
|
||||
func isGroupLinkedToPolicy(policies []*Policy, groupID string) (bool, *Policy) {
|
||||
for _, policy := range policies {
|
||||
for _, rule := range policy.Rules {
|
||||
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
|
||||
return true, policy
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
|
||||
func isGroupLinkedToDns(nameServerGroups map[string]*nbdns.NameServerGroup, groupID string) (bool, *nbdns.NameServerGroup) {
|
||||
for _, dns := range nameServerGroups {
|
||||
for _, g := range dns.Groups {
|
||||
if g == groupID {
|
||||
return true, dns
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
|
||||
func isGroupLinkedToSetupKey(setupKeys map[string]*SetupKey, groupID string) (bool, *SetupKey) {
|
||||
for _, setupKey := range setupKeys {
|
||||
if slices.Contains(setupKey.AutoGroups, groupID) {
|
||||
return true, setupKey
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// isGroupLinkedToUser checks if a group is linked to any user in the account.
|
||||
func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) {
|
||||
for _, user := range users {
|
||||
if slices.Contains(user.AutoGroups, groupID) {
|
||||
return true, user
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -3,12 +3,14 @@ package server
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -21,7 +23,7 @@ func TestDefaultAccountManager_CreateGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestGroupAccount(am)
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@@ -56,7 +58,7 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestGroupAccount(am)
|
||||
_, account, err := initTestGroupAccount(am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
@@ -132,7 +134,136 @@ func TestDefaultAccountManager_DeleteGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
func TestDefaultAccountManager_DeleteGroups(t *testing.T) {
|
||||
am, err := createManager(t)
|
||||
assert.NoError(t, err, "Failed to create account manager")
|
||||
|
||||
manager, account, err := initTestGroupAccount(am)
|
||||
assert.NoError(t, err, "Failed to init testing account")
|
||||
|
||||
groups := make([]*nbgroup.Group, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
groups[i] = &nbgroup.Group{
|
||||
ID: fmt.Sprintf("group-%d", i+1),
|
||||
AccountID: account.Id,
|
||||
Name: fmt.Sprintf("group-%d", i+1),
|
||||
Issued: nbgroup.GroupIssuedAPI,
|
||||
}
|
||||
}
|
||||
|
||||
err = manager.SaveGroups(context.Background(), account.Id, groupAdminUserID, groups)
|
||||
assert.NoError(t, err, "Failed to save test groups")
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
groupIDs []string
|
||||
expectedReasons []string
|
||||
expectedDeleted []string
|
||||
expectedNotDeleted []string
|
||||
}{
|
||||
{
|
||||
name: "route",
|
||||
groupIDs: []string{"grp-for-route"},
|
||||
expectedReasons: []string{"route"},
|
||||
},
|
||||
{
|
||||
name: "route with peer groups",
|
||||
groupIDs: []string{"grp-for-route2"},
|
||||
expectedReasons: []string{"route"},
|
||||
},
|
||||
{
|
||||
name: "name server groups",
|
||||
groupIDs: []string{"grp-for-name-server-grp"},
|
||||
expectedReasons: []string{"name server groups"},
|
||||
},
|
||||
{
|
||||
name: "policy",
|
||||
groupIDs: []string{"grp-for-policies"},
|
||||
expectedReasons: []string{"policy"},
|
||||
},
|
||||
{
|
||||
name: "setup keys",
|
||||
groupIDs: []string{"grp-for-keys"},
|
||||
expectedReasons: []string{"setup key"},
|
||||
},
|
||||
{
|
||||
name: "users",
|
||||
groupIDs: []string{"grp-for-users"},
|
||||
expectedReasons: []string{"user"},
|
||||
},
|
||||
{
|
||||
name: "integration",
|
||||
groupIDs: []string{"grp-for-integration"},
|
||||
expectedReasons: []string{"only service users with admin power can delete integration group"},
|
||||
},
|
||||
{
|
||||
name: "successfully delete multiple groups",
|
||||
groupIDs: []string{"group-1", "group-2"},
|
||||
expectedDeleted: []string{"group-1", "group-2"},
|
||||
},
|
||||
{
|
||||
name: "delete non-existent group",
|
||||
groupIDs: []string{"non-existent-group"},
|
||||
expectedDeleted: []string{"non-existent-group"},
|
||||
},
|
||||
{
|
||||
name: "delete multiple groups with mixed results",
|
||||
groupIDs: []string{"group-3", "grp-for-policies", "group-4", "grp-for-users"},
|
||||
expectedReasons: []string{"policy", "user"},
|
||||
expectedDeleted: []string{"group-3", "group-4"},
|
||||
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
|
||||
},
|
||||
{
|
||||
name: "delete groups with multiple errors",
|
||||
groupIDs: []string{"grp-for-policies", "grp-for-users"},
|
||||
expectedReasons: []string{"policy", "user"},
|
||||
expectedNotDeleted: []string{"grp-for-policies", "grp-for-users"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteGroups(context.Background(), account.Id, groupAdminUserID, tc.groupIDs)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
assert.Equal(t, ok, true)
|
||||
|
||||
for _, e := range wrappedErr.Unwrap() {
|
||||
var sErr *status.Error
|
||||
if errors.As(e, &sErr) {
|
||||
assert.Contains(t, tc.expectedReasons, sErr.Message, "unexpected error message")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
|
||||
var gErr *GroupLinkError
|
||||
if errors.As(e, &gErr) {
|
||||
assert.Contains(t, tc.expectedReasons, gErr.Resource, "unexpected error resource")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
for _, groupID := range tc.expectedDeleted {
|
||||
_, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
|
||||
assert.Error(t, err, "group should have been deleted: %s", groupID)
|
||||
}
|
||||
|
||||
for _, groupID := range tc.expectedNotDeleted {
|
||||
group, err := am.GetGroup(context.Background(), account.Id, groupID, groupAdminUserID)
|
||||
assert.NoError(t, err, "group should not have been deleted: %s", groupID)
|
||||
assert.NotNil(t, group, "group should exist: %s", groupID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func initTestGroupAccount(am *DefaultAccountManager) (*DefaultAccountManager, *Account, error) {
|
||||
accountID := "testingAcc"
|
||||
domain := "example.com"
|
||||
|
||||
@@ -236,7 +367,7 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
|
||||
err := am.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForRoute)
|
||||
@@ -247,5 +378,9 @@ func initTestGroupAccount(am *DefaultAccountManager) (*Account, error) {
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForUsers)
|
||||
_ = am.SaveGroup(context.Background(), accountID, groupAdminUserID, groupForIntegration)
|
||||
|
||||
return am.Store.GetAccount(context.Background(), account.Id)
|
||||
acc, err := am.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return am, acc, nil
|
||||
}
|
||||
|
||||
@@ -256,7 +256,7 @@ func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string
|
||||
}
|
||||
|
||||
if err := s.accountManager.CheckUserAccessByJWTGroups(ctx, claims); err != nil {
|
||||
return "", status.Errorf(codes.PermissionDenied, err.Error())
|
||||
return "", status.Error(codes.PermissionDenied, err.Error())
|
||||
}
|
||||
|
||||
return claims.UserId, nil
|
||||
@@ -267,15 +267,15 @@ func mapError(ctx context.Context, err error) error {
|
||||
if e, ok := internalStatus.FromError(err); ok {
|
||||
switch e.Type() {
|
||||
case internalStatus.PermissionDenied:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthorized:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.Unauthenticated:
|
||||
return status.Errorf(codes.PermissionDenied, e.Message)
|
||||
return status.Error(codes.PermissionDenied, e.Message)
|
||||
case internalStatus.PreconditionFailed:
|
||||
return status.Errorf(codes.FailedPrecondition, e.Message)
|
||||
return status.Error(codes.FailedPrecondition, e.Message)
|
||||
case internalStatus.NotFound:
|
||||
return status.Errorf(codes.NotFound, e.Message)
|
||||
return status.Error(codes.NotFound, e.Message)
|
||||
default:
|
||||
}
|
||||
}
|
||||
@@ -550,53 +550,46 @@ func toPeerConfig(peer *nbpeer.Peer, network *Network, dnsName string) *proto.Pe
|
||||
}
|
||||
}
|
||||
|
||||
func toRemotePeerConfig(peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
remotePeers := []*proto.RemotePeerConfig{}
|
||||
for _, rPeer := range peers {
|
||||
fqdn := rPeer.FQDN(dnsName)
|
||||
remotePeers = append(remotePeers, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{fmt.Sprintf(AllowedIPsFormat, rPeer.IP)},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: fqdn,
|
||||
})
|
||||
}
|
||||
return remotePeers
|
||||
}
|
||||
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks) *proto.SyncResponse {
|
||||
wtConfig := toWiretrusteeConfig(config, turnCredentials, relayCredentials)
|
||||
|
||||
pConfig := toPeerConfig(peer, networkMap.Network, dnsName)
|
||||
|
||||
remotePeers := toRemotePeerConfig(networkMap.Peers, dnsName)
|
||||
|
||||
routesUpdate := toProtocolRoutes(networkMap.Routes)
|
||||
|
||||
dnsUpdate := toProtocolDNSConfig(networkMap.DNSConfig)
|
||||
|
||||
offlinePeers := toRemotePeerConfig(networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
|
||||
return &proto.SyncResponse{
|
||||
WiretrusteeConfig: wtConfig,
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
func toSyncResponse(ctx context.Context, config *Config, peer *nbpeer.Peer, turnCredentials *TURNRelayToken, relayCredentials *TURNRelayToken, networkMap *NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache) *proto.SyncResponse {
|
||||
response := &proto.SyncResponse{
|
||||
WiretrusteeConfig: toWiretrusteeConfig(config, turnCredentials, relayCredentials),
|
||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName),
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
PeerConfig: pConfig,
|
||||
RemotePeers: remotePeers,
|
||||
OfflinePeers: offlinePeers,
|
||||
RemotePeersIsEmpty: len(remotePeers) == 0,
|
||||
Routes: routesUpdate,
|
||||
DNSConfig: dnsUpdate,
|
||||
FirewallRules: firewallRules,
|
||||
FirewallRulesIsEmpty: len(firewallRules) == 0,
|
||||
Serial: networkMap.Network.CurrentSerial(),
|
||||
Routes: toProtocolRoutes(networkMap.Routes),
|
||||
DNSConfig: toProtocolDNSConfig(networkMap.DNSConfig, dnsCache),
|
||||
},
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
response.NetworkMap.PeerConfig = response.PeerConfig
|
||||
|
||||
allPeers := make([]*proto.RemotePeerConfig, 0, len(networkMap.Peers)+len(networkMap.OfflinePeers))
|
||||
allPeers = appendRemotePeerConfig(allPeers, networkMap.Peers, dnsName)
|
||||
response.RemotePeers = allPeers
|
||||
response.NetworkMap.RemotePeers = allPeers
|
||||
response.RemotePeersIsEmpty = len(allPeers) == 0
|
||||
response.NetworkMap.RemotePeersIsEmpty = response.RemotePeersIsEmpty
|
||||
|
||||
response.NetworkMap.OfflinePeers = appendRemotePeerConfig(nil, networkMap.OfflinePeers, dnsName)
|
||||
|
||||
firewallRules := toProtocolFirewallRules(networkMap.FirewallRules)
|
||||
response.NetworkMap.FirewallRules = firewallRules
|
||||
response.NetworkMap.FirewallRulesIsEmpty = len(firewallRules) == 0
|
||||
|
||||
return response
|
||||
}
|
||||
|
||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||
for _, rPeer := range peers {
|
||||
dst = append(dst, &proto.RemotePeerConfig{
|
||||
WgPubKey: rPeer.Key,
|
||||
AllowedIps: []string{rPeer.IP.String() + "/32"},
|
||||
SshConfig: &proto.SSHConfig{SshPubKey: []byte(rPeer.SSHKey)},
|
||||
Fqdn: rPeer.FQDN(dnsName),
|
||||
})
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// IsHealthy indicates whether the service is healthy
|
||||
@@ -615,7 +608,7 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
|
||||
if s.config.TURNConfig.TimeBasedCredentials {
|
||||
turnCredentials = trt
|
||||
}
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks)
|
||||
plainResp := toSyncResponse(ctx, s.config, peer, turnCredentials, trt, networkMap, s.accountManager.GetDNSDomain(), postureChecks, nil)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
if err != nil {
|
||||
|
||||
@@ -71,7 +71,8 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee
|
||||
return
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
@@ -115,7 +116,9 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account,
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validPeers)
|
||||
|
||||
customZone := account.GetPeersCustomZone(ctx, h.accountManager.GetDNSDomain())
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, customZone, validPeers, nil)
|
||||
accessiblePeers := toAccessiblePeers(netMap, dnsDomain)
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
@@ -194,9 +197,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
groupMinimumInfo := toGroupsInfo(account.Groups, peer.ID)
|
||||
|
||||
accessiblePeerNumbers, _ := h.accessiblePeersNumber(r.Context(), account, peer.ID)
|
||||
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, accessiblePeerNumbers))
|
||||
respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
@@ -210,16 +211,6 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *PeersHandler) accessiblePeersNumber(ctx context.Context, account *server.Account, peerID string) (int, error) {
|
||||
validatedPeersMap, err := h.accountManager.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
netMap := account.GetPeerNetworkMap(ctx, peerID, h.accountManager.GetDNSDomain(), validatedPeersMap)
|
||||
return len(netMap.Peers) + len(netMap.OfflinePeers), nil
|
||||
}
|
||||
|
||||
func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approvedPeersMap map[string]struct{}) {
|
||||
for _, peer := range respBody {
|
||||
_, ok := approvedPeersMap[peer.Id]
|
||||
|
||||
@@ -46,7 +46,7 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH
|
||||
testPostureChecks[postureChecks.ID] = postureChecks
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -3,6 +3,7 @@ package idp
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -44,14 +45,14 @@ type mockJsonParser struct {
|
||||
|
||||
func (m *mockJsonParser) Marshal(v interface{}) ([]byte, error) {
|
||||
if m.marshalErrorString != "" {
|
||||
return nil, fmt.Errorf(m.marshalErrorString)
|
||||
return nil, errors.New(m.marshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Marshal(v)
|
||||
}
|
||||
|
||||
func (m *mockJsonParser) Unmarshal(data []byte, v interface{}) error {
|
||||
if m.unmarshalErrorString != "" {
|
||||
return fmt.Errorf(m.unmarshalErrorString)
|
||||
return errors.New(m.unmarshalErrorString)
|
||||
}
|
||||
return m.jsonParser.Unmarshal(data, v)
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
||||
// If we get here, the required token is missing
|
||||
errorMsg := "required authorization token not found"
|
||||
log.WithContext(ctx).Debugf(" Error: No credentials found (CredentialsOptional=false)")
|
||||
return nil, fmt.Errorf(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
// Now parse the token
|
||||
@@ -173,7 +173,7 @@ func (m *JWTValidator) ValidateAndParse(ctx context.Context, token string) (*jwt
|
||||
// Check if the parsed token is valid...
|
||||
if !parsedToken.Valid {
|
||||
errorMsg := "token is invalid"
|
||||
log.WithContext(ctx).Debugf(errorMsg)
|
||||
log.WithContext(ctx).Debug(errorMsg)
|
||||
return nil, errors.New(errorMsg)
|
||||
}
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
mgmtProto "github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -419,8 +420,12 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, *DefaultAccoun
|
||||
|
||||
ctx := context.WithValue(context.Background(), formatter.ExecutionContextKey, formatter.SystemSource) //nolint:staticcheck
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
accountManager, err := BuildManager(ctx, store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, "", err
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
@@ -541,8 +542,13 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
|
||||
peersUpdateManager := server.NewPeersUpdateManager(nil)
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted",
|
||||
eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating metrics: %v", err)
|
||||
}
|
||||
|
||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -42,6 +42,7 @@ type MockAccountManager struct {
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *group.Group) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*group.Group) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error
|
||||
ListGroupsFunc func(ctx context.Context, accountID string) ([]*group.Group, error)
|
||||
GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error
|
||||
@@ -67,6 +68,7 @@ type MockAccountManager struct {
|
||||
SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *server.User, addIfNotExists bool) (*server.UserInfo, error)
|
||||
SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*server.User, addIfNotExists bool) ([]*server.UserInfo, error)
|
||||
DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error
|
||||
DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error
|
||||
CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*server.PersonalAccessTokenGenerated, error)
|
||||
DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error
|
||||
GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*server.PersonalAccessToken, error)
|
||||
@@ -326,6 +328,14 @@ func (am *MockAccountManager) DeleteGroup(ctx context.Context, accountId, userId
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroup is not implemented")
|
||||
}
|
||||
|
||||
// DeleteGroups mock implementation of DeleteGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroups(ctx context.Context, accountId, userId string, groupIDs []string) error {
|
||||
if am.DeleteGroupsFunc != nil {
|
||||
return am.DeleteGroupsFunc(ctx, accountId, userId, groupIDs)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteGroups is not implemented")
|
||||
}
|
||||
|
||||
// ListGroups mock implementation of ListGroups from server.AccountManager interface
|
||||
func (am *MockAccountManager) ListGroups(ctx context.Context, accountID string) ([]*group.Group, error) {
|
||||
if am.ListGroupsFunc != nil {
|
||||
@@ -528,6 +538,14 @@ func (am *MockAccountManager) DeleteUser(ctx context.Context, accountID string,
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteUser is not implemented")
|
||||
}
|
||||
|
||||
// DeleteRegularUsers mocks DeleteRegularUsers of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRegularUsers(ctx context.Context, accountID string, initiatorUserID string, targetUserIDs []string) error {
|
||||
if am.DeleteRegularUsersFunc != nil {
|
||||
return am.DeleteRegularUsersFunc(ctx, accountID, initiatorUserID, targetUserIDs)
|
||||
}
|
||||
return status.Errorf(codes.Unimplemented, "method DeleteRegularUsers is not implemented")
|
||||
}
|
||||
|
||||
func (am *MockAccountManager) InviteUser(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error {
|
||||
if am.InviteUserFunc != nil {
|
||||
return am.InviteUserFunc(ctx, accountID, initiatorUserID, targetUserID)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -762,7 +763,11 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
@@ -65,12 +66,14 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
peersMap := make(map[string]*nbpeer.Peer)
|
||||
|
||||
if !user.HasAdminPower() && !user.IsServiceUser && account.Settings.RegularUsersViewBlocked {
|
||||
regularUser := !user.HasAdminPower() && !user.IsServiceUser
|
||||
|
||||
if regularUser && account.Settings.RegularUsersViewBlocked {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !(user.HasAdminPower() || user.IsServiceUser) && user.Id != peer.UserID {
|
||||
if regularUser && user.Id != peer.UserID {
|
||||
// only display peers that belong to the current user if the current user is not an admin
|
||||
continue
|
||||
}
|
||||
@@ -79,6 +82,10 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
|
||||
peersMap[peer.ID] = p
|
||||
}
|
||||
|
||||
if !regularUser {
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
// fetch all the peers that have access to the user's peers
|
||||
for _, peer := range peers {
|
||||
aclPeers, _ := account.getPeerConnectionResources(ctx, peer.ID, approvedPeersMap)
|
||||
@@ -316,7 +323,8 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validatedPeers), nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, nil), nil
|
||||
}
|
||||
|
||||
// GetPeerNetwork returns the Network for a given peer
|
||||
@@ -529,7 +537,8 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, am.dnsDomain, approvedPeersMap)
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
return newPeer, networkMap, postureChecks, nil
|
||||
}
|
||||
|
||||
@@ -540,16 +549,25 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
if peer.UserID != "" {
|
||||
log.Infof("Peer has no userID")
|
||||
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, user)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
return nil, nil, nil, status.NewPeerLoginExpiredError()
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, sync.Meta, account)
|
||||
updated := peer.UpdateMetaIfNew(sync.Meta)
|
||||
if updated {
|
||||
err = am.Store.SavePeer(ctx, account.Id, peer)
|
||||
if err != nil {
|
||||
@@ -585,7 +603,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, validPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
// LoginPeer logs in or registers a peer.
|
||||
@@ -614,31 +633,28 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// it means that the client has already checked if it needs login and had been through the SSO flow
|
||||
// so, we can skip this check and directly proceed with the login
|
||||
if login.UserID == "" {
|
||||
log.Info("Peer needs login")
|
||||
err = am.checkIFPeerNeedsLoginWithoutLock(ctx, accountID, login)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlockAccount := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer unlockAccount()
|
||||
unlockPeer := am.Store.AcquireWriteLockByUID(ctx, login.WireGuardPubKey)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
if unlockPeer != nil {
|
||||
unlockPeer()
|
||||
}
|
||||
}()
|
||||
|
||||
// fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, err := account.FindPeerByPubKey(login.WireGuardPubKey)
|
||||
if err != nil {
|
||||
return nil, nil, nil, status.NewPeerNotRegisteredError()
|
||||
}
|
||||
|
||||
err = checkIfPeerOwnerIsBlocked(peer, account)
|
||||
settings, err := am.Store.GetAccountSettings(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
@@ -646,21 +662,39 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
// this flag prevents unnecessary calls to the persistent store.
|
||||
shouldStorePeer := false
|
||||
updateRemotePeers := false
|
||||
if peerLoginExpired(ctx, peer, account.Settings) {
|
||||
err = am.handleExpiredPeer(ctx, login, account, peer)
|
||||
|
||||
if login.UserID != "" {
|
||||
changed, err := am.handleUserPeer(ctx, peer, settings)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
updateRemotePeers = true
|
||||
shouldStorePeer = true
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
updateRemotePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, account.Id, peer, account.GetPeerGroupsList(peer.ID), account.Settings.Extra)
|
||||
groups, err := am.Store.GetAccountGroups(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
peer, updated := updatePeerMeta(peer, login.Meta, account)
|
||||
var grps []string
|
||||
for _, group := range groups {
|
||||
for _, id := range group.Peers {
|
||||
if id == peer.ID {
|
||||
grps = append(grps, group.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, grps, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
updated := peer.UpdateMetaIfNew(login.Meta)
|
||||
if updated {
|
||||
shouldStorePeer = true
|
||||
}
|
||||
@@ -677,8 +711,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
|
||||
}
|
||||
}
|
||||
|
||||
unlock()
|
||||
unlock = nil
|
||||
unlockPeer()
|
||||
unlockPeer = nil
|
||||
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
if updateRemotePeers || isStatusChanged {
|
||||
am.updateAccountPeers(ctx, account)
|
||||
@@ -732,39 +771,34 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
|
||||
}
|
||||
postureChecks = am.getPeerPostureChecks(account, peer)
|
||||
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap), postureChecks, nil
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, login PeerLogin, account *Account, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, login.UserID, peer)
|
||||
func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *User, peer *nbpeer.Peer) error {
|
||||
err := checkAuth(ctx, user.Id, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If peer was expired before and if it reached this point, it is re-authenticated.
|
||||
// UserID is present, meaning that JWT validation passed successfully in the API layer.
|
||||
updatePeerLastLogin(peer, account)
|
||||
|
||||
// sync user last login with peer last login
|
||||
user, err := account.FindUser(login.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "couldn't find user")
|
||||
}
|
||||
|
||||
err = am.Store.SaveUserLastLogin(account.Id, user.Id, peer.LastLogin)
|
||||
peer = peer.UpdateLastLogin()
|
||||
err = am.Store.SavePeer(ctx, peer.AccountID, peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, login.UserID, peer.ID, account.Id, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
am.StoreEvent(ctx, user.Id, peer.ID, user.AccountID, activity.UserLoggedInPeer, peer.EventMeta(am.GetDNSDomain()))
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, account *Account) error {
|
||||
func checkIfPeerOwnerIsBlocked(peer *nbpeer.Peer, user *User) error {
|
||||
if peer.AddedWithSSOLogin() {
|
||||
user, err := account.FindUser(peer.UserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.PermissionDenied, "user doesn't exist")
|
||||
}
|
||||
if user.IsBlocked() {
|
||||
return status.Errorf(status.PermissionDenied, "user is blocked")
|
||||
}
|
||||
@@ -794,11 +828,6 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *Settings
|
||||
return false
|
||||
}
|
||||
|
||||
func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) {
|
||||
peer.UpdateLastLogin()
|
||||
account.UpdatePeer(peer)
|
||||
}
|
||||
|
||||
// UpdatePeerSSHKey updates peer's public SSH key
|
||||
func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error {
|
||||
if sshKey == "" {
|
||||
@@ -897,33 +926,48 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
return nil, status.Errorf(status.Internal, "user %s has no access to peer %s under account %s", userID, peerID, accountID)
|
||||
}
|
||||
|
||||
func updatePeerMeta(peer *nbpeer.Peer, meta nbpeer.PeerSystemMeta, account *Account) (*nbpeer.Peer, bool) {
|
||||
if peer.UpdateMetaIfNew(meta) {
|
||||
account.UpdatePeer(peer)
|
||||
return peer, true
|
||||
}
|
||||
return peer, false
|
||||
}
|
||||
|
||||
// updateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
if am.metrics != nil {
|
||||
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start))
|
||||
}
|
||||
}()
|
||||
|
||||
peers := account.GetPeers()
|
||||
|
||||
approvedPeersMap, err := am.GetValidatedPeers(account)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed send out updates to peers, failed to validate peer: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to validate peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
semaphore := make(chan struct{}, 10)
|
||||
|
||||
dnsCache := &DNSConfigCache{}
|
||||
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
|
||||
|
||||
for _, peer := range peers {
|
||||
if !am.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
wg.Add(1)
|
||||
semaphore <- struct{}{}
|
||||
go func(p *nbpeer.Peer) {
|
||||
defer wg.Done()
|
||||
defer func() { <-semaphore }()
|
||||
|
||||
postureChecks := am.getPeerPostureChecks(account, p)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
|
||||
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks, dnsCache)
|
||||
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
|
||||
}(peer)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"slices"
|
||||
@@ -241,7 +240,7 @@ func (p *Peer) FQDN(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
return p.DNSLabel + "." + dnsDomain
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the peer
|
||||
|
||||
31
management/server/peer/peer_test.go
Normal file
31
management/server/peer/peer_test.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// FQDNOld is the original implementation for benchmarking purposes
|
||||
func (p *Peer) FQDNOld(dnsDomain string) string {
|
||||
if dnsDomain == "" {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("%s.%s", p.DNSLabel, dnsDomain)
|
||||
}
|
||||
|
||||
func BenchmarkFQDN(b *testing.B) {
|
||||
p := &Peer{DNSLabel: "test-peer"}
|
||||
dnsDomain := "example.com"
|
||||
|
||||
b.Run("Old", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDNOld(dnsDomain)
|
||||
}
|
||||
})
|
||||
|
||||
b.Run("New", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
p.FQDN(dnsDomain)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2,15 +2,26 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/domain"
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
nbroute "github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
func TestPeer_LoginExpired(t *testing.T) {
|
||||
@@ -633,3 +644,354 @@ func TestDefaultAccountManager_GetPeers(t *testing.T) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func setupTestAccountManager(b *testing.B, peers int, groups int) (*DefaultAccountManager, string, string, error) {
|
||||
b.Helper()
|
||||
|
||||
manager, err := createManager(b)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
accountID := "test_account"
|
||||
adminUser := "account_creator"
|
||||
regularUser := "regular_user"
|
||||
|
||||
account := newAccountWithId(context.Background(), accountID, adminUser, "")
|
||||
account.Users[regularUser] = &User{
|
||||
Id: regularUser,
|
||||
Role: UserRoleUser,
|
||||
}
|
||||
|
||||
// Create peers
|
||||
for i := 0; i < peers; i++ {
|
||||
peerKey, _ := wgtypes.GeneratePrivateKey()
|
||||
peer := &nbpeer.Peer{
|
||||
ID: fmt.Sprintf("peer-%d", i),
|
||||
DNSLabel: fmt.Sprintf("peer-%d", i),
|
||||
Key: peerKey.PublicKey().String(),
|
||||
IP: net.ParseIP(fmt.Sprintf("100.64.%d.%d", i/256, i%256)),
|
||||
Status: &nbpeer.PeerStatus{},
|
||||
UserID: regularUser,
|
||||
}
|
||||
account.Peers[peer.ID] = peer
|
||||
}
|
||||
|
||||
// Create groups and policies
|
||||
account.Policies = make([]*Policy, 0, groups)
|
||||
for i := 0; i < groups; i++ {
|
||||
groupID := fmt.Sprintf("group-%d", i)
|
||||
group := &nbgroup.Group{
|
||||
ID: groupID,
|
||||
Name: fmt.Sprintf("Group %d", i),
|
||||
}
|
||||
for j := 0; j < peers/groups; j++ {
|
||||
peerIndex := i*(peers/groups) + j
|
||||
group.Peers = append(group.Peers, fmt.Sprintf("peer-%d", peerIndex))
|
||||
}
|
||||
account.Groups[groupID] = group
|
||||
|
||||
// Create a policy for this group
|
||||
policy := &Policy{
|
||||
ID: fmt.Sprintf("policy-%d", i),
|
||||
Name: fmt.Sprintf("Policy for Group %d", i),
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
ID: fmt.Sprintf("rule-%d", i),
|
||||
Name: fmt.Sprintf("Rule for Group %d", i),
|
||||
Enabled: true,
|
||||
Sources: []string{groupID},
|
||||
Destinations: []string{groupID},
|
||||
Bidirectional: true,
|
||||
Protocol: PolicyRuleProtocolALL,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
account.Policies = append(account.Policies, policy)
|
||||
}
|
||||
|
||||
account.PostureChecks = []*posture.Checks{
|
||||
{
|
||||
ID: "PostureChecksAll",
|
||||
Name: "All",
|
||||
Checks: posture.ChecksDefinition{
|
||||
NBVersionCheck: &posture.NBVersionCheck{
|
||||
MinVersion: "0.0.1",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
return nil, "", "", err
|
||||
}
|
||||
|
||||
return manager, accountID, regularUser, nil
|
||||
}
|
||||
|
||||
func BenchmarkGetPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, userID, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_, err := manager.GetPeers(context.Background(), accountID, userID)
|
||||
if err != nil {
|
||||
b.Fatalf("GetPeers failed: %v", err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkUpdateAccountPeers(b *testing.B) {
|
||||
benchCases := []struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}{
|
||||
{"Small", 50, 5},
|
||||
{"Medium", 500, 10},
|
||||
{"Large", 5000, 20},
|
||||
{"Small single", 50, 1},
|
||||
{"Medium single", 500, 1},
|
||||
{"Large 5", 5000, 5},
|
||||
}
|
||||
|
||||
log.SetOutput(io.Discard)
|
||||
defer log.SetOutput(os.Stderr)
|
||||
|
||||
for _, bc := range benchCases {
|
||||
b.Run(bc.name, func(b *testing.B) {
|
||||
manager, accountID, _, err := setupTestAccountManager(b, bc.peers, bc.groups)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to setup test account manager: %v", err)
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
account, err := manager.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
b.Fatalf("Failed to get account: %v", err)
|
||||
}
|
||||
|
||||
peerChannels := make(map[string]chan *UpdateMessage)
|
||||
|
||||
for peerID := range account.Peers {
|
||||
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
|
||||
}
|
||||
|
||||
manager.peersUpdateManager.peerChannels = peerChannels
|
||||
|
||||
b.ResetTimer()
|
||||
start := time.Now()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
manager.updateAccountPeers(ctx, account)
|
||||
}
|
||||
|
||||
duration := time.Since(start)
|
||||
b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op")
|
||||
b.ReportMetric(0, "ns/op")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSyncResponse(t *testing.T) {
|
||||
_, ipnet, err := net.ParseCIDR("192.168.1.0/24")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
domainList, err := domain.FromStringList([]string{"example.com"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config := &Config{
|
||||
Signal: &Host{
|
||||
Proto: "https",
|
||||
URI: "signal.uri",
|
||||
Username: "",
|
||||
Password: "",
|
||||
},
|
||||
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
|
||||
TURNConfig: &TURNConfig{
|
||||
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
|
||||
},
|
||||
}
|
||||
peer := &nbpeer.Peer{
|
||||
IP: net.ParseIP("192.168.1.1"),
|
||||
SSHEnabled: true,
|
||||
Key: "peer-key",
|
||||
DNSLabel: "peer1",
|
||||
SSHKey: "peer1-ssh-key",
|
||||
}
|
||||
turnCredentials := &TURNCredentials{
|
||||
Username: "turn-user",
|
||||
Password: "turn-pass",
|
||||
}
|
||||
networkMap := &NetworkMap{
|
||||
Network: &Network{Net: *ipnet, Serial: 1000},
|
||||
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
|
||||
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
|
||||
Routes: []*nbroute.Route{
|
||||
{
|
||||
ID: "route1",
|
||||
Network: netip.MustParsePrefix("10.0.0.0/24"),
|
||||
Domains: domainList,
|
||||
KeepRoute: true,
|
||||
NetID: "route1",
|
||||
Peer: "peer1",
|
||||
NetworkType: 1,
|
||||
Masquerade: true,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
DNSConfig: nbdns.Config{
|
||||
ServiceEnable: true,
|
||||
NameServerGroups: []*nbdns.NameServerGroup{
|
||||
{
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
{
|
||||
ID: "ns1",
|
||||
NameServers: []nbdns.NameServer{{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
}},
|
||||
Groups: []string{"group1"},
|
||||
Primary: true,
|
||||
Domains: []string{"example.com"},
|
||||
Enabled: true,
|
||||
SearchDomainsEnabled: true,
|
||||
},
|
||||
},
|
||||
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
|
||||
},
|
||||
FirewallRules: []*FirewallRule{
|
||||
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
|
||||
},
|
||||
}
|
||||
dnsName := "example.com"
|
||||
checks := []*posture.Checks{
|
||||
{
|
||||
Checks: posture.ChecksDefinition{
|
||||
ProcessCheck: &posture.ProcessCheck{
|
||||
Processes: []posture.Process{{LinuxPath: "/usr/bin/netbird"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
dnsCache := &DNSConfigCache{}
|
||||
|
||||
response := toSyncResponse(context.Background(), config, peer, turnCredentials, networkMap, dnsName, checks, dnsCache)
|
||||
|
||||
assert.NotNil(t, response)
|
||||
// assert peer config
|
||||
assert.Equal(t, "192.168.1.1/24", response.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert wiretrustee config
|
||||
assert.Equal(t, "signal.uri", response.WiretrusteeConfig.Signal.Uri)
|
||||
assert.Equal(t, proto.HostConfig_HTTPS, response.WiretrusteeConfig.Signal.GetProtocol())
|
||||
assert.Equal(t, "stun.uri", response.WiretrusteeConfig.Stuns[0].Uri)
|
||||
assert.Equal(t, "turn.uri", response.WiretrusteeConfig.Turns[0].HostConfig.GetUri())
|
||||
assert.Equal(t, "turn-user", response.WiretrusteeConfig.Turns[0].User)
|
||||
assert.Equal(t, "turn-pass", response.WiretrusteeConfig.Turns[0].Password)
|
||||
// assert RemotePeers
|
||||
assert.Equal(t, 1, len(response.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, false, response.RemotePeers[0].GetSshConfig().GetSshEnabled())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map
|
||||
assert.Equal(t, uint64(1000), response.NetworkMap.Serial)
|
||||
assert.Equal(t, "192.168.1.1/24", response.NetworkMap.PeerConfig.Address)
|
||||
assert.Equal(t, "peer1.example.com", response.NetworkMap.PeerConfig.Fqdn)
|
||||
assert.Equal(t, true, response.NetworkMap.PeerConfig.SshConfig.SshEnabled)
|
||||
// assert network map RemotePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, "192.168.1.2/32", response.NetworkMap.RemotePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer2-key", response.NetworkMap.RemotePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer2.example.com", response.NetworkMap.RemotePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer2-ssh-key"), response.NetworkMap.RemotePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map OfflinePeers
|
||||
assert.Equal(t, 1, len(response.NetworkMap.OfflinePeers))
|
||||
assert.Equal(t, "192.168.1.3/32", response.NetworkMap.OfflinePeers[0].AllowedIps[0])
|
||||
assert.Equal(t, "peer3-key", response.NetworkMap.OfflinePeers[0].WgPubKey)
|
||||
assert.Equal(t, "peer3.example.com", response.NetworkMap.OfflinePeers[0].GetFqdn())
|
||||
assert.Equal(t, []byte("peer3-ssh-key"), response.NetworkMap.OfflinePeers[0].GetSshConfig().GetSshPubKey())
|
||||
// assert network map Routes
|
||||
assert.Equal(t, 1, len(response.NetworkMap.Routes))
|
||||
assert.Equal(t, "10.0.0.0/24", response.NetworkMap.Routes[0].Network)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].ID)
|
||||
assert.Equal(t, "peer1", response.NetworkMap.Routes[0].Peer)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.Routes[0].Domains[0])
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].KeepRoute)
|
||||
assert.Equal(t, true, response.NetworkMap.Routes[0].Masquerade)
|
||||
assert.Equal(t, int64(9999), response.NetworkMap.Routes[0].Metric)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.Routes[0].NetworkType)
|
||||
assert.Equal(t, "route1", response.NetworkMap.Routes[0].NetID)
|
||||
// assert network map DNSConfig
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.ServiceEnable)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones))
|
||||
assert.Equal(t, 2, len(response.NetworkMap.DNSConfig.NameServerGroups))
|
||||
// assert network map DNSConfig.CustomZones
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Domain)
|
||||
assert.Equal(t, 1, len(response.NetworkMap.DNSConfig.CustomZones[0].Records))
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Name)
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Type)
|
||||
assert.Equal(t, "IN", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].Class)
|
||||
assert.Equal(t, int64(60), response.NetworkMap.DNSConfig.CustomZones[0].Records[0].TTL)
|
||||
assert.Equal(t, "100.64.0.1", response.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData)
|
||||
// assert network map DNSConfig.NameServerGroups
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].Primary)
|
||||
assert.Equal(t, true, response.NetworkMap.DNSConfig.NameServerGroups[0].SearchDomainsEnabled)
|
||||
assert.Equal(t, "example.com", response.NetworkMap.DNSConfig.NameServerGroups[0].Domains[0])
|
||||
assert.Equal(t, "8.8.8.8", response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetIP())
|
||||
assert.Equal(t, int64(1), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetNSType())
|
||||
assert.Equal(t, int64(53), response.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].GetPort())
|
||||
// assert network map Firewall
|
||||
assert.Equal(t, 1, len(response.NetworkMap.FirewallRules))
|
||||
assert.Equal(t, "192.168.1.2", response.NetworkMap.FirewallRules[0].PeerIP)
|
||||
assert.Equal(t, proto.FirewallRule_IN, response.NetworkMap.FirewallRules[0].Direction)
|
||||
assert.Equal(t, proto.FirewallRule_ACCEPT, response.NetworkMap.FirewallRules[0].Action)
|
||||
assert.Equal(t, proto.FirewallRule_TCP, response.NetworkMap.FirewallRules[0].Protocol)
|
||||
assert.Equal(t, "80", response.NetworkMap.FirewallRules[0].Port)
|
||||
// assert posture checks
|
||||
assert.Equal(t, 1, len(response.Checks))
|
||||
assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0])
|
||||
}
|
||||
|
||||
@@ -213,7 +213,6 @@ type FirewallRule struct {
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx)
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
@@ -225,8 +224,8 @@ func (a *Account) getPeerConnectionResources(ctx context.Context, peerID string,
|
||||
continue
|
||||
}
|
||||
|
||||
sourcePeers, peerInSources := getAllPeersFromGroups(ctx, a, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := getAllPeersFromGroups(ctx, a, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap)
|
||||
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap)
|
||||
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
@@ -290,8 +289,8 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
|
||||
fr.PeerIP = "0.0.0.0"
|
||||
}
|
||||
|
||||
ruleID := (rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ","))
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||
if _, ok := rulesExists[ruleID]; ok {
|
||||
continue
|
||||
}
|
||||
@@ -491,23 +490,23 @@ func toProtocolFirewallRules(update []*FirewallRule) []*proto.FirewallRule {
|
||||
//
|
||||
// Important: Posture checks are applicable only to source group peers,
|
||||
// for destination group peers, call this method with an empty list of sourcePostureChecksIDs
|
||||
func getAllPeersFromGroups(ctx context.Context, account *Account, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
func (a *Account) getAllPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, bool) {
|
||||
peerInGroups := false
|
||||
filteredPeers := make([]*nbpeer.Peer, 0, len(groups))
|
||||
for _, g := range groups {
|
||||
group, ok := account.Groups[g]
|
||||
group, ok := a.Groups[g]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, p := range group.Peers {
|
||||
peer, ok := account.Peers[p]
|
||||
peer, ok := a.Peers[p]
|
||||
if !ok || peer == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// validate the peer based on policy posture checks applied
|
||||
isValid := account.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
isValid := a.validatePostureChecksOnPeer(ctx, sourcePostureChecksIDs, peer.ID)
|
||||
if !isValid {
|
||||
continue
|
||||
}
|
||||
@@ -535,7 +534,7 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
}
|
||||
|
||||
for _, postureChecksID := range sourcePostureChecksID {
|
||||
postureChecks := getPostureChecks(a, postureChecksID)
|
||||
postureChecks := a.getPostureChecks(postureChecksID)
|
||||
if postureChecks == nil {
|
||||
continue
|
||||
}
|
||||
@@ -553,8 +552,8 @@ func (a *Account) validatePostureChecksOnPeer(ctx context.Context, sourcePosture
|
||||
return true
|
||||
}
|
||||
|
||||
func getPostureChecks(account *Account, postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range account.PostureChecks {
|
||||
func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks {
|
||||
for _, postureChecks := range a.PostureChecks {
|
||||
if postureChecks.ID == postureChecksID {
|
||||
return postureChecks
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
|
||||
}
|
||||
|
||||
if err := postureChecks.Validate(); err != nil {
|
||||
return status.Errorf(status.InvalidArgument, err.Error())
|
||||
return status.Errorf(status.InvalidArgument, err.Error()) //nolint
|
||||
}
|
||||
|
||||
exists, uniqName := am.savePostureChecks(account, postureChecks)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbgroup "github.com/netbirdio/netbird/management/server/group"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
@@ -1233,7 +1234,11 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{})
|
||||
|
||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.selfhosted", eventStore, nil, false, MocIntegratedValidator{}, metrics)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -223,10 +223,8 @@ func (am *DefaultAccountManager) CreateSetupKey(ctx context.Context, accountID s
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, group := range autoGroups {
|
||||
if _, ok := account.Groups[group]; !ok {
|
||||
return nil, status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
}
|
||||
if err := validateSetupKeyAutoGroups(account, autoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
setupKey := GenerateSetupKey(keyName, keyType, keyDuration, autoGroups, usageLimit, ephemeral)
|
||||
@@ -279,6 +277,10 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
|
||||
return nil, status.Errorf(status.NotFound, "setup key not found")
|
||||
}
|
||||
|
||||
if err := validateSetupKeyAutoGroups(account, keyToSave.AutoGroups); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// only auto groups, revoked status, and name can be updated for now
|
||||
newKey := oldKey.Copy()
|
||||
newKey.Name = keyToSave.Name
|
||||
@@ -399,3 +401,16 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
|
||||
|
||||
return foundKey, nil
|
||||
}
|
||||
|
||||
func validateSetupKeyAutoGroups(account *Account, autoGroups []string) error {
|
||||
for _, group := range autoGroups {
|
||||
g, ok := account.Groups[group]
|
||||
if !ok {
|
||||
return status.Errorf(status.NotFound, "group %s doesn't exist", group)
|
||||
}
|
||||
if g.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the setup key")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -26,10 +26,17 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
err = manager.SaveGroups(context.Background(), account.Id, userID, []*nbgroup.Group{
|
||||
{
|
||||
ID: "group_1",
|
||||
Name: "group_name_1",
|
||||
Peers: []string{},
|
||||
},
|
||||
{
|
||||
ID: "group_2",
|
||||
Name: "group_name_2",
|
||||
Peers: []string{},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -70,6 +77,19 @@ func TestDefaultAccountManager_SaveSetupKey(t *testing.T) {
|
||||
assert.NotEmpty(t, ev.Meta["key"])
|
||||
assert.Equal(t, userID, ev.InitiatorID)
|
||||
assert.Equal(t, key.Id, ev.TargetID)
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
// saving setup key with All group assigned to auto groups should return error
|
||||
autoGroups = append(autoGroups, groupAll.ID)
|
||||
_, err = manager.SaveSetupKey(context.Background(), account.Id, &SetupKey{
|
||||
Id: key.Id,
|
||||
Name: newKeyName,
|
||||
Revoked: revoked,
|
||||
AutoGroups: autoGroups,
|
||||
}, userID)
|
||||
assert.Error(t, err, "should not save setup key with All group assigned in auto groups")
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
@@ -102,6 +122,9 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
groupAll, err := account.GetGroupAll()
|
||||
assert.NoError(t, err)
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
|
||||
@@ -134,8 +157,14 @@ func TestDefaultAccountManager_CreateSetupKey(t *testing.T) {
|
||||
expectedGroups: []string{"FAKE"},
|
||||
expectedFailure: true,
|
||||
}
|
||||
testCase3 := testCase{
|
||||
name: "Create Setup Key should fail because of All group",
|
||||
expectedKeyName: "my-test-key",
|
||||
expectedGroups: []string{groupAll.ID},
|
||||
expectedFailure: true,
|
||||
}
|
||||
|
||||
for _, tCase := range []testCase{testCase1, testCase2} {
|
||||
for _, tCase := range []testCase{testCase1, testCase2, testCase3} {
|
||||
t.Run(tCase.name, func(t *testing.T) {
|
||||
key, err := manager.CreateSetupKey(context.Background(), account.Id, tCase.expectedKeyName, SetupKeyReusable, expiresIn,
|
||||
tCase.expectedGroups, SetupKeyUnlimitedUsage, userID, false)
|
||||
|
||||
@@ -468,6 +468,34 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) {
|
||||
var user User
|
||||
result := s.db.First(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "user not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting user from store")
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
|
||||
var groups []*nbgroup.Group
|
||||
result := s.db.Find(&groups, idQueryCondition, accountID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "issue getting groups from store")
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) {
|
||||
var accounts []Account
|
||||
result := s.db.Find(&accounts)
|
||||
|
||||
@@ -41,6 +41,8 @@ type Store interface {
|
||||
GetAccountByPrivateDomain(ctx context.Context, domain string) (*Account, error)
|
||||
GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error)
|
||||
GetUserByTokenID(ctx context.Context, tokenID string) (*User, error)
|
||||
GetUserByUserID(ctx context.Context, userID string) (*User, error)
|
||||
GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error)
|
||||
GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
|
||||
SaveAccount(ctx context.Context, account *Account) error
|
||||
SaveUsers(accountID string, users map[string]*User) error
|
||||
|
||||
69
management/server/telemetry/accountmanager_metrics.go
Normal file
69
management/server/telemetry/accountmanager_metrics.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
)
|
||||
|
||||
// AccountManagerMetrics represents all metrics related to the AccountManager
|
||||
type AccountManagerMetrics struct {
|
||||
ctx context.Context
|
||||
updateAccountPeersDurationMs metric.Float64Histogram
|
||||
getPeerNetworkMapDurationMs metric.Float64Histogram
|
||||
networkMapObjectCount metric.Int64Histogram
|
||||
}
|
||||
|
||||
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
|
||||
func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*AccountManagerMetrics, error) {
|
||||
updateAccountPeersDurationMs, err := meter.Float64Histogram("management.account.update.account.peers.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000, 2500, 5000, 10000, 30000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
getPeerNetworkMapDurationMs, err := meter.Float64Histogram("management.account.get.peer.network.map.duration.ms",
|
||||
metric.WithUnit("milliseconds"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
0.1, 0.5, 1, 2.5, 5, 10, 25, 50, 100, 250, 500, 1000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
networkMapObjectCount, err := meter.Int64Histogram("management.account.network.map.object.count",
|
||||
metric.WithUnit("objects"),
|
||||
metric.WithExplicitBucketBoundaries(
|
||||
50, 100, 200, 500, 1000, 2500, 5000, 10000,
|
||||
))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &AccountManagerMetrics{
|
||||
ctx: ctx,
|
||||
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
|
||||
updateAccountPeersDurationMs: updateAccountPeersDurationMs,
|
||||
networkMapObjectCount: networkMapObjectCount,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
// CountUpdateAccountPeersDuration counts the duration of updating account peers
|
||||
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
|
||||
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountGetPeerNetworkMapDuration counts the duration of getting the peer network map
|
||||
func (metrics *AccountManagerMetrics) CountGetPeerNetworkMapDuration(duration time.Duration) {
|
||||
metrics.getPeerNetworkMapDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
|
||||
}
|
||||
|
||||
// CountNetworkMapObjects counts the number of network map objects
|
||||
func (metrics *AccountManagerMetrics) CountNetworkMapObjects(count int64) {
|
||||
metrics.networkMapObjectCount.Record(metrics.ctx, count)
|
||||
}
|
||||
@@ -20,14 +20,15 @@ const defaultEndpoint = "/metrics"
|
||||
|
||||
// MockAppMetrics mocks the AppMetrics interface
|
||||
type MockAppMetrics struct {
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
GetMeterFunc func() metric2.Meter
|
||||
CloseFunc func() error
|
||||
ExposeFunc func(ctx context.Context, port int, endpoint string) error
|
||||
IDPMetricsFunc func() *IDPMetrics
|
||||
HTTPMiddlewareFunc func() *HTTPMiddleware
|
||||
GRPCMetricsFunc func() *GRPCMetrics
|
||||
StoreMetricsFunc func() *StoreMetrics
|
||||
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
|
||||
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// GetMeter mocks the GetMeter function of the AppMetrics interface
|
||||
@@ -94,6 +95,14 @@ func (mock *MockAppMetrics) UpdateChannelMetrics() *UpdateChannelMetrics {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccountManagerMetrics mocks the MockAppMetrics function of the AccountManagerMetrics interface
|
||||
func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
if mock.AddAccountManagerMetricsFunc != nil {
|
||||
return mock.AddAccountManagerMetricsFunc()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AppMetrics is metrics interface
|
||||
type AppMetrics interface {
|
||||
GetMeter() metric2.Meter
|
||||
@@ -104,19 +113,21 @@ type AppMetrics interface {
|
||||
GRPCMetrics() *GRPCMetrics
|
||||
StoreMetrics() *StoreMetrics
|
||||
UpdateChannelMetrics() *UpdateChannelMetrics
|
||||
AccountManagerMetrics() *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
|
||||
type defaultAppMetrics struct {
|
||||
// Meter can be used by different application parts to create counters and measure things
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
Meter metric2.Meter
|
||||
listener net.Listener
|
||||
ctx context.Context
|
||||
idpMetrics *IDPMetrics
|
||||
httpMiddleware *HTTPMiddleware
|
||||
grpcMetrics *GRPCMetrics
|
||||
storeMetrics *StoreMetrics
|
||||
updateChannelMetrics *UpdateChannelMetrics
|
||||
accountManagerMetrics *AccountManagerMetrics
|
||||
}
|
||||
|
||||
// IDPMetrics returns metrics for the idp package
|
||||
@@ -144,6 +155,11 @@ func (appMetrics *defaultAppMetrics) UpdateChannelMetrics() *UpdateChannelMetric
|
||||
return appMetrics.updateChannelMetrics
|
||||
}
|
||||
|
||||
// AccountManagerMetrics returns metrics for the account manager
|
||||
func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
|
||||
return appMetrics.accountManagerMetrics
|
||||
}
|
||||
|
||||
// Close stop application metrics HTTP handler and closes listener.
|
||||
func (appMetrics *defaultAppMetrics) Close() error {
|
||||
if appMetrics.listener == nil {
|
||||
@@ -220,13 +236,19 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountManagerMetrics, err := NewAccountManagerMetrics(ctx, meter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &defaultAppMetrics{
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
Meter: meter,
|
||||
ctx: ctx,
|
||||
idpMetrics: idpMetrics,
|
||||
httpMiddleware: middleware,
|
||||
grpcMetrics: grpcMetrics,
|
||||
storeMetrics: storeMetrics,
|
||||
updateChannelMetrics: updateChannelMetrics,
|
||||
accountManagerMetrics: accountManagerMetrics,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -472,51 +473,18 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP.Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
|
||||
}
|
||||
|
||||
var tuCreatedAt time.Time
|
||||
if u != nil {
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
return nil
|
||||
@@ -976,10 +944,14 @@ func validateUserUpdate(account *Account, initiatorUser, oldUser, update *User)
|
||||
}
|
||||
|
||||
for _, newGroupID := range update.AutoGroups {
|
||||
if _, ok := account.Groups[newGroupID]; !ok {
|
||||
group, ok := account.Groups[newGroupID]
|
||||
if !ok {
|
||||
return status.Errorf(status.InvalidArgument, "provided group ID %s in the user %s update doesn't exist",
|
||||
newGroupID, update.Id)
|
||||
}
|
||||
if group.Name == "All" {
|
||||
return status.Errorf(status.InvalidArgument, "can't add All group to the user")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -1190,6 +1162,116 @@ func (am *DefaultAccountManager) getEmailAndNameOfTargetUser(ctx context.Context
|
||||
return "", "", fmt.Errorf("user info not found for user: %s", targetId)
|
||||
}
|
||||
|
||||
// DeleteRegularUsers deletes regular users from an account.
|
||||
// Note: This function does not acquire the global lock.
|
||||
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
|
||||
//
|
||||
// If an error occurs while deleting the user, the function skips it and continues deleting other users.
|
||||
// Errors are collected and returned at the end.
|
||||
func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string) error {
|
||||
account, err := am.Store.GetAccount(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
executingUser := account.Users[initiatorUserID]
|
||||
if executingUser == nil {
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if !executingUser.HasAdminPower() {
|
||||
return status.Errorf(status.PermissionDenied, "only users with admin power can delete users")
|
||||
}
|
||||
|
||||
var allErrors error
|
||||
|
||||
deletedUsersMeta := make(map[string]map[string]any)
|
||||
for _, targetUserID := range targetUserIDs {
|
||||
if initiatorUserID == targetUserID {
|
||||
allErrors = errors.Join(allErrors, errors.New("self deletion is not allowed"))
|
||||
continue
|
||||
}
|
||||
|
||||
targetUser := account.Users[targetUserID]
|
||||
if targetUser == nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("target user: %s not found", targetUserID))
|
||||
continue
|
||||
}
|
||||
|
||||
if targetUser.Role == UserRoleOwner {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("unable to delete a user: %s with owner role", targetUserID))
|
||||
continue
|
||||
}
|
||||
|
||||
// disable deleting integration user if the initiator is not admin service user
|
||||
if targetUser.Issued == UserIssuedIntegration && !executingUser.IsServiceUser {
|
||||
allErrors = errors.Join(allErrors, errors.New("only integration service user can delete this user"))
|
||||
continue
|
||||
}
|
||||
|
||||
meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err))
|
||||
continue
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
deletedUsersMeta[targetUserID] = meta
|
||||
}
|
||||
|
||||
err = am.Store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete users: %w", err)
|
||||
}
|
||||
|
||||
am.updateAccountPeers(ctx, account)
|
||||
|
||||
for targetUserID, meta := range deletedUsersMeta {
|
||||
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
|
||||
}
|
||||
|
||||
return allErrors
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) {
|
||||
tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve email address: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
// Delete if the user already exists in the IdP. Necessary in cases where a user account
|
||||
// was created where a user account was provisioned but the user did not sign in
|
||||
_, err = am.idpManager.GetUserDataByID(ctx, targetUserID, idp.AppMetadata{WTAccountID: account.Id})
|
||||
if err == nil {
|
||||
err = am.deleteUserFromIDP(ctx, targetUserID, account.Id)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID)
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err)
|
||||
}
|
||||
}
|
||||
|
||||
err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
u, err := account.FindUser(targetUserID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to find user %s for deletion, this should never happen: %s", targetUserID, err)
|
||||
}
|
||||
|
||||
var tuCreatedAt time.Time
|
||||
if u != nil {
|
||||
tuCreatedAt = u.CreatedAt
|
||||
}
|
||||
|
||||
return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
|
||||
@@ -662,6 +662,157 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
func TestUser_DeleteUser_RegularUsers(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
account := newAccountWithId(context.Background(), mockAccountID, mockUserID, "")
|
||||
|
||||
targetId := "user2"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: true,
|
||||
ServiceUserName: "user2username",
|
||||
}
|
||||
targetId = "user3"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
targetId = "user4"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedIntegration,
|
||||
}
|
||||
|
||||
targetId = "user5"
|
||||
account.Users[targetId] = &User{
|
||||
Id: targetId,
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleOwner,
|
||||
}
|
||||
account.Users["user6"] = &User{
|
||||
Id: "user6",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user7"] = &User{
|
||||
Id: "user7",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
}
|
||||
account.Users["user8"] = &User{
|
||||
Id: "user8",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
account.Users["user9"] = &User{
|
||||
Id: "user9",
|
||||
IsServiceUser: false,
|
||||
Issued: UserIssuedAPI,
|
||||
Role: UserRoleAdmin,
|
||||
}
|
||||
|
||||
err := store.SaveAccount(context.Background(), account)
|
||||
if err != nil {
|
||||
t.Fatalf("Error when saving account: %s", err)
|
||||
}
|
||||
|
||||
am := DefaultAccountManager{
|
||||
Store: store,
|
||||
eventStore: &activity.InMemoryEventStore{},
|
||||
integratedPeerValidator: MocIntegratedValidator{},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
userIDs []string
|
||||
expectedReasons []string
|
||||
expectedDeleted []string
|
||||
expectedNotDeleted []string
|
||||
}{
|
||||
{
|
||||
name: "Delete service user successfully ",
|
||||
userIDs: []string{"user2"},
|
||||
expectedDeleted: []string{"user2"},
|
||||
},
|
||||
{
|
||||
name: "Delete regular user successfully",
|
||||
userIDs: []string{"user3"},
|
||||
expectedDeleted: []string{"user3"},
|
||||
},
|
||||
{
|
||||
name: "Delete integration regular user permission denied",
|
||||
userIDs: []string{"user4"},
|
||||
expectedReasons: []string{"only integration service user can delete this user"},
|
||||
expectedNotDeleted: []string{"user4"},
|
||||
},
|
||||
{
|
||||
name: "Delete user with owner role should return permission denied",
|
||||
userIDs: []string{"user5"},
|
||||
expectedReasons: []string{"unable to delete a user: user5 with owner role"},
|
||||
expectedNotDeleted: []string{"user5"},
|
||||
},
|
||||
{
|
||||
name: "Delete multiple users with mixed results",
|
||||
userIDs: []string{"user5", "user5", "user6", "user7"},
|
||||
expectedReasons: []string{"only integration service user can delete this user", "unable to delete a user: user5 with owner role"},
|
||||
expectedDeleted: []string{"user6", "user7"},
|
||||
expectedNotDeleted: []string{"user4", "user5"},
|
||||
},
|
||||
{
|
||||
name: "Delete non-existent user",
|
||||
userIDs: []string{"non-existent-user"},
|
||||
expectedReasons: []string{"target user: non-existent-user not found"},
|
||||
expectedNotDeleted: []string{},
|
||||
},
|
||||
{
|
||||
name: "Delete multiple regular users successfully",
|
||||
userIDs: []string{"user8", "user9"},
|
||||
expectedDeleted: []string{"user8", "user9"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
err = am.DeleteRegularUsers(context.Background(), mockAccountID, mockUserID, tc.userIDs)
|
||||
if len(tc.expectedReasons) > 0 {
|
||||
assert.Error(t, err)
|
||||
var foundExpectedErrors int
|
||||
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
assert.Equal(t, ok, true)
|
||||
|
||||
for _, e := range wrappedErr.Unwrap() {
|
||||
assert.Contains(t, tc.expectedReasons, e.Error(), "unexpected error message")
|
||||
foundExpectedErrors++
|
||||
}
|
||||
|
||||
assert.Equal(t, len(tc.expectedReasons), foundExpectedErrors, "not all expected errors were found")
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
acc, err := am.GetAccountByUserOrAccountID(context.Background(), "", account.Id, "")
|
||||
assert.NoError(t, err)
|
||||
|
||||
for _, id := range tc.expectedDeleted {
|
||||
_, exists := acc.Users[id]
|
||||
assert.False(t, exists, "user should have been deleted: %s", id)
|
||||
}
|
||||
|
||||
for _, id := range tc.expectedNotDeleted {
|
||||
user, exists := acc.Users[id]
|
||||
assert.True(t, exists, "user should not have been deleted: %s", id)
|
||||
assert.NotNil(t, user, "user should exist: %s", id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
store := newStore(t)
|
||||
defer store.Close(context.Background())
|
||||
|
||||
Reference in New Issue
Block a user