merge main

This commit is contained in:
braginini
2021-11-07 13:05:21 +01:00
parent b9aa2aa329
commit 04de743dff
5 changed files with 125 additions and 53 deletions

View File

@@ -10,12 +10,17 @@ import (
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
// Client is an interface describing Signal client
type Client interface { type Client interface {
Receive(msgHandler func(msg *proto.Message) error) // Receive handles incoming messages from the Signal service
Receive(msgHandler func(msg *proto.Message) error) error
Close() error Close() error
// Send sends a message to the Signal service (just one time rpc call, not stream)
Send(msg *proto.Message) error Send(msg *proto.Message) error
// SendToStream sends a message to the Signal service through a connected stream
SendToStream(msg *proto.EncryptedMessage) error SendToStream(msg *proto.EncryptedMessage) error
WaitConnected() // WaitStreamConnected blocks until client is connected to the Signal stream
WaitStreamConnected()
} }
// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key // decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key

View File

@@ -5,6 +5,7 @@ import (
. "github.com/onsi/ginkgo" . "github.com/onsi/ginkgo"
. "github.com/onsi/gomega" . "github.com/onsi/gomega"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/wiretrustee/wiretrustee/signal/peer"
sigProto "github.com/wiretrustee/wiretrustee/signal/proto" sigProto "github.com/wiretrustee/wiretrustee/signal/proto"
"github.com/wiretrustee/wiretrustee/signal/server" "github.com/wiretrustee/wiretrustee/signal/server"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
@@ -159,7 +160,7 @@ var _ = Describe("Client", func() {
}) })
func createSignalClient(addr string, key wgtypes.Key) *Client { func createSignalClient(addr string, key wgtypes.Key) Client {
var sigTLSEnabled = false var sigTLSEnabled = false
client, err := NewClient(context.Background(), addr, key, sigTLSEnabled) client, err := NewClient(context.Background(), addr, key, sigTLSEnabled)
if err != nil { if err != nil {
@@ -189,7 +190,7 @@ func startSignal() (*grpc.Server, net.Listener) {
panic(err) panic(err)
} }
s := grpc.NewServer() s := grpc.NewServer()
sigProto.RegisterSignalExchangeServer(s, server.NewServer()) sigProto.RegisterSignalExchangeServer(s, server.NewServer(peer.NewRegistry()))
go func() { go func() {
if err := s.Serve(lis); err != nil { if err := s.Serve(lis); err != nil {
log.Fatalf("failed to serve: %v", err) log.Fatalf("failed to serve: %v", err)

View File

@@ -10,6 +10,7 @@ import (
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/metadata" "google.golang.org/grpc/metadata"
@@ -19,6 +20,14 @@ import (
"time" "time"
) )
// A set of tools to exchange connection details (Wireguard endpoints) with the remote peer.
// Status is the status of the client
type Status string
const streamConnected Status = "streamConnected"
const streamDisconnected Status = "streamDisconnected"
// GrpcClient Wraps the Signal Exchange Service gRpc client // GrpcClient Wraps the Signal Exchange Service gRpc client
type GrpcClient struct { type GrpcClient struct {
key wgtypes.Key key wgtypes.Key
@@ -26,8 +35,11 @@ type GrpcClient struct {
signalConn *grpc.ClientConn signalConn *grpc.ClientConn
ctx context.Context ctx context.Context
stream proto.SignalExchange_ConnectStreamClient stream proto.SignalExchange_ConnectStreamClient
//waiting group to notify once stream is connected // connectedCh used to notify goroutines waiting for the connection to the Signal stream
connWg *sync.WaitGroup //todo use a channel instead?? connectedCh chan struct{}
mux sync.Mutex
// streamConnected indicates whether this GrpcClient is streamConnected to the Signal stream
status Status
} }
// Close Closes underlying connections to the Signal Exchange // Close Closes underlying connections to the Signal Exchange
@@ -35,7 +47,7 @@ func (c *GrpcClient) Close() error {
return c.signalConn.Close() return c.signalConn.Close()
} }
// NewClient creates a new Signal client // NewClient creates a new Signal GrpcClient
func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) {
transportOption := grpc.WithInsecure() transportOption := grpc.WithInsecure()
@@ -57,17 +69,17 @@ func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled boo
})) }))
if err != nil { if err != nil {
log.Errorf("failed to connect to the Signal gRPC server %v", err) log.Errorf("failed to connect to the signalling server %v", err)
return nil, err return nil, err
} }
var wg sync.WaitGroup
return &GrpcClient{ return &GrpcClient{
realClient: proto.NewSignalExchangeClient(conn), realClient: proto.NewSignalExchangeClient(conn),
ctx: ctx, ctx: ctx,
signalConn: conn, signalConn: conn,
key: key, key: key,
connWg: &wg, mux: sync.Mutex{},
status: streamDisconnected,
}, nil }, nil
} }
@@ -77,8 +89,8 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
InitialInterval: 800 * time.Millisecond, InitialInterval: 800 * time.Millisecond,
RandomizationFactor: backoff.DefaultRandomizationFactor, RandomizationFactor: backoff.DefaultRandomizationFactor,
Multiplier: backoff.DefaultMultiplier, Multiplier: backoff.DefaultMultiplier,
MaxInterval: 15 * time.Minute, MaxInterval: 10 * time.Second,
MaxElapsedTime: time.Hour, //stop after an hour of trying, the error will be propagated to the general retry of the client MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client
Stop: backoff.Stop, Stop: backoff.Stop,
Clock: backoff.SystemClock, Clock: backoff.SystemClock,
}, ctx) }, ctx)
@@ -87,38 +99,76 @@ func defaultBackoff(ctx context.Context) backoff.BackOff {
// Receive Connects to the Signal Exchange message stream and starts receiving messages. // Receive Connects to the Signal Exchange message stream and starts receiving messages.
// The messages will be handled by msgHandler function provided. // The messages will be handled by msgHandler function provided.
// This function runs a goroutine underneath and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) // This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart)
// The key is the identifier of our Peer (could be Wireguard public key) // The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller.
func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) { func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error {
c.connWg.Add(1)
go func() {
var backOff = defaultBackoff(c.ctx) var backOff = defaultBackoff(c.ctx)
operation := func() error { operation := func() error {
stream, err := c.connect(c.key.PublicKey().String()) c.notifyStreamDisconnected()
if err != nil {
log.Warnf("disconnected from the Signal Exchange due to an error: %v", err)
c.connWg.Add(1)
return err
}
err = c.receive(stream, msgHandler) log.Debugf("signal connection state %v", c.signalConn.GetState())
if err != nil { if !c.ready() {
backOff.Reset() return fmt.Errorf("no connection to signal")
return err
}
return nil
} }
err := backoff.Retry(operation, backOff) // connect to Signal stream identifying ourselves with a public Wireguard key
// todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management)
stream, err := c.connect(c.key.PublicKey().String())
if err != nil { if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err)
return return err
} }
}()
c.notifyStreamConnected()
log.Infof("streamConnected to the Signal Service stream")
// start receiving messages from the Signal stream (from other peers through signal)
err = c.receive(stream, msgHandler)
if err != nil {
log.Warnf("streamDisconnected from the Signal Exchange due to an error: %v", err)
backOff.Reset()
return err
}
return nil
}
err := backoff.Retry(operation, backOff)
if err != nil {
log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err)
return err
}
return nil
}
func (c *GrpcClient) notifyStreamDisconnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = streamDisconnected
}
func (c *GrpcClient) notifyStreamConnected() {
c.mux.Lock()
defer c.mux.Unlock()
c.status = streamConnected
if c.connectedCh != nil {
// there are goroutines waiting on this channel -> release them
close(c.connectedCh)
c.connectedCh = nil
}
}
func (c *GrpcClient) getStreamStatusChan() <-chan struct{} {
c.mux.Lock()
defer c.mux.Unlock()
if c.connectedCh == nil {
c.connectedCh = make(chan struct{})
}
return c.connectedCh
} }
func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) {
@@ -143,24 +193,37 @@ func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClie
if len(registered) == 0 { if len(registered) == 0 {
return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams")
} }
//connection established we are good to use the stream
c.connWg.Done()
log.Infof("connected to the Signal Exchange Stream")
return stream, nil return stream, nil
} }
// WaitConnected waits until the client is connected to the message stream // ready indicates whether the client is okay and ready to be used
func (c *GrpcClient) WaitConnected() { // for now it just checks whether gRPC connection to the service is in state Ready
c.connWg.Wait() func (c *GrpcClient) ready() bool {
return c.signalConn.GetState() == connectivity.Ready
}
// WaitStreamConnected waits until the client is connected to the Signal stream
func (c *GrpcClient) WaitStreamConnected() {
if c.status == streamConnected {
return
}
ch := c.getStreamStatusChan()
select {
case <-c.ctx.Done():
case <-ch:
}
} }
// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server // SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server
// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange // The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange
// Client.connWg can be used to wait // Client.connWg can be used to wait
func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error { func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
if c.stream == nil { if c.stream == nil {
return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages")
} }
@@ -177,13 +240,16 @@ func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error {
// Send sends a message to the remote Peer through the Signal Exchange. // Send sends a message to the remote Peer through the Signal Exchange.
func (c *GrpcClient) Send(msg *proto.Message) error { func (c *GrpcClient) Send(msg *proto.Message) error {
if !c.ready() {
return fmt.Errorf("no connection to signal")
}
encryptedMessage, err := encryptMessage(msg, c.key) encryptedMessage, err := encryptMessage(msg, c.key)
if err != nil { if err != nil {
return err return err
} }
_, err = c.realClient.Send(context.TODO(), encryptedMessage) _, err = c.realClient.Send(c.ctx, encryptedMessage)
if err != nil { if err != nil {
log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err)
return err return err
} }
@@ -200,10 +266,10 @@ func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient,
log.Warnf("stream canceled (usually indicates shutdown)") log.Warnf("stream canceled (usually indicates shutdown)")
return err return err
} else if s.Code() == codes.Unavailable { } else if s.Code() == codes.Unavailable {
log.Warnf("server has been stopped") log.Warnf("Signal Service is unavailable")
return err return err
} else if err == io.EOF { } else if err == io.EOF {
log.Warnf("stream closed by server") log.Warnf("Signal Service stream closed by server")
return err return err
} else if err != nil { } else if err != nil {
return err return err

View File

@@ -42,14 +42,14 @@ func (c *WebsocketClient) Close() error {
return c.conn.Close(websocket.StatusNormalClosure, "close") return c.conn.Close(websocket.StatusNormalClosure, "close")
} }
func (c *WebsocketClient) Receive(msgHandler func(msg *proto.Message) error) { func (c *WebsocketClient) Receive(msgHandler func(msg *proto.Message) error) error {
for { for {
_, byteMsg, err := c.conn.Read(c.ctx) _, byteMsg, err := c.conn.Read(c.ctx)
if err != nil { if err != nil {
log.Errorf("failed reading message from Signal Websocket %v", err) log.Errorf("failed reading message from Signal Websocket %v", err)
time.Sleep(2 * time.Second) time.Sleep(2 * time.Second)
//todo propagate to the upper layer and retry //todo propagate to the upper layer and retry
return return err
} }
encryptedMsg := &proto.EncryptedMessage{} encryptedMsg := &proto.EncryptedMessage{}
@@ -97,6 +97,6 @@ func (c *WebsocketClient) Send(msg *proto.Message) error {
} }
func (c *WebsocketClient) WaitConnected() { func (c *WebsocketClient) WaitStreamConnected() {
} }

View File

@@ -117,10 +117,10 @@ func (s *Server) serveWs(w http.ResponseWriter, r *http.Request) {
conn.SetReadLimit(1024 * 1024 * 3) conn.SetReadLimit(1024 * 1024 * 3)
for { for {
t, byteMsg, err := conn.ReadMessage() _, byteMsg, err := conn.ReadMessage()
if err != nil { if err != nil {
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
log.Printf("error: %v", err, t) log.Errorf("error: %v", err)
} }
break break
} }