mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management] Refactor network map controller (#4789)
This commit is contained in:
@@ -19,6 +19,7 @@ import (
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
@@ -42,6 +43,7 @@ type Controller struct {
|
||||
accountManagerMetrics *telemetry.AccountManagerMetrics
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
settingsManager settings.Manager
|
||||
EphemeralPeersManager ephemeral.Manager
|
||||
|
||||
accountUpdateLocks sync.Map
|
||||
sendAccountUpdateLocks sync.Map
|
||||
@@ -70,7 +72,7 @@ type bufferUpdate struct {
|
||||
|
||||
var _ network_map.Controller = (*Controller)(nil)
|
||||
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, config *config.Config) *Controller {
|
||||
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
|
||||
nMetrics, err := newMetrics(metrics.UpdateChannelMetrics())
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||
@@ -99,7 +101,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
dnsDomain: dnsDomain,
|
||||
config: config,
|
||||
|
||||
proxyController: proxyController,
|
||||
proxyController: proxyController,
|
||||
EphemeralPeersManager: ephemeralPeersManager,
|
||||
|
||||
holder: types.NewHolder(),
|
||||
expNewNetworkMap: newNetworkMapBuilder,
|
||||
@@ -107,6 +110,31 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *network_map.UpdateMessage, error) {
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get peer %s: %v", peerID, err)
|
||||
}
|
||||
|
||||
c.EphemeralPeersManager.OnPeerConnected(ctx, peer)
|
||||
|
||||
return c.peersUpdateManager.CreateChannel(ctx, peerID), nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) {
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
peer, err := c.repo.GetPeerByID(ctx, accountID, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get peer %s: %v", peerID, err)
|
||||
return
|
||||
}
|
||||
c.EphemeralPeersManager.OnPeerDisconnected(ctx, peer)
|
||||
}
|
||||
|
||||
func (c *Controller) CountStreams() int {
|
||||
return c.peersUpdateManager.CountStreams()
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
var (
|
||||
@@ -366,38 +394,6 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) DeletePeer(ctx context.Context, accountId string, peerId string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerId, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerId)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
|
||||
if isRequiresApproval {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
@@ -698,35 +694,83 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerUpdated(accountId string, peer *nbpeer.Peer) {
|
||||
c.UpdatePeerInNetworkMapCache(accountId, peer)
|
||||
_ = c.bufferSendUpdateAccountPeers(context.Background(), accountId)
|
||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||
}
|
||||
|
||||
for _, peer := range peers {
|
||||
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||
}
|
||||
|
||||
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerAdded(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
for _, peerID := range peerIDs {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
err = c.onPeerAddedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
func (c *Controller) OnPeerDeleted(ctx context.Context, accountID string, peerID string) error {
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
network, err := c.repo.GetAccountNetwork(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peers, err := c.repo.GetAccountPeers(ctx, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dnsFwdPort := computeForwarderPort(peers, network_map.DnsForwarderPortMinVersion)
|
||||
for _, peerID := range peerIDs {
|
||||
c.peersUpdateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
NetworkMap: &proto.NetworkMap{
|
||||
Serial: network.CurrentSerial(),
|
||||
RemotePeers: []*proto.RemotePeerConfig{},
|
||||
RemotePeersIsEmpty: true,
|
||||
FirewallRules: []*proto.FirewallRule{},
|
||||
FirewallRulesIsEmpty: true,
|
||||
DNSConfig: &proto.DNSConfig{
|
||||
ForwarderPort: dnsFwdPort,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||
|
||||
if c.experimentalNetworkMap(accountID) {
|
||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||
continue
|
||||
}
|
||||
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -778,10 +822,6 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
||||
return networkMap, nil
|
||||
}
|
||||
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
func (c *Controller) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||
c.peersUpdateManager.CloseChannels(ctx, peerIDs)
|
||||
}
|
||||
|
||||
func (c *Controller) IsConnected(peerID string) bool {
|
||||
return c.peersUpdateManager.HasChannel(peerID)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,8 @@ type Repository interface {
|
||||
GetAccountNetwork(ctx context.Context, accountID string) (*types.Network, error)
|
||||
GetAccountPeers(ctx context.Context, accountID string) ([]*peer.Peer, error)
|
||||
GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error)
|
||||
GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error)
|
||||
GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error)
|
||||
}
|
||||
|
||||
type repository struct {
|
||||
@@ -37,3 +39,11 @@ func (r *repository) GetAccountPeers(ctx context.Context, accountID string) ([]*
|
||||
func (r *repository) GetAccountByPeerID(ctx context.Context, peerID string) (*types.Account, error) {
|
||||
return r.store.GetAccountByPeerID(ctx, peerID)
|
||||
}
|
||||
|
||||
func (r *repository) GetPeersByIDs(ctx context.Context, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
|
||||
return r.store.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, peerIDs)
|
||||
}
|
||||
|
||||
func (r *repository) GetPeerByID(ctx context.Context, accountID string, peerID string) (*peer.Peer, error) {
|
||||
return r.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
@@ -28,12 +28,12 @@ type Controller interface {
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
|
||||
CountStreams() int
|
||||
|
||||
DeletePeer(ctx context.Context, accountId string, peerId string) error
|
||||
|
||||
OnPeerUpdated(accountId string, peer *nbpeer.Peer)
|
||||
OnPeerAdded(ctx context.Context, accountID string, peerID string) error
|
||||
OnPeerDeleted(ctx context.Context, accountID string, peerID string) error
|
||||
DisconnectPeers(ctx context.Context, peerIDs []string)
|
||||
IsConnected(peerID string) bool
|
||||
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
|
||||
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
|
||||
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
|
||||
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
|
||||
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
|
||||
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)
|
||||
}
|
||||
|
||||
@@ -57,30 +57,30 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID an
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
}
|
||||
|
||||
// DeletePeer mocks base method.
|
||||
func (m *MockController) DeletePeer(ctx context.Context, accountId, peerId string) error {
|
||||
// CountStreams mocks base method.
|
||||
func (m *MockController) CountStreams() int {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeletePeer", ctx, accountId, peerId)
|
||||
ret0, _ := ret[0].(error)
|
||||
ret := m.ctrl.Call(m, "CountStreams")
|
||||
ret0, _ := ret[0].(int)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeletePeer indicates an expected call of DeletePeer.
|
||||
func (mr *MockControllerMockRecorder) DeletePeer(ctx, accountId, peerId any) *gomock.Call {
|
||||
// CountStreams indicates an expected call of CountStreams.
|
||||
func (mr *MockControllerMockRecorder) CountStreams() *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeer", reflect.TypeOf((*MockController)(nil).DeletePeer), ctx, accountId, peerId)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountStreams", reflect.TypeOf((*MockController)(nil).CountStreams))
|
||||
}
|
||||
|
||||
// DisconnectPeers mocks base method.
|
||||
func (m *MockController) DisconnectPeers(ctx context.Context, peerIDs []string) {
|
||||
func (m *MockController) DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "DisconnectPeers", ctx, peerIDs)
|
||||
m.ctrl.Call(m, "DisconnectPeers", ctx, accountId, peerIDs)
|
||||
}
|
||||
|
||||
// DisconnectPeers indicates an expected call of DisconnectPeers.
|
||||
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, peerIDs any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) DisconnectPeers(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, peerIDs)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DisconnectPeers", reflect.TypeOf((*MockController)(nil).DisconnectPeers), ctx, accountId, peerIDs)
|
||||
}
|
||||
|
||||
// GetDNSDomain mocks base method.
|
||||
@@ -130,58 +130,73 @@ func (mr *MockControllerMockRecorder) GetValidatedPeerWithMap(ctx, isRequiresApp
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetValidatedPeerWithMap", reflect.TypeOf((*MockController)(nil).GetValidatedPeerWithMap), ctx, isRequiresApproval, accountID, p)
|
||||
}
|
||||
|
||||
// IsConnected mocks base method.
|
||||
func (m *MockController) IsConnected(peerID string) bool {
|
||||
// OnPeerConnected mocks base method.
|
||||
func (m *MockController) OnPeerConnected(ctx context.Context, accountID, peerID string) (chan *UpdateMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "IsConnected", peerID)
|
||||
ret0, _ := ret[0].(bool)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "OnPeerConnected", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].(chan *UpdateMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// IsConnected indicates an expected call of IsConnected.
|
||||
func (mr *MockControllerMockRecorder) IsConnected(peerID any) *gomock.Call {
|
||||
// OnPeerConnected indicates an expected call of OnPeerConnected.
|
||||
func (mr *MockControllerMockRecorder) OnPeerConnected(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IsConnected", reflect.TypeOf((*MockController)(nil).IsConnected), peerID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerConnected", reflect.TypeOf((*MockController)(nil).OnPeerConnected), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerAdded mocks base method.
|
||||
func (m *MockController) OnPeerAdded(ctx context.Context, accountID, peerID string) error {
|
||||
// OnPeerDisconnected mocks base method.
|
||||
func (m *MockController) OnPeerDisconnected(ctx context.Context, accountID, peerID string) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerAdded", ctx, accountID, peerID)
|
||||
m.ctrl.Call(m, "OnPeerDisconnected", ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeerDisconnected indicates an expected call of OnPeerDisconnected.
|
||||
func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDisconnected", reflect.TypeOf((*MockController)(nil).OnPeerDisconnected), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// OnPeersAdded mocks base method.
|
||||
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerAdded indicates an expected call of OnPeerAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeerAdded(ctx, accountID, peerID any) *gomock.Call {
|
||||
// OnPeersAdded indicates an expected call of OnPeersAdded.
|
||||
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerAdded", reflect.TypeOf((*MockController)(nil).OnPeerAdded), ctx, accountID, peerID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// OnPeerDeleted mocks base method.
|
||||
func (m *MockController) OnPeerDeleted(ctx context.Context, accountID, peerID string) error {
|
||||
// OnPeersDeleted mocks base method.
|
||||
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OnPeerDeleted", ctx, accountID, peerID)
|
||||
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerDeleted indicates an expected call of OnPeerDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeerDeleted(ctx, accountID, peerID any) *gomock.Call {
|
||||
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
|
||||
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerDeleted", reflect.TypeOf((*MockController)(nil).OnPeerDeleted), ctx, accountID, peerID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
|
||||
}
|
||||
|
||||
// OnPeerUpdated mocks base method.
|
||||
func (m *MockController) OnPeerUpdated(accountId string, peer *peer.Peer) {
|
||||
// OnPeersUpdated mocks base method.
|
||||
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "OnPeerUpdated", accountId, peer)
|
||||
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// OnPeerUpdated indicates an expected call of OnPeerUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeerUpdated(accountId, peer any) *gomock.Call {
|
||||
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
|
||||
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeerUpdated", reflect.TypeOf((*MockController)(nil).OnPeerUpdated), accountId, peer)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
|
||||
}
|
||||
|
||||
// StartWarmup mocks base method.
|
||||
|
||||
19
management/internals/modules/peers/ephemeral/interface.go
Normal file
19
management/internals/modules/peers/ephemeral/interface.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package ephemeral
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
const (
|
||||
EphemeralLifeTime = 10 * time.Minute
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
LoadInitialPeers(ctx context.Context)
|
||||
Stop()
|
||||
OnPeerConnected(ctx context.Context, peer *nbpeer.Peer)
|
||||
OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer)
|
||||
}
|
||||
@@ -0,0 +1,250 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
)
|
||||
|
||||
const (
|
||||
// cleanupWindow is the time window to wait after nearest peer deadline to start the cleanup procedure.
|
||||
cleanupWindow = 1 * time.Minute
|
||||
)
|
||||
|
||||
var (
|
||||
timeNow = time.Now
|
||||
)
|
||||
|
||||
type ephemeralPeer struct {
|
||||
id string
|
||||
accountID string
|
||||
deadline time.Time
|
||||
next *ephemeralPeer
|
||||
}
|
||||
|
||||
// todo: consider to remove peer from ephemeral list when the peer has been deleted via API. If we do not do it
|
||||
// in worst case we will get invalid error message in this manager.
|
||||
|
||||
// EphemeralManager keep a list of ephemeral peers. After EphemeralLifeTime inactivity the peer will be deleted
|
||||
// automatically. Inactivity means the peer disconnected from the Management server.
|
||||
type EphemeralManager struct {
|
||||
store store.Store
|
||||
peersManager peers.Manager
|
||||
|
||||
headPeer *ephemeralPeer
|
||||
tailPeer *ephemeralPeer
|
||||
peersLock sync.Mutex
|
||||
timer *time.Timer
|
||||
|
||||
lifeTime time.Duration
|
||||
cleanupWindow time.Duration
|
||||
}
|
||||
|
||||
// NewEphemeralManager instantiate new EphemeralManager
|
||||
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
|
||||
return &EphemeralManager{
|
||||
store: store,
|
||||
peersManager: peersManager,
|
||||
|
||||
lifeTime: ephemeral.EphemeralLifeTime,
|
||||
cleanupWindow: cleanupWindow,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadInitialPeers load from the database the ephemeral type of peers and schedule a cleanup procedure to the head
|
||||
// of the linked list (to the most deprecated peer). At the end of cleanup it schedules the next cleanup to the new
|
||||
// head.
|
||||
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.loadEphemeralPeers(ctx)
|
||||
if e.headPeer != nil {
|
||||
e.timer = time.AfterFunc(e.lifeTime, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Stop timer
|
||||
func (e *EphemeralManager) Stop() {
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
if e.timer != nil {
|
||||
e.timer.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// OnPeerConnected remove the peer from the linked list of ephemeral peers. Because it has been called when the peer
|
||||
// is active the manager will not delete it while it is active.
|
||||
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
e.removePeer(peer.ID)
|
||||
|
||||
// stop the unnecessary timer
|
||||
if e.headPeer == nil && e.timer != nil {
|
||||
e.timer.Stop()
|
||||
e.timer = nil
|
||||
}
|
||||
}
|
||||
|
||||
// OnPeerDisconnected add the peer to the linked list of ephemeral peers. Because of the peer
|
||||
// is inactive it will be deleted after the EphemeralLifeTime period.
|
||||
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
|
||||
if !peer.Ephemeral {
|
||||
return
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
|
||||
|
||||
e.peersLock.Lock()
|
||||
defer e.peersLock.Unlock()
|
||||
|
||||
if e.isPeerOnList(peer.ID) {
|
||||
return
|
||||
}
|
||||
|
||||
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
|
||||
if e.timer == nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
|
||||
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
t := e.newDeadLine()
|
||||
for _, p := range peers {
|
||||
e.addPeer(p.AccountID, p.ID, t)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) cleanup(ctx context.Context) {
|
||||
log.Tracef("on ephemeral cleanup")
|
||||
deletePeers := make(map[string]*ephemeralPeer)
|
||||
|
||||
e.peersLock.Lock()
|
||||
now := timeNow()
|
||||
for p := e.headPeer; p != nil; p = p.next {
|
||||
if now.Before(p.deadline) {
|
||||
break
|
||||
}
|
||||
|
||||
deletePeers[p.id] = p
|
||||
e.headPeer = p.next
|
||||
if p.next == nil {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
}
|
||||
|
||||
if e.headPeer != nil {
|
||||
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
|
||||
if delay < 0 {
|
||||
delay = 0
|
||||
}
|
||||
e.timer = time.AfterFunc(delay, func() {
|
||||
e.cleanup(ctx)
|
||||
})
|
||||
} else {
|
||||
e.timer = nil
|
||||
}
|
||||
|
||||
e.peersLock.Unlock()
|
||||
|
||||
peerIDsPerAccount := make(map[string][]string)
|
||||
for id, p := range deletePeers {
|
||||
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
|
||||
}
|
||||
|
||||
for accountID, peerIDs := range peerIDsPerAccount {
|
||||
log.WithContext(ctx).Debugf("delete ephemeral peers for account: %s", accountID)
|
||||
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
|
||||
ep := &ephemeralPeer{
|
||||
id: peerID,
|
||||
accountID: accountID,
|
||||
deadline: deadline,
|
||||
}
|
||||
|
||||
if e.headPeer == nil {
|
||||
e.headPeer = ep
|
||||
}
|
||||
if e.tailPeer != nil {
|
||||
e.tailPeer.next = ep
|
||||
}
|
||||
e.tailPeer = ep
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) removePeer(id string) {
|
||||
if e.headPeer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if e.headPeer.id == id {
|
||||
e.headPeer = e.headPeer.next
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for p := e.headPeer; p.next != nil; p = p.next {
|
||||
if p.next.id == id {
|
||||
// if we remove the last element from the chain then set the last-1 as tail
|
||||
if e.tailPeer.id == id {
|
||||
e.tailPeer = p
|
||||
}
|
||||
p.next = p.next.next
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) isPeerOnList(id string) bool {
|
||||
for p := e.headPeer; p != nil; p = p.next {
|
||||
if p.id == id {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (e *EphemeralManager) newDeadLine() time.Time {
|
||||
return timeNow().Add(e.lifeTime)
|
||||
}
|
||||
@@ -0,0 +1,353 @@
|
||||
package manager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
type MockStore struct {
|
||||
store.Store
|
||||
account *types.Account
|
||||
}
|
||||
|
||||
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
|
||||
var peers []*nbpeer.Peer
|
||||
for _, v := range s.account.Peers {
|
||||
if v.Ephemeral {
|
||||
peers = append(peers, v)
|
||||
}
|
||||
}
|
||||
return peers, nil
|
||||
}
|
||||
|
||||
type MockAccountManager struct {
|
||||
mu sync.Mutex
|
||||
nbAccount.Manager
|
||||
store *MockStore
|
||||
deletePeerCalls int
|
||||
bufferUpdateCalls map[string]int
|
||||
wg *sync.WaitGroup
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
a.deletePeerCalls++
|
||||
delete(a.store.account.Peers, peerID)
|
||||
if a.wg != nil {
|
||||
a.wg.Done()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
a.bufferUpdateCalls = make(map[string]int)
|
||||
}
|
||||
a.bufferUpdateCalls[accountID]++
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
return 0
|
||||
}
|
||||
return a.bufferUpdateCalls[accountID]
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) GetStore() store.Store {
|
||||
return a.store
|
||||
}
|
||||
|
||||
func TestNewManager(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for ephemeral peers
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
return nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
if len(store.account.Peers) != numberOfPeers {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManagerPeerConnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
return nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewManagerPeerDisconnected(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
store := &MockStore{}
|
||||
ctrl := gomock.NewController(t)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
numberOfPeers := 5
|
||||
numberOfEphemeralPeers := 3
|
||||
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
|
||||
|
||||
// Expect DeletePeers to be called for the one disconnected peer
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
for _, peerID := range peerIDs {
|
||||
delete(store.account.Peers, peerID)
|
||||
}
|
||||
return nil
|
||||
}).
|
||||
AnyTimes()
|
||||
|
||||
mgr := NewEphemeralManager(store, peersManager)
|
||||
mgr.loadEphemeralPeers(context.Background())
|
||||
for _, v := range store.account.Peers {
|
||||
mgr.OnPeerConnected(context.Background(), v)
|
||||
|
||||
}
|
||||
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
|
||||
|
||||
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
|
||||
mgr.cleanup(context.Background())
|
||||
|
||||
expected := numberOfPeers + numberOfEphemeralPeers - 1
|
||||
if len(store.account.Peers) != expected {
|
||||
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
const (
|
||||
ephemeralPeers = 10
|
||||
testLifeTime = 1 * time.Second
|
||||
testCleanupWindow = 100 * time.Millisecond
|
||||
)
|
||||
|
||||
t.Cleanup(func() {
|
||||
timeNow = time.Now
|
||||
})
|
||||
startTime := time.Now()
|
||||
timeNow = func() time.Time {
|
||||
return startTime
|
||||
}
|
||||
|
||||
mockStore := &MockStore{}
|
||||
account := newAccountWithId(context.Background(), "account", "", "", false)
|
||||
mockStore.account = account
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(ephemeralPeers)
|
||||
mockAM := &MockAccountManager{
|
||||
store: mockStore,
|
||||
wg: wg,
|
||||
}
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
peersManager := peers.NewMockManager(ctrl)
|
||||
|
||||
// Set up expectation that DeletePeers will be called once with all peer IDs
|
||||
peersManager.EXPECT().
|
||||
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
|
||||
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
// Simulate the actual deletion behavior
|
||||
for _, peerID := range peerIDs {
|
||||
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID)
|
||||
return nil
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
mgr := NewEphemeralManager(mockStore, peersManager)
|
||||
mgr.lifeTime = testLifeTime
|
||||
mgr.cleanupWindow = testCleanupWindow
|
||||
|
||||
// Add peers and disconnect them at slightly different times (within cleanup window)
|
||||
for i := range ephemeralPeers {
|
||||
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
|
||||
mockStore.account.Peers[p.ID] = p
|
||||
mgr.OnPeerDisconnected(context.Background(), p)
|
||||
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
|
||||
}
|
||||
|
||||
// Advance time past the lifetime to trigger cleanup
|
||||
startTime = startTime.Add(testLifeTime + testCleanupWindow)
|
||||
|
||||
// Wait for all deletions to complete
|
||||
wg.Wait()
|
||||
|
||||
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
|
||||
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
|
||||
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
|
||||
}
|
||||
|
||||
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
|
||||
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
|
||||
|
||||
for i := 0; i < numberOfPeers; i++ {
|
||||
peerId := fmt.Sprintf("peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
Ephemeral: false,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
}
|
||||
|
||||
for i := 0; i < numberOfEphemeralPeers; i++ {
|
||||
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
|
||||
p := &nbpeer.Peer{
|
||||
ID: peerId,
|
||||
Ephemeral: true,
|
||||
}
|
||||
store.account.Peers[p.ID] = p
|
||||
}
|
||||
}
|
||||
|
||||
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
|
||||
func newAccountWithId(ctx context.Context, accountID, userID, domain string, disableDefaultPolicy bool) *types.Account {
|
||||
log.WithContext(ctx).Debugf("creating new account")
|
||||
|
||||
network := types.NewNetwork()
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
users := make(map[string]*types.User)
|
||||
routes := make(map[route.ID]*route.Route)
|
||||
setupKeys := map[string]*types.SetupKey{}
|
||||
nameServersGroups := make(map[string]*nbdns.NameServerGroup)
|
||||
|
||||
owner := types.NewOwnerUser(userID)
|
||||
owner.AccountID = accountID
|
||||
users[userID] = owner
|
||||
|
||||
dnsSettings := types.DNSSettings{
|
||||
DisabledManagementGroups: make([]string, 0),
|
||||
}
|
||||
log.WithContext(ctx).Debugf("created new account %s", accountID)
|
||||
|
||||
acc := &types.Account{
|
||||
Id: accountID,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
SetupKeys: setupKeys,
|
||||
Network: network,
|
||||
Peers: peers,
|
||||
Users: users,
|
||||
CreatedBy: userID,
|
||||
Domain: domain,
|
||||
Routes: routes,
|
||||
NameServerGroups: nameServersGroups,
|
||||
DNSSettings: dnsSettings,
|
||||
Settings: &types.Settings{
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerLoginExpiration: types.DefaultPeerLoginExpiration,
|
||||
GroupsPropagationEnabled: true,
|
||||
RegularUsersViewBlocked: true,
|
||||
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
PeerInactivityExpiration: types.DefaultPeerInactivityExpiration,
|
||||
RoutingPeerDNSResolutionEnabled: true,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
},
|
||||
}
|
||||
|
||||
if err := acc.AddAllGroup(disableDefaultPolicy); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
}
|
||||
162
management/internals/modules/peers/manager.go
Normal file
162
management/internals/modules/peers/manager.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package peers
|
||||
|
||||
//go:generate go run github.com/golang/mock/mockgen -package peers -destination=manager_mock.go -source=./manager.go -build_flags=-mod=mod
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error)
|
||||
GetPeerAccountID(ctx context.Context, peerID string) (string, error)
|
||||
GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error)
|
||||
GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error)
|
||||
DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error
|
||||
SetNetworkMapController(networkMapController network_map.Controller)
|
||||
SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator)
|
||||
SetAccountManager(accountManager account.Manager)
|
||||
}
|
||||
|
||||
type managerImpl struct {
|
||||
store store.Store
|
||||
permissionsManager permissions.Manager
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||
accountManager account.Manager
|
||||
|
||||
networkMapController network_map.Controller
|
||||
}
|
||||
|
||||
func NewManager(store store.Store, permissionsManager permissions.Manager) Manager {
|
||||
return &managerImpl{
|
||||
store: store,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) SetNetworkMapController(networkMapController network_map.Controller) {
|
||||
m.networkMapController = networkMapController
|
||||
}
|
||||
|
||||
func (m *managerImpl) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
|
||||
m.integratedPeerValidator = integratedPeerValidator
|
||||
}
|
||||
|
||||
func (m *managerImpl) SetAccountManager(accountManager account.Manager) {
|
||||
m.accountManager = accountManager
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
allowed, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
|
||||
}
|
||||
|
||||
return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
|
||||
}
|
||||
|
||||
func (m *managerImpl) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
return m.store.GetPeersByGroupIDs(ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dnsDomain := m.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
for _, peerID := range peerIDs {
|
||||
var eventsToStore []func()
|
||||
err := m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
peer, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if checkConnected && (peer.Status.Connected || peer.Status.LastSeen.After(time.Now().Add(-(ephemeral.EphemeralLifeTime - 10*time.Second)))) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := transaction.RemovePeerFromAllGroups(ctx, peerID); err != nil {
|
||||
return fmt.Errorf("failed to remove peer %s from groups", peerID)
|
||||
}
|
||||
|
||||
if err := m.integratedPeerValidator.PeerDeleted(ctx, accountID, peerID, settings.Extra); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
peerPolicyRules, err := transaction.GetPolicyRulesByResourceID(ctx, store.LockingStrengthNone, accountID, peerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, rule := range peerPolicyRules {
|
||||
policy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, rule.PolicyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = transaction.DeletePolicy(ctx, accountID, rule.PolicyID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
|
||||
})
|
||||
}
|
||||
|
||||
if err = transaction.DeletePeer(ctx, accountID, peerID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eventsToStore = append(eventsToStore, func() {
|
||||
m.accountManager.StoreEvent(ctx, userID, peer.ID, accountID, activity.PeerRemovedByUser, peer.EventMeta(dnsDomain))
|
||||
})
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, event := range eventsToStore {
|
||||
event()
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
149
management/internals/modules/peers/manager_mock.go
Normal file
149
management/internals/modules/peers/manager_mock.go
Normal file
@@ -0,0 +1,149 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: ./manager.go
|
||||
|
||||
// Package peers is a generated GoMock package.
|
||||
package peers
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
network_map "github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
account "github.com/netbirdio/netbird/management/server/account"
|
||||
integrated_validator "github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
peer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
type MockManager struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockManagerMockRecorder
|
||||
}
|
||||
|
||||
// MockManagerMockRecorder is the mock recorder for MockManager.
|
||||
type MockManagerMockRecorder struct {
|
||||
mock *MockManager
|
||||
}
|
||||
|
||||
// NewMockManager creates a new mock instance.
|
||||
func NewMockManager(ctrl *gomock.Controller) *MockManager {
|
||||
mock := &MockManager{ctrl: ctrl}
|
||||
mock.recorder = &MockManagerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockManager) EXPECT() *MockManagerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// DeletePeers mocks base method.
|
||||
func (m *MockManager) DeletePeers(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeletePeers", ctx, accountID, peerIDs, userID, checkConnected)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeletePeers indicates an expected call of DeletePeers.
|
||||
func (mr *MockManagerMockRecorder) DeletePeers(ctx, accountID, peerIDs, userID, checkConnected interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeletePeers", reflect.TypeOf((*MockManager)(nil).DeletePeers), ctx, accountID, peerIDs, userID, checkConnected)
|
||||
}
|
||||
|
||||
// GetAllPeers mocks base method.
|
||||
func (m *MockManager) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAllPeers", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAllPeers indicates an expected call of GetAllPeers.
|
||||
func (mr *MockManagerMockRecorder) GetAllPeers(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllPeers", reflect.TypeOf((*MockManager)(nil).GetAllPeers), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// GetPeer mocks base method.
|
||||
func (m *MockManager) GetPeer(ctx context.Context, accountID, userID, peerID string) (*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeer", ctx, accountID, userID, peerID)
|
||||
ret0, _ := ret[0].(*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeer indicates an expected call of GetPeer.
|
||||
func (mr *MockManagerMockRecorder) GetPeer(ctx, accountID, userID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeer", reflect.TypeOf((*MockManager)(nil).GetPeer), ctx, accountID, userID, peerID)
|
||||
}
|
||||
|
||||
// GetPeerAccountID mocks base method.
|
||||
func (m *MockManager) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeerAccountID", ctx, peerID)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeerAccountID indicates an expected call of GetPeerAccountID.
|
||||
func (mr *MockManagerMockRecorder) GetPeerAccountID(ctx, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerAccountID", reflect.TypeOf((*MockManager)(nil).GetPeerAccountID), ctx, peerID)
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs mocks base method.
|
||||
func (m *MockManager) GetPeersByGroupIDs(ctx context.Context, accountID string, groupsIDs []string) ([]*peer.Peer, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetPeersByGroupIDs", ctx, accountID, groupsIDs)
|
||||
ret0, _ := ret[0].([]*peer.Peer)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetPeersByGroupIDs indicates an expected call of GetPeersByGroupIDs.
|
||||
func (mr *MockManagerMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupsIDs interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockManager)(nil).GetPeersByGroupIDs), ctx, accountID, groupsIDs)
|
||||
}
|
||||
|
||||
// SetAccountManager mocks base method.
|
||||
func (m *MockManager) SetAccountManager(accountManager account.Manager) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetAccountManager", accountManager)
|
||||
}
|
||||
|
||||
// SetAccountManager indicates an expected call of SetAccountManager.
|
||||
func (mr *MockManagerMockRecorder) SetAccountManager(accountManager interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetAccountManager", reflect.TypeOf((*MockManager)(nil).SetAccountManager), accountManager)
|
||||
}
|
||||
|
||||
// SetIntegratedPeerValidator mocks base method.
|
||||
func (m *MockManager) SetIntegratedPeerValidator(integratedPeerValidator integrated_validator.IntegratedValidator) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetIntegratedPeerValidator", integratedPeerValidator)
|
||||
}
|
||||
|
||||
// SetIntegratedPeerValidator indicates an expected call of SetIntegratedPeerValidator.
|
||||
func (mr *MockManagerMockRecorder) SetIntegratedPeerValidator(integratedPeerValidator interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetIntegratedPeerValidator", reflect.TypeOf((*MockManager)(nil).SetIntegratedPeerValidator), integratedPeerValidator)
|
||||
}
|
||||
|
||||
// SetNetworkMapController mocks base method.
|
||||
func (m *MockManager) SetNetworkMapController(networkMapController network_map.Controller) {
|
||||
m.ctrl.T.Helper()
|
||||
m.ctrl.Call(m, "SetNetworkMapController", networkMapController)
|
||||
}
|
||||
|
||||
// SetNetworkMapController indicates an expected call of SetNetworkMapController.
|
||||
func (mr *MockManagerMockRecorder) SetNetworkMapController(networkMapController interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetworkMapController", reflect.TypeOf((*MockManager)(nil).SetNetworkMapController), networkMapController)
|
||||
}
|
||||
@@ -57,7 +57,7 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics {
|
||||
|
||||
func (s *BaseServer) Store() store.Store {
|
||||
return Create(s, func() store.Store {
|
||||
store, err := store.NewStore(context.Background(), s.config.StoreConfig.Engine, s.config.Datadir, s.Metrics(), false)
|
||||
store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create store: %v", err)
|
||||
}
|
||||
@@ -73,17 +73,17 @@ func (s *BaseServer) EventStore() activity.Store {
|
||||
log.Fatalf("failed to initialize integration metrics: %v", err)
|
||||
}
|
||||
|
||||
eventStore, key, err := integrations.InitEventStore(context.Background(), s.config.Datadir, s.config.DataStoreEncryptionKey, integrationMetrics)
|
||||
eventStore, key, err := integrations.InitEventStore(context.Background(), s.Config.Datadir, s.Config.DataStoreEncryptionKey, integrationMetrics)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to initialize event store: %v", err)
|
||||
}
|
||||
|
||||
if s.config.DataStoreEncryptionKey != key {
|
||||
log.WithContext(context.Background()).Infof("update config with activity store key")
|
||||
s.config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.config)
|
||||
if s.Config.DataStoreEncryptionKey != key {
|
||||
log.WithContext(context.Background()).Infof("update Config with activity store key")
|
||||
s.Config.DataStoreEncryptionKey = key
|
||||
err := updateMgmtConfig(context.Background(), nbconfig.MgmtConfigPath, s.Config)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to update config with activity store: %v", err)
|
||||
log.Fatalf("failed to update Config with activity store: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -103,14 +103,14 @@ func (s *BaseServer) APIHandler() http.Handler {
|
||||
|
||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
return Create(s, func() *grpc.Server {
|
||||
trustedPeers := s.config.ReverseProxy.TrustedPeers
|
||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||
defaultTrustedPeers := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0"), netip.MustParsePrefix("::/0")}
|
||||
if len(trustedPeers) == 0 || slices.Equal[[]netip.Prefix](trustedPeers, defaultTrustedPeers) {
|
||||
log.WithContext(context.Background()).Warn("TrustedPeers are configured to default value '0.0.0.0/0', '::/0'. This allows connection IP spoofing.")
|
||||
trustedPeers = defaultTrustedPeers
|
||||
}
|
||||
trustedHTTPProxies := s.config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := s.config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
trustedHTTPProxies := s.Config.ReverseProxy.TrustedHTTPProxies
|
||||
trustedProxiesCount := s.Config.ReverseProxy.TrustedHTTPProxiesCount
|
||||
if len(trustedHTTPProxies) > 0 && trustedProxiesCount > 0 {
|
||||
log.WithContext(context.Background()).Warn("TrustedHTTPProxies and TrustedHTTPProxiesCount both are configured. " +
|
||||
"This is not recommended way to extract X-Forwarded-For. Consider using one of these options.")
|
||||
@@ -128,15 +128,15 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
grpc.ChainStreamInterceptor(realip.StreamServerInterceptorOpts(realipOpts...), streamInterceptor),
|
||||
}
|
||||
|
||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
certManager, err := encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create certificate manager: %v", err)
|
||||
}
|
||||
transportCredentials := credentials.NewTLS(certManager.TLSConfig())
|
||||
gRPCOpts = append(gRPCOpts, grpc.Creds(transportCredentials))
|
||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err := loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
||||
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err := loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.Fatalf("cannot load TLS credentials: %v", err)
|
||||
}
|
||||
@@ -145,7 +145,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.config, s.AccountManager(), s.SettingsManager(), s.PeersUpdateManager(), s.SecretsManager(), s.Metrics(), s.EphemeralManager(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -9,17 +9,17 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nmapcontroller "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/auth"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral/manager"
|
||||
)
|
||||
|
||||
func (s *BaseServer) PeersUpdateManager() network_map.PeersUpdateManager {
|
||||
return Create(s, func() *update_channel.PeersUpdateManager {
|
||||
return Create(s, func() network_map.PeersUpdateManager {
|
||||
return update_channel.NewPeersUpdateManager(s.Metrics())
|
||||
})
|
||||
}
|
||||
@@ -44,33 +44,37 @@ func (s *BaseServer) ProxyController() port_forwarding.Controller {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SecretsManager() *grpc.TimeBasedAuthSecretsManager {
|
||||
return Create(s, func() *grpc.TimeBasedAuthSecretsManager {
|
||||
return grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.config.TURNConfig, s.config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||
return Create(s, func() grpc.SecretsManager {
|
||||
secretsManager, err := grpc.NewTimeBasedAuthSecretsManager(s.PeersUpdateManager(), s.Config.TURNConfig, s.Config.Relay, s.SettingsManager(), s.GroupsManager())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create secrets manager: %v", err)
|
||||
}
|
||||
return secretsManager
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthManager() auth.Manager {
|
||||
return Create(s, func() auth.Manager {
|
||||
return auth.NewManager(s.Store(),
|
||||
s.config.HttpConfig.AuthIssuer,
|
||||
s.config.HttpConfig.AuthAudience,
|
||||
s.config.HttpConfig.AuthKeysLocation,
|
||||
s.config.HttpConfig.AuthUserIDClaim,
|
||||
s.config.GetAuthAudiences(),
|
||||
s.config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
s.Config.HttpConfig.AuthIssuer,
|
||||
s.Config.HttpConfig.AuthAudience,
|
||||
s.Config.HttpConfig.AuthKeysLocation,
|
||||
s.Config.HttpConfig.AuthUserIDClaim,
|
||||
s.Config.GetAuthAudiences(),
|
||||
s.Config.HttpConfig.IdpSignKeyRefreshEnabled)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
|
||||
return Create(s, func() ephemeral.Manager {
|
||||
return manager.NewEphemeralManager(s.Store(), s.AccountManager())
|
||||
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) NetworkMapController() network_map.Controller {
|
||||
return Create(s, func() *nmapcontroller.Controller {
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.dnsDomain, s.ProxyController(), s.config)
|
||||
return Create(s, func() network_map.Controller {
|
||||
return nmapcontroller.NewController(context.Background(), s.Store(), s.Metrics(), s.PeersUpdateManager(), s.AccountRequestBuffer(), s.IntegratedValidator(), s.SettingsManager(), s.DNSDomain(), s.ProxyController(), s.EphemeralManager(), s.Config)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -79,3 +83,7 @@ func (s *BaseServer) AccountRequestBuffer() *server.AccountRequestBuffer {
|
||||
return server.NewAccountRequestBuffer(context.Background(), s.Store())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) DNSDomain() string {
|
||||
return s.dnsDomain
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
@@ -14,7 +15,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/peers"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
@@ -22,12 +23,12 @@ import (
|
||||
|
||||
func (s *BaseServer) GeoLocationManager() geolocation.Geolocation {
|
||||
return Create(s, func() geolocation.Geolocation {
|
||||
geo, err := geolocation.NewGeolocation(context.Background(), s.config.Datadir, !s.disableGeoliteUpdate)
|
||||
geo, err := geolocation.NewGeolocation(context.Background(), s.Config.Datadir, !s.disableGeoliteUpdate)
|
||||
if err != nil {
|
||||
log.Fatalf("could not initialize geolocation service: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("geolocation service has been initialized from %s", s.config.Datadir)
|
||||
log.Infof("geolocation service has been initialized from %s", s.Config.Datadir)
|
||||
|
||||
return geo
|
||||
})
|
||||
@@ -60,20 +61,22 @@ func (s *BaseServer) SettingsManager() settings.Manager {
|
||||
|
||||
func (s *BaseServer) PeersManager() peers.Manager {
|
||||
return Create(s, func() peers.Manager {
|
||||
return peers.NewManager(s.Store(), s.PermissionsManager())
|
||||
manager := peers.NewManager(s.Store(), s.PermissionsManager())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
manager.SetNetworkMapController(s.NetworkMapController())
|
||||
manager.SetIntegratedPeerValidator(s.IntegratedValidator())
|
||||
manager.SetAccountManager(s.AccountManager())
|
||||
})
|
||||
return manager
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AccountManager() account.Manager {
|
||||
return Create(s, func() account.Manager {
|
||||
accountManager, err := server.BuildManager(context.Background(), s.config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.config.DisableDefaultPolicy)
|
||||
accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create account manager: %v", err)
|
||||
}
|
||||
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
accountManager.SetEphemeralManager(s.EphemeralManager())
|
||||
})
|
||||
return accountManager
|
||||
})
|
||||
}
|
||||
@@ -82,8 +85,8 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
||||
return Create(s, func() idp.Manager {
|
||||
var idpManager idp.Manager
|
||||
var err error
|
||||
if s.config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.config.IdpManagerConfig, s.Metrics())
|
||||
if s.Config.IdpManagerConfig != nil {
|
||||
idpManager, err = idp.NewManager(context.Background(), *s.Config.IdpManagerConfig, s.Metrics())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create IDP manager: %v", err)
|
||||
}
|
||||
|
||||
@@ -41,10 +41,10 @@ type Server interface {
|
||||
}
|
||||
|
||||
// Server holds the HTTP BaseServer instance.
|
||||
// Add any additional fields you need, such as database connections, config, etc.
|
||||
// Add any additional fields you need, such as database connections, Config, etc.
|
||||
type BaseServer struct {
|
||||
// config holds the server configuration
|
||||
config *nbconfig.Config
|
||||
// Config holds the server configuration
|
||||
Config *nbconfig.Config
|
||||
// container of dependencies, each dependency is identified by a unique string.
|
||||
container map[string]any
|
||||
// AfterInit is a function that will be called after the server is initialized
|
||||
@@ -70,7 +70,7 @@ type BaseServer struct {
|
||||
// NewServer initializes and configures a new Server instance
|
||||
func NewServer(config *nbconfig.Config, dnsDomain, mgmtSingleAccModeDomain string, mgmtPort, mgmtMetricsPort int, disableMetrics, disableGeoliteUpdate, userDeleteFromIDPEnabled bool) *BaseServer {
|
||||
return &BaseServer{
|
||||
config: config,
|
||||
Config: config,
|
||||
container: make(map[string]any),
|
||||
dnsDomain: dnsDomain,
|
||||
mgmtSingleAccModeDomain: mgmtSingleAccModeDomain,
|
||||
@@ -103,14 +103,14 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
|
||||
var tlsConfig *tls.Config
|
||||
tlsEnabled := false
|
||||
if s.config.HttpConfig.LetsEncryptDomain != "" {
|
||||
s.certManager, err = encryption.CreateCertManager(s.config.Datadir, s.config.HttpConfig.LetsEncryptDomain)
|
||||
if s.Config.HttpConfig.LetsEncryptDomain != "" {
|
||||
s.certManager, err = encryption.CreateCertManager(s.Config.Datadir, s.Config.HttpConfig.LetsEncryptDomain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed creating LetsEncrypt cert manager: %v", err)
|
||||
}
|
||||
tlsEnabled = true
|
||||
} else if s.config.HttpConfig.CertFile != "" && s.config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err = loadTLSConfig(s.config.HttpConfig.CertFile, s.config.HttpConfig.CertKey)
|
||||
} else if s.Config.HttpConfig.CertFile != "" && s.Config.HttpConfig.CertKey != "" {
|
||||
tlsConfig, err = loadTLSConfig(s.Config.HttpConfig.CertFile, s.Config.HttpConfig.CertKey)
|
||||
if err != nil {
|
||||
log.WithContext(srvCtx).Errorf("cannot load TLS credentials: %v", err)
|
||||
return err
|
||||
@@ -126,8 +126,8 @@ func (s *BaseServer) Start(ctx context.Context) error {
|
||||
|
||||
if !s.disableMetrics {
|
||||
idpManager := "disabled"
|
||||
if s.config.IdpManagerConfig != nil && s.config.IdpManagerConfig.ManagerType != "" {
|
||||
idpManager = s.config.IdpManagerConfig.ManagerType
|
||||
if s.Config.IdpManagerConfig != nil && s.Config.IdpManagerConfig.ManagerType != "" {
|
||||
idpManager = s.Config.IdpManagerConfig.ManagerType
|
||||
}
|
||||
metricsWorker := metrics.NewWorker(srvCtx, installationID, s.Store(), s.PeersUpdateManager(), idpManager)
|
||||
go metricsWorker.Run(srvCtx)
|
||||
|
||||
@@ -24,7 +24,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/integrations/integrated_validator"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -55,15 +54,12 @@ const (
|
||||
type Server struct {
|
||||
accountManager account.Manager
|
||||
settingsManager settings.Manager
|
||||
wgKey wgtypes.Key
|
||||
proto.UnimplementedManagementServiceServer
|
||||
peersUpdateManager network_map.PeersUpdateManager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
ephemeralManager ephemeral.Manager
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
config *nbconfig.Config
|
||||
secretsManager SecretsManager
|
||||
appMetrics telemetry.AppMetrics
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
@@ -82,23 +78,16 @@ func NewServer(
|
||||
config *nbconfig.Config,
|
||||
accountManager account.Manager,
|
||||
settingsManager settings.Manager,
|
||||
peersUpdateManager network_map.PeersUpdateManager,
|
||||
secretsManager SecretsManager,
|
||||
appMetrics telemetry.AppMetrics,
|
||||
ephemeralManager ephemeral.Manager,
|
||||
authManager auth.Manager,
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
) (*Server, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
err = appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(peersUpdateManager.CountStreams())
|
||||
err := appMetrics.GRPCMetrics().RegisterConnectedStreams(func() int64 {
|
||||
return int64(networkMapController.CountStreams())
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -120,16 +109,12 @@ func NewServer(
|
||||
}
|
||||
|
||||
return &Server{
|
||||
wgKey: key,
|
||||
// peerKey -> event channel
|
||||
peersUpdateManager: peersUpdateManager,
|
||||
accountManager: accountManager,
|
||||
settingsManager: settingsManager,
|
||||
config: config,
|
||||
secretsManager: secretsManager,
|
||||
authManager: authManager,
|
||||
appMetrics: appMetrics,
|
||||
ephemeralManager: ephemeralManager,
|
||||
logBlockedPeers: logBlockedPeers,
|
||||
blockPeersWithSameConfig: blockPeersWithSameConfig,
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
@@ -163,8 +148,14 @@ func (s *Server) GetServerKey(ctx context.Context, req *proto.Empty) (*proto.Ser
|
||||
nanos := int32(now.Nanosecond())
|
||||
expiresAt := ×tamp.Timestamp{Seconds: secs, Nanos: nanos}
|
||||
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get wireguard key: %v", err)
|
||||
return nil, errors.New("failed to get wireguard key")
|
||||
}
|
||||
|
||||
return &proto.ServerKeyResponse{
|
||||
Key: s.wgKey.PublicKey().String(),
|
||||
Key: key.PublicKey().String(),
|
||||
ExpiresAt: expiresAt,
|
||||
}, nil
|
||||
}
|
||||
@@ -269,9 +260,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return err
|
||||
}
|
||||
|
||||
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
|
||||
|
||||
s.ephemeralManager.OnPeerConnected(ctx, peer)
|
||||
updates, err := s.networkMapController.OnPeerConnected(ctx, accountID, peer.ID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err)
|
||||
s.syncSem.Add(-1)
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return err
|
||||
}
|
||||
|
||||
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
|
||||
|
||||
@@ -323,13 +318,19 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg
|
||||
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
|
||||
// then sends the encrypted message to the connected peer via the sync server.
|
||||
func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error {
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update)
|
||||
if err != nil {
|
||||
s.cancelPeerRoutines(ctx, accountID, peer)
|
||||
return status.Errorf(codes.Internal, "failed processing update message")
|
||||
}
|
||||
err = srv.SendMsg(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -348,9 +349,8 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err)
|
||||
}
|
||||
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
s.networkMapController.OnPeerDisconnected(ctx, accountID, peer.ID)
|
||||
s.secretsManager.CancelRefresh(peer.ID)
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
|
||||
log.WithContext(ctx).Tracef("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
@@ -504,7 +504,12 @@ func (s *Server) parseRequest(ctx context.Context, req *proto.EncryptedMessage,
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "provided wgPubKey %s is invalid", req.WgPubKey)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, parsed)
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, status.Errorf(codes.Internal, "failed processing request")
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, key, req.Body, parsed)
|
||||
if err != nil {
|
||||
return wgtypes.Key{}, status.Errorf(codes.InvalidArgument, "invalid request message")
|
||||
}
|
||||
@@ -601,12 +606,6 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
|
||||
|
||||
// if the login request contains setup key then it is a registration request
|
||||
if loginReq.GetSetupKey() != "" {
|
||||
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
|
||||
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
|
||||
}
|
||||
|
||||
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed preparing login response for peer %s: %s", peerKey, err)
|
||||
@@ -615,14 +614,20 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
|
||||
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting server's WireGuard private key: %s", err)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, loginResp)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
|
||||
return nil, status.Errorf(codes.Internal, "failed logging in peer")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
@@ -715,14 +720,19 @@ func (s *Server) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, peer
|
||||
|
||||
plainResp := ToSyncResponse(ctx, s.config, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, peer, turnToken, relayToken, networkMap, s.networkMapController.GetDNSDomain(settings), postureChecks, nil, settings, settings.Extra, peerGroups, dnsFwdPort)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, plainResp)
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "failed getting server key")
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, plainResp)
|
||||
if err != nil {
|
||||
return status.Errorf(codes.Internal, "error handling request")
|
||||
}
|
||||
|
||||
sendStart := time.Now()
|
||||
err = srv.Send(&proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
})
|
||||
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
|
||||
@@ -752,7 +762,12 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get server key")
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.DeviceAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
@@ -782,13 +797,13 @@ func (s *Server) GetDeviceAuthorizationFlow(ctx context.Context, req *proto.Encr
|
||||
},
|
||||
}
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no device authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
@@ -810,7 +825,12 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
|
||||
return nil, status.Error(codes.InvalidArgument, errMSG)
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, s.wgKey, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||
key, err := s.secretsManager.GetWGKey()
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "failed to get server key")
|
||||
}
|
||||
|
||||
err = encryption.DecryptMessage(peerKey, key, req.Body, &proto.PKCEAuthorizationFlowRequest{})
|
||||
if err != nil {
|
||||
errMSG := fmt.Sprintf("error while decrypting peer's message with Wireguard public key %s.", req.WgPubKey)
|
||||
log.WithContext(ctx).Warn(errMSG)
|
||||
@@ -838,13 +858,13 @@ func (s *Server) GetPKCEAuthorizationFlow(ctx context.Context, req *proto.Encryp
|
||||
|
||||
flowInfoResp := s.integratedPeerValidator.ValidateFlowResponse(ctx, peerKey.String(), initInfoFlow)
|
||||
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, flowInfoResp)
|
||||
encryptedResp, err := encryption.EncryptMessage(peerKey, key, flowInfoResp)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to encrypt no pkce authorization flow information")
|
||||
}
|
||||
|
||||
return &proto.EncryptedMessage{
|
||||
WgPubKey: s.wgKey.PublicKey().String(),
|
||||
WgPubKey: key.PublicKey().String(),
|
||||
Body: encryptedResp,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -73,15 +73,17 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
for _, testCase := range testCases {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
mgmtServer := &Server{
|
||||
wgKey: testingServerKey,
|
||||
secretsManager: &TimeBasedAuthSecretsManager{wgKey: testingServerKey},
|
||||
config: &config.Config{
|
||||
DeviceAuthorizationFlow: testCase.inputFlow,
|
||||
},
|
||||
}
|
||||
|
||||
message := &mgmtProto.DeviceAuthorizationFlowRequest{}
|
||||
key, err := mgmtServer.secretsManager.GetWGKey()
|
||||
require.NoError(t, err, "should be able to get server key")
|
||||
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), mgmtServer.wgKey, message)
|
||||
encryptedMSG, err := encryption.EncryptMessage(testingClientKey.PublicKey(), key, message)
|
||||
require.NoError(t, err, "should be able to encrypt message")
|
||||
|
||||
resp, err := mgmtServer.GetDeviceAuthorizationFlow(
|
||||
@@ -95,7 +97,7 @@ func TestServer_GetDeviceAuthorizationFlow(t *testing.T) {
|
||||
if testCase.expectedComparisonFunc != nil {
|
||||
flowInfoResp := &mgmtProto.DeviceAuthorizationFlow{}
|
||||
|
||||
err = encryption.DecryptMessage(mgmtServer.wgKey.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
err = encryption.DecryptMessage(key.PublicKey(), testingClientKey, resp.Body, flowInfoResp)
|
||||
require.NoError(t, err, "should be able to decrypt")
|
||||
|
||||
testCase.expectedComparisonFunc(t, testCase.expectedFlow.Provider, flowInfoResp.Provider, testCase.expectedComparisonMSG)
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
@@ -29,6 +30,7 @@ type SecretsManager interface {
|
||||
GenerateRelayToken() (*Token, error)
|
||||
SetupRefresh(ctx context.Context, accountID, peerKey string)
|
||||
CancelRefresh(peerKey string)
|
||||
GetWGKey() (wgtypes.Key, error)
|
||||
}
|
||||
|
||||
// TimeBasedAuthSecretsManager generates credentials with TTL and using pre-shared secret known to TURN server
|
||||
@@ -43,11 +45,17 @@ type TimeBasedAuthSecretsManager struct {
|
||||
groupsManager groups.Manager
|
||||
turnCancelMap map[string]chan struct{}
|
||||
relayCancelMap map[string]chan struct{}
|
||||
wgKey wgtypes.Key
|
||||
}
|
||||
|
||||
type Token auth.Token
|
||||
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) *TimeBasedAuthSecretsManager {
|
||||
func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager, turnCfg *nbconfig.TURNConfig, relayCfg *nbconfig.Relay, settingsManager settings.Manager, groupsManager groups.Manager) (*TimeBasedAuthSecretsManager, error) {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mgr := &TimeBasedAuthSecretsManager{
|
||||
updateManager: updateManager,
|
||||
turnCfg: turnCfg,
|
||||
@@ -56,6 +64,7 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
||||
relayCancelMap: make(map[string]chan struct{}),
|
||||
settingsManager: settingsManager,
|
||||
groupsManager: groupsManager,
|
||||
wgKey: key,
|
||||
}
|
||||
|
||||
if turnCfg != nil {
|
||||
@@ -81,7 +90,12 @@ func NewTimeBasedAuthSecretsManager(updateManager network_map.PeersUpdateManager
|
||||
}
|
||||
}
|
||||
|
||||
return mgr
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
// GetWGKey returns WireGuard private key used to generate peer keys
|
||||
func (m *TimeBasedAuthSecretsManager) GetWGKey() (wgtypes.Key, error) {
|
||||
return m.wgKey, nil
|
||||
}
|
||||
|
||||
// GenerateTurnToken generates new time-based secret credentials for TURN
|
||||
|
||||
@@ -46,12 +46,13 @@ func TestTimeBasedAuthSecretsManager_GenerateCredentials(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
turnCredentials, err := tested.GenerateTurnToken()
|
||||
require.NoError(t, err)
|
||||
@@ -98,12 +99,13 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
|
||||
settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), "someAccountID").Return(&types.ExtraSettings{}, nil).AnyTimes()
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
@@ -201,12 +203,13 @@ func TestTimeBasedAuthSecretsManager_CancelRefresh(t *testing.T) {
|
||||
settingsMockManager := settings.NewMockManager(ctrl)
|
||||
groupsManager := groups.NewManagerMock()
|
||||
|
||||
tested := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
tested, err := NewTimeBasedAuthSecretsManager(peersManager, &config.TURNConfig{
|
||||
CredentialsTTL: ttl,
|
||||
Secret: secret,
|
||||
Turns: []*config.Host{TurnTestHost},
|
||||
TimeBasedCredentials: true,
|
||||
}, rc, settingsMockManager, groupsManager)
|
||||
require.NoError(t, err)
|
||||
|
||||
tested.SetupRefresh(context.Background(), "someAccountID", peer)
|
||||
if _, ok := tested.turnCancelMap[peer]; !ok {
|
||||
|
||||
Reference in New Issue
Block a user