diff --git a/client/internal/engine.go b/client/internal/engine.go index 036151c76..b4c44cba0 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -42,6 +42,8 @@ type Engine struct { mgmClient *mgm.Client // conns is a collection of remote peer connections indexed by local public key of the remote peers conns map[string]*Connection + // peerMap is a map that holds all the peers that are known to this peer + peerMap map[string]struct{} // peerMux is used to sync peer operations (e.g. open connection, peer removal) peerMux *sync.Mutex @@ -75,6 +77,7 @@ func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *Engin signal: signalClient, mgmClient: mgmClient, conns: map[string]*Connection{}, + peerMap: map[string]struct{}{}, peerMux: &sync.Mutex{}, syncMsgMux: &sync.Mutex{}, config: config, @@ -139,6 +142,9 @@ func (e *Engine) Start() error { // initializePeer peer agent attempt to open connection func (e *Engine) initializePeer(peer Peer) { + + e.peerMap[peer.WgPubKey] = struct{}{} + var backOff = backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: backoff.DefaultInitialInterval, RandomizationFactor: backoff.DefaultRandomizationFactor, @@ -158,8 +164,8 @@ func (e *Engine) initializePeer(peer Peer) { _, err := e.openPeerConnection(e.wgPort, e.config.WgPrivateKey, peer) e.peerMux.Lock() defer e.peerMux.Unlock() - if _, ok := e.conns[peer.WgPubKey]; !ok { - log.Debugf("removed connection attempt to peer: %v, not retrying", peer.WgPubKey) + if _, ok := e.peerMap[peer.WgPubKey]; !ok { + log.Debugf("peer was removed: %v, stop connecting", peer.WgPubKey) return nil } @@ -170,18 +176,18 @@ func (e *Engine) initializePeer(peer Peer) { return nil } - err := backoff.Retry(operation, backOff) - if err != nil { - // should actually never happen - panic(err) - } + go func() { + err := backoff.Retry(operation, backOff) + if err != nil { + // should actually never happen + panic(err) + } + }() } -func (e *Engine) removePeerConnections(peers []string) error { - e.peerMux.Lock() - defer e.peerMux.Unlock() +func (e *Engine) removePeers(peers []string) error { for _, peer := range peers { - err := e.removePeerConnection(peer) + err := e.removePeer(peer) if err != nil { return err } @@ -194,7 +200,7 @@ func (e *Engine) removeAllPeerConnections() error { e.peerMux.Lock() defer e.peerMux.Unlock() for peer := range e.conns { - err := e.removePeerConnection(peer) + err := e.removePeer(peer) if err != nil { return err } @@ -202,14 +208,17 @@ func (e *Engine) removeAllPeerConnections() error { return nil } -// removePeerConnection closes existing peer connection and removes peer -func (e *Engine) removePeerConnection(peerKey string) error { +// removePeer closes an existing peer connection and removes a peer +func (e *Engine) removePeer(peerKey string) error { + + delete(e.peerMap, peerKey) + conn, exists := e.conns[peerKey] if exists && conn != nil { delete(e.conns, peerKey) return conn.Close() } - log.Infof("removed connection to peer %s", peerKey) + log.Infof("removed peer %s", peerKey) return nil } @@ -228,7 +237,6 @@ func (e *Engine) GetPeerConnectionStatus(peerKey string) *Status { // openPeerConnection opens a new remote peer connection func (e *Engine) openPeerConnection(wgPort int, myKey wgtypes.Key, peer Peer) (*Connection, error) { - e.peerMux.Lock() remoteKey, _ := wgtypes.ParseKey(peer.WgPubKey) connConfig := &ConnConfig{ @@ -254,6 +262,7 @@ func (e *Engine) openPeerConnection(wgPort int, myKey wgtypes.Key, peer Peer) (* return signalCandidate(candidate, myKey, remoteKey, e.signal) } conn := NewConnection(*connConfig, signalCandidate, signalOffer, signalAnswer) + e.peerMux.Lock() e.conns[remoteKey.String()] = conn e.peerMux.Unlock() @@ -388,7 +397,9 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { } func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error { - log.Debugf("got peers update from Management Service, updating") + e.peerMux.Lock() + defer e.peerMux.Unlock() + log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers)) remotePeerMap := make(map[string]struct{}) for _, peer := range remotePeers { remotePeerMap[peer.GetWgPubKey()] = struct{}{} @@ -401,7 +412,7 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error { toRemove = append(toRemove, p) } } - err := e.removePeerConnections(toRemove) + err := e.removePeers(toRemove) if err != nil { return err } @@ -410,8 +421,8 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error { for _, peer := range remotePeers { peerKey := peer.GetWgPubKey() peerIPs := peer.GetAllowedIps() - if _, ok := e.conns[peerKey]; !ok { - go e.initializePeer(Peer{ + if _, ok := e.peerMap[peerKey]; !ok { + e.initializePeer(Peer{ WgPubKey: peerKey, WgAllowedIps: strings.Join(peerIPs, ","), }) diff --git a/management/client/client.go b/management/client/client.go index f78e3eb57..6ccd62ec0 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -83,7 +83,7 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { // ready indicates whether the client is okay and ready to be used // for now it just checks whether gRPC connection to the service is ready func (c *Client) ready() bool { - return c.conn.GetState() == connectivity.Ready + return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle } // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages diff --git a/signal/client/client.go b/signal/client/client.go index 0c9222a34..c086de54b 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -207,7 +207,7 @@ func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, // ready indicates whether the client is okay and ready to be used // for now it just checks whether gRPC connection to the service is in state Ready func (c *Client) ready() bool { - return c.signalConn.GetState() == connectivity.Ready + return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle } // WaitStreamConnected waits until the client is connected to the Signal stream