From 241b8191560b3f4689252f08b41f614a3b7e286e Mon Sep 17 00:00:00 2001 From: braginini Date: Wed, 1 Mar 2023 18:54:27 +0100 Subject: [PATCH] Refactor Sync --- management/server/account.go | 1 + management/server/grpcserver.go | 68 ++++-------- management/server/mock_server/account_mock.go | 9 ++ management/server/peer.go | 101 ++++++++++++------ 4 files changed, 99 insertions(+), 80 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index c19ea90cc..d36d21bf5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -97,6 +97,7 @@ type AccountManager interface { GetPeer(accountID, peerID, userID string) (*Peer, error) UpdateAccountSettings(accountID, userID string, newSettings *Settings) (*Account, error) LoginPeer(login PeerLogin) (*Peer, error) + SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) } type DefaultAccountManager struct { diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 1645dbd6d..045cc9766 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,6 +3,7 @@ package server import ( "context" "fmt" + pb "github.com/golang/protobuf/proto" //nolint "strings" "time" @@ -118,44 +119,18 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi log.Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) } - peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) - if err != nil { - log.Warnf("error while parsing peer's Wireguard public key %s on Sync request.", peerKey.String()) - return status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", peerKey.String()) - } - - peer, err := s.accountManager.GetPeerByKey(peerKey.String()) - if err != nil { - p, _ := gRPCPeer.FromContext(srv.Context()) - msg := status.Errorf(codes.PermissionDenied, "provided peer with the key wgPubKey %s is not registered, remote addr is %s", peerKey.String(), p.Addr.String()) - log.Debug(msg) - return msg - } - - account, err := s.accountManager.GetAccountByPeerID(peer.ID) - if err != nil { - return status.Error(codes.Internal, "internal server error") - } - expired, left := peer.LoginExpired(account.Settings.PeerLoginExpiration) - expired = account.Settings.PeerLoginExpirationEnabled && expired - if peer.UserID != "" && (expired || peer.Status.LoginExpired) { - err = s.accountManager.MarkPeerLoginExpired(peerKey.String(), true) - if err != nil { - log.Warnf("failed marking peer login expired %s %v", peerKey, err) - } - return status.Errorf(codes.PermissionDenied, "peer login has expired %v ago. Please log in once more", left) - } - syncReq := &proto.SyncRequest{} - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, syncReq) + peerKey, err := s.parseRequest(req, syncReq) if err != nil { - p, _ := gRPCPeer.FromContext(srv.Context()) - msg := status.Errorf(codes.InvalidArgument, "invalid request message from %s,remote addr is %s", peerKey.String(), p.Addr.String()) - log.Debug(msg) - return msg + return err } - err = s.sendInitialSync(peerKey, peer, srv) + peer, netMap, err := s.accountManager.SyncPeer(PeerSync{WireGuardPubKey: peerKey.String()}) + if err != nil { + return mapError(err) + } + + err = s.sendInitialSync(peerKey, peer, netMap, srv) if err != nil { log.Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) return err @@ -263,21 +238,19 @@ func extractPeerMeta(loginReq *proto.LoginRequest) PeerSystemMeta { } } -func (s *GRPCServer) parseLoginRequest(req *proto.EncryptedMessage) (*proto.LoginRequest, wgtypes.Key, error) { +func (s *GRPCServer) parseRequest(req *proto.EncryptedMessage, parsed pb.Message) (wgtypes.Key, error) { peerKey, err := wgtypes.ParseKey(req.GetWgPubKey()) if err != nil { log.Warnf("error while parsing peer's WireGuard public key %s.", req.WgPubKey) - return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey) } - loginReq := &proto.LoginRequest{} - err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, loginReq) + err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed) if err != nil { - return nil, wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") + return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message") } - return loginReq, peerKey, nil - + return peerKey, nil } // Login endpoint first checks whether peer is registered under any account @@ -293,7 +266,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p log.Debugf("Login request from peer [%s] [%s]", req.WgPubKey, p.Addr.String()) } - loginReq, peerKey, err := s.parseLoginRequest(req) + loginReq := &proto.LoginRequest{} + peerKey, err := s.parseRequest(req, loginReq) if err != nil { return nil, err } @@ -310,7 +284,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // JWT token is not always provided, it is fine for userID to be empty cuz it might be that peer is already registered, // or it uses a setup key to register. if loginReq.GetJwtToken() != "" { - // todo what about the case when JWT provided expired? + // todo what about the case when JWT provided is expired? userID, err = s.validateToken(loginReq.GetJwtToken()) if err != nil { log.Warnf("failed validating JWT token sent from peer %s", peerKey) @@ -473,13 +447,7 @@ func (s *GRPCServer) IsHealthy(ctx context.Context, req *proto.Empty) (*proto.Em } // sendInitialSync sends initial proto.SyncResponse to the peer requesting synchronization -func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.ManagementService_SyncServer) error { - networkMap, err := s.accountManager.GetNetworkMap(peer.ID) - if err != nil { - log.Warnf("error getting a list of peers for a peer %s", peer.ID) - return err - } - +func (s *GRPCServer) sendInitialSync(peerKey wgtypes.Key, peer *Peer, networkMap *NetworkMap, srv proto.ManagementService_SyncServer) error { // make secret time based TURN credentials optional var turnCredentials *TURNCredentials if s.config.TURNConfig.TimeBasedCredentials { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 0af18f988..ab8b1a267 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -73,6 +73,7 @@ type MockAccountManager struct { GetAccountByPeerIDFunc func(peerID string) (*server.Account, error) UpdateAccountSettingsFunc func(accountID, userID string, newSettings *server.Settings) (*server.Account, error) LoginPeerFunc func(login server.PeerLogin) (*server.Peer, error) + SyncPeerFunc func(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface @@ -564,3 +565,11 @@ func (am *MockAccountManager) LoginPeer(login server.PeerLogin) (*server.Peer, e } return nil, status.Errorf(codes.Unimplemented, "method LoginPeer is not implemented") } + +// SyncPeer mocks SyncPeer of the AccountManager interface +func (am *MockAccountManager) SyncPeer(sync server.PeerSync) (*server.Peer, *server.NetworkMap, error) { + if am.SyncPeerFunc != nil { + return am.SyncPeerFunc(sync) + } + return nil, nil, status.Errorf(codes.Unimplemented, "method SyncPeer is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index 09e77e201..e103bfcd0 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -36,6 +36,12 @@ type PeerStatus struct { LoginExpired bool } +// PeerSync used as a data object between the gRPC API and AccountManager on Sync request. +type PeerSync struct { + // WireGuardPubKey is a peers WireGuard public key + WireGuardPubKey string +} + // PeerLogin used as a data object between the gRPC API and AccountManager on Login request. type PeerLogin struct { // WireGuardPubKey is a peers WireGuard public key @@ -454,6 +460,34 @@ func (am *DefaultAccountManager) GetPeerByIP(accountID string, peerIP string) (* return nil, status.Errorf(status.NotFound, "peer with IP %s not found", peerIP) } +func (am *DefaultAccountManager) getNetworkMap(peer *Peer, account *Account) *NetworkMap { + aclPeers := account.getPeersByACL(peer.ID) + // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. + routesUpdate := account.getRoutesToSync(peer.ID, aclPeers) + + dnsManagementStatus := account.getPeerDNSManagementStatus(peer.ID) + dnsUpdate := nbdns.Config{ + ServiceEnable: dnsManagementStatus, + } + + if dnsManagementStatus { + var zones []nbdns.CustomZone + peersCustomZone := getPeersCustomZone(account, am.dnsDomain) + if peersCustomZone.Domain != "" { + zones = append(zones, peersCustomZone) + } + dnsUpdate.CustomZones = zones + dnsUpdate.NameServerGroups = getPeerNSGroups(account, peer.ID) + } + + return &NetworkMap{ + Peers: aclPeers, + Network: account.Network.Copy(), + Routes: routesUpdate, + DNSConfig: dnsUpdate, + } +} + // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, error) { @@ -467,31 +501,7 @@ func (am *DefaultAccountManager) GetNetworkMap(peerID string) (*NetworkMap, erro return nil, status.Errorf(status.NotFound, "peer with ID %s not found", peerID) } - aclPeers := account.getPeersByACL(peerID) - // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. - routesUpdate := account.getRoutesToSync(peerID, aclPeers) - - dnsManagementStatus := account.getPeerDNSManagementStatus(peerID) - dnsUpdate := nbdns.Config{ - ServiceEnable: dnsManagementStatus, - } - - if dnsManagementStatus { - var zones []nbdns.CustomZone - peersCustomZone := getPeersCustomZone(account, am.dnsDomain) - if peersCustomZone.Domain != "" { - zones = append(zones, peersCustomZone) - } - dnsUpdate.CustomZones = zones - dnsUpdate.NameServerGroups = getPeerNSGroups(account, peerID) - } - - return &NetworkMap{ - Peers: aclPeers, - Network: account.Network.Copy(), - Routes: routesUpdate, - DNSConfig: dnsUpdate, - }, err + return am.getNetworkMap(peer, account), nil } // GetPeerNetwork returns the Network for a given peer @@ -643,13 +653,13 @@ func (am *DefaultAccountManager) AddPeer(setupKey, userID string, peer *Peer) (* return newPeer, nil } -func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer *Peer, account *Account) error { +func (am *DefaultAccountManager) checkPeerLoginExpiration(loginUserID string, peer *Peer, account *Account) error { if peer.AddedWithSSOLogin() { expired, expiresIn := peer.LoginExpired(account.Settings.PeerLoginExpiration) expired = account.Settings.PeerLoginExpirationEnabled && expired if expired || peer.Status.LoginExpired { log.Debugf("peer %s login expired", peer.ID) - if login.UserID == "" { + if loginUserID == "" { // absence of a user ID indicates that JWT wasn't provided. _, err := am.markPeerLoginExpired(peer, account, true) if err != nil { @@ -659,8 +669,8 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer "peer login has expired %v ago. Please log in once more", expiresIn) } else { // user ID is there meaning that JWT validation passed successfully in the API layer. - if peer.UserID != login.UserID { - log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, login.UserID) + if peer.UserID != loginUserID { + log.Warnf("user mismatch when loggin in peer %s: peer user %s, login user %s ", peer.ID, peer.UserID, loginUserID) return status.Errorf(status.Unauthenticated, "can't login") } _ = am.updatePeerLastLogin(peer, account) @@ -671,6 +681,37 @@ func (am *DefaultAccountManager) checkPeerLoginExpiration(login PeerLogin, peer return nil } +// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible +func (am *DefaultAccountManager) SyncPeer(sync PeerSync) (*Peer, *NetworkMap, error) { + account, err := am.Store.GetAccountByPeerPubKey(sync.WireGuardPubKey) + if err != nil { + return nil, nil, err + } + + // we found the peer, and we follow a normal login flow + unlock := am.Store.AcquireAccountLock(account.Id) + defer unlock() + + // fetch the account from the store once more after acquiring lock to avoid concurrent updates inconsistencies + account, err = am.Store.GetAccount(account.Id) + if err != nil { + return nil, nil, err + } + + peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) + if err != nil { + return nil, nil, err + } + + err = am.checkPeerLoginExpiration("", peer, account) + if err != nil { + return nil, nil, err + } + + return peer, am.getNetworkMap(peer, account), nil + +} + // LoginPeer logs in or registers a peer. // If peer doesn't exist the function checks whether a setup key or a user is present and registers a new peer if so. func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) { @@ -705,7 +746,7 @@ func (am *DefaultAccountManager) LoginPeer(login PeerLogin) (*Peer, error) { return nil, err } - err = am.checkPeerLoginExpiration(login, peer, account) + err = am.checkPeerLoginExpiration(login.UserID, peer, account) if err != nil { return nil, err }