diff --git a/management/server/account.go b/management/server/account.go index 520858a16..d2aae927a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -49,6 +49,7 @@ type AccountManager interface { DeletePeer(accountId string, peerKey string) (*Peer, error) GetPeerByIP(accountId string, peerIP string) (*Peer, error) GetNetworkMap(peerKey string) (*NetworkMap, error) + GetPeerNetwork(peerKey string) (*Network, error) AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error UpdatePeerSSHKey(peerKey string, sshKey string) error diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 0fa8a4026..f39f738db 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -230,7 +230,7 @@ func (s *Server) registerPeer(peerKey wgtypes.Key, req *proto.LoginRequest) (*Pe peersToSend = append(peersToSend, p) } } - update := toSyncResponse(s.config, remotePeer, peersToSend, nil, networkMap.Network.CurrentSerial()) + update := toSyncResponse(s.config, remotePeer, peersToSend, nil, networkMap.Network.CurrentSerial(), networkMap.Network) err = s.peersUpdateManager.SendUpdate(remotePeer.Key, &UpdateMessage{Update: update}) if err != nil { // todo rethink if we should keep this return @@ -309,10 +309,15 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto } } + network, err := s.accountManager.GetPeerNetwork(peer.Key) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed getting peer network on login") + } + // if peer has reached this point then it has logged in loginResp := &proto.LoginResponse{ WiretrusteeConfig: toWiretrusteeConfig(s.config, nil), - PeerConfig: toPeerConfig(peer), + PeerConfig: toPeerConfig(peer, network), } encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp) if err != nil { @@ -382,9 +387,10 @@ func toWiretrusteeConfig(config *Config, turnCredentials *TURNCredentials) *prot } } -func toPeerConfig(peer *Peer) *proto.PeerConfig { +func toPeerConfig(peer *Peer, network *Network) *proto.PeerConfig { + netmask, _ := network.Net.Mask.Size() return &proto.PeerConfig{ - Address: fmt.Sprintf("%s/%d", peer.IP.String(), SubnetSize), // take it from the network + Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, } } @@ -401,10 +407,10 @@ func toRemotePeerConfig(peers []*Peer) []*proto.RemotePeerConfig { return remotePeers } -func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64) *proto.SyncResponse { +func toSyncResponse(config *Config, peer *Peer, peers []*Peer, turnCredentials *TURNCredentials, serial uint64, network *Network) *proto.SyncResponse { wtConfig := toWiretrusteeConfig(config, turnCredentials) - pConfig := toPeerConfig(peer) + pConfig := toPeerConfig(peer, network) remotePeers := toRemotePeerConfig(peers) @@ -443,7 +449,7 @@ func (s *Server) sendInitialSync(peerKey wgtypes.Key, peer *Peer, srv proto.Mana } else { turnCredentials = nil } - plainResp := toSyncResponse(s.config, peer, networkMap.Peers, turnCredentials, networkMap.Network.CurrentSerial()) + plainResp := toSyncResponse(s.config, peer, networkMap.Peers, turnCredentials, networkMap.Network.CurrentSerial(), networkMap.Network) encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp) if err != nil { diff --git a/management/server/management_test.go b/management/server/management_test.go index 75a258009..b66b7a827 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -422,6 +422,22 @@ var _ = Describe("Management service", func() { close(ipChannel) }) }) + + Context("after login two peers", func() { + Specify("then they receive the same network", func() { + key, _ := wgtypes.GenerateKey() + firstLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) + key, _ = wgtypes.GenerateKey() + secondLogin := loginPeerWithValidSetupKey(serverPubKey, key, client) + + _, firstLoginNetwork, err := net.ParseCIDR(firstLogin.GetPeerConfig().GetAddress()) + Expect(err).NotTo(HaveOccurred()) + _, secondLoginNetwork, err := net.ParseCIDR(secondLogin.GetPeerConfig().GetAddress()) + Expect(err).NotTo(HaveOccurred()) + + Expect(secondLoginNetwork.String()).To(BeEquivalentTo(firstLoginNetwork.String())) + }) + }) }) func loginPeerWithValidSetupKey(serverPubKey wgtypes.Key, key wgtypes.Key, client mgmtProto.ManagementServiceClient) *mgmtProto.LoginResponse { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 5e85bf2e6..587e7c54b 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -25,6 +25,7 @@ type MockAccountManager struct { DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) GetPeerByIPFunc func(accountId string, peerIP string) (*server.Peer, error) GetNetworkMapFunc func(peerKey string) (*server.NetworkMap, error) + GetPeerNetworkFunc func(peerKey string) (*server.Network, error) AddPeerFunc func(setupKey string, userId string, peer *server.Peer) (*server.Peer, error) GetGroupFunc func(accountID, groupID string) (*server.Group, error) SaveGroupFunc func(accountID string, group *server.Group) error @@ -204,6 +205,14 @@ func (am *MockAccountManager) GetNetworkMap(peerKey string) (*server.NetworkMap, return nil, status.Errorf(codes.Unimplemented, "method GetNetworkMap is not implemented") } +// GetPeerNetwork mock implementation of GetPeerNetwork from server.AccountManager interface +func (am *MockAccountManager) GetPeerNetwork(peerKey string) (*server.Network, error) { + if am.GetPeerNetworkFunc != nil { + return am.GetPeerNetworkFunc(peerKey) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeerNetwork is not implemented") +} + // AddPeer mock implementation of AddPeer from server.AccountManager interface func (am *MockAccountManager) AddPeer( setupKey string, diff --git a/management/server/peer.go b/management/server/peer.go index 9fafb346c..1b078abfc 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -257,6 +257,19 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err }, err } +// GetPeerNetwork returns the Network for a given peer +func (am *DefaultAccountManager) GetPeerNetwork(peerKey string) (*Network, error) { + am.mux.Lock() + defer am.mux.Unlock() + + account, err := am.Store.GetPeerAccount(peerKey) + if err != nil { + return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey) + } + + return account.Network.Copy(), err +} + // AddPeer adds a new peer to the Store. // Each Account has a list of pre-authorised SetupKey and if no Account has a given key err wit ha code codes.Unauthenticated // will be returned, meaning the key is invalid @@ -493,6 +506,8 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { return err } + network := account.Network.Copy() + for _, p := range peers { update := toRemotePeerConfig(am.getPeersByACL(account, p.Key)) err = am.peersUpdateManager.SendUpdate(p.Key, @@ -506,7 +521,7 @@ func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { Serial: account.Network.CurrentSerial(), RemotePeers: update, RemotePeersIsEmpty: len(update) == 0, - PeerConfig: toPeerConfig(p), + PeerConfig: toPeerConfig(p, network), }, }, }) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index b99b27e84..ea392b076 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -253,3 +253,69 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) { t.Errorf("expecting Account NetworkMap to have 0 peers, got %v", len(networkMap2.Peers)) } } + +func TestAccountManager_GetPeerNetwork(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + expectedId := "test_account" + userId := "account_creator" + account, err := createAccount(manager, expectedId, userId, "") + if err != nil { + t.Fatal(err) + } + + var setupKey *SetupKey + for _, key := range account.SetupKeys { + if key.Type == SetupKeyReusable { + setupKey = key + } + } + + peerKey1, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + + _, err = manager.AddPeer(setupKey.Key, "", &Peer{ + Key: peerKey1.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer-2", + }) + + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + peerKey2, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return + } + _, err = manager.AddPeer(setupKey.Key, "", &Peer{ + Key: peerKey2.PublicKey().String(), + Meta: PeerSystemMeta{}, + Name: "test-peer-2", + }) + + if err != nil { + t.Errorf("expecting peer to be added, got failure %v", err) + return + } + + network, err := manager.GetPeerNetwork(peerKey1.PublicKey().String()) + if err != nil { + t.Fatal(err) + return + } + + if account.Network.Id != network.Id { + t.Errorf("expecting Account Networks ID to be equal, got %s expected %s", network.Id, account.Network.Id) + } + +}