From 5db130a12e4e367fb3c158edd73794151a2f6e07 Mon Sep 17 00:00:00 2001 From: Mikhail Bragin Date: Tue, 18 Jan 2022 16:44:58 +0100 Subject: [PATCH] Support new Management service protocol (NetworkMap) (#193) * feature: support new management service protocol * chore: add more logging to track networkmap serial * refactor: organize peer update code in engine * chore: fix lint issues * refactor: extract Signal client interface * test: add signal client mock * refactor: introduce Management Service client interface * chore: place management and signal clients mocks to respective packages * test: add Serial test to the engine * fix: lint issues * test: unit tests for a networkMapUpdate * test: unit tests Sync update --- client/cmd/login.go | 4 +- client/cmd/up.go | 4 +- client/internal/engine.go | 168 +++++++++------ client/internal/engine_test.go | 254 +++++++++++++++++++++-- client/internal/peer/conn.go | 4 + management/client/client.go | 250 +---------------------- management/client/client_test.go | 2 +- management/client/grpc.go | 253 +++++++++++++++++++++++ management/client/mock.go | 49 +++++ signal/client/client.go | 337 +------------------------------ signal/client/client_test.go | 4 +- signal/client/grpc.go | 336 ++++++++++++++++++++++++++++++ signal/client/mock.go | 72 +++++++ util/common.go | 16 ++ 14 files changed, 1102 insertions(+), 651 deletions(-) create mode 100644 management/client/grpc.go create mode 100644 management/client/mock.go create mode 100644 signal/client/grpc.go create mode 100644 signal/client/mock.go create mode 100644 util/common.go diff --git a/client/cmd/login.go b/client/cmd/login.go index 551dda300..e439b8e8f 100644 --- a/client/cmd/login.go +++ b/client/cmd/login.go @@ -82,7 +82,7 @@ var ( ) // loginPeer attempts to login to Management Service. If peer wasn't registered, tries the registration flow. -func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) { +func loginPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) { loginResp, err := client.Login(serverPublicKey) if err != nil { @@ -101,7 +101,7 @@ func loginPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) // registerPeer checks whether setupKey was provided via cmd line and if not then it prompts user to enter a key. // Otherwise tries to register with the provided setupKey via command line. -func registerPeer(serverPublicKey wgtypes.Key, client *mgm.Client, setupKey string) (*mgmProto.LoginResponse, error) { +func registerPeer(serverPublicKey wgtypes.Key, client *mgm.GrpcClient, setupKey string) (*mgmProto.LoginResponse, error) { var err error if setupKey == "" { diff --git a/client/cmd/up.go b/client/cmd/up.go index b1e6d27de..c9e155aac 100644 --- a/client/cmd/up.go +++ b/client/cmd/up.go @@ -83,7 +83,7 @@ func createEngineConfig(key wgtypes.Key, config *internal.Config, peerConfig *mg } // connectToSignal creates Signal Service client and established a connection -func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.Client, error) { +func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, ourPrivateKey wgtypes.Key) (*signal.GrpcClient, error) { var sigTLSEnabled bool if wtConfig.Signal.Protocol == mgmProto.HostConfig_HTTPS { sigTLSEnabled = true @@ -101,7 +101,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.WiretrusteeConfig, } // connectToManagement creates Management Services client, establishes a connection, logs-in and gets a global Wiretrustee config (signal, turn, stun hosts, etc) -func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.Client, *mgmProto.LoginResponse, error) { +func connectToManagement(ctx context.Context, managementAddr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*mgm.GrpcClient, *mgmProto.LoginResponse, error) { log.Debugf("connecting to management server %s", managementAddr) client, err := mgm.NewClient(ctx, managementAddr, ourPrivateKey, tlsEnabled) if err != nil { diff --git a/client/internal/engine.go b/client/internal/engine.go index 4b10f51b5..18c340eef 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -12,6 +12,7 @@ import ( mgmProto "github.com/wiretrustee/wiretrustee/management/proto" signal "github.com/wiretrustee/wiretrustee/signal/client" sProto "github.com/wiretrustee/wiretrustee/signal/proto" + "github.com/wiretrustee/wiretrustee/util" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "math/rand" "strings" @@ -44,9 +45,9 @@ type EngineConfig struct { // Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers. type Engine struct { // signal is a Signal Service client - signal *signal.Client + signal signal.Client // mgmClient is a Management Service client - mgmClient *mgm.Client + mgmClient mgm.Client // peerConns is a map that holds all the peers that are known to this peer peerConns map[string]*peer.Conn @@ -64,6 +65,9 @@ type Engine struct { ctx context.Context wgInterface iface.WGIface + + // networkSerial is the latest Serial (state ID) of the network sent by the Management service + networkSerial uint64 } // Peer is an instance of the Connection Peer @@ -73,17 +77,18 @@ type Peer struct { } // NewEngine creates a new Connection Engine -func NewEngine(signalClient *signal.Client, mgmClient *mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine { +func NewEngine(signalClient signal.Client, mgmClient mgm.Client, config *EngineConfig, cancel context.CancelFunc, ctx context.Context) *Engine { return &Engine{ - signal: signalClient, - mgmClient: mgmClient, - peerConns: map[string]*peer.Conn{}, - syncMsgMux: &sync.Mutex{}, - config: config, - STUNs: []*ice.URL{}, - TURNs: []*ice.URL{}, - cancel: cancel, - ctx: ctx, + signal: signalClient, + mgmClient: mgmClient, + peerConns: map[string]*peer.Conn{}, + syncMsgMux: &sync.Mutex{}, + config: config, + STUNs: []*ice.URL{}, + TURNs: []*ice.URL{}, + cancel: cancel, + ctx: ctx, + networkSerial: 0, } } @@ -91,7 +96,7 @@ func (e *Engine) Stop() error { e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock() - err := e.removeAllPeerConnections() + err := e.removeAllPeers() if err != nil { return err } @@ -146,8 +151,22 @@ func (e *Engine) Start() error { return nil } -func (e *Engine) removePeers(peers []string) error { - for _, p := range peers { +// removePeers finds and removes peers that do not exist anymore in the network map received from the Management Service +func (e *Engine) removePeers(peersUpdate []*mgmProto.RemotePeerConfig) error { + + currentPeers := make([]string, 0, len(e.peerConns)) + for p := range e.peerConns { + currentPeers = append(currentPeers, p) + } + + newPeers := make([]string, 0, len(peersUpdate)) + for _, p := range peersUpdate { + newPeers = append(newPeers, p.GetWgPubKey()) + } + + toRemove := util.SliceDiff(currentPeers, newPeers) + + for _, p := range toRemove { err := e.removePeer(p) if err != nil { return err @@ -157,7 +176,7 @@ func (e *Engine) removePeers(peers []string) error { return nil } -func (e *Engine) removeAllPeerConnections() error { +func (e *Engine) removeAllPeers() error { log.Debugf("removing all peer connections") for p := range e.peerConns { err := e.removePeer(p) @@ -189,6 +208,16 @@ func (e *Engine) GetPeerConnectionStatus(peerKey string) peer.ConnStatus { return -1 } +func (e *Engine) GetPeers() []string { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + peers := []string{} + for s := range e.peerConns { + peers = append(peers, s) + } + return peers +} // GetConnectedPeers returns a connection Status or nil if peer connection wasn't found func (e *Engine) GetConnectedPeers() []string { @@ -205,7 +234,7 @@ func (e *Engine) GetConnectedPeers() []string { return peers } -func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client) error { +func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client) error { err := s.Send(&sProto.Message{ Key: myKey.PublicKey().String(), RemoteKey: remoteKey.String(), @@ -223,7 +252,7 @@ func signalCandidate(candidate ice.Candidate, myKey wgtypes.Key, remoteKey wgtyp return nil } -func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s *signal.Client, isAnswer bool) error { +func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.Key, s signal.Client, isAnswer bool) error { var t sProto.Body_Type if isAnswer { @@ -246,37 +275,42 @@ func signalAuth(uFrag string, pwd string, myKey wgtypes.Key, remoteKey wgtypes.K return nil } +func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + + if update.GetWiretrusteeConfig() != nil { + err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns()) + if err != nil { + return err + } + + err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns()) + if err != nil { + return err + } + + //todo update signal + } + + if update.GetNetworkMap() != nil { + // only apply new changes and ignore old ones + err := e.updateNetworkMap(update.GetNetworkMap()) + if err != nil { + return err + } + } + + return nil + +} + // receiveManagementEvents connects to the Management Service event stream to receive updates from the management service // E.g. when a new peer has been registered and we are allowed to connect to it. func (e *Engine) receiveManagementEvents() { go func() { err := e.mgmClient.Sync(func(update *mgmProto.SyncResponse) error { - e.syncMsgMux.Lock() - defer e.syncMsgMux.Unlock() - - if update.GetWiretrusteeConfig() != nil { - err := e.updateTURNs(update.GetWiretrusteeConfig().GetTurns()) - if err != nil { - return err - } - - err = e.updateSTUNs(update.GetWiretrusteeConfig().GetStuns()) - if err != nil { - return err - } - - //todo update signal - } - - if update.GetRemotePeers() != nil || update.GetRemotePeersIsEmpty() { - // empty arrays are serialized by protobuf to null, but for our case empty array is a valid state. - err := e.updatePeers(update.GetRemotePeers()) - if err != nil { - return err - } - } - - return nil + return e.handleSync(update) }) if err != nil { // happens if management is unavailable for a long time. @@ -327,27 +361,41 @@ func (e *Engine) updateTURNs(turns []*mgmProto.ProtectedHostConfig) error { return nil } -func (e *Engine) updatePeers(remotePeers []*mgmProto.RemotePeerConfig) error { - log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(remotePeers)) - remotePeerMap := make(map[string]struct{}) - for _, p := range remotePeers { - remotePeerMap[p.GetWgPubKey()] = struct{}{} +func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { + + serial := networkMap.GetSerial() + if e.networkSerial > serial { + log.Debugf("received outdated NetworkMap with serial %d, ignoring", serial) + return nil } - //remove peers that are no longer available for us - toRemove := []string{} - for p := range e.peerConns { - if _, ok := remotePeerMap[p]; !ok { - toRemove = append(toRemove, p) + log.Debugf("got peers update from Management Service, total peers to connect to = %d", len(networkMap.GetRemotePeers())) + + // cleanup request, most likely our peer has been deleted + if networkMap.GetRemotePeersIsEmpty() { + err := e.removeAllPeers() + if err != nil { + return err + } + } else { + err := e.removePeers(networkMap.GetRemotePeers()) + if err != nil { + return err + } + + err = e.addNewPeers(networkMap.GetRemotePeers()) + if err != nil { + return err } } - err := e.removePeers(toRemove) - if err != nil { - return err - } - // add new peers - for _, p := range remotePeers { + e.networkSerial = serial + return nil +} + +// addNewPeers finds and adds peers that were not know before but arrived from the Management service with the update +func (e *Engine) addNewPeers(peersUpdate []*mgmProto.RemotePeerConfig) error { + for _, p := range peersUpdate { peerKey := p.GetWgPubKey() peerIPs := p.GetAllowedIps() if _, ok := e.peerConns[peerKey]; !ok { diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 130440a96..4812148de 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -37,6 +37,232 @@ var ( } ) +func TestEngine_UpdateNetworkMap(t *testing.T) { + + // test setup + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{}, &EngineConfig{ + WgIfaceName: "utun100", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, cancel, ctx) + + type testCase struct { + idx int + networkMap *mgmtProto.NetworkMap + expectedLen int + expectedPeers []string + expectedSerial uint64 + } + + peer1 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + AllowedIps: []string{"100.64.0.10/24"}, + } + + peer2 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + AllowedIps: []string{"100.64.0.11/24"}, + } + + peer3 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + AllowedIps: []string{"100.64.0.12/24"}, + } + + // 1st case - new peer and network map has Serial grater than local => apply the update + case1 := testCase{ + idx: 1, + networkMap: &mgmtProto.NetworkMap{ + Serial: 1, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{ + peer1, + }, + RemotePeersIsEmpty: false, + }, + expectedLen: 1, + expectedPeers: []string{peer1.GetWgPubKey()}, + expectedSerial: 1, + } + + // 2nd case - one extra peer added and network map has Serial grater than local => apply the update + case2 := testCase{ + idx: 2, + networkMap: &mgmtProto.NetworkMap{ + Serial: 2, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{ + peer1, peer2, + }, + RemotePeersIsEmpty: false, + }, + expectedLen: 2, + expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()}, + expectedSerial: 2, + } + + // 3rd case - an update with 3 peers and Serial lower than the current serial of the engine => ignore the update + case3 := testCase{ + idx: 3, + networkMap: &mgmtProto.NetworkMap{ + Serial: 0, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{ + peer1, peer2, peer3, + }, + RemotePeersIsEmpty: false, + }, + expectedLen: 2, + expectedPeers: []string{peer1.GetWgPubKey(), peer2.GetWgPubKey()}, + expectedSerial: 2, + } + + // 4th case - an update with 2 peers (1 new and 1 old) => apply the update removing old peer and adding a new one + case4 := testCase{ + idx: 3, + networkMap: &mgmtProto.NetworkMap{ + Serial: 4, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{ + peer2, peer3, + }, + RemotePeersIsEmpty: false, + }, + expectedLen: 2, + expectedPeers: []string{peer2.GetWgPubKey(), peer3.GetWgPubKey()}, + expectedSerial: 4, + } + + // 5th case - an update with all peers to be removed + case5 := testCase{ + idx: 3, + networkMap: &mgmtProto.NetworkMap{ + Serial: 5, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{}, + RemotePeersIsEmpty: true, + }, + expectedLen: 0, + expectedPeers: nil, + expectedSerial: 5, + } + + for _, c := range []testCase{case1, case2, case3, case4, case5} { + err = engine.updateNetworkMap(c.networkMap) + if err != nil { + t.Fatal(err) + return + } + + if len(engine.peerConns) != c.expectedLen { + t.Errorf("case %d expecting Engine.peerConns to be of size %d, got %d", c.idx, c.expectedLen, len(engine.peerConns)) + } + + if engine.networkSerial != c.expectedSerial { + t.Errorf("case %d expecting Engine.networkSerial to be equal to %d, actual %d", c.idx, c.expectedSerial, engine.networkSerial) + } + + for _, p := range c.expectedPeers { + if _, ok := engine.peerConns[p]; !ok { + t.Errorf("case %d expecting Engine.peerConns to contain peer %s", c.idx, p) + } + } + } + +} + +func TestEngine_Sync(t *testing.T) { + + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // feed updates to Engine via mocked Management client + updates := make(chan *mgmtProto.SyncResponse) + defer close(updates) + syncFunc := func(msgHandler func(msg *mgmtProto.SyncResponse) error) error { + + for msg := range updates { + err := msgHandler(msg) + if err != nil { + t.Fatal(err) + } + } + return nil + } + + engine := NewEngine(&signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, &EngineConfig{ + WgIfaceName: "utun100", + WgAddr: "100.64.0.1/24", + WgPrivateKey: key, + WgPort: 33100, + }, cancel, ctx) + + defer func() { + err := engine.Stop() + if err != nil { + return + } + }() + + err = engine.Start() + if err != nil { + t.Fatal(err) + return + } + + peer1 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "RRHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=", + AllowedIps: []string{"100.64.0.10/24"}, + } + peer2 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "LLHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=", + AllowedIps: []string{"100.64.0.11/24"}, + } + peer3 := &mgmtProto.RemotePeerConfig{ + WgPubKey: "GGHf3Ma6z6mdLbriAJbqhX9+nM/B71lgw2+91q3LlhU=", + AllowedIps: []string{"100.64.0.12/24"}, + } + // 1st update with just 1 peer and serial larger than the current serial of the engine => apply update + updates <- &mgmtProto.SyncResponse{ + NetworkMap: &mgmtProto.NetworkMap{ + Serial: 10, + PeerConfig: nil, + RemotePeers: []*mgmtProto.RemotePeerConfig{peer1, peer2, peer3}, + RemotePeersIsEmpty: false, + }, + } + + timeout := time.After(time.Second * 2) + for { + select { + case <-timeout: + t.Fatalf("timeout while waiting for test to finish") + default: + } + + if len(engine.GetPeers()) == 3 && engine.networkSerial == 10 { + break + } + } + +} + func TestEngine_MultiplePeers(t *testing.T) { //log.SetLevel(log.DebugLevel) @@ -58,23 +284,14 @@ func TestEngine_MultiplePeers(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() sport := 10010 - signalServer, err := startSignal(sport) + sigServer, err := startSignal(sport) if err != nil { t.Fatal(err) return } - defer signalServer.Stop() + defer sigServer.Stop() mport := 33081 - mgmtServer, err := startManagement(mport, &server.Config{ - Stuns: []*server.Host{}, - TURNConfig: &server.TURNConfig{}, - Signal: &server.Host{ - Proto: "http", - URI: "localhost:10000", - }, - Datadir: dir, - HttpConfig: nil, - }) + mgmtServer, err := startManagement(mport, dir) if err != nil { t.Fatal(err) return @@ -201,7 +418,18 @@ func startSignal(port int) (*grpc.Server, error) { return s, nil } -func startManagement(port int, config *server.Config) (*grpc.Server, error) { +func startManagement(port int, dataDir string) (*grpc.Server, error) { + + config := &server.Config{ + Stuns: []*server.Host{}, + TURNConfig: &server.TURNConfig{}, + Signal: &server.Host{ + Proto: "http", + URI: "localhost:10000", + }, + Datadir: dataDir, + HttpConfig: nil, + } lis, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", port)) if err != nil { diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 9b881559b..2458284ca 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -419,3 +419,7 @@ func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate) { } }() } + +func (conn *Conn) GetKey() string { + return conn.config.Key +} diff --git a/management/client/client.go b/management/client/client.go index 6ccd62ec0..e881fb40d 100644 --- a/management/client/client.go +++ b/management/client/client.go @@ -1,253 +1,15 @@ package client import ( - "context" - "crypto/tls" - "fmt" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "github.com/wiretrustee/wiretrustee/client/system" - "github.com/wiretrustee/wiretrustee/encryption" "github.com/wiretrustee/wiretrustee/management/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" "io" - "time" ) -type Client struct { - key wgtypes.Key - realClient proto.ManagementServiceClient - ctx context.Context - conn *grpc.ClientConn -} - -// NewClient creates a new client to Management service -func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*Client, error) { - - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) - } - - mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - conn, err := grpc.DialContext( - mgmCtx, - addr, - transportOption, - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 15 * time.Second, - Timeout: 10 * time.Second, - })) - - if err != nil { - log.Errorf("failed creating connection to Management Service %v", err) - return nil, err - } - - realClient := proto.NewManagementServiceClient(conn) - - return &Client{ - key: ourPrivateKey, - realClient: realClient, - ctx: ctx, - conn: conn, - }, nil -} - -// Close closes connection to the Management Service -func (c *Client) Close() error { - return c.conn.Close() -} - -//defaultBackoff is a basic backoff mechanism for general issues -func defaultBackoff(ctx context.Context) backoff.BackOff { - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: backoff.DefaultRandomizationFactor, - Multiplier: backoff.DefaultMultiplier, - MaxInterval: 10 * time.Second, - MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) -} - -// ready indicates whether the client is okay and ready to be used -// for now it just checks whether gRPC connection to the service is ready -func (c *Client) ready() bool { - return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle -} - -// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages -// Blocking request. The result will be sent via msgHandler callback function -func (c *Client) Sync(msgHandler func(msg *proto.SyncResponse) error) error { - - var backOff = defaultBackoff(c.ctx) - - operation := func() error { - - log.Debugf("management connection state %v", c.conn.GetState()) - - if !c.ready() { - return fmt.Errorf("no connection to management") - } - - // todo we already have it since we did the Login, maybe cache it locally? - serverPubKey, err := c.GetServerPublicKey() - if err != nil { - log.Errorf("failed getting Management Service public key: %s", err) - return err - } - - stream, err := c.connectToStream(*serverPubKey) - if err != nil { - log.Errorf("failed to open Management Service stream: %s", err) - return err - } - - log.Infof("connected to the Management Service stream") - - // blocking until error - err = c.receiveEvents(stream, *serverPubKey, msgHandler) - if err != nil { - backOff.Reset() - return err - } - - return nil - } - - err := backoff.Retry(operation, backOff) - if err != nil { - log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err) - return err - } - - return nil -} - -func (c *Client) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { - req := &proto.SyncRequest{} - - myPrivateKey := c.key - myPublicKey := myPrivateKey.PublicKey() - - encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req) - if err != nil { - log.Errorf("failed encrypting message: %s", err) - return nil, err - } - - syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq} - return c.realClient.Sync(c.ctx, syncReq) -} - -func (c *Client) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { - for { - update, err := stream.Recv() - if err == io.EOF { - log.Errorf("Management stream has been closed by server: %s", err) - return err - } - if err != nil { - log.Warnf("disconnected from Management Service sync stream: %v", err) - return err - } - - log.Debugf("got an update message from Management Service") - decryptedResp := &proto.SyncResponse{} - err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp) - if err != nil { - log.Errorf("failed decrypting update message from Management Service: %s", err) - return err - } - - err = msgHandler(decryptedResp) - if err != nil { - log.Errorf("failed handling an update message received from Management Service: %v", err.Error()) - return err - } - } -} - -// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server) -func (c *Client) GetServerPublicKey() (*wgtypes.Key, error) { - if !c.ready() { - return nil, fmt.Errorf("no connection to management") - } - - mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting - defer cancel() - resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) - if err != nil { - return nil, err - } - - serverKey, err := wgtypes.ParseKey(resp.Key) - if err != nil { - return nil, err - } - - return &serverKey, nil -} - -func (c *Client) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { - if !c.ready() { - return nil, fmt.Errorf("no connection to management") - } - loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) - if err != nil { - log.Errorf("failed to encrypt message: %s", err) - return nil, err - } - mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting - defer cancel() - resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ - WgPubKey: c.key.PublicKey().String(), - Body: loginReq, - }) - - if err != nil { - return nil, err - } - - loginResp := &proto.LoginResponse{} - err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) - if err != nil { - log.Errorf("failed to decrypt registration message: %s", err) - return nil, err - } - - return loginResp, nil -} - -// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key -// Takes care of encrypting and decrypting messages. -// This method will also collect system info and send it with the request (e.g. hostname, os, etc) -func (c *Client) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) { - gi := system.GetInfo() - meta := &proto.PeerSystemMeta{ - Hostname: gi.Hostname, - GoOS: gi.GoOS, - OS: gi.OS, - Core: gi.OSVersion, - Platform: gi.Platform, - Kernel: gi.Kernel, - WiretrusteeVersion: "", - } - log.Debugf("detected system %v", meta) - return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta}) -} - -// Login attempts login to Management Server. Takes care of encrypting and decrypting messages. -func (c *Client) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) { - return c.login(serverKey, &proto.LoginRequest{}) +type Client interface { + io.Closer + Sync(msgHandler func(msg *proto.SyncResponse) error) error + GetServerPublicKey() (*wgtypes.Key, error) + Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) + Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) } diff --git a/management/client/client_test.go b/management/client/client_test.go index f6efbac62..4313d3ae6 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -16,7 +16,7 @@ import ( "time" ) -var tested *Client +var tested *GrpcClient var serverAddr string const ValidKey = "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" diff --git a/management/client/grpc.go b/management/client/grpc.go new file mode 100644 index 000000000..9d2e603b2 --- /dev/null +++ b/management/client/grpc.go @@ -0,0 +1,253 @@ +package client + +import ( + "context" + "crypto/tls" + "fmt" + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/client/system" + "github.com/wiretrustee/wiretrustee/encryption" + "github.com/wiretrustee/wiretrustee/management/proto" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "io" + "time" +) + +type GrpcClient struct { + key wgtypes.Key + realClient proto.ManagementServiceClient + ctx context.Context + conn *grpc.ClientConn +} + +// NewClient creates a new client to Management service +func NewClient(ctx context.Context, addr string, ourPrivateKey wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { + + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + + if tlsEnabled { + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + } + + mgmCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + conn, err := grpc.DialContext( + mgmCtx, + addr, + transportOption, + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 15 * time.Second, + Timeout: 10 * time.Second, + })) + + if err != nil { + log.Errorf("failed creating connection to Management Service %v", err) + return nil, err + } + + realClient := proto.NewManagementServiceClient(conn) + + return &GrpcClient{ + key: ourPrivateKey, + realClient: realClient, + ctx: ctx, + conn: conn, + }, nil +} + +// Close closes connection to the Management Service +func (c *GrpcClient) Close() error { + return c.conn.Close() +} + +//defaultBackoff is a basic backoff mechanism for general issues +func defaultBackoff(ctx context.Context) backoff.BackOff { + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: backoff.DefaultRandomizationFactor, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) +} + +// ready indicates whether the client is okay and ready to be used +// for now it just checks whether gRPC connection to the service is ready +func (c *GrpcClient) ready() bool { + return c.conn.GetState() == connectivity.Ready || c.conn.GetState() == connectivity.Idle +} + +// Sync wraps the real client's Sync endpoint call and takes care of retries and encryption/decryption of messages +// Blocking request. The result will be sent via msgHandler callback function +func (c *GrpcClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error { + + var backOff = defaultBackoff(c.ctx) + + operation := func() error { + + log.Debugf("management connection state %v", c.conn.GetState()) + + if !c.ready() { + return fmt.Errorf("no connection to management") + } + + // todo we already have it since we did the Login, maybe cache it locally? + serverPubKey, err := c.GetServerPublicKey() + if err != nil { + log.Errorf("failed getting Management Service public key: %s", err) + return err + } + + stream, err := c.connectToStream(*serverPubKey) + if err != nil { + log.Errorf("failed to open Management Service stream: %s", err) + return err + } + + log.Infof("connected to the Management Service stream") + + // blocking until error + err = c.receiveEvents(stream, *serverPubKey, msgHandler) + if err != nil { + backOff.Reset() + return err + } + + return nil + } + + err := backoff.Retry(operation, backOff) + if err != nil { + log.Warnf("exiting Management Service connection retry loop due to unrecoverable error: %s", err) + return err + } + + return nil +} + +func (c *GrpcClient) connectToStream(serverPubKey wgtypes.Key) (proto.ManagementService_SyncClient, error) { + req := &proto.SyncRequest{} + + myPrivateKey := c.key + myPublicKey := myPrivateKey.PublicKey() + + encryptedReq, err := encryption.EncryptMessage(serverPubKey, myPrivateKey, req) + if err != nil { + log.Errorf("failed encrypting message: %s", err) + return nil, err + } + + syncReq := &proto.EncryptedMessage{WgPubKey: myPublicKey.String(), Body: encryptedReq} + return c.realClient.Sync(c.ctx, syncReq) +} + +func (c *GrpcClient) receiveEvents(stream proto.ManagementService_SyncClient, serverPubKey wgtypes.Key, msgHandler func(msg *proto.SyncResponse) error) error { + for { + update, err := stream.Recv() + if err == io.EOF { + log.Errorf("Management stream has been closed by server: %s", err) + return err + } + if err != nil { + log.Warnf("disconnected from Management Service sync stream: %v", err) + return err + } + + log.Debugf("got an update message from Management Service") + decryptedResp := &proto.SyncResponse{} + err = encryption.DecryptMessage(serverPubKey, c.key, update.Body, decryptedResp) + if err != nil { + log.Errorf("failed decrypting update message from Management Service: %s", err) + return err + } + + err = msgHandler(decryptedResp) + if err != nil { + log.Errorf("failed handling an update message received from Management Service: %v", err.Error()) + return err + } + } +} + +// GetServerPublicKey returns server Wireguard public key (used later for encrypting messages sent to the server) +func (c *GrpcClient) GetServerPublicKey() (*wgtypes.Key, error) { + if !c.ready() { + return nil, fmt.Errorf("no connection to management") + } + + mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting + defer cancel() + resp, err := c.realClient.GetServerKey(mgmCtx, &proto.Empty{}) + if err != nil { + return nil, err + } + + serverKey, err := wgtypes.ParseKey(resp.Key) + if err != nil { + return nil, err + } + + return &serverKey, nil +} + +func (c *GrpcClient) login(serverKey wgtypes.Key, req *proto.LoginRequest) (*proto.LoginResponse, error) { + if !c.ready() { + return nil, fmt.Errorf("no connection to management") + } + loginReq, err := encryption.EncryptMessage(serverKey, c.key, req) + if err != nil { + log.Errorf("failed to encrypt message: %s", err) + return nil, err + } + mgmCtx, cancel := context.WithTimeout(c.ctx, 5*time.Second) //todo make a general setting + defer cancel() + resp, err := c.realClient.Login(mgmCtx, &proto.EncryptedMessage{ + WgPubKey: c.key.PublicKey().String(), + Body: loginReq, + }) + + if err != nil { + return nil, err + } + + loginResp := &proto.LoginResponse{} + err = encryption.DecryptMessage(serverKey, c.key, resp.Body, loginResp) + if err != nil { + log.Errorf("failed to decrypt registration message: %s", err) + return nil, err + } + + return loginResp, nil +} + +// Register registers peer on Management Server. It actually calls a Login endpoint with a provided setup key +// Takes care of encrypting and decrypting messages. +// This method will also collect system info and send it with the request (e.g. hostname, os, etc) +func (c *GrpcClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) { + gi := system.GetInfo() + meta := &proto.PeerSystemMeta{ + Hostname: gi.Hostname, + GoOS: gi.GoOS, + OS: gi.OS, + Core: gi.OSVersion, + Platform: gi.Platform, + Kernel: gi.Kernel, + WiretrusteeVersion: "", + } + log.Debugf("detected system %v", meta) + return c.login(serverKey, &proto.LoginRequest{SetupKey: setupKey, Meta: meta}) +} + +// Login attempts login to Management Server. Takes care of encrypting and decrypting messages. +func (c *GrpcClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) { + return c.login(serverKey, &proto.LoginRequest{}) +} diff --git a/management/client/mock.go b/management/client/mock.go new file mode 100644 index 000000000..af99ba863 --- /dev/null +++ b/management/client/mock.go @@ -0,0 +1,49 @@ +package client + +import ( + "github.com/wiretrustee/wiretrustee/management/proto" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type MockClient struct { + CloseFunc func() error + SyncFunc func(msgHandler func(msg *proto.SyncResponse) error) error + GetServerPublicKeyFunc func() (*wgtypes.Key, error) + RegisterFunc func(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) + LoginFunc func(serverKey wgtypes.Key) (*proto.LoginResponse, error) +} + +func (m *MockClient) Close() error { + if m.CloseFunc == nil { + return nil + } + return m.CloseFunc() +} + +func (m *MockClient) Sync(msgHandler func(msg *proto.SyncResponse) error) error { + if m.SyncFunc == nil { + return nil + } + return m.SyncFunc(msgHandler) +} + +func (m *MockClient) GetServerPublicKey() (*wgtypes.Key, error) { + if m.GetServerPublicKeyFunc == nil { + return nil, nil + } + return m.GetServerPublicKeyFunc() +} + +func (m *MockClient) Register(serverKey wgtypes.Key, setupKey string) (*proto.LoginResponse, error) { + if m.RegisterFunc == nil { + return nil, nil + } + return m.RegisterFunc(serverKey, setupKey) +} + +func (m *MockClient) Login(serverKey wgtypes.Key) (*proto.LoginResponse, error) { + if m.LoginFunc == nil { + return nil, nil + } + return m.LoginFunc(serverKey) +} diff --git a/signal/client/client.go b/signal/client/client.go index bfa3f4f8d..e1a8bb767 100644 --- a/signal/client/client.go +++ b/signal/client/client.go @@ -1,26 +1,11 @@ package client import ( - "context" - "crypto/tls" "fmt" - "github.com/cenkalti/backoff/v4" - log "github.com/sirupsen/logrus" - "github.com/wiretrustee/wiretrustee/encryption" "github.com/wiretrustee/wiretrustee/signal/proto" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/connectivity" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/keepalive" - "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" "io" "strings" - "sync" - "time" ) // A set of tools to exchange connection details (Wireguard endpoints) with the remote peer. @@ -31,317 +16,15 @@ type Status string const StreamConnected Status = "Connected" const StreamDisconnected Status = "Disconnected" -// Client Wraps the Signal Exchange Service gRpc client -type Client struct { - key wgtypes.Key - realClient proto.SignalExchangeClient - signalConn *grpc.ClientConn - ctx context.Context - stream proto.SignalExchange_ConnectStreamClient - // connectedCh used to notify goroutines waiting for the connection to the Signal stream - connectedCh chan struct{} - mux sync.Mutex - // StreamConnected indicates whether this client is StreamConnected to the Signal stream - status Status -} - -func (c *Client) StreamConnected() bool { - return c.status == StreamConnected -} - -func (c *Client) GetStatus() Status { - return c.status -} - -// Close Closes underlying connections to the Signal Exchange -func (c *Client) Close() error { - return c.signalConn.Close() -} - -// NewClient creates a new Signal client -func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*Client, error) { - - transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) - - if tlsEnabled { - transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) - } - - sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second) - defer cancel() - conn, err := grpc.DialContext( - sigCtx, - addr, - transportOption, - grpc.WithBlock(), - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 15 * time.Second, - Timeout: 10 * time.Second, - })) - - if err != nil { - log.Errorf("failed to connect to the signalling server %v", err) - return nil, err - } - - return &Client{ - realClient: proto.NewSignalExchangeClient(conn), - ctx: ctx, - signalConn: conn, - key: key, - mux: sync.Mutex{}, - status: StreamDisconnected, - }, nil -} - -//defaultBackoff is a basic backoff mechanism for general issues -func defaultBackoff(ctx context.Context) backoff.BackOff { - return backoff.WithContext(&backoff.ExponentialBackOff{ - InitialInterval: 800 * time.Millisecond, - RandomizationFactor: backoff.DefaultRandomizationFactor, - Multiplier: backoff.DefaultMultiplier, - MaxInterval: 10 * time.Second, - MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client - Stop: backoff.Stop, - Clock: backoff.SystemClock, - }, ctx) - -} - -// Receive Connects to the Signal Exchange message stream and starts receiving messages. -// The messages will be handled by msgHandler function provided. -// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) -// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller. -func (c *Client) Receive(msgHandler func(msg *proto.Message) error) error { - - var backOff = defaultBackoff(c.ctx) - - operation := func() error { - - c.notifyStreamDisconnected() - - log.Debugf("signal connection state %v", c.signalConn.GetState()) - if !c.Ready() { - return fmt.Errorf("no connection to signal") - } - - // connect to Signal stream identifying ourselves with a public Wireguard key - // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) - stream, err := c.connect(c.key.PublicKey().String()) - if err != nil { - log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) - return err - } - - c.notifyStreamConnected() - - log.Infof("connected to the Signal Service stream") - - // start receiving messages from the Signal stream (from other peers through signal) - err = c.receive(stream, msgHandler) - if err != nil { - log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) - backOff.Reset() - return err - } - - return nil - } - - err := backoff.Retry(operation, backOff) - if err != nil { - log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) - return err - } - - return nil -} -func (c *Client) notifyStreamDisconnected() { - c.mux.Lock() - defer c.mux.Unlock() - c.status = StreamDisconnected -} - -func (c *Client) notifyStreamConnected() { - c.mux.Lock() - defer c.mux.Unlock() - c.status = StreamConnected - if c.connectedCh != nil { - // there are goroutines waiting on this channel -> release them - close(c.connectedCh) - c.connectedCh = nil - } -} - -func (c *Client) getStreamStatusChan() <-chan struct{} { - c.mux.Lock() - defer c.mux.Unlock() - if c.connectedCh == nil { - c.connectedCh = make(chan struct{}) - } - return c.connectedCh -} - -func (c *Client) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { - c.stream = nil - - // add key fingerprint to the request header to be identified on the server side - md := metadata.New(map[string]string{proto.HeaderId: key}) - ctx := metadata.NewOutgoingContext(c.ctx, md) - - stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true)) - - c.stream = stream - if err != nil { - return nil, err - } - // blocks - header, err := c.stream.Header() - if err != nil { - return nil, err - } - registered := header.Get(proto.HeaderRegistered) - if len(registered) == 0 { - return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") - } - - return stream, nil -} - -// Ready indicates whether the client is okay and Ready to be used -// for now it just checks whether gRPC connection to the service is in state Ready -func (c *Client) Ready() bool { - return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle -} - -// WaitStreamConnected waits until the client is connected to the Signal stream -func (c *Client) WaitStreamConnected() { - - if c.status == StreamConnected { - return - } - - ch := c.getStreamStatusChan() - select { - case <-c.ctx.Done(): - case <-ch: - } -} - -// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server -// The Client.Receive method must be called before sending messages to establish initial connection to the Signal Exchange -// Client.connWg can be used to wait -func (c *Client) SendToStream(msg *proto.EncryptedMessage) error { - if !c.Ready() { - return fmt.Errorf("no connection to signal") - } - if c.stream == nil { - return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call Client.Receive before sending messages") - } - - err := c.stream.Send(msg) - if err != nil { - log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) - return err - } - - return nil -} - -// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key -func (c *Client) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) { - remoteKey, err := wgtypes.ParseKey(msg.GetKey()) - if err != nil { - return nil, err - } - - body := &proto.Body{} - err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body) - if err != nil { - return nil, err - } - - return &proto.Message{ - Key: msg.Key, - RemoteKey: msg.RemoteKey, - Body: body, - }, nil -} - -// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key -func (c *Client) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) { - - remoteKey, err := wgtypes.ParseKey(msg.RemoteKey) - if err != nil { - return nil, err - } - - encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body) - if err != nil { - return nil, err - } - - return &proto.EncryptedMessage{ - Key: msg.GetKey(), - RemoteKey: msg.GetRemoteKey(), - Body: encryptedBody, - }, nil -} - -// Send sends a message to the remote Peer through the Signal Exchange. -func (c *Client) Send(msg *proto.Message) error { - - if !c.Ready() { - return fmt.Errorf("no connection to signal") - } - - encryptedMessage, err := c.encryptMessage(msg) - if err != nil { - return err - } - - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - _, err = c.realClient.Send(ctx, encryptedMessage) - if err != nil { - return err - } - - return nil -} - -// receive receives messages from other peers coming through the Signal Exchange -func (c *Client) receive(stream proto.SignalExchange_ConnectStreamClient, - msgHandler func(msg *proto.Message) error) error { - - for { - msg, err := stream.Recv() - if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { - log.Warnf("stream canceled (usually indicates shutdown)") - return err - } else if s.Code() == codes.Unavailable { - log.Warnf("Signal Service is unavailable") - return err - } else if err == io.EOF { - log.Warnf("Signal Service stream closed by server") - return err - } else if err != nil { - return err - } - log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key) - - decryptedMessage, err := c.decryptMessage(msg) - if err != nil { - log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) - } - - err = msgHandler(decryptedMessage) - - if err != nil { - log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) - //todo send something?? - } - } +type Client interface { + io.Closer + StreamConnected() bool + GetStatus() Status + Receive(msgHandler func(msg *proto.Message) error) error + Ready() bool + WaitStreamConnected() + SendToStream(msg *proto.EncryptedMessage) error + Send(msg *proto.Message) error } // UnMarshalCredential parses the credentials from the message and returns a Credential instance @@ -369,7 +52,7 @@ func MarshalCredential(myKey wgtypes.Key, remoteKey wgtypes.Key, credential *Cre }, nil } -// Credential is an instance of a Client's Credential +// Credential is an instance of a GrpcClient's Credential type Credential struct { UFrag string Pwd string diff --git a/signal/client/client_test.go b/signal/client/client_test.go index f6c814325..5f59e5b80 100644 --- a/signal/client/client_test.go +++ b/signal/client/client_test.go @@ -17,7 +17,7 @@ import ( "time" ) -var _ = Describe("Client", func() { +var _ = Describe("GrpcClient", func() { var ( addr string @@ -160,7 +160,7 @@ var _ = Describe("Client", func() { }) -func createSignalClient(addr string, key wgtypes.Key) *Client { +func createSignalClient(addr string, key wgtypes.Key) *GrpcClient { var sigTLSEnabled = false client, err := NewClient(context.Background(), addr, key, sigTLSEnabled) if err != nil { diff --git a/signal/client/grpc.go b/signal/client/grpc.go new file mode 100644 index 000000000..dc7b8cb52 --- /dev/null +++ b/signal/client/grpc.go @@ -0,0 +1,336 @@ +package client + +import ( + "context" + "crypto/tls" + "fmt" + "github.com/cenkalti/backoff/v4" + log "github.com/sirupsen/logrus" + "github.com/wiretrustee/wiretrustee/encryption" + "github.com/wiretrustee/wiretrustee/signal/proto" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/keepalive" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "io" + "sync" + "time" +) + +// GrpcClient Wraps the Signal Exchange Service gRpc client +type GrpcClient struct { + key wgtypes.Key + realClient proto.SignalExchangeClient + signalConn *grpc.ClientConn + ctx context.Context + stream proto.SignalExchange_ConnectStreamClient + // connectedCh used to notify goroutines waiting for the connection to the Signal stream + connectedCh chan struct{} + mux sync.Mutex + // StreamConnected indicates whether this client is StreamConnected to the Signal stream + status Status +} + +func (c *GrpcClient) StreamConnected() bool { + return c.status == StreamConnected +} + +func (c *GrpcClient) GetStatus() Status { + return c.status +} + +// Close Closes underlying connections to the Signal Exchange +func (c *GrpcClient) Close() error { + return c.signalConn.Close() +} + +// NewClient creates a new Signal client +func NewClient(ctx context.Context, addr string, key wgtypes.Key, tlsEnabled bool) (*GrpcClient, error) { + + transportOption := grpc.WithTransportCredentials(insecure.NewCredentials()) + + if tlsEnabled { + transportOption = grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{})) + } + + sigCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + conn, err := grpc.DialContext( + sigCtx, + addr, + transportOption, + grpc.WithBlock(), + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 15 * time.Second, + Timeout: 10 * time.Second, + })) + + if err != nil { + log.Errorf("failed to connect to the signalling server %v", err) + return nil, err + } + + return &GrpcClient{ + realClient: proto.NewSignalExchangeClient(conn), + ctx: ctx, + signalConn: conn, + key: key, + mux: sync.Mutex{}, + status: StreamDisconnected, + }, nil +} + +//defaultBackoff is a basic backoff mechanism for general issues +func defaultBackoff(ctx context.Context) backoff.BackOff { + return backoff.WithContext(&backoff.ExponentialBackOff{ + InitialInterval: 800 * time.Millisecond, + RandomizationFactor: backoff.DefaultRandomizationFactor, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: 10 * time.Second, + MaxElapsedTime: 12 * time.Hour, //stop after 12 hours of trying, the error will be propagated to the general retry of the client + Stop: backoff.Stop, + Clock: backoff.SystemClock, + }, ctx) + +} + +// Receive Connects to the Signal Exchange message stream and starts receiving messages. +// The messages will be handled by msgHandler function provided. +// This function is blocking and reconnects to the Signal Exchange if errors occur (e.g. Exchange restart) +// The connection retry logic will try to reconnect for 30 min and if wasn't successful will propagate the error to the function caller. +func (c *GrpcClient) Receive(msgHandler func(msg *proto.Message) error) error { + + var backOff = defaultBackoff(c.ctx) + + operation := func() error { + + c.notifyStreamDisconnected() + + log.Debugf("signal connection state %v", c.signalConn.GetState()) + if !c.Ready() { + return fmt.Errorf("no connection to signal") + } + + // connect to Signal stream identifying ourselves with a public Wireguard key + // todo once the key rotation logic has been implemented, consider changing to some other identifier (received from management) + stream, err := c.connect(c.key.PublicKey().String()) + if err != nil { + log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) + return err + } + + c.notifyStreamConnected() + + log.Infof("connected to the Signal Service stream") + + // start receiving messages from the Signal stream (from other peers through signal) + err = c.receive(stream, msgHandler) + if err != nil { + log.Warnf("disconnected from the Signal Exchange due to an error: %v", err) + backOff.Reset() + return err + } + + return nil + } + + err := backoff.Retry(operation, backOff) + if err != nil { + log.Errorf("exiting Signal Service connection retry loop due to unrecoverable error: %s", err) + return err + } + + return nil +} +func (c *GrpcClient) notifyStreamDisconnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = StreamDisconnected +} + +func (c *GrpcClient) notifyStreamConnected() { + c.mux.Lock() + defer c.mux.Unlock() + c.status = StreamConnected + if c.connectedCh != nil { + // there are goroutines waiting on this channel -> release them + close(c.connectedCh) + c.connectedCh = nil + } +} + +func (c *GrpcClient) getStreamStatusChan() <-chan struct{} { + c.mux.Lock() + defer c.mux.Unlock() + if c.connectedCh == nil { + c.connectedCh = make(chan struct{}) + } + return c.connectedCh +} + +func (c *GrpcClient) connect(key string) (proto.SignalExchange_ConnectStreamClient, error) { + c.stream = nil + + // add key fingerprint to the request header to be identified on the server side + md := metadata.New(map[string]string{proto.HeaderId: key}) + ctx := metadata.NewOutgoingContext(c.ctx, md) + + stream, err := c.realClient.ConnectStream(ctx, grpc.WaitForReady(true)) + + c.stream = stream + if err != nil { + return nil, err + } + // blocks + header, err := c.stream.Header() + if err != nil { + return nil, err + } + registered := header.Get(proto.HeaderRegistered) + if len(registered) == 0 { + return nil, fmt.Errorf("didn't receive a registration header from the Signal server whille connecting to the streams") + } + + return stream, nil +} + +// Ready indicates whether the client is okay and Ready to be used +// for now it just checks whether gRPC connection to the service is in state Ready +func (c *GrpcClient) Ready() bool { + return c.signalConn.GetState() == connectivity.Ready || c.signalConn.GetState() == connectivity.Idle +} + +// WaitStreamConnected waits until the client is connected to the Signal stream +func (c *GrpcClient) WaitStreamConnected() { + + if c.status == StreamConnected { + return + } + + ch := c.getStreamStatusChan() + select { + case <-c.ctx.Done(): + case <-ch: + } +} + +// SendToStream sends a message to the remote Peer through the Signal Exchange using established stream connection to the Signal Server +// The GrpcClient.Receive method must be called before sending messages to establish initial connection to the Signal Exchange +// GrpcClient.connWg can be used to wait +func (c *GrpcClient) SendToStream(msg *proto.EncryptedMessage) error { + if !c.Ready() { + return fmt.Errorf("no connection to signal") + } + if c.stream == nil { + return fmt.Errorf("connection to the Signal Exchnage has not been established yet. Please call GrpcClient.Receive before sending messages") + } + + err := c.stream.Send(msg) + if err != nil { + log.Errorf("error while sending message to peer [%s] [error: %v]", msg.RemoteKey, err) + return err + } + + return nil +} + +// decryptMessage decrypts the body of the msg using Wireguard private key and Remote peer's public key +func (c *GrpcClient) decryptMessage(msg *proto.EncryptedMessage) (*proto.Message, error) { + remoteKey, err := wgtypes.ParseKey(msg.GetKey()) + if err != nil { + return nil, err + } + + body := &proto.Body{} + err = encryption.DecryptMessage(remoteKey, c.key, msg.GetBody(), body) + if err != nil { + return nil, err + } + + return &proto.Message{ + Key: msg.Key, + RemoteKey: msg.RemoteKey, + Body: body, + }, nil +} + +// encryptMessage encrypts the body of the msg using Wireguard private key and Remote peer's public key +func (c *GrpcClient) encryptMessage(msg *proto.Message) (*proto.EncryptedMessage, error) { + + remoteKey, err := wgtypes.ParseKey(msg.RemoteKey) + if err != nil { + return nil, err + } + + encryptedBody, err := encryption.EncryptMessage(remoteKey, c.key, msg.Body) + if err != nil { + return nil, err + } + + return &proto.EncryptedMessage{ + Key: msg.GetKey(), + RemoteKey: msg.GetRemoteKey(), + Body: encryptedBody, + }, nil +} + +// Send sends a message to the remote Peer through the Signal Exchange. +func (c *GrpcClient) Send(msg *proto.Message) error { + + if !c.Ready() { + return fmt.Errorf("no connection to signal") + } + + encryptedMessage, err := c.encryptMessage(msg) + if err != nil { + return err + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + _, err = c.realClient.Send(ctx, encryptedMessage) + if err != nil { + return err + } + + return nil +} + +// receive receives messages from other peers coming through the Signal Exchange +func (c *GrpcClient) receive(stream proto.SignalExchange_ConnectStreamClient, + msgHandler func(msg *proto.Message) error) error { + + for { + msg, err := stream.Recv() + if s, ok := status.FromError(err); ok && s.Code() == codes.Canceled { + log.Warnf("stream canceled (usually indicates shutdown)") + return err + } else if s.Code() == codes.Unavailable { + log.Warnf("Signal Service is unavailable") + return err + } else if err == io.EOF { + log.Warnf("Signal Service stream closed by server") + return err + } else if err != nil { + return err + } + log.Debugf("received a new message from Peer [fingerprint: %s]", msg.Key) + + decryptedMessage, err := c.decryptMessage(msg) + if err != nil { + log.Errorf("failed decrypting message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) + } + + err = msgHandler(decryptedMessage) + + if err != nil { + log.Errorf("error while handling message of Peer [key: %s] error: [%s]", msg.Key, err.Error()) + //todo send something?? + } + } +} diff --git a/signal/client/mock.go b/signal/client/mock.go new file mode 100644 index 000000000..0b86defc7 --- /dev/null +++ b/signal/client/mock.go @@ -0,0 +1,72 @@ +package client + +import ( + "github.com/wiretrustee/wiretrustee/signal/proto" +) + +type MockClient struct { + CloseFunc func() error + GetStatusFunc func() Status + StreamConnectedFunc func() bool + ReadyFunc func() bool + WaitStreamConnectedFunc func() + ReceiveFunc func(msgHandler func(msg *proto.Message) error) error + SendToStreamFunc func(msg *proto.EncryptedMessage) error + SendFunc func(msg *proto.Message) error +} + +func (sm *MockClient) Close() error { + if sm.CloseFunc == nil { + return nil + } + return sm.CloseFunc() +} + +func (sm *MockClient) GetStatus() Status { + if sm.GetStatusFunc == nil { + return "" + } + return sm.GetStatusFunc() +} + +func (sm *MockClient) StreamConnected() bool { + if sm.StreamConnectedFunc == nil { + return false + } + return sm.StreamConnectedFunc() +} + +func (sm *MockClient) Ready() bool { + if sm.ReadyFunc == nil { + return false + } + return sm.ReadyFunc() +} + +func (sm *MockClient) WaitStreamConnected() { + if sm.WaitStreamConnectedFunc == nil { + return + } + sm.WaitStreamConnectedFunc() +} + +func (sm *MockClient) Receive(msgHandler func(msg *proto.Message) error) error { + if sm.ReceiveFunc == nil { + return nil + } + return sm.ReceiveFunc(msgHandler) +} + +func (sm *MockClient) SendToStream(msg *proto.EncryptedMessage) error { + if sm.SendToStreamFunc == nil { + return nil + } + return sm.SendToStreamFunc(msg) +} + +func (sm *MockClient) Send(msg *proto.Message) error { + if sm.SendFunc == nil { + return nil + } + return sm.SendFunc(msg) +} diff --git a/util/common.go b/util/common.go new file mode 100644 index 000000000..4d0059ed9 --- /dev/null +++ b/util/common.go @@ -0,0 +1,16 @@ +package util + +// SliceDiff returns the elements in slice `x` that are not in slice `y` +func SliceDiff(x, y []string) []string { + mapY := make(map[string]struct{}, len(y)) + for _, val := range y { + mapY[val] = struct{}{} + } + var diff []string + for _, val := range x { + if _, found := mapY[val]; !found { + diff = append(diff, val) + } + } + return diff +}