mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
[management] use readlock on add peer (#4308)
This commit is contained in:
@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
|||||||
|
|
||||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
empty := &proto.Empty{}
|
empty := &proto.Empty{}
|
||||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||||
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
|||||||
|
|
||||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||||
|
|
||||||
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
|
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||||
|
|
||||||
return &proto.Empty{}, nil
|
return &proto.Empty{}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
|||||||
newPeer.DNSLabel = freeLabel
|
newPeer.DNSLabel = freeLabel
|
||||||
newPeer.IP = freeIP
|
newPeer.IP = freeIP
|
||||||
|
|
||||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||||
defer func() {
|
defer func() {
|
||||||
if unlock != nil {
|
if unlock != nil {
|
||||||
unlock()
|
unlock()
|
||||||
|
|||||||
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
|||||||
|
|
||||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||||
if engine == "sqlite" || engine == "" {
|
if engine == "sqlite" || engine == "mysql" || engine == "" {
|
||||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
// we intentionally disabled foreign keys in mysql
|
||||||
|
t.Skip("Skipping test because store is not respecting foreign keys")
|
||||||
}
|
}
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
"gorm.io/gorm/logger"
|
"gorm.io/gorm/logger"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
|||||||
conns = runtime.NumCPU()
|
conns = runtime.NumCPU()
|
||||||
}
|
}
|
||||||
|
|
||||||
if storeEngine == types.SqliteStoreEngine {
|
switch storeEngine {
|
||||||
|
case types.MysqlStoreEngine:
|
||||||
|
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
case types.SqliteStoreEngine:
|
||||||
if err == nil {
|
if err == nil {
|
||||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||||
}
|
}
|
||||||
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
|||||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||||
|
|
||||||
start := time.Now()
|
startWait := time.Now()
|
||||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||||
mtx := value.(*sync.RWMutex)
|
mtx := value.(*sync.RWMutex)
|
||||||
mtx.Lock()
|
mtx.Lock()
|
||||||
|
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||||
|
startHold := time.Now()
|
||||||
|
|
||||||
unlock = func() {
|
unlock = func() {
|
||||||
mtx.Unlock()
|
mtx.Unlock()
|
||||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||||
}
|
}
|
||||||
|
|
||||||
return unlock
|
return unlock
|
||||||
@@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
|
|||||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||||
|
|
||||||
start := time.Now()
|
startWait := time.Now()
|
||||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||||
mtx := value.(*sync.RWMutex)
|
mtx := value.(*sync.RWMutex)
|
||||||
mtx.RLock()
|
mtx.RLock()
|
||||||
|
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||||
|
startHold := time.Now()
|
||||||
|
|
||||||
unlock = func() {
|
unlock = func() {
|
||||||
mtx.RUnlock()
|
mtx.RUnlock()
|
||||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||||
}
|
}
|
||||||
|
|
||||||
return unlock
|
return unlock
|
||||||
@@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := tx.Take(&user, idQueryCondition, userID)
|
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewUserNotFoundError(userID)
|
return nil, status.NewUserNotFoundError(userID)
|
||||||
@@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var accountNetwork types.AccountNetwork
|
var accountNetwork types.AccountNetwork
|
||||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, status.NewAccountNotFoundError(accountID)
|
return nil, status.NewAccountNotFoundError(accountID)
|
||||||
}
|
}
|
||||||
@@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var peer nbpeer.Peer
|
var peer nbpeer.Peer
|
||||||
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
|
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
@@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
|
|||||||
|
|
||||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
var user types.User
|
var user types.User
|
||||||
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
|
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||||
return status.NewUserNotFoundError(userID)
|
return status.NewUserNotFoundError(userID)
|
||||||
@@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
tx := s.db
|
tx := s.db
|
||||||
if lockStrength != LockingStrengthNone {
|
if lockStrength != LockingStrengthNone {
|
||||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||||
}
|
}
|
||||||
|
|
||||||
var setupKey types.SetupKey
|
var setupKey types.SetupKey
|
||||||
result := tx.
|
result := tx.WithContext(ctx).
|
||||||
Take(&setupKey, GetKeyQueryCondition(s), key)
|
Take(&setupKey, GetKeyQueryCondition(s), key)
|
||||||
|
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
@@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||||
result := s.db.Model(&types.SetupKey{}).
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
|
||||||
Where(idQueryCondition, setupKeyID).
|
Where(idQueryCondition, setupKeyID).
|
||||||
Updates(map[string]interface{}{
|
Updates(map[string]interface{}{
|
||||||
"used_times": gorm.Expr("used_times + 1"),
|
"used_times": gorm.Expr("used_times + 1"),
|
||||||
@@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
|||||||
|
|
||||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
var groupID string
|
var groupID string
|
||||||
_ = s.db.Model(types.Group{}).
|
_ = s.db.WithContext(ctx).Model(types.Group{}).
|
||||||
Select("id").
|
Select("id").
|
||||||
Where("account_id = ? AND name = ?", accountID, "All").
|
Where("account_id = ? AND name = ?", accountID, "All").
|
||||||
Limit(1).
|
Limit(1).
|
||||||
@@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
|||||||
|
|
||||||
// AddPeerToGroup adds a peer to a group
|
// AddPeerToGroup adds a peer to a group
|
||||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
||||||
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
peer := &types.GroupPeer{
|
peer := &types.GroupPeer{
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
GroupID: groupID,
|
GroupID: groupID,
|
||||||
PeerID: peerID,
|
PeerID: peerID,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := s.db.Clauses(clause.OnConflict{
|
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||||
DoNothing: true,
|
DoNothing: true,
|
||||||
}).Create(peer).Error
|
}).Create(peer).Error
|
||||||
@@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||||
if err := s.db.Create(peer).Error; err != nil {
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||||
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
ctx, cancel := getDebuggingCtx(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||||
if result.Error != nil {
|
if result.Error != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||||
@@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
|
|||||||
|
|
||||||
return groupPeers, nil
|
return groupPeers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
|
||||||
|
if ok {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
|
||||||
|
if ok {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
|
||||||
|
if ok {
|
||||||
|
//nolint
|
||||||
|
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
case <-grpcCtx.Done():
|
||||||
|
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return ctx, cancel
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user