diff --git a/signal/client/client.go b/signal/client/client.go index 2cfbd70e5..5378b1625 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -10,12 +10,17 @@ import ( // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. +// Client is an interface describing Signal client type Client interface { - Receive(msgHandler func(msg *proto.Message) error) + // Receive handles incoming messages from the Signal service + Receive(msgHandler func(msg *proto.Message) error) error Close() error + // Send sends a message to the Signal service (just one time rpc call, not stream) Send(msg *proto.Message) error + // SendToStream sends a message to the Signal service through a connected stream SendToStream(msg *proto.EncryptedMessage) error - WaitConnected() + // WaitStreamConnected blocks until client is connected to the Signal stream + WaitStreamConnected() } // decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key diff --git a/signal/client/client_test.go b/signal/client/client_test.go index 55aeaf2c6..19c49a6a4 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -5,6 +5,7 @@ import ( . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/signal/peer" sigProto "github.com/wiretrustee/wiretrustee/signal/proto" "github.com/wiretrustee/wiretrustee/signal/server" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -159,7 +160,7 @@ var _ = Describe("Client", func() { }) -func createSignalClient(addr string, key wgtypes.Key) *Client { +func createSignalClient(addr string, key wgtypes.Key) Client { var sigTLSEnabled = false client, err := NewClient(context.Background(), addr, key, sigTLSEnabled) if err != nil { @@ -189,7 +190,7 @@ func startSignal() (*grpc.Server, net.Listener) { panic(err) } s := grpc.NewServer() - sigProto.RegisterSignalExchangeServer(s, server.NewServer()) + sigProto.RegisterSignalExchangeServer(s, server.NewServer(peer.NewRegistry())) go func() { if err := s.Serve(lis); err != nil { log.Fatalf("failed to serve: %v", err) diff --git a/signal/client/grpc_client.go b/signal/client/grpc.go similarity index 57% rename from signal/client/grpc_client.go rename to signal/client/grpc.go index 25d818bff..e44ae8e76 100644 --- a/signal/client/grpc_client.go +++ b/signal/client/grpc.go @@ -10,6 +10,7 @@ import ( "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/metadata" @@ -19,6 +20,14 @@ import ( "time" ) +// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. + +// Status is the status of the client +type Status string + +const streamConnected Status = "streamConnected" +const streamDisconnected Status = "streamDisconnected" + // GrpcClient Wraps the Signal Exchange Service gRpc client type GrpcClient struct { key wgtypes.Key @@ -26,8 +35,11 @@ type GrpcClient struct { signalConn *grpc.ClientConn ctx context.Context stream proto.SignalExchange_ConnectStreamClient - //waiting group to notify once stream is connected - connWg *sync.WaitGroup //todo use a channel instead?? + // connectedCh used to notify goroutines waiting for the connection to the Signal stream + connectedCh chan struct{} + mux sync.Mutex + // streamConnected indicates whether this GrpcClient is streamConnected to the Signal stream + status Status } // Close Closes underlying connections to the Signal Exchange @@ -35,7 +47,7 @@ func (c *GrpcClient) Close() error { return c.signalConn.Close() } -// NewClient creates a new Signal client +// NewClient creates a new Signal GrpcClient func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { transportOption := grpc.WithInsecure() @@ -57,17 +69,17 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo })) if err != nil { - log.Errorf("failed to connect to the Signal gRPC server %v", err) + log.Errorf("failed to connect to the signalling server %v", err) return nil, err } - var wg sync.WaitGroup return &GrpcClient{ realClient: proto.NewSignalExchangeClient(conn), ctx: ctx, signalConn: conn, key: key, - connWg: &wg, + mux: sync.Mutex{}, + status: streamDisconnected, }, nil } @@ -77,8 +89,8 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { InitialInterval: 800 * time.Millisecond, RandomizationFactor: backoff.DefaultRandomizationFactor, Multiplier: backoff.DefaultMultiplier, - MaxInterval: 15 * time.Minute, - MaxElapsedTime: time.Hour, //stop after an hour of trying, the error will be propagated to the general retry of the client + MaxInterval: 10 * time.Second, + MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client Stop: backoff.Stop, Clock: backoff.SystemClock, }, ctx) @@ -87,38 +99,76 @@ func defaultBackoff(ctx context.Context) backoff.BackOff { // Receive Connects to the Signal Exchange message stream and starts receiving messages. // The messages will be handled by msgHandler function provided. -// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) -// The key is the identifier of our Peer (could be Wireguard public key) -func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) { - c.connWg.Add(1) - go func() { +// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) +// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller. +func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error { - var backOff = defaultBackoff(c.ctx) + var backOff = defaultBackoff(c.ctx) - operation := func() error { + operation := func() error { - stream, err := c.connect(c.key.PublicKey().String()) - if err != nil { - log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) - c.connWg.Add(1) - return err - } + c.notifyStreamDisconnected() - err = c.receive(stream, msgHandler) - if err != nil { - backOff.Reset() - return err - } - - return nil + log.Debugf("signal connection state %v", c.signalConn.GetState()) + if !c.ready() { + return fmt.Errorf("no connection to signal") } - err := backoff.Retry(operation, backOff) + // connect to Signal stream identifying ourselves with a public Wireguard key + // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) + stream, err := c.connect(c.key.PublicKey().String()) if err != nil { - log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) - return + log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err) + return err } - }() + + c.notifyStreamConnected() + + log.Infof("streamConnected to the Signal Service stream") + + // start receiving messages from the Signal stream (from other peers through signal) + err = c.receive(stream, msgHandler) + if err != nil { + log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err) + backOff.Reset() + return err + } + + return nil + } + + err := backoff.Retry(operation, backOff) + if err != nil { + log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) + return err + } + + return nil +} +func (c *GrpcClient) notifyStreamDisconnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = streamDisconnected +} + +func (c *GrpcClient) notifyStreamConnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = streamConnected + if c.connectedCh != nil { + // there are goroutines waiting on this channel -> release them + close(c.connectedCh) + c.connectedCh = nil + } +} + +func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { + c.mux.Lock() + defer c.mux.Unlock() + if c.connectedCh == nil { + c.connectedCh = make(chan struct{}) + } + return c.connectedCh } func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { @@ -143,24 +193,37 @@ func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClie if len(registered) == 0 { return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") } - //connection established we are good to use the stream - c.connWg.Done() - - log.Infof("connected to the Signal Exchange Stream") return stream, nil } -// WaitConnected waits until the client is connected to the message stream -func (c *GrpcClient) WaitConnected() { - c.connWg.Wait() +// ready indicates whether the client is okay and ready to be used +// for now it just checks whether gRPC connection to the service is in state Ready +func (c *GrpcClient) ready() bool { + return c.signalConn.GetState() == connectivity.Ready +} + +// WaitStreamConnected waits until the client is connected to the Signal stream +func (c *GrpcClient) WaitStreamConnected() { + + if c.status == streamConnected { + return + } + + ch := c.getStreamStatusChan() + select { + case <-c.ctx.Done(): + case <-ch: + } } // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // Client.connWg can be used to wait func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error { - + if !c.ready() { + return fmt.Errorf("no connection to signal") + } if c.stream == nil { return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") } @@ -177,13 +240,16 @@ func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error { // Send sends a message to the remote Peer through the Signal Exchange. func (c *GrpcClient) Send(msg *proto.Message) error { + if !c.ready() { + return fmt.Errorf("no connection to signal") + } + encryptedMessage, err := encryptMessage(msg, c.key) if err != nil { return err } - _, err = c.realClient.Send(context.TODO(), encryptedMessage) + _, err = c.realClient.Send(c.ctx, encryptedMessage) if err != nil { - log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) return err } @@ -200,10 +266,10 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, log.Warnf("stream canceled (usually indicates shutdown)") return err } else if s.Code() == codes.Unavailable { - log.Warnf("server has been stopped") + log.Warnf("Signal Service is unavailable") return err } else if err == io.EOF { - log.Warnf("stream closed by server") + log.Warnf("Signal Service stream closed by server") return err } else if err != nil { return err diff --git a/signal/client/ws_client.go b/signal/client/websocket.go similarity index 96% rename from signal/client/ws_client.go rename to signal/client/websocket.go index 1a8ae4dcc..2ec37a57c 100644 --- a/signal/client/ws_client.go +++ b/signal/client/websocket.go @@ -42,14 +42,14 @@ func (c *WebsocketClient) Close() error { return c.conn.Close(websocket.StatusNormalClosure, "close") } -func (c *WebsocketClient) Receive(msgHandler func(msg *proto.Message) error) { +func (c *WebsocketClient) Receive(msgHandler func(msg *proto.Message) error) error { for { _, byteMsg, err := c.conn.Read(c.ctx) if err != nil { log.Errorf("failed reading message from Signal Websocket %v", err) time.Sleep(2 * time.Second) //todo propagate to the upper layer and retry - return + return err } encryptedMsg := &proto.EncryptedMessage{} @@ -97,6 +97,6 @@ func (c *WebsocketClient) Send(msg *proto.Message) error { } -func (c *WebsocketClient) WaitConnected() { +func (c *WebsocketClient) WaitStreamConnected() { } diff --git a/signal/server/http/server.go b/signal/server/http/server.go index 56ace1dc1..b218db189 100644 --- a/signal/server/http/server.go +++ b/signal/server/http/server.go @@ -117,10 +117,10 @@ func (s *Server) serveWs(w http.ResponseWriter, r *http.Request) { conn.SetReadLimit(1024 * 1024 * 3) for { - t, byteMsg, err := conn.ReadMessage() + _, byteMsg, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { - log.Printf("error: %v", err, t) + log.Errorf("error: %v", err) } break }