diff --git a/management/server/account_test.go b/management/server/account_test.go index 432b17d3f..3dab9b347 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2,6 +2,7 @@ package server import ( "net" + "sync" "testing" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -513,6 +514,189 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { } } +func TestAccountManager_NetworkUpdates(t *testing.T) { + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + return + } + + account, err := manager.AddAccount("test_account", "account_creator", "") + if err != nil { + t.Fatal(err) + } + + var setupKey *SetupKey + for _, key := range account.SetupKeys { + setupKey = key + if setupKey.Type == SetupKeyReusable { + break + } + } + + if setupKey == nil { + t.Errorf("expecting account to have a default setup key") + return + } + + if account.Network.Serial != 0 { + t.Errorf("expecting account network to have an initial Serial=0") + return + } + + getPeer := func() *Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + return nil + } + expectedPeerKey := key.PublicKey().String() + + peer, err := manager.AddPeer(setupKey.Key, "", &Peer{ + Key: expectedPeerKey, + Meta: PeerSystemMeta{}, + Name: expectedPeerKey, + }) + if err != nil { + t.Fatalf("expecting peer1 to be added, got failure %v", err) + return nil + } + + return peer + } + + peer1 := getPeer() + peer2 := getPeer() + peer3 := getPeer() + + account, err = manager.GetAccountById(account.Id) + if err != nil { + t.Fatal(err) + return + } + + updMsg := manager.peersUpdateManager.CreateChannel(peer1.Key) + defer manager.peersUpdateManager.CloseChannel(peer1.Key) + + group := Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.Key, peer2.Key, peer3.Key}, + } + + rule := Rule{ + Source: []string{"group-id"}, + Destination: []string{"group-id"}, + Flow: TrafficFlowBidirect, + } + + wg := sync.WaitGroup{} + t.Run("save group update", func(t *testing.T) { + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.SaveGroup(account.Id, &group); err != nil { + t.Errorf("save group: %v", err) + return + } + + wg.Wait() + }) + + t.Run("delete rule update", func(t *testing.T) { + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + var defaultRule *Rule + for _, r := range account.Rules { + defaultRule = r + } + + if err := manager.DeleteRule(account.Id, defaultRule.ID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() + }) + + t.Run("save rule update", func(t *testing.T) { + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.SaveRule(account.Id, &rule); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() + }) + + t.Run("delete peer update", func(t *testing.T) { + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if _, err := manager.DeletePeer(account.Id, peer3.Key); err != nil { + t.Errorf("delete peer: %v", err) + return + } + + wg.Wait() + }) + + t.Run("delete group update", func(t *testing.T) { + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeleteGroup(account.Id, group.ID); err != nil { + t.Errorf("delete group rule: %v", err) + return + } + + wg.Wait() + }) +} + func TestAccountManager_DeletePeer(t *testing.T) { manager, err := createManager(t) if err != nil { @@ -664,7 +848,6 @@ func TestAccountManager_UpdatePeerMeta(t *testing.T) { } assert.Equal(t, newMeta, p.Meta) - } func createManager(t *testing.T) (*DefaultAccountManager, error) { diff --git a/management/server/file_store.go b/management/server/file_store.go index 500a8070f..a3abebb0c 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -180,6 +180,8 @@ func (s *FileStore) DeletePeer(accountId string, peerKey string) (*Peer, error) delete(account.Peers, peerKey) delete(s.PeerKeyId2AccountId, peerKey) + delete(s.PeerKeyId2DstRulesId, peerKey) + delete(s.PeerKeyId2SrcRulesId, peerKey) // cleanup groups var peers []string @@ -240,9 +242,34 @@ func (s *FileStore) SaveAccount(account *Account) error { s.PeerKeyId2AccountId[peer.Key] = account.Id } + // remove all peers related to account from rules indexes + cleanIDs := make([]string, 0) + for key := range s.PeerKeyId2SrcRulesId { + if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id { + cleanIDs = append(cleanIDs, key) + } + } + for _, key := range cleanIDs { + delete(s.PeerKeyId2SrcRulesId, key) + } + cleanIDs = cleanIDs[:0] + for key := range s.PeerKeyId2DstRulesId { + if accountID, ok := s.PeerKeyId2AccountId[key]; ok && accountID == account.Id { + cleanIDs = append(cleanIDs, key) + } + } + for _, key := range cleanIDs { + delete(s.PeerKeyId2DstRulesId, key) + } + + // rebuild rule indexes for _, rule := range account.Rules { for _, gid := range rule.Source { - for _, pid := range account.Groups[gid].Peers { + g, ok := account.Groups[gid] + if !ok { + break + } + for _, pid := range g.Peers { rules := s.PeerKeyId2SrcRulesId[pid] if rules == nil { rules = map[string]struct{}{} @@ -252,7 +279,11 @@ func (s *FileStore) SaveAccount(account *Account) error { } } for _, gid := range rule.Destination { - for _, pid := range account.Groups[gid].Peers { + g, ok := account.Groups[gid] + if !ok { + break + } + for _, pid := range g.Peers { rules := s.PeerKeyId2DstRulesId[pid] if rules == nil { rules = map[string]struct{}{} diff --git a/management/server/group.go b/management/server/group.go index e71f33eb3..de57ac2f3 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -54,7 +54,13 @@ func (am *DefaultAccountManager) SaveGroup(accountID string, group *Group) error } account.Groups[group.ID] = group - return am.Store.SaveAccount(account) + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + return am.updateAccountPeers(account) } // DeleteGroup object of the peers @@ -69,7 +75,12 @@ func (am *DefaultAccountManager) DeleteGroup(accountID, groupID string) error { delete(account.Groups, groupID) - return am.Store.SaveAccount(account) + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + return am.updateAccountPeers(account) } // ListGroups objects of the peers @@ -116,7 +127,12 @@ func (am *DefaultAccountManager) GroupAddPeer(accountID, groupID, peerKey string group.Peers = append(group.Peers, peerKey) } - return am.Store.SaveAccount(account) + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + return am.updateAccountPeers(account) } // GroupDeletePeer removes peer from the group @@ -134,14 +150,17 @@ func (am *DefaultAccountManager) GroupDeletePeer(accountID, groupID, peerKey str return status.Errorf(codes.NotFound, "group with ID %s not found", groupID) } + account.Network.IncSerial() for i, itemID := range group.Peers { if itemID == peerKey { group.Peers = append(group.Peers[:i], group.Peers[i+1:]...) - return am.Store.SaveAccount(account) + if err := am.Store.SaveAccount(account); err != nil { + return status.Errorf(codes.Internal, "can't save account") + } } } - return nil + return am.updateAccountPeers(account) } // GroupListPeers returns list of the peers from the group diff --git a/management/server/peer.go b/management/server/peer.go index a9b33f334..e76f2b067 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -134,6 +134,16 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (* return nil, status.Errorf(codes.NotFound, "account not found") } + // delete peer from groups + for _, g := range account.Groups { + for i, pk := range g.Peers { + if pk == peerKey { + g.Peers = append(g.Peers[:i], g.Peers[i+1:]...) + break + } + } + } + peer, err := am.Store.DeletePeer(accountId, peerKey) if err != nil { return nil, err @@ -163,39 +173,10 @@ func (am *DefaultAccountManager) DeletePeer(accountId string, peerKey string) (* return nil, err } - // notify other peers of the change - peers, err := am.Store.GetAccountPeers(accountId) - if err != nil { + if err := am.updateAccountPeers(account); err != nil { return nil, err } - for _, p := range peers { - peersToSend := []*Peer{} - for _, remote := range peers { - if p.Key != remote.Key { - peersToSend = append(peersToSend, remote) - } - } - update := toRemotePeerConfig(peersToSend) - err = am.peersUpdateManager.SendUpdate(p.Key, - &UpdateMessage{ - Update: &proto.SyncResponse{ - // fill those field for backward compatibility - RemotePeers: update, - RemotePeersIsEmpty: len(update) == 0, - // new field - NetworkMap: &proto.NetworkMap{ - Serial: account.Network.CurrentSerial(), - RemotePeers: update, - RemotePeersIsEmpty: len(update) == 0, - }, - }, - }) - if err != nil { - return nil, err - } - } - am.peersUpdateManager.CloseChannel(peerKey) return peer, nil } @@ -229,56 +210,8 @@ func (am *DefaultAccountManager) GetNetworkMap(peerKey string) (*NetworkMap, err return nil, status.Errorf(codes.Internal, "Invalid peer key %s", peerKey) } - var res []*Peer - srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey) - if err != nil { - return &NetworkMap{ - Peers: res, - Network: account.Network.Copy(), - }, nil - } - - dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey) - if err != nil { - return &NetworkMap{ - Peers: res, - Network: account.Network.Copy(), - }, nil - } - - groups := map[string]*Group{} - for _, r := range srcRules { - if r.Flow == TrafficFlowBidirect { - for _, gid := range r.Destination { - groups[gid] = account.Groups[gid] - } - } - } - - for _, r := range dstRules { - if r.Flow == TrafficFlowBidirect { - for _, gid := range r.Source { - groups[gid] = account.Groups[gid] - } - } - } - - for _, g := range groups { - for _, pid := range g.Peers { - peer, ok := account.Peers[pid] - if !ok { - log.Warnf("peer %s found in group %s but doesn't belong to account %s", pid, g.ID, account.Id) - continue - } - // exclude original peer - if peer.Key != peerKey { - res = append(res, peer.Copy()) - } - } - } - return &NetworkMap{ - Peers: res, + Peers: am.getPeersByACL(account, peerKey), Network: account.Network.Copy(), }, err } @@ -411,3 +344,93 @@ func (am *DefaultAccountManager) UpdatePeerMeta(peerKey string, meta PeerSystemM } return nil } + +// getPeersByACL allowed for given peer by ACL +func (am *DefaultAccountManager) getPeersByACL(account *Account, peerKey string) []*Peer { + var peers []*Peer + srcRules, err := am.Store.GetPeerSrcRules(account.Id, peerKey) + if err != nil { + srcRules = []*Rule{} + } + + dstRules, err := am.Store.GetPeerDstRules(account.Id, peerKey) + if err != nil { + dstRules = []*Rule{} + } + + groups := map[string]*Group{} + for _, r := range srcRules { + if r.Flow == TrafficFlowBidirect { + for _, gid := range r.Destination { + if group, ok := account.Groups[gid]; ok { + groups[gid] = group + } + } + } + } + + for _, r := range dstRules { + if r.Flow == TrafficFlowBidirect { + for _, gid := range r.Source { + if group, ok := account.Groups[gid]; ok { + groups[gid] = group + } + } + } + } + + peersSet := make(map[string]struct{}) + for _, g := range groups { + for _, pid := range g.Peers { + peer, ok := account.Peers[pid] + if !ok { + log.Warnf( + "peer %s found in group %s but doesn't belong to account %s", + pid, + g.ID, + account.Id, + ) + continue + } + // exclude original peer + if _, ok := peersSet[peer.Key]; peer.Key != peerKey && !ok { + peersSet[peer.Key] = struct{}{} + peers = append(peers, peer.Copy()) + } + } + } + + return peers +} + +// updateAccountPeers network map constructed by ACL +func (am *DefaultAccountManager) updateAccountPeers(account *Account) error { + // notify other peers of the change + peers, err := am.Store.GetAccountPeers(account.Id) + if err != nil { + return err + } + + for _, p := range peers { + update := toRemotePeerConfig(am.getPeersByACL(account, p.Key)) + err = am.peersUpdateManager.SendUpdate(p.Key, + &UpdateMessage{ + Update: &proto.SyncResponse{ + // fill those field for backward compatibility + RemotePeers: update, + RemotePeersIsEmpty: len(update) == 0, + // new field + NetworkMap: &proto.NetworkMap{ + Serial: account.Network.CurrentSerial(), + RemotePeers: update, + RemotePeersIsEmpty: len(update) == 0, + }, + }, + }) + if err != nil { + return err + } + } + + return nil +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ba87ca773..58210e61d 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -192,6 +192,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) { len(networkMap1.Peers), networkMap1.Peers, ) + return } if networkMap1.Peers[0].Key != peerKey2.PublicKey().String() { diff --git a/management/server/rule.go b/management/server/rule.go index 3cbf172d7..ebb8caab2 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -70,7 +70,13 @@ func (am *DefaultAccountManager) SaveRule(accountID string, rule *Rule) error { } account.Rules[rule.ID] = rule - return am.Store.SaveAccount(account) + + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + return am.updateAccountPeers(account) } // DeleteRule of ACL from the store @@ -85,7 +91,12 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { delete(account.Rules, ruleID) - return am.Store.SaveAccount(account) + account.Network.IncSerial() + if err = am.Store.SaveAccount(account); err != nil { + return err + } + + return am.updateAccountPeers(account) } // ListRules of ACL from the store