mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Merge branch 'main' into idp-user-cache
This commit is contained in:
@@ -81,7 +81,6 @@ type AccountManager interface {
|
||||
GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error)
|
||||
GetGroup(accountId, groupID string) (*Group, error)
|
||||
SaveGroup(accountID, userID string, group *Group) error
|
||||
UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error)
|
||||
DeleteGroup(accountId, userId, groupID string) error
|
||||
ListGroups(accountId string) ([]*Group, error)
|
||||
GroupAddPeer(accountId, groupID, peerID string) error
|
||||
@@ -94,13 +93,11 @@ type AccountManager interface {
|
||||
GetRoute(accountID, routeID, userID string) (*route.Route, error)
|
||||
CreateRoute(accountID string, prefix, peerID, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
SaveRoute(accountID, userID string, route *route.Route) error
|
||||
UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRoute(accountID, routeID, userID string) error
|
||||
ListRoutes(accountID, userID string) ([]*route.Route, error)
|
||||
GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroup(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroup(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroups(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
GetDNSDomain() string
|
||||
@@ -134,6 +131,9 @@ type DefaultAccountManager struct {
|
||||
// dnsDomain is used for peer resolution. This is appended to the peer's name
|
||||
dnsDomain string
|
||||
peerLoginExpiry Scheduler
|
||||
|
||||
// userDeleteFromIDPEnabled allows to delete user from IDP when user is deleted from account
|
||||
userDeleteFromIDPEnabled bool
|
||||
}
|
||||
|
||||
// Settings represents Account settings structure that can be modified via API and Dashboard
|
||||
@@ -739,18 +739,19 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) {
|
||||
|
||||
// BuildManager creates a new DefaultAccountManager with a provided Store
|
||||
func BuildManager(store Store, peersUpdateManager *PeersUpdateManager, idpManager idp.Manager,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store,
|
||||
singleAccountModeDomain string, dnsDomain string, eventStore activity.Store, userDeleteFromIDPEnabled bool,
|
||||
) (*DefaultAccountManager, error) {
|
||||
am := &DefaultAccountManager{
|
||||
Store: store,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
Store: store,
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
idpManager: idpManager,
|
||||
ctx: context.Background(),
|
||||
cacheMux: sync.Mutex{},
|
||||
cacheLoading: map[string]chan struct{}{},
|
||||
dnsDomain: dnsDomain,
|
||||
eventStore: eventStore,
|
||||
peerLoginExpiry: NewDefaultScheduler(),
|
||||
userDeleteFromIDPEnabled: userDeleteFromIDPEnabled,
|
||||
}
|
||||
allAccounts := store.GetAllAccounts()
|
||||
// enable single account mode only if configured by user and number of existing accounts is not grater than 1
|
||||
@@ -875,33 +876,19 @@ func (am *DefaultAccountManager) peerLoginExpirationJob(accountID string) func()
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
|
||||
expiredPeers := account.GetExpiredPeers()
|
||||
var peerIDs []string
|
||||
for _, peer := range account.GetExpiredPeers() {
|
||||
if peer.Status.LoginExpired {
|
||||
continue
|
||||
}
|
||||
for _, peer := range expiredPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||
if err != nil {
|
||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
am.storeEvent(peer.UserID, peer.ID, account.Id, activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()))
|
||||
}
|
||||
|
||||
log.Debugf("discovered %d peers to expire for account %s", len(peerIDs), account.Id)
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers for account %s", accountID)
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
if err := am.expireAndUpdatePeers(account, expiredPeers); err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers for account %s", account.Id)
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
|
||||
return account.GetNextPeerExpiration()
|
||||
}
|
||||
}
|
||||
@@ -1672,19 +1659,3 @@ func newAccountWithId(accountID, userID, domain string) *Account {
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
func removeFromList(inputList []string, toRemove []string) []string {
|
||||
toRemoveMap := make(map[string]struct{})
|
||||
for _, item := range toRemove {
|
||||
toRemoveMap[item] = struct{}{}
|
||||
}
|
||||
|
||||
var resultList []string
|
||||
for _, item := range inputList {
|
||||
_, ok := toRemoveMap[item]
|
||||
if !ok {
|
||||
resultList = append(resultList, item)
|
||||
}
|
||||
}
|
||||
return resultList
|
||||
}
|
||||
|
||||
@@ -2063,7 +2063,7 @@ func createManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore)
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.cloud", eventStore, false)
|
||||
}
|
||||
|
||||
func createStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -104,6 +104,8 @@ const (
|
||||
UserBlocked
|
||||
// UserUnblocked indicates that a user unblocked another user
|
||||
UserUnblocked
|
||||
// UserDeleted indicates that a user deleted another user
|
||||
UserDeleted
|
||||
// GroupDeleted indicates that a user deleted group
|
||||
GroupDeleted
|
||||
// UserLoggedInPeer indicates that user logged in their peer with an interactive SSO login
|
||||
@@ -162,6 +164,7 @@ var activityMap = map[Activity]Code{
|
||||
ServiceUserDeleted: {"Service user deleted", "service.user.delete"},
|
||||
UserBlocked: {"User blocked", "user.block"},
|
||||
UserUnblocked: {"User unblocked", "user.unblock"},
|
||||
UserDeleted: {"User deleted", "user.delete"},
|
||||
GroupDeleted: {"Group deleted", "group.delete"},
|
||||
UserLoggedInPeer: {"User logged in peer", "user.peer.login"},
|
||||
PeerLoginExpired: {"Peer login expired", "peer.login.expire"},
|
||||
|
||||
@@ -18,10 +18,13 @@ type Event struct {
|
||||
ID uint64
|
||||
// InitiatorID is the ID of an object that initiated the event (e.g., a user)
|
||||
InitiatorID string
|
||||
// InitiatorEmail is the email address of an object that initiated the event. This will be set on deleted users only
|
||||
InitiatorEmail string
|
||||
// TargetID is the ID of an object that was effected by the event (e.g., a peer)
|
||||
TargetID string
|
||||
// AccountID is the ID of an account where the event happened
|
||||
AccountID string
|
||||
|
||||
// Meta of the event, e.g. deleted peer information like name, IP, etc
|
||||
Meta map[string]any
|
||||
}
|
||||
@@ -35,12 +38,13 @@ func (e *Event) Copy() *Event {
|
||||
}
|
||||
|
||||
return &Event{
|
||||
Timestamp: e.Timestamp,
|
||||
Activity: e.Activity,
|
||||
ID: e.ID,
|
||||
InitiatorID: e.InitiatorID,
|
||||
TargetID: e.TargetID,
|
||||
AccountID: e.AccountID,
|
||||
Meta: meta,
|
||||
Timestamp: e.Timestamp,
|
||||
Activity: e.Activity,
|
||||
ID: e.ID,
|
||||
InitiatorID: e.InitiatorID,
|
||||
InitiatorEmail: e.InitiatorEmail,
|
||||
TargetID: e.TargetID,
|
||||
AccountID: e.AccountID,
|
||||
Meta: meta,
|
||||
}
|
||||
}
|
||||
|
||||
81
management/server/activity/sqlite/crypt.go
Normal file
81
management/server/activity/sqlite/crypt.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05}
|
||||
|
||||
type EmailEncrypt struct {
|
||||
block cipher.Block
|
||||
}
|
||||
|
||||
func GenerateKey() (string, error) {
|
||||
key := make([]byte, 32)
|
||||
_, err := rand.Read(key)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
readableKey := base64.StdEncoding.EncodeToString(key)
|
||||
return readableKey, nil
|
||||
}
|
||||
|
||||
func NewEmailEncrypt(key string) (*EmailEncrypt, error) {
|
||||
binKey, err := base64.StdEncoding.DecodeString(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(binKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ec := &EmailEncrypt{
|
||||
block: block,
|
||||
}
|
||||
|
||||
return ec, nil
|
||||
}
|
||||
|
||||
func (ec *EmailEncrypt) Encrypt(payload string) string {
|
||||
plainText := pkcs5Padding([]byte(payload))
|
||||
cipherText := make([]byte, len(plainText))
|
||||
cbc := cipher.NewCBCEncrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, plainText)
|
||||
return base64.StdEncoding.EncodeToString(cipherText)
|
||||
}
|
||||
|
||||
func (ec *EmailEncrypt) Decrypt(data string) (string, error) {
|
||||
cipherText, err := base64.StdEncoding.DecodeString(data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
cbc := cipher.NewCBCDecrypter(ec.block, iv)
|
||||
cbc.CryptBlocks(cipherText, cipherText)
|
||||
payload, err := pkcs5UnPadding(cipherText)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(payload), nil
|
||||
}
|
||||
|
||||
func pkcs5Padding(ciphertext []byte) []byte {
|
||||
padding := aes.BlockSize - len(ciphertext)%aes.BlockSize
|
||||
padText := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(ciphertext, padText...)
|
||||
}
|
||||
|
||||
func pkcs5UnPadding(src []byte) ([]byte, error) {
|
||||
srcLen := len(src)
|
||||
paddingLen := int(src[srcLen-1])
|
||||
if paddingLen >= srcLen || paddingLen > aes.BlockSize {
|
||||
return nil, fmt.Errorf("padding size error")
|
||||
}
|
||||
return src[:srcLen-paddingLen], nil
|
||||
}
|
||||
63
management/server/activity/sqlite/crypt_test.go
Normal file
63
management/server/activity/sqlite/crypt_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGenerateKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewEmailEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.Encrypt(testData)
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
decrypted, err := ee.Decrypt(encrypted)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decrypt data: %s", err)
|
||||
}
|
||||
|
||||
if decrypted != testData {
|
||||
t.Fatalf("decrypted data is not match with test data: %s, %s", testData, decrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCorruptKey(t *testing.T) {
|
||||
testData := "exampl@netbird.io"
|
||||
key, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
ee, err := NewEmailEncrypt(key)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
encrypted := ee.Encrypt(testData)
|
||||
if encrypted == "" {
|
||||
t.Fatalf("invalid encrypted text")
|
||||
}
|
||||
|
||||
newKey, err := GenerateKey()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate key: %s", err)
|
||||
}
|
||||
|
||||
ee, err = NewEmailEncrypt(newKey)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to init email encryption: %s", err)
|
||||
}
|
||||
|
||||
res, err := ee.Decrypt(encrypted)
|
||||
if err == nil || res == testData {
|
||||
t.Fatalf("incorrect decryption, the result is: %s", res)
|
||||
}
|
||||
}
|
||||
@@ -3,14 +3,14 @@ package sqlite
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
|
||||
// sqlite driver
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
_ "github.com/mattn/go-sqlite3" // sqlite driver
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,35 +25,62 @@ const (
|
||||
"meta TEXT," +
|
||||
" target_id TEXT);"
|
||||
|
||||
selectDescQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
|
||||
" FROM events WHERE account_id = ? ORDER BY timestamp DESC LIMIT ? OFFSET ?;"
|
||||
selectAscQuery = "SELECT id, activity, timestamp, initiator_id, target_id, account_id, meta" +
|
||||
" FROM events WHERE account_id = ? ORDER BY timestamp ASC LIMIT ? OFFSET ?;"
|
||||
creatTableAccountEmailQuery = `CREATE TABLE IF NOT EXISTS deleted_users (id TEXT NOT NULL, email TEXT NOT NULL);`
|
||||
|
||||
selectDescQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
|
||||
FROM events
|
||||
LEFT JOIN deleted_users i ON events.initiator_id = i.id
|
||||
LEFT JOIN deleted_users t ON events.target_id = t.id
|
||||
WHERE account_id = ?
|
||||
ORDER BY timestamp DESC LIMIT ? OFFSET ?;`
|
||||
|
||||
selectAscQuery = `SELECT events.id, activity, timestamp, initiator_id, i.email as "initiator_email", target_id, t.email as "target_email", account_id, meta
|
||||
FROM events
|
||||
LEFT JOIN deleted_users i ON events.initiator_id = i.id
|
||||
LEFT JOIN deleted_users t ON events.target_id = t.id
|
||||
WHERE account_id = ?
|
||||
ORDER BY timestamp ASC LIMIT ? OFFSET ?;`
|
||||
|
||||
insertQuery = "INSERT INTO events(activity, timestamp, initiator_id, target_id, account_id, meta) " +
|
||||
"VALUES(?, ?, ?, ?, ?, ?)"
|
||||
|
||||
insertDeleteUserQuery = `INSERT INTO deleted_users(id, email) VALUES(?, ?)`
|
||||
)
|
||||
|
||||
// Store is the implementation of the activity.Store interface backed by SQLite
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
emailEncrypt *EmailEncrypt
|
||||
|
||||
insertStatement *sql.Stmt
|
||||
selectAscStatement *sql.Stmt
|
||||
selectDescStatement *sql.Stmt
|
||||
deleteUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
// NewSQLiteStore creates a new Store with an event table if not exists.
|
||||
func NewSQLiteStore(dataDir string) (*Store, error) {
|
||||
func NewSQLiteStore(dataDir string, encryptionKey string) (*Store, error) {
|
||||
dbFile := filepath.Join(dataDir, eventSinkDB)
|
||||
db, err := sql.Open("sqlite3", dbFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
crypt, err := NewEmailEncrypt(encryptionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.Exec(createTableQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, err = db.Exec(creatTableAccountEmailQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
insertStmt, err := db.Prepare(insertQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -69,25 +96,35 @@ func NewSQLiteStore(dataDir string) (*Store, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Store{
|
||||
deleteUserStmt, err := db.Prepare(insertDeleteUserQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Store{
|
||||
db: db,
|
||||
emailEncrypt: crypt,
|
||||
insertStatement: insertStmt,
|
||||
selectDescStatement: selectDescStmt,
|
||||
selectAscStatement: selectAscStmt,
|
||||
}, nil
|
||||
deleteUserStmt: deleteUserStmt,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
func (store *Store) processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
events := make([]*activity.Event, 0)
|
||||
for result.Next() {
|
||||
var id int64
|
||||
var operation activity.Activity
|
||||
var timestamp time.Time
|
||||
var initiator string
|
||||
var initiatorEmail *string
|
||||
var target string
|
||||
var targetEmail *string
|
||||
var account string
|
||||
var jsonMeta string
|
||||
err := result.Scan(&id, &operation, ×tamp, &initiator, &target, &account, &jsonMeta)
|
||||
err := result.Scan(&id, &operation, ×tamp, &initiator, &initiatorEmail, &target, &targetEmail, &account, &jsonMeta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -100,7 +137,17 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events, &activity.Event{
|
||||
if targetEmail != nil {
|
||||
email, err := store.emailEncrypt.Decrypt(*targetEmail)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt email address for target id: %s", target)
|
||||
meta["email"] = ""
|
||||
} else {
|
||||
meta["email"] = email
|
||||
}
|
||||
}
|
||||
|
||||
event := &activity.Event{
|
||||
Timestamp: timestamp,
|
||||
Activity: operation,
|
||||
ID: uint64(id),
|
||||
@@ -108,7 +155,18 @@ func processResult(result *sql.Rows) ([]*activity.Event, error) {
|
||||
TargetID: target,
|
||||
AccountID: account,
|
||||
Meta: meta,
|
||||
})
|
||||
}
|
||||
|
||||
if initiatorEmail != nil {
|
||||
email, err := store.emailEncrypt.Decrypt(*initiatorEmail)
|
||||
if err != nil {
|
||||
log.Errorf("failed to decrypt email address of initiator: %s", initiator)
|
||||
} else {
|
||||
event.InitiatorEmail = email
|
||||
}
|
||||
}
|
||||
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, nil
|
||||
@@ -127,13 +185,18 @@ func (store *Store) Get(accountID string, offset, limit int, descending bool) ([
|
||||
}
|
||||
|
||||
defer result.Close() //nolint
|
||||
return processResult(result)
|
||||
return store.processResult(result)
|
||||
}
|
||||
|
||||
// Save an event in the SQLite events table
|
||||
// Save an event in the SQLite events table end encrypt the "email" element in meta map
|
||||
func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
||||
var jsonMeta string
|
||||
if event.Meta != nil {
|
||||
meta, err := store.saveDeletedUserEmailInEncrypted(event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if meta != nil {
|
||||
metaBytes, err := json.Marshal(event.Meta)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -156,6 +219,29 @@ func (store *Store) Save(event *activity.Event) (*activity.Event, error) {
|
||||
return eventCopy, nil
|
||||
}
|
||||
|
||||
// saveDeletedUserEmailInEncrypted if the meta contains email then store it in encrypted way and delete this item from
|
||||
// meta map
|
||||
func (store *Store) saveDeletedUserEmailInEncrypted(event *activity.Event) (map[string]any, error) {
|
||||
email, ok := event.Meta["email"]
|
||||
if !ok {
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
delete(event.Meta, "email")
|
||||
|
||||
encrypted := store.emailEncrypt.Encrypt(fmt.Sprintf("%s", email))
|
||||
_, err := store.deleteUserStmt.Exec(event.TargetID, encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(event.Meta) == 1 {
|
||||
return nil, nil // nolint
|
||||
}
|
||||
delete(event.Meta, "email")
|
||||
return event.Meta, nil
|
||||
}
|
||||
|
||||
// Close the Store
|
||||
func (store *Store) Close() error {
|
||||
if store.db != nil {
|
||||
|
||||
@@ -12,7 +12,8 @@ import (
|
||||
|
||||
func TestNewSQLiteStore(t *testing.T) {
|
||||
dataDir := t.TempDir()
|
||||
store, err := NewSQLiteStore(dataDir)
|
||||
key, _ := GenerateKey()
|
||||
store, err := NewSQLiteStore(dataDir, key)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
||||
@@ -35,7 +35,8 @@ type Config struct {
|
||||
TURNConfig *TURNConfig
|
||||
Signal *Host
|
||||
|
||||
Datadir string
|
||||
Datadir string
|
||||
DataStoreEncryptionKey string
|
||||
|
||||
HttpConfig *HttpServerConfig
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore)
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "netbird.test", eventStore, false)
|
||||
}
|
||||
|
||||
func createDNSStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -33,26 +33,6 @@ type Group struct {
|
||||
Peers []string
|
||||
}
|
||||
|
||||
const (
|
||||
// UpdateGroupName indicates a name update operation
|
||||
UpdateGroupName GroupUpdateOperationType = iota
|
||||
// InsertPeersToGroup indicates insert peers to group operation
|
||||
InsertPeersToGroup
|
||||
// RemovePeersFromGroup indicates a remove peers from group operation
|
||||
RemovePeersFromGroup
|
||||
// UpdateGroupPeers indicates a replacement of group peers list
|
||||
UpdateGroupPeers
|
||||
)
|
||||
|
||||
// GroupUpdateOperationType operation type
|
||||
type GroupUpdateOperationType int
|
||||
|
||||
// GroupUpdateOperation operation object with type and values to be applied
|
||||
type GroupUpdateOperation struct {
|
||||
Type GroupUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
|
||||
// EventMeta returns activity event meta related to the group
|
||||
func (g *Group) EventMeta() map[string]any {
|
||||
return map[string]any{"name": g.Name}
|
||||
@@ -165,57 +145,6 @@ func difference(a, b []string) []string {
|
||||
return diff
|
||||
}
|
||||
|
||||
// UpdateGroup updates a group using a list of operations
|
||||
func (am *DefaultAccountManager) UpdateGroup(accountID string,
|
||||
groupID string, operations []GroupUpdateOperation,
|
||||
) (*Group, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groupToUpdate, ok := account.Groups[groupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "group with ID %s no longer exists", groupID)
|
||||
}
|
||||
|
||||
group := groupToUpdate.Copy()
|
||||
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
case UpdateGroupName:
|
||||
group.Name = operation.Values[0]
|
||||
case UpdateGroupPeers:
|
||||
group.Peers = operation.Values
|
||||
case InsertPeersToGroup:
|
||||
sourceList := group.Peers
|
||||
resultList := removeFromList(sourceList, operation.Values)
|
||||
group.Peers = append(resultList, operation.Values...)
|
||||
case RemovePeersFromGroup:
|
||||
sourceList := group.Peers
|
||||
resultList := removeFromList(sourceList, operation.Values)
|
||||
group.Peers = resultList
|
||||
}
|
||||
}
|
||||
|
||||
account.Groups[groupID] = group
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return group, nil
|
||||
}
|
||||
|
||||
// DeleteGroup object of the peers
|
||||
func (am *DefaultAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountId)
|
||||
|
||||
0
management/server/http/api/generate.sh
Normal file → Executable file
0
management/server/http/api/generate.sh
Normal file → Executable file
@@ -922,6 +922,10 @@ components:
|
||||
description: The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
||||
type: string
|
||||
example: google-oauth2|123456789012345678901
|
||||
initiator_email:
|
||||
description: The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
|
||||
type: string
|
||||
example: demo@netbird.io
|
||||
target_id:
|
||||
description: The ID of the target of the event. E.g., an ID of the peer that a user removed.
|
||||
type: string
|
||||
@@ -938,6 +942,7 @@ components:
|
||||
- activity
|
||||
- activity_code
|
||||
- initiator_id
|
||||
- initiator_email
|
||||
- target_id
|
||||
- meta
|
||||
responses:
|
||||
|
||||
@@ -164,6 +164,9 @@ type Event struct {
|
||||
// Id Event unique identifier
|
||||
Id string `json:"id"`
|
||||
|
||||
// InitiatorEmail The e-mail address of the initiator of the event. E.g., an e-mail of a user that triggered the event.
|
||||
InitiatorEmail string `json:"initiator_email"`
|
||||
|
||||
// InitiatorId The ID of the initiator of the event. E.g., an ID of a user that triggered the event.
|
||||
InitiatorId string `json:"initiator_id"`
|
||||
|
||||
|
||||
@@ -45,14 +45,46 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
events := make([]*api.Event, 0)
|
||||
for _, e := range accountEvents {
|
||||
events = append(events, toEventResponse(e))
|
||||
events := make([]*api.Event, len(accountEvents))
|
||||
for i, e := range accountEvents {
|
||||
events[i] = toEventResponse(e)
|
||||
}
|
||||
|
||||
err = h.fillEventsWithInitiatorEmail(events, account.Id, user.Id)
|
||||
if err != nil {
|
||||
util.WriteError(err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(w, events)
|
||||
}
|
||||
|
||||
func (h *EventsHandler) fillEventsWithInitiatorEmail(events []*api.Event, accountId, userId string) error {
|
||||
// build email map based on users
|
||||
userInfos, err := h.accountManager.GetUsersFromAccount(accountId, userId)
|
||||
if err != nil {
|
||||
log.Errorf("failed to get users from account: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
emails := make(map[string]string)
|
||||
for _, ui := range userInfos {
|
||||
emails[ui.ID] = ui.Email
|
||||
}
|
||||
|
||||
// fill event with email of initiator
|
||||
var ok bool
|
||||
for _, event := range events {
|
||||
if event.InitiatorEmail == "" {
|
||||
event.InitiatorEmail, ok = emails[event.InitiatorId]
|
||||
if !ok {
|
||||
log.Warnf("failed to resolve email for initiator: %s", event.InitiatorId)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toEventResponse(event *activity.Event) *api.Event {
|
||||
meta := make(map[string]string)
|
||||
if event.Meta != nil {
|
||||
@@ -60,13 +92,15 @@ func toEventResponse(event *activity.Event) *api.Event {
|
||||
meta[s] = fmt.Sprintf("%v", a)
|
||||
}
|
||||
}
|
||||
return &api.Event{
|
||||
Id: fmt.Sprint(event.ID),
|
||||
InitiatorId: event.InitiatorID,
|
||||
Activity: event.Activity.Message(),
|
||||
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
|
||||
TargetId: event.TargetID,
|
||||
Timestamp: event.Timestamp,
|
||||
Meta: meta,
|
||||
e := &api.Event{
|
||||
Id: fmt.Sprint(event.ID),
|
||||
InitiatorId: event.InitiatorID,
|
||||
InitiatorEmail: event.InitiatorEmail,
|
||||
Activity: event.Activity.Message(),
|
||||
ActivityCode: api.EventActivityCode(event.Activity.StringCode()),
|
||||
TargetId: event.TargetID,
|
||||
Timestamp: event.Timestamp,
|
||||
Meta: meta,
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
@@ -37,6 +37,9 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E
|
||||
},
|
||||
}, user, nil
|
||||
},
|
||||
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
|
||||
return make([]*server.UserInfo, 0), nil
|
||||
},
|
||||
},
|
||||
claimsExtractor: jwtclaims.NewClaimsExtractor(
|
||||
jwtclaims.WithFromRequestContext(func(r *http.Request) jwtclaims.AuthorizationClaims {
|
||||
|
||||
@@ -53,22 +53,6 @@ func initGroupTestData(user *server.User, groups ...*server.Group) *GroupsHandle
|
||||
Issued: server.GroupIssuedAPI,
|
||||
}, nil
|
||||
},
|
||||
UpdateGroupFunc: func(_ string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
||||
var group server.Group
|
||||
group.ID = groupID
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
case server.UpdateGroupName:
|
||||
group.Name = operation.Values[0]
|
||||
case server.UpdateGroupPeers, server.InsertPeersToGroup:
|
||||
group.Peers = operation.Values
|
||||
case server.RemovePeersFromGroup:
|
||||
default:
|
||||
return nil, fmt.Errorf("no operation")
|
||||
}
|
||||
}
|
||||
return &group, nil
|
||||
},
|
||||
GetPeerByIPFunc: func(_ string, peerIP string) (*server.Peer, error) {
|
||||
for _, peer := range TestPeers {
|
||||
if peer.IP.String() == peerIP {
|
||||
|
||||
@@ -88,31 +88,6 @@ func initNameserversTestData() *NameserversHandler {
|
||||
}
|
||||
return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID)
|
||||
},
|
||||
UpdateNameServerGroupFunc: func(accountID, nsGroupID, _ string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
||||
nsGroupToUpdate := baseExistingNSGroup.Copy()
|
||||
if nsGroupID != nsGroupToUpdate.ID {
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
}
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
case server.UpdateNameServerGroupName:
|
||||
nsGroupToUpdate.Name = operation.Values[0]
|
||||
case server.UpdateNameServerGroupDescription:
|
||||
nsGroupToUpdate.Description = operation.Values[0]
|
||||
case server.UpdateNameServerGroupNameServers:
|
||||
var parsedNSList []nbdns.NameServer
|
||||
for _, nsURL := range operation.Values {
|
||||
parsed, err := nbdns.ParseNameServerURL(nsURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsedNSList = append(parsedNSList, parsed)
|
||||
}
|
||||
nsGroupToUpdate.NameServers = parsedNSList
|
||||
}
|
||||
}
|
||||
return nsGroupToUpdate, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingNSAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/http/api"
|
||||
@@ -108,38 +107,6 @@ func initRoutesTestData() *RoutesHandler {
|
||||
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
|
||||
}, nil
|
||||
},
|
||||
UpdateRouteFunc: func(_ string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
||||
routeToUpdate := baseExistingRoute
|
||||
if routeID != routeToUpdate.ID {
|
||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||
}
|
||||
for _, operation := range operations {
|
||||
switch operation.Type {
|
||||
case server.UpdateRouteNetwork:
|
||||
routeToUpdate.NetworkType, routeToUpdate.Network, _ = route.ParseNetwork(operation.Values[0])
|
||||
case server.UpdateRouteDescription:
|
||||
routeToUpdate.Description = operation.Values[0]
|
||||
case server.UpdateRouteNetworkIdentifier:
|
||||
routeToUpdate.NetID = operation.Values[0]
|
||||
case server.UpdateRoutePeer:
|
||||
routeToUpdate.Peer = operation.Values[0]
|
||||
if routeToUpdate.Peer == notFoundPeerID {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", routeToUpdate.Peer)
|
||||
}
|
||||
case server.UpdateRouteMetric:
|
||||
routeToUpdate.Metric, _ = strconv.Atoi(operation.Values[0])
|
||||
case server.UpdateRouteMasquerade:
|
||||
routeToUpdate.Masquerade, _ = strconv.ParseBool(operation.Values[0])
|
||||
case server.UpdateRouteEnabled:
|
||||
routeToUpdate.Enabled, _ = strconv.ParseBool(operation.Values[0])
|
||||
case server.UpdateRouteGroups:
|
||||
routeToUpdate.Groups = operation.Values
|
||||
default:
|
||||
return nil, fmt.Errorf("no operation")
|
||||
}
|
||||
}
|
||||
return routeToUpdate, nil
|
||||
},
|
||||
GetAccountFromTokenFunc: func(_ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) {
|
||||
return testingAccount, testingAccount.Users["test_user"], nil
|
||||
},
|
||||
|
||||
@@ -513,7 +513,9 @@ func buildUserExportRequest() (string, error) {
|
||||
return string(str), nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||
func (am *Auth0Manager) createRequest(
|
||||
method string, endpoint string, body io.Reader,
|
||||
) (*http.Request, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -521,17 +523,23 @@ func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*
|
||||
|
||||
reqURL := am.authIssuer + endpoint
|
||||
|
||||
payload := strings.NewReader(payloadStr)
|
||||
|
||||
req, err := http.NewRequest("POST", reqURL, payload)
|
||||
req, err := http.NewRequest(method, reqURL, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func (am *Auth0Manager) createPostRequest(endpoint string, payloadStr string) (*http.Request, error) {
|
||||
req, err := am.createRequest("POST", endpoint, strings.NewReader(payloadStr))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
return req, nil
|
||||
|
||||
}
|
||||
|
||||
// GetAllAccounts gets all registered accounts with corresponding user data.
|
||||
@@ -737,6 +745,38 @@ func (am *Auth0Manager) InviteUserByID(userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteUser from Auth0
|
||||
func (am *Auth0Manager) DeleteUser(userID string) error {
|
||||
req, err := am.createRequest(http.MethodDelete, "/api/v2/users/"+url.QueryEscape(userID), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Debugf("execute delete request: %v", err)
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
log.Errorf("close delete request body: %v", err)
|
||||
}
|
||||
}()
|
||||
if resp.StatusCode != 204 {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkExportJobStatus checks the status of the job created at CreateExportUsersJob.
|
||||
// If the status is "completed", then return the downloadLink
|
||||
func (am *Auth0Manager) checkExportJobStatus(jobID string) (bool, string, error) {
|
||||
|
||||
@@ -458,6 +458,38 @@ func (am *AuthentikManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Authentik
|
||||
func (am *AuthentikManager) DeleteUser(userID string) error {
|
||||
ctx, err := am.authenticationContext()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userPk, err := strconv.ParseInt(userID, 10, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp, err := am.apiClient.CoreApi.CoreUsersDestroy(ctx, int32(userPk)).Execute()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user %s, statusCode %d", userID, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AuthentikManager) authenticationContext() (context.Context, error) {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
|
||||
@@ -454,6 +454,43 @@ func (am *AzureManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Azure
|
||||
func (am *AzureManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := am.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", am.GraphAPIEndpoint, url.QueryEscape(userID))
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
log.Debugf("delete idp user %s", userID)
|
||||
|
||||
resp, err := am.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if am.appMetrics != nil {
|
||||
am.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *AzureManager) getUserExtensions() ([]azureExtension, error) {
|
||||
q := url.Values{}
|
||||
q.Add("$select", extensionFields)
|
||||
|
||||
@@ -254,6 +254,19 @@ func (gm *GoogleWorkspaceManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from GoogleWorkspace.
|
||||
func (gm *GoogleWorkspaceManager) DeleteUser(userID string) error {
|
||||
if err := gm.usersService.Delete(userID).Do(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if gm.appMetrics != nil {
|
||||
gm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey.
|
||||
// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it.
|
||||
// If that fails, it falls back to using the default Google credentials path.
|
||||
|
||||
@@ -26,6 +26,7 @@ type Manager interface {
|
||||
CreateUser(email, name, accountID, invitedByEmail string) (*UserData, error)
|
||||
GetUserByEmail(email string) ([]*UserData, error)
|
||||
InviteUserByID(userID string) error
|
||||
DeleteUser(userID string) error
|
||||
}
|
||||
|
||||
// ClientConfig defines common client configuration for all IdP manager
|
||||
|
||||
@@ -467,6 +467,47 @@ func (km *KeycloakManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Keycloack
|
||||
func (km *KeycloakManager) DeleteUser(userID string) error {
|
||||
jwtToken, err := km.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/users/%s", km.adminEndpoint, url.QueryEscape(userID))
|
||||
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
resp, err := km.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close() // nolint
|
||||
|
||||
// In the docs, they specified 200, but in the endpoints, they return 204
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent {
|
||||
if km.appMetrics != nil {
|
||||
km.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildKeycloakCreateUserRequestPayload(email string, name string, appMetadata AppMetadata) (string, error) {
|
||||
attrs := keycloakUserAttributes{}
|
||||
attrs.Set(wtAccountID, appMetadata.WTAccountID)
|
||||
|
||||
@@ -319,6 +319,28 @@ func (om *OktaManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Okta
|
||||
func (om *OktaManager) DeleteUser(userID string) error {
|
||||
resp, err := om.client.User.DeactivateOrDeleteUser(context.Background(), userID, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
return err
|
||||
}
|
||||
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if om.appMetrics != nil {
|
||||
om.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
return fmt.Errorf("unable to delete user, statusCode %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateUserProfileSchema updates the Okta user schema to include custom fields,
|
||||
// wt_account_id and wt_pending_invite.
|
||||
func updateUserProfileSchema(client *okta.Client) error {
|
||||
|
||||
@@ -428,7 +428,7 @@ func (zm *ZitadelManager) UpdateUserAppMetadata(userID string, appMetadata AppMe
|
||||
return err
|
||||
}
|
||||
|
||||
resource := fmt.Sprintf("users/%s/metadata/_bulk", userID)
|
||||
resource := fmt.Sprintf("users/%s", userID)
|
||||
_, err = zm.post(resource, string(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -447,6 +447,21 @@ func (zm *ZitadelManager) InviteUserByID(_ string) error {
|
||||
return fmt.Errorf("method InviteUserByID not implemented")
|
||||
}
|
||||
|
||||
// DeleteUser from Zitadel
|
||||
func (zm *ZitadelManager) DeleteUser(userID string) error {
|
||||
resource := fmt.Sprintf("users/%s", userID)
|
||||
if err := zm.delete(resource); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountDeleteUser()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
// getUserMetadata requests user metadata from zitadel via ID.
|
||||
func (zm *ZitadelManager) getUserMetadata(userID string) ([]zitadelMetadata, error) {
|
||||
resource := fmt.Sprintf("users/%s/metadata/_search", userID)
|
||||
@@ -500,6 +515,42 @@ func (zm *ZitadelManager) post(resource string, body string) ([]byte, error) {
|
||||
return io.ReadAll(resp.Body)
|
||||
}
|
||||
|
||||
// delete perform Delete requests.
|
||||
func (zm *ZitadelManager) delete(resource string) error {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
reqURL := fmt.Sprintf("%s/%s", zm.managementEndpoint, resource)
|
||||
req, err := http.NewRequest(http.MethodDelete, reqURL, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Add("authorization", "Bearer "+jwtToken.AccessToken)
|
||||
req.Header.Add("content-type", "application/json")
|
||||
|
||||
resp, err := zm.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestError()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != 200 {
|
||||
if zm.appMetrics != nil {
|
||||
zm.appMetrics.IDPMetrics().CountRequestStatusError()
|
||||
}
|
||||
|
||||
return fmt.Errorf("unable to delete %s, statusCode %d", reqURL, resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// get perform Get requests.
|
||||
func (zm *ZitadelManager) get(resource string, q url.Values) ([]byte, error) {
|
||||
jwtToken, err := zm.credentials.Authenticate()
|
||||
|
||||
@@ -412,7 +412,7 @@ func startManagement(t *testing.T, config *Config) (*grpc.Server, string, error)
|
||||
peersUpdateManager := NewPeersUpdateManager()
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
eventStore, false)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
@@ -503,7 +503,7 @@ func startServer(config *server.Config) (*grpc.Server, net.Listener) {
|
||||
peersUpdateManager := server.NewPeersUpdateManager()
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
accountManager, err := server.BuildManager(store, peersUpdateManager, nil, "", "",
|
||||
eventStore)
|
||||
eventStore, false)
|
||||
if err != nil {
|
||||
log.Fatalf("failed creating a manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ type MockAccountManager struct {
|
||||
AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, *server.NetworkMap, error)
|
||||
GetGroupFunc func(accountID, groupID string) (*server.Group, error)
|
||||
SaveGroupFunc func(accountID, userID string, group *server.Group) error
|
||||
UpdateGroupFunc func(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error)
|
||||
DeleteGroupFunc func(accountID, userId, groupID string) error
|
||||
ListGroupsFunc func(accountID string) ([]*server.Group, error)
|
||||
GroupAddPeerFunc func(accountID, groupID, peerKey string) error
|
||||
@@ -54,7 +53,6 @@ type MockAccountManager struct {
|
||||
CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, groups []string, enabled bool, userID string) (*route.Route, error)
|
||||
GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error)
|
||||
SaveRouteFunc func(accountID, userID string, route *route.Route) error
|
||||
UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error)
|
||||
DeleteRouteFunc func(accountID, routeID, userID string) error
|
||||
ListRoutesFunc func(accountID, userID string) ([]*route.Route, error)
|
||||
SaveSetupKeyFunc func(accountID string, key *server.SetupKey, userID string) (*server.SetupKey, error)
|
||||
@@ -68,7 +66,6 @@ type MockAccountManager struct {
|
||||
GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error)
|
||||
CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string) (*nbdns.NameServerGroup, error)
|
||||
SaveNameServerGroupFunc func(accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error
|
||||
UpdateNameServerGroupFunc func(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error)
|
||||
DeleteNameServerGroupFunc func(accountID, nsGroupID, userID string) error
|
||||
ListNameServerGroupsFunc func(accountID string) ([]*nbdns.NameServerGroup, error)
|
||||
CreateUserFunc func(accountID, userID string, key *server.UserInfo) (*server.UserInfo, error)
|
||||
@@ -267,14 +264,6 @@ func (am *MockAccountManager) SaveGroup(accountID, userID string, group *server.
|
||||
return status.Errorf(codes.Unimplemented, "method SaveGroup is not implemented")
|
||||
}
|
||||
|
||||
// UpdateGroup mock implementation of UpdateGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) UpdateGroup(accountID string, groupID string, operations []server.GroupUpdateOperation) (*server.Group, error) {
|
||||
if am.UpdateGroupFunc != nil {
|
||||
return am.UpdateGroupFunc(accountID, groupID, operations)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateGroup not implemented")
|
||||
}
|
||||
|
||||
// DeleteGroup mock implementation of DeleteGroup from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteGroup(accountId, userId, groupID string) error {
|
||||
if am.DeleteGroupFunc != nil {
|
||||
@@ -435,14 +424,6 @@ func (am *MockAccountManager) SaveRoute(accountID, userID string, route *route.R
|
||||
return status.Errorf(codes.Unimplemented, "method SaveRoute is not implemented")
|
||||
}
|
||||
|
||||
// UpdateRoute mock implementation of UpdateRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) UpdateRoute(accountID, ruleID string, operations []server.RouteUpdateOperation) (*route.Route, error) {
|
||||
if am.UpdateRouteFunc != nil {
|
||||
return am.UpdateRouteFunc(accountID, ruleID, operations)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method UpdateRoute not implemented")
|
||||
}
|
||||
|
||||
// DeleteRoute mock implementation of DeleteRoute from server.AccountManager interface
|
||||
func (am *MockAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||
if am.DeleteRouteFunc != nil {
|
||||
@@ -533,14 +514,6 @@ func (am *MockAccountManager) SaveNameServerGroup(accountID, userID string, nsGr
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateNameServerGroup mocks UpdateNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []server.NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
||||
if am.UpdateNameServerGroupFunc != nil {
|
||||
return am.UpdateNameServerGroupFunc(accountID, nsGroupID, userID, operations)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup mocks DeleteNameServerGroup of the AccountManager interface
|
||||
func (am *MockAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
if am.DeleteNameServerGroupFunc != nil {
|
||||
|
||||
@@ -3,7 +3,6 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -15,54 +14,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/status"
|
||||
)
|
||||
|
||||
const (
|
||||
// UpdateNameServerGroupName indicates a nameserver group name update operation
|
||||
UpdateNameServerGroupName NameServerGroupUpdateOperationType = iota
|
||||
// UpdateNameServerGroupDescription indicates a nameserver group description update operation
|
||||
UpdateNameServerGroupDescription
|
||||
// UpdateNameServerGroupNameServers indicates a nameserver group nameservers list update operation
|
||||
UpdateNameServerGroupNameServers
|
||||
// UpdateNameServerGroupGroups indicates a nameserver group' groups update operation
|
||||
UpdateNameServerGroupGroups
|
||||
// UpdateNameServerGroupEnabled indicates a nameserver group status update operation
|
||||
UpdateNameServerGroupEnabled
|
||||
// UpdateNameServerGroupPrimary indicates a nameserver group primary status update operation
|
||||
UpdateNameServerGroupPrimary
|
||||
// UpdateNameServerGroupDomains indicates a nameserver group' domains update operation
|
||||
UpdateNameServerGroupDomains
|
||||
|
||||
domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
)
|
||||
|
||||
// NameServerGroupUpdateOperationType operation type
|
||||
type NameServerGroupUpdateOperationType int
|
||||
|
||||
func (t NameServerGroupUpdateOperationType) String() string {
|
||||
switch t {
|
||||
case UpdateNameServerGroupDescription:
|
||||
return "UpdateNameServerGroupDescription"
|
||||
case UpdateNameServerGroupName:
|
||||
return "UpdateNameServerGroupName"
|
||||
case UpdateNameServerGroupNameServers:
|
||||
return "UpdateNameServerGroupNameServers"
|
||||
case UpdateNameServerGroupGroups:
|
||||
return "UpdateNameServerGroupGroups"
|
||||
case UpdateNameServerGroupEnabled:
|
||||
return "UpdateNameServerGroupEnabled"
|
||||
case UpdateNameServerGroupPrimary:
|
||||
return "UpdateNameServerGroupPrimary"
|
||||
case UpdateNameServerGroupDomains:
|
||||
return "UpdateNameServerGroupDomains"
|
||||
default:
|
||||
return "InvalidOperation"
|
||||
}
|
||||
}
|
||||
|
||||
// NameServerGroupUpdateOperation operation object with type and values to be applied
|
||||
type NameServerGroupUpdateOperation struct {
|
||||
Type NameServerGroupUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
const domainPattern = `^(?i)[a-z0-9]+([\-\.]{1}[a-z0-9]+)*\.[a-z]{2,}$`
|
||||
|
||||
// GetNameServerGroup gets a nameserver group object from account and nameserver group IDs
|
||||
func (am *DefaultAccountManager) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
|
||||
@@ -172,109 +124,6 @@ func (am *DefaultAccountManager) SaveNameServerGroup(accountID, userID string, n
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateNameServerGroup updates existing nameserver group with set of operations
|
||||
func (am *DefaultAccountManager) UpdateNameServerGroup(accountID, nsGroupID, userID string, operations []NameServerGroupUpdateOperation) (*nbdns.NameServerGroup, error) {
|
||||
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(operations) == 0 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "operations shouldn't be empty")
|
||||
}
|
||||
|
||||
nsGroupToUpdate, ok := account.NameServerGroups[nsGroupID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "nameserver group ID %s no longer exists", nsGroupID)
|
||||
}
|
||||
|
||||
newNSGroup := nsGroupToUpdate.Copy()
|
||||
|
||||
for _, operation := range operations {
|
||||
valuesCount := len(operation.Values)
|
||||
if valuesCount < 1 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be at least 1", operation.Type.String())
|
||||
}
|
||||
|
||||
for _, value := range operation.Values {
|
||||
if value == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid empty string value", operation.Type.String())
|
||||
}
|
||||
}
|
||||
switch operation.Type {
|
||||
case UpdateNameServerGroupDescription:
|
||||
newNSGroup.Description = operation.Values[0]
|
||||
case UpdateNameServerGroupName:
|
||||
if valuesCount > 1 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse name values, expected 1 value got %d", valuesCount)
|
||||
}
|
||||
err = validateNSGroupName(operation.Values[0], nsGroupID, account.NameServerGroups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newNSGroup.Name = operation.Values[0]
|
||||
case UpdateNameServerGroupNameServers:
|
||||
var nsList []nbdns.NameServer
|
||||
for _, url := range operation.Values {
|
||||
ns, err := nbdns.ParseNameServerURL(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nsList = append(nsList, ns)
|
||||
}
|
||||
err = validateNSList(nsList)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newNSGroup.NameServers = nsList
|
||||
case UpdateNameServerGroupGroups:
|
||||
err = validateGroups(operation.Values, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newNSGroup.Groups = operation.Values
|
||||
case UpdateNameServerGroupEnabled:
|
||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Enabled = enabled
|
||||
case UpdateNameServerGroupPrimary:
|
||||
primary, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse primary status %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newNSGroup.Primary = primary
|
||||
case UpdateNameServerGroupDomains:
|
||||
err = validateDomainInput(false, operation.Values)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newNSGroup.Domains = operation.Values
|
||||
}
|
||||
}
|
||||
|
||||
account.NameServerGroups[nsGroupID] = newNSGroup
|
||||
|
||||
account.Network.IncSerial()
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return newNSGroup.Copy(), status.Errorf(status.Internal, "failed to update peers after update nameserver %s", newNSGroup.Name)
|
||||
}
|
||||
|
||||
return newNSGroup.Copy(), nil
|
||||
}
|
||||
|
||||
// DeleteNameServerGroup deletes nameserver group with nsGroupID
|
||||
func (am *DefaultAccountManager) DeleteNameServerGroup(accountID, nsGroupID, userID string) error {
|
||||
|
||||
|
||||
@@ -655,323 +655,6 @@ func TestSaveNameServerGroup(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateNameServerGroup(t *testing.T) {
|
||||
nsGroupID := "testingNSGroup"
|
||||
|
||||
existingNSGroup := &nbdns.NameServerGroup{
|
||||
ID: nsGroupID,
|
||||
Name: "super",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
Groups: []string{group1ID},
|
||||
Enabled: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingNSGroup *nbdns.NameServerGroup
|
||||
nsGroupID string
|
||||
operations []NameServerGroupUpdateOperation
|
||||
shouldCreate bool
|
||||
errFunc require.ErrorAssertionFunc
|
||||
expectedNSGroup *nbdns.NameServerGroup
|
||||
}{
|
||||
{
|
||||
name: "Should Config Single Property",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{"superNew"},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedNSGroup: &nbdns.NameServerGroup{
|
||||
ID: nsGroupID,
|
||||
Name: "superNew",
|
||||
Description: "super",
|
||||
Primary: true,
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.1.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("1.1.2.2"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
Groups: []string{group1ID},
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should Config Multiple Properties",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{"superNew"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupDescription,
|
||||
Values: []string{"superDescription"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupNameServers,
|
||||
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupGroups,
|
||||
Values: []string{group1ID, group2ID},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupEnabled,
|
||||
Values: []string{"false"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupPrimary,
|
||||
Values: []string{"false"},
|
||||
},
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupDomains,
|
||||
Values: []string{validDomain},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedNSGroup: &nbdns.NameServerGroup{
|
||||
ID: nsGroupID,
|
||||
Name: "superNew",
|
||||
Description: "superDescription",
|
||||
Primary: false,
|
||||
Domains: []string{validDomain},
|
||||
NameServers: []nbdns.NameServer{
|
||||
{
|
||||
IP: netip.MustParseAddr("127.0.0.1"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
{
|
||||
IP: netip.MustParseAddr("8.8.8.8"),
|
||||
NSType: nbdns.UDPNameServerType,
|
||||
Port: nbdns.DefaultDNSPort,
|
||||
},
|
||||
},
|
||||
Groups: []string{group1ID, group2ID},
|
||||
Enabled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid ID",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: "nonExistingNSGroup",
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Empty Operations",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Empty Values",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Empty String",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{""},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Name Large String",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid On Existing Name",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{existingNSGroupName},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid On Multiple Name Values",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupName,
|
||||
Values: []string{"nameOne", "nameTwo"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Boolean",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupEnabled,
|
||||
Values: []string{"yes"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Nameservers Wrong Schema",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupNameServers,
|
||||
Values: []string{"https://127.0.0.1:53"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Nameservers Wrong IP",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupNameServers,
|
||||
Values: []string{"udp://8.8.8.300:53"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Large Number Of Nameservers",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupNameServers,
|
||||
Values: []string{"udp://127.0.0.1:53", "udp://8.8.8.8:53", "udp://8.8.4.4:53"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid GroupID",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupGroups,
|
||||
Values: []string{"nonExistingGroupID"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Domains",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupDomains,
|
||||
Values: []string{invalidDomain},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Should Not Config On Invalid Primary Status",
|
||||
existingNSGroup: existingNSGroup,
|
||||
nsGroupID: existingNSGroup.ID,
|
||||
operations: []NameServerGroupUpdateOperation{
|
||||
NameServerGroupUpdateOperation{
|
||||
Type: UpdateNameServerGroupPrimary,
|
||||
Values: []string{"yes"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createNSManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestNSAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.NameServerGroups[testCase.existingNSGroup.ID] = testCase.existingNSGroup
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
|
||||
updatedRoute, err := am.UpdateNameServerGroup(account.Id, testCase.nsGroupID, userID, testCase.operations)
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
if !testCase.shouldCreate {
|
||||
return
|
||||
}
|
||||
|
||||
testCase.expectedNSGroup.ID = updatedRoute.ID
|
||||
|
||||
if !testCase.expectedNSGroup.IsEqual(updatedRoute) {
|
||||
t.Errorf("new nameserver group didn't match expected group:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedNSGroup)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteNameServerGroup(t *testing.T) {
|
||||
nsGroupID := "testingNSGroup"
|
||||
|
||||
@@ -1061,7 +744,7 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
|
||||
}
|
||||
|
||||
func createNSStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -13,57 +12,6 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// UpdateRouteDescription indicates a route description update operation
|
||||
UpdateRouteDescription RouteUpdateOperationType = iota
|
||||
// UpdateRouteNetwork indicates a route IP update operation
|
||||
UpdateRouteNetwork
|
||||
// UpdateRoutePeer indicates a route peer update operation
|
||||
UpdateRoutePeer
|
||||
// UpdateRouteMetric indicates a route metric update operation
|
||||
UpdateRouteMetric
|
||||
// UpdateRouteMasquerade indicates a route masquerade update operation
|
||||
UpdateRouteMasquerade
|
||||
// UpdateRouteEnabled indicates a route enabled update operation
|
||||
UpdateRouteEnabled
|
||||
// UpdateRouteNetworkIdentifier indicates a route net ID update operation
|
||||
UpdateRouteNetworkIdentifier
|
||||
// UpdateRouteGroups indicates a group list update operation
|
||||
UpdateRouteGroups
|
||||
)
|
||||
|
||||
// RouteUpdateOperationType operation type
|
||||
type RouteUpdateOperationType int
|
||||
|
||||
func (t RouteUpdateOperationType) String() string {
|
||||
switch t {
|
||||
case UpdateRouteDescription:
|
||||
return "UpdateRouteDescription"
|
||||
case UpdateRouteNetwork:
|
||||
return "UpdateRouteNetwork"
|
||||
case UpdateRoutePeer:
|
||||
return "UpdateRoutePeer"
|
||||
case UpdateRouteMetric:
|
||||
return "UpdateRouteMetric"
|
||||
case UpdateRouteMasquerade:
|
||||
return "UpdateRouteMasquerade"
|
||||
case UpdateRouteEnabled:
|
||||
return "UpdateRouteEnabled"
|
||||
case UpdateRouteNetworkIdentifier:
|
||||
return "UpdateRouteNetworkIdentifier"
|
||||
case UpdateRouteGroups:
|
||||
return "UpdateRouteGroups"
|
||||
default:
|
||||
return "InvalidOperation"
|
||||
}
|
||||
}
|
||||
|
||||
// RouteUpdateOperation operation object with type and values to be applied
|
||||
type RouteUpdateOperation struct {
|
||||
Type RouteUpdateOperationType
|
||||
Values []string
|
||||
}
|
||||
|
||||
// GetRoute gets a route object from account and route IDs
|
||||
func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
@@ -241,109 +189,6 @@ func (am *DefaultAccountManager) SaveRoute(accountID, userID string, routeToSave
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateRoute updates existing route with set of operations
|
||||
func (am *DefaultAccountManager) UpdateRoute(accountID, routeID string, operations []RouteUpdateOperation) (*route.Route, error) {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
defer unlock()
|
||||
|
||||
account, err := am.Store.GetAccount(accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routeToUpdate, ok := account.Routes[routeID]
|
||||
if !ok {
|
||||
return nil, status.Errorf(status.NotFound, "route %s no longer exists", routeID)
|
||||
}
|
||||
|
||||
newRoute := routeToUpdate.Copy()
|
||||
|
||||
for _, operation := range operations {
|
||||
|
||||
if len(operation.Values) != 1 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "operation %s contains invalid number of values, it should be 1", operation.Type.String())
|
||||
}
|
||||
|
||||
switch operation.Type {
|
||||
case UpdateRouteDescription:
|
||||
newRoute.Description = operation.Values[0]
|
||||
case UpdateRouteNetworkIdentifier:
|
||||
if utf8.RuneCountInString(operation.Values[0]) > route.MaxNetIDChar || operation.Values[0] == "" {
|
||||
return nil, status.Errorf(status.InvalidArgument, "identifier should be between 1 and %d", route.MaxNetIDChar)
|
||||
}
|
||||
newRoute.NetID = operation.Values[0]
|
||||
case UpdateRouteNetwork:
|
||||
prefixType, prefix, err := route.ParseNetwork(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse IP %s", operation.Values[0])
|
||||
}
|
||||
err = am.checkPrefixPeerExists(accountID, routeToUpdate.Peer, prefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRoute.Network = prefix
|
||||
newRoute.NetworkType = prefixType
|
||||
case UpdateRoutePeer:
|
||||
if operation.Values[0] != "" {
|
||||
peer := account.GetPeer(operation.Values[0])
|
||||
if peer == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "peer with ID %s not found", operation.Values[0])
|
||||
}
|
||||
}
|
||||
|
||||
err = am.checkPrefixPeerExists(accountID, operation.Values[0], routeToUpdate.Network)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRoute.Peer = operation.Values[0]
|
||||
case UpdateRouteMetric:
|
||||
metric, err := strconv.Atoi(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, not int", operation.Values[0])
|
||||
}
|
||||
if metric < route.MinMetric || metric > route.MaxMetric {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse metric %s, value should be %d > N < %d",
|
||||
operation.Values[0],
|
||||
route.MinMetric,
|
||||
route.MaxMetric,
|
||||
)
|
||||
}
|
||||
newRoute.Metric = metric
|
||||
case UpdateRouteMasquerade:
|
||||
masquerade, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse masquerade %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newRoute.Masquerade = masquerade
|
||||
case UpdateRouteEnabled:
|
||||
enabled, err := strconv.ParseBool(operation.Values[0])
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "failed to parse enabled %s, not boolean", operation.Values[0])
|
||||
}
|
||||
newRoute.Enabled = enabled
|
||||
case UpdateRouteGroups:
|
||||
err = validateGroups(operation.Values, account.Groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newRoute.Groups = operation.Values
|
||||
}
|
||||
}
|
||||
|
||||
account.Routes[routeID] = newRoute
|
||||
|
||||
account.Network.IncSerial()
|
||||
if err = am.Store.SaveAccount(account); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(status.Internal, "failed to update account peers")
|
||||
}
|
||||
return newRoute, nil
|
||||
}
|
||||
|
||||
// DeleteRoute deletes route with routeID
|
||||
func (am *DefaultAccountManager) DeleteRoute(accountID, routeID, userID string) error {
|
||||
unlock := am.Store.AcquireAccountLock(accountID)
|
||||
|
||||
@@ -524,265 +524,6 @@ func TestSaveRoute(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRoute(t *testing.T) {
|
||||
routeID := "testingRouteID"
|
||||
|
||||
existingRoute := &route.Route{
|
||||
ID: routeID,
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superRoute",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer1ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingRoute *route.Route
|
||||
operations []RouteUpdateOperation
|
||||
shouldCreate bool
|
||||
errFunc require.ErrorAssertionFunc
|
||||
expectedRoute *route.Route
|
||||
}{
|
||||
{
|
||||
name: "Happy Path Single OPS",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
Values: []string{peer2ID},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: routeID,
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superRoute",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer2ID,
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Happy Path Multiple OPS",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteDescription,
|
||||
Values: []string{"great"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteNetwork,
|
||||
Values: []string{"192.168.0.0/24"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
Values: []string{peer2ID},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteMetric,
|
||||
Values: []string{"3030"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteMasquerade,
|
||||
Values: []string{"true"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteEnabled,
|
||||
Values: []string{"false"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteNetworkIdentifier,
|
||||
Values: []string{"megaRoute"},
|
||||
},
|
||||
{
|
||||
Type: UpdateRouteGroups,
|
||||
Values: []string{routeGroup2},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: routeID,
|
||||
Network: netip.MustParsePrefix("192.168.0.0/24"),
|
||||
NetID: "megaRoute",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: peer2ID,
|
||||
Description: "great",
|
||||
Masquerade: true,
|
||||
Metric: 3030,
|
||||
Enabled: false,
|
||||
Groups: []string{routeGroup2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Empty Values Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Multiple Values Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
Values: []string{peer2ID, peer1ID},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Bad Prefix Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteNetwork,
|
||||
Values: []string{"192.168.0.0/34"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Bad Peer Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
Values: []string{"non existing Peer"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Empty Peer",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRoutePeer,
|
||||
Values: []string{""},
|
||||
},
|
||||
},
|
||||
errFunc: require.NoError,
|
||||
shouldCreate: true,
|
||||
expectedRoute: &route.Route{
|
||||
ID: routeID,
|
||||
Network: netip.MustParsePrefix("192.168.0.0/16"),
|
||||
NetID: "superRoute",
|
||||
NetworkType: route.IPv4Network,
|
||||
Peer: "",
|
||||
Description: "super",
|
||||
Masquerade: false,
|
||||
Metric: 9999,
|
||||
Enabled: true,
|
||||
Groups: []string{routeGroup1},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Large Network ID Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteNetworkIdentifier,
|
||||
Values: []string{"12345678901234567890qwertyuiopqwertyuiop1"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Empty Network ID Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteNetworkIdentifier,
|
||||
Values: []string{""},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid Metric Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteMetric,
|
||||
Values: []string{"999999"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid Boolean Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteMasquerade,
|
||||
Values: []string{"yes"},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
{
|
||||
name: "Invalid Group Should Fail",
|
||||
existingRoute: existingRoute,
|
||||
operations: []RouteUpdateOperation{
|
||||
{
|
||||
Type: UpdateRouteGroups,
|
||||
Values: []string{routeInvalidGroup1},
|
||||
},
|
||||
},
|
||||
errFunc: require.Error,
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
am, err := createRouterManager(t)
|
||||
if err != nil {
|
||||
t.Error("failed to create account manager")
|
||||
}
|
||||
|
||||
account, err := initTestRouteAccount(t, am)
|
||||
if err != nil {
|
||||
t.Error("failed to init testing account")
|
||||
}
|
||||
|
||||
account.Routes[testCase.existingRoute.ID] = testCase.existingRoute
|
||||
|
||||
err = am.Store.SaveAccount(account)
|
||||
if err != nil {
|
||||
t.Error("account should be saved")
|
||||
}
|
||||
|
||||
updatedRoute, err := am.UpdateRoute(account.Id, testCase.existingRoute.ID, testCase.operations)
|
||||
|
||||
testCase.errFunc(t, err)
|
||||
|
||||
if !testCase.shouldCreate {
|
||||
return
|
||||
}
|
||||
|
||||
testCase.expectedRoute.ID = updatedRoute.ID
|
||||
|
||||
if !testCase.expectedRoute.IsEqual(updatedRoute) {
|
||||
t.Errorf("new route didn't match expected route:\nGot %#v\nExpected:%#v\n", updatedRoute, testCase.expectedRoute)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRoute(t *testing.T) {
|
||||
testingRoute := &route.Route{
|
||||
ID: "testingRoute",
|
||||
@@ -940,7 +681,7 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, error) {
|
||||
return nil, err
|
||||
}
|
||||
eventStore := &activity.InMemoryEventStore{}
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore)
|
||||
return BuildManager(store, NewPeersUpdateManager(), nil, "", "", eventStore, false)
|
||||
}
|
||||
|
||||
func createRouterStore(t *testing.T) (Store, error) {
|
||||
|
||||
@@ -2,6 +2,7 @@ package telemetry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/metric"
|
||||
"go.opentelemetry.io/otel/metric/instrument"
|
||||
"go.opentelemetry.io/otel/metric/instrument/syncint64"
|
||||
@@ -13,6 +14,7 @@ type IDPMetrics struct {
|
||||
getUserByEmailCounter syncint64.Counter
|
||||
getAllAccountsCounter syncint64.Counter
|
||||
createUserCounter syncint64.Counter
|
||||
deleteUserCounter syncint64.Counter
|
||||
getAccountCounter syncint64.Counter
|
||||
getUserByIDCounter syncint64.Counter
|
||||
authenticateRequestCounter syncint64.Counter
|
||||
@@ -39,6 +41,10 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
deleteUserCounter, err := meter.SyncInt64().Counter("management.idp.delete.user.counter", instrument.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
getAccountCounter, err := meter.SyncInt64().Counter("management.idp.get.account.counter", instrument.WithUnit("1"))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -65,6 +71,7 @@ func NewIDPMetrics(ctx context.Context, meter metric.Meter) (*IDPMetrics, error)
|
||||
getUserByEmailCounter: getUserByEmailCounter,
|
||||
getAllAccountsCounter: getAllAccountsCounter,
|
||||
createUserCounter: createUserCounter,
|
||||
deleteUserCounter: deleteUserCounter,
|
||||
getAccountCounter: getAccountCounter,
|
||||
getUserByIDCounter: getUserByIDCounter,
|
||||
authenticateRequestCounter: authenticateRequestCounter,
|
||||
@@ -88,6 +95,11 @@ func (idpMetrics *IDPMetrics) CountCreateUser() {
|
||||
idpMetrics.createUserCounter.Add(idpMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountDeleteUser ...
|
||||
func (idpMetrics *IDPMetrics) CountDeleteUser() {
|
||||
idpMetrics.deleteUserCounter.Add(idpMetrics.ctx, 1)
|
||||
}
|
||||
|
||||
// CountGetAllAccounts ...
|
||||
func (idpMetrics *IDPMetrics) CountGetAllAccounts() {
|
||||
idpMetrics.getAllAccountsCounter.Add(idpMetrics.ctx, 1)
|
||||
|
||||
@@ -327,15 +327,43 @@ func (am *DefaultAccountManager) DeleteUser(accountID, initiatorUserID string, t
|
||||
return status.Errorf(status.NotFound, "user not found")
|
||||
}
|
||||
if executingUser.Role != UserRoleAdmin {
|
||||
return status.Errorf(status.PermissionDenied, "only admins can delete service users")
|
||||
return status.Errorf(status.PermissionDenied, "only admins can delete users")
|
||||
}
|
||||
|
||||
if !targetUser.IsServiceUser {
|
||||
return status.Errorf(status.PermissionDenied, "regular users can not be deleted")
|
||||
peers, err := account.FindUserPeers(targetUserID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to find user peers")
|
||||
}
|
||||
|
||||
meta := map[string]any{"name": targetUser.ServiceUserName}
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, activity.ServiceUserDeleted, meta)
|
||||
if err := am.expireAndUpdatePeers(account, peers); err != nil {
|
||||
log.Errorf("failed update deleted peers expiration: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
targetUserEmail, err := am.getEmailOfTargetUser(account.Id, initiatorUserID, targetUserID)
|
||||
if err != nil {
|
||||
log.Errorf("failed to resolve email address: %s", err)
|
||||
return err
|
||||
}
|
||||
|
||||
var meta map[string]any
|
||||
var eventAction activity.Activity
|
||||
if targetUser.IsServiceUser {
|
||||
meta = map[string]any{"name": targetUser.ServiceUserName}
|
||||
eventAction = activity.ServiceUserDeleted
|
||||
} else {
|
||||
meta = map[string]any{"email": targetUserEmail}
|
||||
eventAction = activity.UserDeleted
|
||||
|
||||
}
|
||||
am.storeEvent(initiatorUserID, targetUserID, accountID, eventAction, meta)
|
||||
|
||||
if !isNil(am.idpManager) {
|
||||
err := am.deleteUserFromIDP(targetUserID, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
delete(account.Users, targetUserID)
|
||||
|
||||
@@ -609,23 +637,10 @@ func (am *DefaultAccountManager) SaveUser(accountID, initiatorUserID string, upd
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var peerIDs []string
|
||||
for _, peer := range blockedPeers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
err = am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status)
|
||||
if err != nil {
|
||||
log.Errorf("failed saving peer status while expiring peer %s", peer.ID)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
err = am.updateAccountPeers(account)
|
||||
if err != nil {
|
||||
log.Errorf("failed updating account peers while expiring peers of a blocked user %s", accountID)
|
||||
return nil, err
|
||||
|
||||
if err := am.expireAndUpdatePeers(account, blockedPeers); err != nil {
|
||||
log.Errorf("failed update expired peers: %s", err)
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -814,6 +829,67 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) (
|
||||
return userInfos, nil
|
||||
}
|
||||
|
||||
// expireAndUpdatePeers expires all peers of the given user and updates them in the account
|
||||
func (am *DefaultAccountManager) expireAndUpdatePeers(account *Account, peers []*Peer) error {
|
||||
var peerIDs []string
|
||||
for _, peer := range peers {
|
||||
peerIDs = append(peerIDs, peer.ID)
|
||||
peer.MarkLoginExpired(true)
|
||||
account.UpdatePeer(peer)
|
||||
if err := am.Store.SavePeerStatus(account.Id, peer.ID, *peer.Status); err != nil {
|
||||
return err
|
||||
}
|
||||
am.storeEvent(
|
||||
peer.UserID, peer.ID, account.Id,
|
||||
activity.PeerLoginExpired, peer.EventMeta(am.GetDNSDomain()),
|
||||
)
|
||||
}
|
||||
|
||||
if len(peerIDs) != 0 {
|
||||
// this will trigger peer disconnect from the management service
|
||||
am.peersUpdateManager.CloseChannels(peerIDs)
|
||||
if err := am.updateAccountPeers(account); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) deleteUserFromIDP(targetUserID, accountID string) error {
|
||||
if am.userDeleteFromIDPEnabled {
|
||||
log.Debugf("user %s deleted from IdP", targetUserID)
|
||||
err := am.idpManager.DeleteUser(targetUserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete user %s from IdP: %s", targetUserID, err)
|
||||
}
|
||||
} else {
|
||||
err := am.idpManager.UpdateUserAppMetadata(targetUserID, idp.AppMetadata{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove user %s app metadata in IdP: %s", targetUserID, err)
|
||||
}
|
||||
|
||||
_, err = am.refreshCache(accountID)
|
||||
if err != nil {
|
||||
log.Errorf("refresh account (%q) cache: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) getEmailOfTargetUser(accountId string, initiatorId, targetId string) (string, error) {
|
||||
userInfos, err := am.GetUsersFromAccount(accountId, initiatorId)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, ui := range userInfos {
|
||||
if ui.ID == targetId {
|
||||
return ui.Email, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("email not found for user: %s", targetId)
|
||||
}
|
||||
|
||||
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {
|
||||
for _, user := range userData {
|
||||
if user.ID == userID {
|
||||
|
||||
@@ -439,8 +439,9 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
|
||||
}
|
||||
|
||||
err = am.DeleteUser(mockAccountID, mockUserID, mockUserID)
|
||||
|
||||
assert.Errorf(t, err, "Regular users can not be deleted (yet)")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_GetUser(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user