[management] Refactor network map controller (#4789)

This commit is contained in:
Pascal Fischer
2025-12-02 12:34:28 +01:00
committed by GitHub
parent 52948ccd61
commit 7193bd2da7
45 changed files with 819 additions and 492 deletions

View File

@@ -19,6 +19,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -42,6 +43,7 @@ type Controller struct {
accountManagerMetrics *telemetry.AccountManagerMetrics
peersUpdateManager network_map.PeersUpdateManager
settingsManager settings.Manager
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
@@ -70,7 +72,7 @@ type bufferUpdate struct {
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
if err != nil {
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
@@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
dnsDomain: dnsDomain,
config: config,
proxyController: proxyController,
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
}
}
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
}
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
}
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
c.peersUpdateManager.CloseChannel(ctx, peerID)
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
return
}
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
}
func (c *Controller) CountStreams() int {
return c.peersUpdateManager.CountStreams()
}
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
var (
@@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountId)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountId)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerId)
return nil
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
c.UpdatePeerInNetworkMapCache(accountId, peer)
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
return nil
}
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
for _, peerID := range peerIDs {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
return err
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
}
peers, err := c.repo.GetAccountPeers(ctx, accountID)
if err != nil {
return err
}
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
for _, peerID := range peerIDs {
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
Update: &proto.SyncResponse{
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
NetworkMap: &proto.NetworkMap{
Serial: network.CurrentSerial(),
RemotePeers: []*proto.RemotePeerConfig{},
RemotePeersIsEmpty: true,
FirewallRules: []*proto.FirewallRule{},
FirewallRulesIsEmpty: true,
DNSConfig: &proto.DNSConfig{
ForwarderPort: dnsFwdPort,
},
},
},
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
@@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return networkMap, nil
}
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
}
func (c *Controller) IsConnected(peerID string) bool {
return c.peersUpdateManager.HasChannel(peerID)
}

View File

@@ -12,6 +12,8 @@ type Repository interface {
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
}
type repository struct {
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
return r.store.GetAccountByPeerID(ctx, peerID)
}
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
}
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
}

View File

@@ -28,12 +28,12 @@ type Controller interface {
GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context)
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
DeletePeer(ctx context.Context, accountId string, peerId string) error
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
DisconnectPeers(ctx context.Context, peerIDs []string)
IsConnected(peerID string) bool
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
}

View File

@@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
}
// DeletePeer mocks base method.
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
ret0, _ := ret[0].(error)
ret := m.ctrl.Call(m, "CountStreams")
ret0, _ := ret[0].(int)
return ret0
}
// DeletePeer indicates an expected call of DeletePeer.
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
// CountStreams indicates an expected call of CountStreams.
func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
}
// DisconnectPeers mocks base method.
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
}
// DisconnectPeers indicates an expected call of DisconnectPeers.
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
}
// GetDNSDomain mocks base method.
@@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
}
// IsConnected mocks base method.
func (m *MockController) IsConnected(peerID string) bool {
// OnPeerConnected mocks base method.
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "IsConnected", peerID)
ret0, _ := ret[0].(bool)
return ret0
ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
ret0, _ := ret[0].(chan *UpdateMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// IsConnected indicates an expected call of IsConnected.
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
// OnPeerConnected indicates an expected call of OnPeerConnected.
func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
}
// OnPeerAdded mocks base method.
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
// OnPeerDisconnected mocks base method.
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
}
// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerAdded indicates an expected call of OnPeerAdded.
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
}
// OnPeerDeleted mocks base method.
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
}
// OnPeerUpdated mocks base method.
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
m.ctrl.T.Helper()
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
}
// StartWarmup mocks base method.