diff --git a/management/server/account.go b/management/server/account.go index 5c474a343..e092a4633 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -44,6 +44,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" ) const ( @@ -101,6 +102,10 @@ type DefaultAccountManager struct { accountUpdateLocks sync.Map updateAccountPeersBufferInterval atomic.Int64 + + addPeerMap sync.Map + addPeerSemaphore *semaphoregroup.SemaphoreGroup + addPeerSemaphoreMu sync.Mutex } // getJWTGroupsChanges calculates the changes needed to sync a user's JWT groups. diff --git a/management/server/peer.go b/management/server/peer.go index 9ff80442e..f4d8abd63 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -19,6 +19,7 @@ import ( "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" @@ -762,6 +763,13 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } +type addPeerResult struct { + Peer *nbpeer.Peer + NetworkMap *types.NetworkMap + Checks []*posture.Checks + Error error +} + func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, login types.PeerLogin, err error) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) { if errStatus, ok := status.FromError(err); ok && errStatus.Type() == status.NotFound { // we couldn't find this peer by its public key which can mean that peer hasn't been registered yet. @@ -773,7 +781,43 @@ func (am *DefaultAccountManager) handlePeerLoginNotFound(ctx context.Context, lo Location: nbpeer.Location{ConnectionIP: login.ConnectionIP}, } - return am.AddPeer(ctx, login.SetupKey, login.UserID, newPeer) + if v, ok := am.addPeerMap.Load(login.WireGuardPubKey); ok { + if r, ok := v.(addPeerResult); ok { + return r.Peer, r.NetworkMap, r.Checks, r.Error + } + } + + am.addPeerSemaphoreMu.Lock() + if am.addPeerSemaphore == nil { + am.addPeerSemaphore = semaphoregroup.NewSemaphoreGroup(10) + // ensures cleanup go routine waits for at least one add peer to be done + am.addPeerSemaphore.Add(context.Background()) + go func() { + // cleanup routine, keep recreating the semaphore + am.addPeerSemaphore.Wait() + am.addPeerSemaphore = nil + }() + } else { + am.addPeerSemaphore.Add(context.Background()) + } + am.addPeerSemaphoreMu.Unlock() + + go func() { + defer am.addPeerSemaphore.Done(context.Background()) + + peer, nMap, checks, err := am.AddPeer(context.Background(), login.SetupKey, login.UserID, newPeer) + + am.addPeerMap.Store(login.WireGuardPubKey, addPeerResult{peer, nMap, checks, err}) + }() + + // check again! + if v, ok := am.addPeerMap.Load(login.WireGuardPubKey); ok { + if r, ok := v.(addPeerResult); ok { + return r.Peer, r.NetworkMap, r.Checks, r.Error + } + } + + return nil, nil, nil, status.Errorf(status.Internal, "failed while logging in peer") } log.WithContext(ctx).Errorf("failed while logging in peer %s: %v", login.WireGuardPubKey, err)