mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[management] use readlock on add peer (#4308)
This commit is contained in:
@@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage)
|
||||
|
||||
func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) {
|
||||
log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey)
|
||||
start := time.Now()
|
||||
|
||||
empty := &proto.Empty{}
|
||||
peerKey, err := s.parseRequest(ctx, req, empty)
|
||||
@@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*
|
||||
|
||||
s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID)
|
||||
|
||||
log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String())
|
||||
log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start))
|
||||
|
||||
return &proto.Empty{}, nil
|
||||
}
|
||||
|
||||
@@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
|
||||
newPeer.DNSLabel = freeLabel
|
||||
newPeer.IP = freeIP
|
||||
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
unlock := am.Store.AcquireReadLockByUID(ctx, accountID)
|
||||
defer func() {
|
||||
if unlock != nil {
|
||||
unlock()
|
||||
|
||||
@@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
|
||||
|
||||
func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine == "sqlite" || engine == "" {
|
||||
t.Skip("Skipping test because sqlite test store is not respecting foreign keys")
|
||||
if engine == "sqlite" || engine == "mysql" || engine == "" {
|
||||
// we intentionally disabled foreign keys in mysql
|
||||
t.Skip("Skipping test because store is not respecting foreign keys")
|
||||
}
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("The SQLite store is not properly supported by Windows yet")
|
||||
|
||||
@@ -24,6 +24,7 @@ import (
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||
@@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met
|
||||
conns = runtime.NumCPU()
|
||||
}
|
||||
|
||||
if storeEngine == types.SqliteStoreEngine {
|
||||
switch storeEngine {
|
||||
case types.MysqlStoreEngine:
|
||||
if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case types.SqliteStoreEngine:
|
||||
if err == nil {
|
||||
log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1")
|
||||
}
|
||||
@@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
|
||||
func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.Lock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.Unlock()
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (
|
||||
func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) {
|
||||
log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID)
|
||||
|
||||
start := time.Now()
|
||||
startWait := time.Now()
|
||||
value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{})
|
||||
mtx := value.(*sync.RWMutex)
|
||||
mtx.RLock()
|
||||
log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait))
|
||||
startHold := time.Now()
|
||||
|
||||
unlock = func() {
|
||||
mtx.RUnlock()
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start))
|
||||
log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold))
|
||||
}
|
||||
|
||||
return unlock
|
||||
@@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var user types.User
|
||||
result := tx.Take(&user, idQueryCondition, userID)
|
||||
result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewUserNotFoundError(userID)
|
||||
@@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var accountNetwork types.AccountNetwork
|
||||
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||
if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
}
|
||||
@@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var peer nbpeer.Peer
|
||||
result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||
result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey)
|
||||
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
@@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking
|
||||
|
||||
// SaveUserLastLogin stores the last login time for a user in DB.
|
||||
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
var user types.User
|
||||
result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID)
|
||||
if result.Error != nil {
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
return status.NewUserNotFoundError(userID)
|
||||
@@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var setupKey types.SetupKey
|
||||
result := tx.
|
||||
result := tx.WithContext(ctx).
|
||||
Take(&setupKey, GetKeyQueryCondition(s), key)
|
||||
|
||||
if result.Error != nil {
|
||||
@@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
|
||||
result := s.db.Model(&types.SetupKey{}).
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&types.SetupKey{}).
|
||||
Where(idQueryCondition, setupKeyID).
|
||||
Updates(map[string]interface{}{
|
||||
"used_times": gorm.Expr("used_times + 1"),
|
||||
@@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
|
||||
|
||||
// AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction
|
||||
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
var groupID string
|
||||
_ = s.db.Model(types.Group{}).
|
||||
_ = s.db.WithContext(ctx).Model(types.Group{}).
|
||||
Select("id").
|
||||
Where("account_id = ? AND name = ?", accountID, "All").
|
||||
Limit(1).
|
||||
@@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
|
||||
|
||||
// AddPeerToGroup adds a peer to a group
|
||||
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
peer := &types.GroupPeer{
|
||||
AccountID: accountID,
|
||||
GroupID: groupID,
|
||||
PeerID: peerID,
|
||||
}
|
||||
|
||||
err := s.db.Clauses(clause.OnConflict{
|
||||
err := s.db.WithContext(ctx).Clauses(clause.OnConflict{
|
||||
Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}},
|
||||
DoNothing: true,
|
||||
}).Create(peer).Error
|
||||
@@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
|
||||
}
|
||||
|
||||
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
|
||||
if err := s.db.Create(peer).Error; err != nil {
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
|
||||
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
|
||||
}
|
||||
|
||||
@@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri
|
||||
}
|
||||
|
||||
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
|
||||
result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
ctx, cancel := getDebuggingCtx(ctx)
|
||||
defer cancel()
|
||||
|
||||
result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to increment network serial count in store")
|
||||
@@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin
|
||||
|
||||
return groupPeers, nil
|
||||
}
|
||||
|
||||
func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID)
|
||||
}
|
||||
|
||||
requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID)
|
||||
}
|
||||
|
||||
accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string)
|
||||
if ok {
|
||||
//nolint
|
||||
ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID)
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-grpcCtx.Done():
|
||||
log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err())
|
||||
}
|
||||
}()
|
||||
return ctx, cancel
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user