[management] use readlock on add peer (#4308)

This commit is contained in:
Pascal Fischer
2025-08-11 15:21:26 +02:00
committed by GitHub
parent 1022a5015c
commit a942e4add5
4 changed files with 91 additions and 19 deletions

View File

@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
start := time.Now()
empty := &proto.Empty{}
peerKey, err := s.parseRequest(ctx, req, empty)
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
return &proto.Empty{}, nil
}

View File

@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
newPeer.DNSLabel = freeLabel
newPeer.IP = freeIP
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
defer func() {
if unlock != nil {
unlock()

View File

@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
engine := os.Getenv("NETBIRD_STORE_ENGINE")
if engine == "sqlite" || engine == "" {
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
if engine == "sqlite" || engine == "mysql" || engine == "" {
// we intentionally disabled foreign keys in mysql
t.Skip("Skipping test because store is not respecting foreign keys")
}
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -24,6 +24,7 @@ import (
"gorm.io/gorm/logger"
nbdns "github.com/netbirdio/netbird/dns"
nbcontext "github.com/netbirdio/netbird/management/server/context"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
conns = runtime.NumCPU()
}
if storeEngine == types.SqliteStoreEngine {
switch storeEngine {
case types.MysqlStoreEngine:
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
return nil, err
}
case types.SqliteStoreEngine:
if err == nil {
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
}
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
start := time.Now()
startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.Lock()
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() {
mtx.Unlock()
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
}
return unlock
@@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
start := time.Now()
startWait := time.Now()
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
mtx := value.(*sync.RWMutex)
mtx.RLock()
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
startHold := time.Now()
unlock = func() {
mtx.RUnlock()
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
}
return unlock
@@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User
result := tx.Take(&user, idQueryCondition, userID)
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
@@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
@@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var user types.User
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
@@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey
result := tx.
result := tx.WithContext(ctx).
Take(&setupKey, GetKeyQueryCondition(s), key)
if result.Error != nil {
@@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
result := s.db.Model(&types.SetupKey{}).
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
"used_times": gorm.Expr("used_times + 1"),
@@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
var groupID string
_ = s.db.Model(types.Group{}).
_ = s.db.WithContext(ctx).Model(types.Group{}).
Select("id").
Where("account_id = ? AND name = ?", accountID, "All").
Limit(1).
@@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
// AddPeerToGroup adds a peer to a group
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
peer := &types.GroupPeer{
AccountID: accountID,
GroupID: groupID,
PeerID: peerID,
}
err := s.db.Clauses(clause.OnConflict{
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
DoNothing: true,
}).Create(peer).Error
@@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
if err := s.db.Create(peer).Error; err != nil {
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
ctx, cancel := getDebuggingCtx(ctx)
defer cancel()
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
return status.Errorf(status.Internal, "failed to increment network serial count in store")
@@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
return groupPeers, nil
}
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
}
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
}
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
if ok {
//nolint
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
}
go func() {
select {
case <-ctx.Done():
case <-grpcCtx.Done():
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
}
}()
return ctx, cancel
}