diff --git a/client/internal/engine.go b/client/internal/engine.go index 1fd011635..4145665e9 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -4,7 +4,7 @@ import ( "context" "fmt" "github.com/cenkalti/backoff/v4" - ice "github.com/pion/ice/v2" + "github.com/pion/ice/v2" log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/iface" mgm "github.com/wiretrustee/wiretrustee/management/client" @@ -142,7 +142,7 @@ func (e *Engine) initializePeer(peer Peer) { RandomizationFactor: backoff.DefaultRandomizationFactor, Multiplier: backoff.DefaultMultiplier, MaxInterval: 5 * time.Second, - MaxElapsedTime: time.Duration(0), //never stop + MaxElapsedTime: 0, //never stop Stop: backoff.Stop, Clock: backoff.SystemClock, }, e.ctx) @@ -157,8 +157,7 @@ func (e *Engine) initializePeer(peer Peer) { } if err != nil { - log.Warnln(err) - log.Debugf("retrying connection because of error: %s", err.Error()) + log.Infof("retrying connection because of error: %s", err.Error()) return err } return nil @@ -332,6 +331,8 @@ func (e *Engine) receiveManagementEvents() { return nil }) if err != nil { + // happens if management is unavailable for a long time. 
+ // We want to cancel the operation of the whole client e.cancel() return } @@ -414,68 +415,77 @@ func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error { // receiveSignalEvents connects to the Signal Service event stream to negotiate connection with remote peers func (e *Engine) receiveSignalEvents() { - // connect to a stream of messages coming from the signal server - e.signal.Receive(func(msg *sProto.Message) error { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() + go func() { + // connect to a stream of messages coming from the signal server + err := e.signal.Receive(func(msg *sProto.Message) error { - conn := e.conns[msg.Key] - if conn == nil { - return fmt.Errorf("wrongly addressed message %s", msg.Key) - } + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() - if conn.Config.RemoteWgKey.String() != msg.Key { - return fmt.Errorf("unknown peer %s", msg.Key) - } - - switch msg.GetBody().Type { - case sProto.Body_OFFER: - remoteCred, err := signal.UnMarshalCredential(msg) - if err != nil { - return err + conn := e.conns[msg.Key] + if conn == nil { + return fmt.Errorf("wrongly addressed message %s", msg.Key) } - err = conn.OnOffer(IceCredentials{ - uFrag: remoteCred.UFrag, - pwd: remoteCred.Pwd, - }) - if err != nil { - return err + if conn.Config.RemoteWgKey.String() != msg.Key { + return fmt.Errorf("unknown peer %s", msg.Key) + } + + switch msg.GetBody().Type { + case sProto.Body_OFFER: + remoteCred, err := signal.UnMarshalCredential(msg) + if err != nil { + return err + } + err = conn.OnOffer(IceCredentials{ + uFrag: remoteCred.UFrag, + pwd: remoteCred.Pwd, + }) + + if err != nil { + return err + } + + return nil + case sProto.Body_ANSWER: + remoteCred, err := signal.UnMarshalCredential(msg) + if err != nil { + return err + } + err = conn.OnAnswer(IceCredentials{ + uFrag: remoteCred.UFrag, + pwd: remoteCred.Pwd, + }) + + if err != nil { + return err + } + + case sProto.Body_CANDIDATE: + + candidate, err := 
ice.UnmarshalCandidate(msg.GetBody().Payload) + if err != nil { + log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err) + return err + } + + err = conn.OnRemoteCandidate(candidate) + if err != nil { + log.Errorf("error handling CANDIATE from %s", msg.Key) + return err + } } return nil - case sProto.Body_ANSWER: - remoteCred, err := signal.UnMarshalCredential(msg) - if err != nil { - return err - } - err = conn.OnAnswer(IceCredentials{ - uFrag: remoteCred.UFrag, - pwd: remoteCred.Pwd, - }) - - if err != nil { - return err - } - - case sProto.Body_CANDIDATE: - - candidate, err := ice.UnmarshalCandidate(msg.GetBody().Payload) - if err != nil { - log.Errorf("failed on parsing remote candidate %s -> %s", candidate, err) - return err - } - - err = conn.OnRemoteCandidate(candidate) - if err != nil { - log.Errorf("error handling CANDIATE from %s", msg.Key) - return err - } + }) + if err != nil { + // happens if signal is unavailable for a long time. + // We want to cancel the operation of the whole client + e.cancel() + return } + }() - return nil - }) - - e.signal.WaitConnected() + e.signal.WaitStreamConnected() } diff --git a/management/client/client.go b/management/client/client.go index 860afb13a..891a9f980 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -3,6 +3,7 @@ package client import ( "context" "crypto/tls" + "fmt" "github.com/cenkalti/backoff/v4" log "github.com/sirupsen/logrus" "github.com/wiretrustee/wiretrustee/client/system" @@ -10,6 +11,7 @@ import ( "github.com/wiretrustee/wiretrustee/management/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "io" @@ -71,12 +73,18 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { RandomizationFactor: backoff.DefaultRandomizationFactor, Multiplier: backoff.DefaultMultiplier, MaxInterval: 10 * time.Second, - MaxElapsedTime: 30 
* time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client + MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client Stop: backoff.Stop, Clock: backoff.SystemClock, }, ctx) } +// ready indicates whether the client is okay and ready to be used +// for now it just checks whether gRPC connection to the service is ready +func (c *Client) ready() bool { + return c.conn.GetState() == connectivity.Ready +} + // Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages // Blocking request. The result will be sent via msgHandler callback function func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { @@ -85,6 +93,12 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { operation := func() error { + log.Debugf("management connection state %v", c.conn.GetState()) + + if !c.ready() { + return fmt.Errorf("no connection to management") + } + // todo we already have it since we did the Login, maybe cache it locally? 
serverPubKey, err := c.GetServerPublicKey() if err != nil { @@ -98,7 +112,7 @@ func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { return err } - log.Infof("connected to the Management Service Stream") + log.Infof("connected to the Management Service stream") // blocking until error err = c.receiveEvents(stream, *serverPubKey, msgHandler) @@ -139,7 +153,7 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server for { update, err := stream.Recv() if err == io.EOF { - log.Errorf("managment stream was closed: %s", err) + log.Errorf("Management stream has been closed by server: %s", err) return err } if err != nil { @@ -165,6 +179,10 @@ func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, server // GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server) func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) { + if !c.ready() { + return nil, fmt.Errorf("no connection to management") + } + mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting defer cancel() resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) @@ -181,6 +199,9 @@ func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) { } func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { + if !c.ready() { + return nil, fmt.Errorf("no connection to management") + } loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) if err != nil { log.Errorf("failed to encrypt message: %s", err) diff --git a/signal/client/client.go b/signal/client/client.go index 52c06aa84..f49375cc4 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -11,6 +11,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" 
"google.golang.org/grpc/metadata" @@ -23,6 +24,12 @@ import ( // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. +// Status is the status of the client +type Status string + +const streamConnected Status = "streamConnected" +const streamDisconnected Status = "streamDisconnected" + // Client Wraps the Signal Exchange Service gRpc client type Client struct { key wgtypes.Key @@ -30,8 +37,11 @@ type Client struct { signalConn *grpc.ClientConn ctx context.Context stream proto.SignalExchange_ConnectStreamClient - //waiting group to notify once stream is connected - connWg *sync.WaitGroup //todo use a channel instead?? + // connectedCh used to notify goroutines waiting for the connection to the Signal stream + connectedCh chan struct{} + mux sync.Mutex + // streamConnected indicates whether this client is streamConnected to the Signal stream + status Status } // Close Closes underlying connections to the Signal Exchange @@ -65,13 +75,13 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo return nil, err } - var wg sync.WaitGroup return &Client{ realClient: proto.NewSignalExchangeClient(conn), ctx: ctx, signalConn: conn, key: key, - connWg: &wg, + mux: sync.Mutex{}, + status: streamDisconnected, }, nil } @@ -82,7 +92,7 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { RandomizationFactor: backoff.DefaultRandomizationFactor, Multiplier: backoff.DefaultMultiplier, MaxInterval: 10 * time.Second, - MaxElapsedTime: 30 * time.Minute, //stop after an 30 min of trying, the error will be propagated to the general retry of the client + MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client Stop: backoff.Stop, Clock: backoff.SystemClock, }, ctx) @@ -91,38 +101,76 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { // Receive Connects to the Signal Exchange message stream and starts receiving messages. 
// The messages will be handled by msgHandler function provided. -// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) -// The key is the identifier of our Peer (could be Wireguard public key) -func (c *Client) Receive(msgHandler func(msg *proto.Message) error) { - c.connWg.Add(1) - go func() { +// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) +// The connection retry logic will try to reconnect for 12 hours and, if unsuccessful, will propagate the error to the function caller. +func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error { - var backOff = defaultBackoff(c.ctx) + var backOff = defaultBackoff(c.ctx) - operation := func() error { + operation := func() error { - stream, err := c.connect(c.key.PublicKey().String()) - if err != nil { - log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) - c.connWg.Add(1) - return err - } + c.notifyStreamDisconnected() - err = c.receive(stream, msgHandler) - if err != nil { - backOff.Reset() - return err - } - - return nil + log.Debugf("signal connection state %v", c.signalConn.GetState()) + if !c.ready() { + return fmt.Errorf("no connection to signal") } - err := backoff.Retry(operation, backOff) + // connect to Signal stream identifying ourselves with a public Wireguard key + // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) + stream, err := c.connect(c.key.PublicKey().String()) if err != nil { - log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) - return + log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err) + return err } - }() + + c.notifyStreamConnected() + + log.Infof("streamConnected to the Signal Service stream") + + // start receiving messages from the Signal stream (from other peers through signal) 
+ err = c.receive(stream, msgHandler) + if err != nil { + log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err) + backOff.Reset() + return err + } + + return nil + } + + err := backoff.Retry(operation, backOff) + if err != nil { + log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) + return err + } + + return nil +} +func (c *Client) notifyStreamDisconnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = streamDisconnected +} + +func (c *Client) notifyStreamConnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = streamConnected + if c.connectedCh != nil { + // there are goroutines waiting on this channel -> release them + close(c.connectedCh) + c.connectedCh = nil + } +} + +func (c *Client) getStreamStatusChan() <-chan struct{} { + c.mux.Lock() + defer c.mux.Unlock() + if c.connectedCh == nil { + c.connectedCh = make(chan struct{}) + } + return c.connectedCh } func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { @@ -147,24 +195,37 @@ func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, if len(registered) == 0 { return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") } - //connection established we are good to use the stream - c.connWg.Done() - - log.Infof("connected to the Signal Exchange Stream") return stream, nil } -// WaitConnected waits until the client is connected to the message stream -func (c *Client) WaitConnected() { - c.connWg.Wait() +// ready indicates whether the client is okay and ready to be used +// for now it just checks whether gRPC connection to the service is in state Ready +func (c *Client) ready() bool { + return c.signalConn.GetState() == connectivity.Ready +} + +// WaitStreamConnected waits until the client is connected to the Signal stream +func (c *Client) WaitStreamConnected() { + + if c.status == streamConnected { + return + } 
+ + ch := c.getStreamStatusChan() + select { + case <-c.ctx.Done(): + case <-ch: + } } // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // Client.connWg can be used to wait func (c *Client) SendToStream(msg *proto.EncryptedMessage) error { - + if !c.ready() { + return fmt.Errorf("no connection to signal") + } if c.stream == nil { return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") } @@ -221,13 +282,17 @@ func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, er // Send sends a message to the remote Peer through the Signal Exchange. func (c *Client) Send(msg *proto.Message) error { + if !c.ready() { + return fmt.Errorf("no connection to signal") + } + encryptedMessage, err := c.encryptMessage(msg) if err != nil { return err } _, err = c.realClient.Send(context.TODO(), encryptedMessage) if err != nil { - log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) + //log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) return err } @@ -244,10 +309,10 @@ func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient, log.Warnf("stream canceled (usually indicates shutdown)") return err } else if s.Code() == codes.Unavailable { - log.Warnf("server has been stopped") + log.Warnf("Signal Service is unavailable") return err } else if err == io.EOF { - log.Warnf("stream closed by server") + log.Warnf("Signal Service stream closed by server") return err } else if err != nil { return err diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 2ac5b03ee..55aeaf2c6 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -36,7 +36,7 @@ var _ = 
Describe("Client", func() { }) Describe("Exchanging messages", func() { - Context("between connected peers", func() { + Context("between streamConnected peers", func() { It("should be successful", func() { var msgReceived sync.WaitGroup @@ -48,30 +48,42 @@ var _ = Describe("Client", func() { // connect PeerA to Signal keyA, _ := wgtypes.GenerateKey() clientA := createSignalClient(addr, keyA) - clientA.Receive(func(msg *sigProto.Message) error { - receivedOnA = msg.GetBody().GetPayload() - msgReceived.Done() - return nil - }) - clientA.WaitConnected() + go func() { + err := clientA.Receive(func(msg *sigProto.Message) error { + receivedOnA = msg.GetBody().GetPayload() + msgReceived.Done() + return nil + }) + if err != nil { + return + } + }() + clientA.WaitStreamConnected() // connect PeerB to Signal keyB, _ := wgtypes.GenerateKey() clientB := createSignalClient(addr, keyB) - clientB.Receive(func(msg *sigProto.Message) error { - receivedOnB = msg.GetBody().GetPayload() - err := clientB.Send(&sigProto.Message{ - Key: keyB.PublicKey().String(), - RemoteKey: keyA.PublicKey().String(), - Body: &sigProto.Body{Payload: "pong"}, + + go func() { + err := clientB.Receive(func(msg *sigProto.Message) error { + receivedOnB = msg.GetBody().GetPayload() + err := clientB.Send(&sigProto.Message{ + Key: keyB.PublicKey().String(), + RemoteKey: keyA.PublicKey().String(), + Body: &sigProto.Body{Payload: "pong"}, + }) + if err != nil { + Fail("failed sending a message to PeerA") + } + msgReceived.Done() + return nil }) if err != nil { - Fail("failed sending a message to PeerA") + return } - msgReceived.Done() - return nil - }) - clientB.WaitConnected() + }() + + clientB.WaitStreamConnected() // PeerA initiates ping-pong err := clientA.Send(&sigProto.Message{ @@ -100,11 +112,15 @@ var _ = Describe("Client", func() { key, _ := wgtypes.GenerateKey() client := createSignalClient(addr, key) - client.Receive(func(msg *sigProto.Message) error { - return nil - }) - client.WaitConnected() - + go 
func() { + err := client.Receive(func(msg *sigProto.Message) error { + return nil + }) + if err != nil { + return + } + }() + client.WaitStreamConnected() Expect(client).NotTo(BeNil()) }) })