diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 8893ad2e2..782e46948 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -913,6 +913,7 @@ func (s *GRPCServer) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (*proto.Empty, error) { log.WithContext(ctx).Debugf("Logout request from peer [%s]", req.WgPubKey) + start := time.Now() empty := &proto.Empty{} peerKey, err := s.parseRequest(ctx, req, empty) @@ -944,7 +945,7 @@ func (s *GRPCServer) Logout(ctx context.Context, req *proto.EncryptedMessage) (* s.accountManager.BufferUpdateAccountPeers(ctx, peer.AccountID) - log.WithContext(ctx).Infof("peer %s logged out successfully", peerKey.String()) + log.WithContext(ctx).Debugf("peer %s logged out successfully after %s", peerKey.String(), time.Since(start)) return &proto.Empty{}, nil } diff --git a/management/server/peer.go b/management/server/peer.go index d72eac91a..a1f669f4f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -609,7 +609,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s newPeer.DNSLabel = freeLabel newPeer.IP = freeIP - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + unlock := am.Store.AcquireReadLockByUID(ctx, accountID) defer func() { if unlock != nil { unlock() diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5c52692f3..d974e7c21 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -1476,8 +1476,9 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { func Test_RegisterPeerRollbackOnFailure(t *testing.T) { engine := os.Getenv("NETBIRD_STORE_ENGINE") - if engine == "sqlite" || engine == "" { - t.Skip("Skipping test because sqlite test store is not respecting foreign keys") + if engine == "sqlite" || engine == "mysql" || engine == "" { + // we intentionally disabled foreign keys in mysql + t.Skip("Skipping test because store is not respecting foreign keys") } if runtime.GOOS == "windows" { t.Skip("The SQLite store is not properly supported by Windows yet") diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 1bcae7048..8aa56f7b0 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -24,6 +24,7 @@ import ( "gorm.io/gorm/logger" nbdns "github.com/netbirdio/netbird/dns" + nbcontext "github.com/netbirdio/netbird/management/server/context" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" @@ -76,7 +77,12 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met conns = runtime.NumCPU() } - if storeEngine == types.SqliteStoreEngine { + switch storeEngine { + case types.MysqlStoreEngine: + if err := db.Exec("SET GLOBAL FOREIGN_KEY_CHECKS = 0").Error; err != nil { + return nil, err + } + case types.SqliteStoreEngine: if err == nil { log.WithContext(ctx).Warnf("setting NB_SQL_MAX_OPEN_CONNS is not supported for sqlite, using default value 1") } @@ -142,14 +148,16 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) { func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring write lock for ID %s", uniqueID) - start := time.Now() + startWait := time.Now() value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) mtx := value.(*sync.RWMutex) mtx.Lock() + log.WithContext(ctx).Tracef("waiting to acquire write lock for ID %s in %v", uniqueID, time.Since(startWait)) + startHold := time.Now() unlock = func() { mtx.Unlock() - log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(start)) + log.WithContext(ctx).Tracef("released write lock for ID %s in %v", uniqueID, time.Since(startHold)) } return unlock @@ -159,14 +167,16 @@ func (s *SqlStore) AcquireWriteLockByUID(ctx context.Context, uniqueID string) ( func (s *SqlStore) AcquireReadLockByUID(ctx context.Context, uniqueID string) (unlock func()) { log.WithContext(ctx).Tracef("acquiring read lock for ID %s", uniqueID) - start := time.Now() + startWait := time.Now() value, _ := s.resourceLocks.LoadOrStore(uniqueID, &sync.RWMutex{}) mtx := value.(*sync.RWMutex) mtx.RLock() + log.WithContext(ctx).Tracef("waiting to acquire read lock for ID %s in %v", uniqueID, time.Since(startWait)) + startHold := time.Now() unlock = func() { mtx.RUnlock() - log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(start)) + log.WithContext(ctx).Tracef("released read lock for ID %s in %v", uniqueID, time.Since(startHold)) } return unlock @@ -604,13 +614,16 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren } func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var user types.User - result := tx.Take(&user, idQueryCondition, userID) + result := tx.WithContext(ctx).Take(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -1076,13 +1089,16 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock } func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var accountNetwork types.AccountNetwork - if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { + if err := tx.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountID).Take(&accountNetwork).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.NewAccountNotFoundError(accountID) } @@ -1092,13 +1108,16 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt } func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var peer nbpeer.Peer - result := tx.Take(&peer, GetKeyQueryCondition(s), peerKey) + result := tx.WithContext(ctx).Take(&peer, GetKeyQueryCondition(s), peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { @@ -1147,8 +1166,11 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + var user types.User - result := s.db.Take(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.WithContext(ctx).Take(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return status.NewUserNotFoundError(userID) @@ -1329,13 +1351,16 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s } func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + tx := s.db if lockStrength != LockingStrengthNone { tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)}) } var setupKey types.SetupKey - result := tx. + result := tx.WithContext(ctx). Take(&setupKey, GetKeyQueryCondition(s), key) if result.Error != nil { @@ -1349,7 +1374,10 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking } func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { - result := s.db.Model(&types.SetupKey{}). + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + + result := s.db.WithContext(ctx).Model(&types.SetupKey{}). Where(idQueryCondition, setupKeyID). Updates(map[string]interface{}{ "used_times": gorm.Expr("used_times + 1"), @@ -1369,8 +1397,11 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string // AddPeerToAllGroup adds a peer to the 'All' group. Method always needs to run in a transaction func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + var groupID string - _ = s.db.Model(types.Group{}). + _ = s.db.WithContext(ctx).Model(types.Group{}). Select("id"). Where("account_id = ? AND name = ?", accountID, "All"). Limit(1). @@ -1398,13 +1429,16 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer // AddPeerToGroup adds a peer to a group func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountID, peerID, groupID string) error { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + peer := &types.GroupPeer{ AccountID: accountID, GroupID: groupID, PeerID: peerID, } - err := s.db.Clauses(clause.OnConflict{ + err := s.db.WithContext(ctx).Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "group_id"}, {Name: "peer_id"}}, DoNothing: true, }).Create(peer).Error @@ -1594,7 +1628,10 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt } func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { - if err := s.db.Create(peer).Error; err != nil { + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + + if err := s.db.WithContext(ctx).Create(peer).Error; err != nil { return status.Errorf(status.Internal, "issue adding peer to account: %s", err) } @@ -1720,7 +1757,10 @@ func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID stri } func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { - result := s.db.Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + ctx, cancel := getDebuggingCtx(ctx) + defer cancel() + + result := s.db.WithContext(ctx).Model(&types.Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) if result.Error != nil { log.WithContext(ctx).Errorf("failed to increment network serial count in store: %v", result.Error) return status.Errorf(status.Internal, "failed to increment network serial count in store") @@ -2762,3 +2802,33 @@ func (s *SqlStore) GetAccountGroupPeers(ctx context.Context, lockStrength Lockin return groupPeers, nil } + +func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + userID, ok := grpcCtx.Value(nbcontext.UserIDKey).(string) + if ok { + //nolint + ctx = context.WithValue(ctx, nbcontext.UserIDKey, userID) + } + + requestID, ok := grpcCtx.Value(nbcontext.RequestIDKey).(string) + if ok { + //nolint + ctx = context.WithValue(ctx, nbcontext.RequestIDKey, requestID) + } + + accountID, ok := grpcCtx.Value(nbcontext.AccountIDKey).(string) + if ok { + //nolint + ctx = context.WithValue(ctx, nbcontext.AccountIDKey, accountID) + } + + go func() { + select { + case <-ctx.Done(): + case <-grpcCtx.Done(): + log.WithContext(grpcCtx).Warnf("grpc context ended early, error: %v", grpcCtx.Err()) + } + }() + return ctx, cancel +}