mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-20 07:39:56 +00:00
Merge branch 'main' of github.com:netbirdio/netbird into feat/local-user-totp
This commit is contained in:
@@ -257,7 +257,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
||||
|
||||
// UpdatePeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||
}
|
||||
|
||||
@@ -331,9 +334,13 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||
|
||||
if c.accountManagerMetrics != nil {
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||
b := bufUpd.(*bufferUpdate)
|
||||
|
||||
@@ -348,14 +355,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
||||
|
||||
go func() {
|
||||
defer b.mu.Unlock()
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
if !b.update.Load() {
|
||||
return
|
||||
}
|
||||
b.update.Store(false)
|
||||
if b.next == nil {
|
||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -18,9 +18,9 @@ const (
|
||||
)
|
||||
|
||||
type Controller interface {
|
||||
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||
GetDNSDomain(settings *types.Settings) string
|
||||
StartWarmup(context.Context)
|
||||
|
||||
@@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers mocks base method.
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
// CountStreams mocks base method.
|
||||
@@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a
|
||||
}
|
||||
|
||||
// UpdateAccountPeers mocks base method.
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
||||
}
|
||||
|
||||
@@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int {
|
||||
return a.deletePeerCalls
|
||||
}
|
||||
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
if a.bufferUpdateCalls == nil {
|
||||
@@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
||||
return err
|
||||
}
|
||||
}
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID)
|
||||
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
|
||||
return nil
|
||||
}).
|
||||
Times(1)
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -178,7 +179,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
||||
}
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
|
||||
// Manager defines the interface for proxy operations
|
||||
type Manager interface {
|
||||
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||
Disconnect(ctx context.Context, proxyID string) error
|
||||
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
||||
Heartbeat(ctx context.Context, p *Proxy) error
|
||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
|
||||
@@ -13,7 +13,8 @@ import (
|
||||
// store defines the interface for proxy persistence operations
|
||||
type store interface {
|
||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
@@ -43,7 +44,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
||||
|
||||
// Connect registers a new proxy connection in the database.
|
||||
// capabilities may be nil for old proxies that do not report them.
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
||||
now := time.Now()
|
||||
var caps proxy.Capabilities
|
||||
if capabilities != nil {
|
||||
@@ -51,6 +52,7 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
}
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
SessionID: sessionID,
|
||||
ClusterAddress: clusterAddress,
|
||||
IPAddress: ipAddress,
|
||||
LastSeen: now,
|
||||
@@ -61,48 +63,42 @@ func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
"clusterAddress": clusterAddress,
|
||||
"ipAddress": ipAddress,
|
||||
}).Info("proxy connected")
|
||||
|
||||
return nil
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// Disconnect marks a proxy as disconnected in the database
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
now := time.Now()
|
||||
p := &proxy.Proxy{
|
||||
ID: proxyID,
|
||||
Status: "disconnected",
|
||||
DisconnectedAt: &now,
|
||||
LastSeen: now,
|
||||
}
|
||||
|
||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||
// Disconnect marks a proxy as disconnected in the database.
|
||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).WithFields(log.Fields{
|
||||
"proxyID": proxyID,
|
||||
"proxyID": proxyID,
|
||||
"sessionID": sessionID,
|
||||
}).Info("proxy disconnected")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Heartbeat updates the proxy's last seen timestamp
|
||||
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
// Heartbeat updates the proxy's last seen timestamp.
|
||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
||||
m.metrics.IncrementProxyHeartbeatCount()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -93,31 +93,32 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
||||
}
|
||||
|
||||
// Connect mocks base method.
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
ret0, _ := ret[0].(*Proxy)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Connect indicates an expected call of Connect.
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
||||
}
|
||||
|
||||
// Disconnect mocks base method.
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Disconnect indicates an expected call of Disconnect.
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
||||
}
|
||||
|
||||
// GetActiveClusterAddresses mocks base method.
|
||||
@@ -151,17 +152,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
||||
}
|
||||
|
||||
// Heartbeat mocks base method.
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Heartbeat indicates an expected call of Heartbeat.
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
||||
}
|
||||
|
||||
// MockController is a mock of Controller interface.
|
||||
|
||||
@@ -18,12 +18,13 @@ type Capabilities struct {
|
||||
// Proxy represents a reverse proxy instance
|
||||
type Proxy struct {
|
||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||
SessionID string `gorm:"type:varchar(36)"`
|
||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||
IPAddress string `gorm:"type:varchar(45)"`
|
||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||
ConnectedAt *time.Time
|
||||
DisconnectedAt *time.Time
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||
Capabilities Capabilities `gorm:"embedded"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -85,7 +85,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -231,7 +232,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
@@ -515,7 +516,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s
|
||||
}
|
||||
|
||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return service, nil
|
||||
}
|
||||
@@ -819,7 +820,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -860,7 +861,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -916,7 +917,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1098,7 +1099,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
||||
}
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
||||
|
||||
serviceURL := "https://" + svc.Domain
|
||||
if service.IsL4Protocol(svc.Mode) {
|
||||
@@ -1210,7 +1211,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
||||
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1261,7 +1262,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
|
||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -447,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||
storedActivity = activityID.(activity.Activity)
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -549,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||
storedActivity = activityID.(activity.Activity)
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -593,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
|
||||
storedMeta = meta
|
||||
},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
}
|
||||
|
||||
mockStore.EXPECT().
|
||||
@@ -704,7 +704,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
accountMgr := &mock_server.MockAccountManager{
|
||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
},
|
||||
@@ -1173,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
mockAcct.EXPECT().
|
||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||
mockAcct.EXPECT().
|
||||
UpdateAccountPeers(ctx, accountID)
|
||||
UpdateAccountPeers(ctx, accountID, gomock.Any())
|
||||
|
||||
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -144,7 +145,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
@@ -206,7 +207,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -95,7 +96,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate})
|
||||
|
||||
return record, nil
|
||||
}
|
||||
@@ -154,7 +155,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate})
|
||||
|
||||
return record, nil
|
||||
}
|
||||
@@ -201,7 +202,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI
|
||||
meta := record.EventMeta(zone.ID, zone.Name)
|
||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -173,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
}
|
||||
|
||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
|
||||
if err != nil {
|
||||
log.Fatalf("failed to create management server: %v", err)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/management-integrations/integrations"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||
|
||||
@@ -66,6 +67,12 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) SessionStore() *auth.SessionStore {
|
||||
return Create(s, func() *auth.SessionStore {
|
||||
return auth.NewSessionStore(s.CacheStore())
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BaseServer) AuthManager() auth.Manager {
|
||||
audiences := s.Config.GetAuthAudiences()
|
||||
audience := s.Config.HttpConfig.AuthAudience
|
||||
|
||||
@@ -11,11 +11,14 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/grpc/codes"
|
||||
@@ -81,14 +84,44 @@ type ProxyServiceServer struct {
|
||||
// Store for PKCE verifiers
|
||||
pkceVerifierStore *PKCEVerifierStore
|
||||
|
||||
// tokenTTL is the lifetime of one-time tokens generated for proxy
|
||||
// authentication. Defaults to defaultProxyTokenTTL when zero.
|
||||
tokenTTL time.Duration
|
||||
|
||||
// snapshotBatchSize is the number of mappings per gRPC message during
|
||||
// initial snapshot delivery. Configurable via NB_PROXY_SNAPSHOT_BATCH_SIZE.
|
||||
snapshotBatchSize int
|
||||
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
const pkceVerifierTTL = 10 * time.Minute
|
||||
|
||||
const defaultProxyTokenTTL = 5 * time.Minute
|
||||
|
||||
const defaultSnapshotBatchSize = 500
|
||||
|
||||
func snapshotBatchSizeFromEnv() int {
|
||||
if v := os.Getenv("NB_PROXY_SNAPSHOT_BATCH_SIZE"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
return n
|
||||
}
|
||||
}
|
||||
return defaultSnapshotBatchSize
|
||||
}
|
||||
|
||||
// proxyTokenTTL returns the configured token TTL or the default when unset.
|
||||
func (s *ProxyServiceServer) proxyTokenTTL() time.Duration {
|
||||
if s.tokenTTL > 0 {
|
||||
return s.tokenTTL
|
||||
}
|
||||
return defaultProxyTokenTTL
|
||||
}
|
||||
|
||||
// proxyConnection represents a connected proxy
|
||||
type proxyConnection struct {
|
||||
proxyID string
|
||||
sessionID string
|
||||
address string
|
||||
capabilities *proto.ProxyCapabilities
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
@@ -108,6 +141,7 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
proxyManager: proxyMgr,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
cancel: cancel,
|
||||
}
|
||||
go s.cleanupStaleProxies(ctx)
|
||||
@@ -166,9 +200,22 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
return status.Errorf(codes.InvalidArgument, "proxy address is invalid")
|
||||
}
|
||||
|
||||
sessionID := uuid.NewString()
|
||||
|
||||
if old, loaded := s.connectedProxies.Load(proxyID); loaded {
|
||||
oldConn := old.(*proxyConnection)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"old_session_id": oldConn.sessionID,
|
||||
"new_session_id": sessionID,
|
||||
}).Info("Superseding existing proxy connection")
|
||||
oldConn.cancel()
|
||||
}
|
||||
|
||||
connCtx, cancel := context.WithCancel(ctx)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
sessionID: sessionID,
|
||||
address: proxyAddress,
|
||||
capabilities: req.GetCapabilities(),
|
||||
stream: stream,
|
||||
@@ -177,79 +224,93 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
cancel: cancel,
|
||||
}
|
||||
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
// Register proxy in database with capabilities
|
||||
var caps *proxy.Capabilities
|
||||
if c := req.GetCapabilities(); c != nil {
|
||||
caps = &proxy.Capabilities{
|
||||
SupportsCustomPorts: c.SupportsCustomPorts,
|
||||
RequireSubdomain: c.RequireSubdomain,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
SupportsCrowdsec: c.SupportsCrowdsec,
|
||||
}
|
||||
}
|
||||
if err := s.proxyManager.Connect(ctx, proxyID, proxyAddress, peerInfo, caps); err != nil {
|
||||
proxyRecord, err := s.proxyManager.Connect(ctx, proxyID, sessionID, proxyAddress, peerInfo, caps)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to register proxy %s in database: %v", proxyID, err)
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(ctx, conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after Connect failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
cancel()
|
||||
return status.Errorf(codes.Internal, "register proxy in database: %v", err)
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
s.connectedProxies.Delete(proxyID)
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s disconnected", proxyID)
|
||||
}()
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
if err := s.proxyController.RegisterProxyToCluster(ctx, conn.address, proxyID); err != nil {
|
||||
log.WithContext(ctx).Warnf("Failed to register proxy %s in cluster: %v", proxyID, err)
|
||||
}
|
||||
|
||||
if err := s.sendSnapshot(ctx, conn); err != nil {
|
||||
if s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
if unregErr := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); unregErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, unregErr)
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
if disconnErr := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); disconnErr != nil {
|
||||
log.WithContext(ctx).Debugf("cleanup after snapshot failure for proxy %s: %v", proxyID, disconnErr)
|
||||
}
|
||||
return fmt.Errorf("send snapshot to proxy %s: %w", proxyID, err)
|
||||
}
|
||||
|
||||
errChan := make(chan error, 2)
|
||||
go s.sender(conn, errChan)
|
||||
|
||||
// Start heartbeat goroutine
|
||||
go s.heartbeat(connCtx, proxyID, proxyAddress, peerInfo)
|
||||
log.WithFields(log.Fields{
|
||||
"proxy_id": proxyID,
|
||||
"session_id": sessionID,
|
||||
"address": proxyAddress,
|
||||
"cluster_addr": proxyAddress,
|
||||
"total_proxies": len(s.GetConnectedProxies()),
|
||||
}).Info("Proxy registered in cluster")
|
||||
defer func() {
|
||||
if !s.connectedProxies.CompareAndDelete(proxyID, conn) {
|
||||
log.Infof("Proxy %s session %s: skipping cleanup, superseded by new connection", proxyID, sessionID)
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.proxyController.UnregisterProxyFromCluster(context.Background(), conn.address, proxyID); err != nil {
|
||||
log.Warnf("Failed to unregister proxy %s from cluster: %v", proxyID, err)
|
||||
}
|
||||
if err := s.proxyManager.Disconnect(context.Background(), proxyID, sessionID); err != nil {
|
||||
log.Warnf("Failed to mark proxy %s as disconnected: %v", proxyID, err)
|
||||
}
|
||||
|
||||
cancel()
|
||||
log.Infof("Proxy %s session %s disconnected", proxyID, sessionID)
|
||||
}()
|
||||
|
||||
go s.heartbeat(connCtx, proxyRecord)
|
||||
|
||||
select {
|
||||
case err := <-errChan:
|
||||
log.WithContext(ctx).Warnf("Failed to send update: %v", err)
|
||||
return fmt.Errorf("send update to proxy %s: %w", proxyID, err)
|
||||
case <-connCtx.Done():
|
||||
log.WithContext(ctx).Infof("Proxy %s context canceled", proxyID)
|
||||
return connCtx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
// heartbeat updates the proxy's last_seen timestamp every minute
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) {
|
||||
func (s *ProxyServiceServer) heartbeat(ctx context.Context, p *proxy.Proxy) {
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
if err := s.proxyManager.Heartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||
if err := s.proxyManager.Heartbeat(ctx, p); err != nil {
|
||||
log.WithContext(ctx).Debugf("Failed to update proxy %s heartbeat: %v", p.ID, err)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.WithContext(ctx).Infof("proxy %s heartbeat stopped: context canceled", p.ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -267,22 +328,27 @@ func (s *ProxyServiceServer) sendSnapshot(ctx context.Context, conn *proxyConnec
|
||||
return err
|
||||
}
|
||||
|
||||
// Send mappings in batches to reduce per-message gRPC overhead while
|
||||
// staying well within the default 4 MB message size limit.
|
||||
for i := 0; i < len(mappings); i += s.snapshotBatchSize {
|
||||
end := i + s.snapshotBatchSize
|
||||
if end > len(mappings) {
|
||||
end = len(mappings)
|
||||
}
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings[i:end],
|
||||
InitialSyncComplete: end == len(mappings),
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot batch: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(mappings) == 0 {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
InitialSyncComplete: true,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send snapshot completion: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for i, m := range mappings {
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{m},
|
||||
InitialSyncComplete: i == len(mappings)-1,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("send proxy mapping: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -300,13 +366,9 @@ func (s *ProxyServiceServer) snapshotServiceMappings(ctx context.Context, conn *
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(service.AccountID, service.ID, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"service": service.Name,
|
||||
"account": service.AccountID,
|
||||
}).WithError(err).Error("failed to generate auth token for snapshot")
|
||||
continue
|
||||
return nil, fmt.Errorf("generate auth token for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m := service.ToProtoMapping(rpservice.Create, token, s.GetOIDCValidationConfig())
|
||||
@@ -386,13 +448,16 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateRes
|
||||
conn := value.(*proxyConnection)
|
||||
resp := s.perProxyMessage(update, conn.proxyID)
|
||||
if resp == nil {
|
||||
log.Warnf("Token generation failed for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
return true
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- resp:
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
log.Warnf("Send channel full for proxy %s, disconnecting to force resync", conn.proxyID)
|
||||
conn.cancel()
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -472,13 +537,16 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
||||
}
|
||||
msg := s.perProxyMessage(updateResponse, proxyID)
|
||||
if msg == nil {
|
||||
log.WithContext(ctx).Warnf("Token generation failed for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.WithContext(ctx).Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
default:
|
||||
log.WithContext(ctx).Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
log.WithContext(ctx).Warnf("Send channel full for proxy %s in cluster %s, disconnecting to force resync", proxyID, clusterAddr)
|
||||
conn.cancel()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -504,7 +572,8 @@ func proxyAcceptsMapping(conn *proxyConnection, mapping *proto.ProxyMapping) boo
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
// Returns nil if token generation fails; the caller must disconnect the
|
||||
// proxy so it can resync via a fresh snapshot on reconnect.
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
@@ -513,7 +582,7 @@ func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateRespo
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, s.proxyTokenTTL())
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
|
||||
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
174
management/internals/shared/grpc/proxy_snapshot_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
// recordingStream captures all messages sent via Send so tests can inspect
|
||||
// batching behaviour without a real gRPC transport.
|
||||
type recordingStream struct {
|
||||
grpc.ServerStream
|
||||
messages []*proto.GetMappingUpdateResponse
|
||||
}
|
||||
|
||||
func (s *recordingStream) Send(m *proto.GetMappingUpdateResponse) error {
|
||||
s.messages = append(s.messages, m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *recordingStream) Context() context.Context { return context.Background() }
|
||||
func (s *recordingStream) SetHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SendHeader(metadata.MD) error { return nil }
|
||||
func (s *recordingStream) SetTrailer(metadata.MD) {}
|
||||
func (s *recordingStream) SendMsg(any) error { return nil }
|
||||
func (s *recordingStream) RecvMsg(any) error { return nil }
|
||||
|
||||
// makeServices creates n enabled services assigned to the given cluster.
|
||||
func makeServices(n int, cluster string) []*rpservice.Service {
|
||||
services := make([]*rpservice.Service, n)
|
||||
for i := range n {
|
||||
services[i] = &rpservice.Service{
|
||||
ID: fmt.Sprintf("svc-%d", i),
|
||||
AccountID: "acct-1",
|
||||
Name: fmt.Sprintf("svc-%d", i),
|
||||
Domain: fmt.Sprintf("svc-%d.example.com", i),
|
||||
ProxyCluster: cluster,
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{TargetType: rpservice.TargetTypeHost, TargetId: "host-1"},
|
||||
},
|
||||
}
|
||||
}
|
||||
return services
|
||||
}
|
||||
|
||||
func newSnapshotTestServer(t *testing.T, batchSize int) *ProxyServiceServer {
|
||||
t.Helper()
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: NewOneTimeTokenStore(context.Background(), testCacheStore(t)),
|
||||
snapshotBatchSize: batchSize,
|
||||
}
|
||||
s.SetProxyController(newTestProxyController())
|
||||
return s
|
||||
}
|
||||
|
||||
func TestSendSnapshot_BatchesMappings(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 7 // 3 + 3 + 1
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{
|
||||
proxyID: "proxy-a",
|
||||
address: cluster,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
err := s.sendSnapshot(context.Background(), conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expect ceil(7/3) = 3 messages
|
||||
require.Len(t, stream.messages, 3, "should send ceil(totalServices/batchSize) messages")
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete, "first batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.False(t, stream.messages[1].InitialSyncComplete, "middle batch should not be sync-complete")
|
||||
|
||||
assert.Len(t, stream.messages[2].Mapping, 1)
|
||||
assert.True(t, stream.messages[2].InitialSyncComplete, "last batch must be sync-complete")
|
||||
|
||||
// Verify all service IDs are present exactly once
|
||||
seen := make(map[string]bool)
|
||||
for _, msg := range stream.messages {
|
||||
for _, m := range msg.Mapping {
|
||||
assert.False(t, seen[m.Id], "duplicate service ID %s", m.Id)
|
||||
seen[m.Id] = true
|
||||
}
|
||||
}
|
||||
assert.Len(t, seen, totalServices)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_ExactBatchMultiple(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 3
|
||||
const totalServices = 6 // exactly 2 batches
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 2)
|
||||
|
||||
assert.Len(t, stream.messages[0].Mapping, 3)
|
||||
assert.False(t, stream.messages[0].InitialSyncComplete)
|
||||
|
||||
assert.Len(t, stream.messages[1].Mapping, 3)
|
||||
assert.True(t, stream.messages[1].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_SingleBatch(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
const batchSize = 100
|
||||
const totalServices = 5
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(makeServices(totalServices, cluster), nil)
|
||||
|
||||
s := newSnapshotTestServer(t, batchSize)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "all mappings should fit in one batch")
|
||||
assert.Len(t, stream.messages[0].Mapping, totalServices)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
|
||||
func TestSendSnapshot_EmptySnapshot(t *testing.T) {
|
||||
const cluster = "cluster.example.com"
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
mgr := rpservice.NewMockManager(ctrl)
|
||||
mgr.EXPECT().GetGlobalServices(gomock.Any()).Return(nil, nil)
|
||||
|
||||
s := newSnapshotTestServer(t, 500)
|
||||
s.serviceManager = mgr
|
||||
|
||||
stream := &recordingStream{}
|
||||
conn := &proxyConnection{proxyID: "proxy-a", address: cluster, stream: stream}
|
||||
|
||||
require.NoError(t, s.sendSnapshot(context.Background(), conn))
|
||||
require.Len(t, stream.messages, 1, "empty snapshot must still send sync-complete")
|
||||
assert.Empty(t, stream.messages[0].Mapping)
|
||||
assert.True(t, stream.messages[0].InitialSyncComplete)
|
||||
}
|
||||
@@ -85,11 +85,14 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
// registerFakeProxyWithCaps adds a fake proxy connection with explicit capabilities.
|
||||
func registerFakeProxyWithCaps(s *ProxyServiceServer, proxyID, clusterAddr string, caps *proto.ProxyCapabilities) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
capabilities: caps,
|
||||
sendChan: ch,
|
||||
ctx: ctx,
|
||||
cancel: cancel,
|
||||
}
|
||||
s.connectedProxies.Store(proxyID, conn)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
jwtv5 "github.com/golang-jwt/jwt/v5"
|
||||
pb "github.com/golang/protobuf/proto" // nolint
|
||||
"github.com/golang/protobuf/ptypes/timestamp"
|
||||
"github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip"
|
||||
@@ -67,6 +68,7 @@ type Server struct {
|
||||
appMetrics telemetry.AppMetrics
|
||||
peerLocks sync.Map
|
||||
authManager auth.Manager
|
||||
sessionStore *auth.SessionStore
|
||||
|
||||
logBlockedPeers bool
|
||||
blockPeersWithSameConfig bool
|
||||
@@ -98,6 +100,7 @@ func NewServer(
|
||||
integratedPeerValidator integrated_validator.IntegratedValidator,
|
||||
networkMapController network_map.Controller,
|
||||
oAuthConfigProvider idp.OAuthConfigProvider,
|
||||
sessionStore *auth.SessionStore,
|
||||
) (*Server, error) {
|
||||
if appMetrics != nil {
|
||||
// update gauge based on number of connected peers which is equal to open gRPC streams
|
||||
@@ -140,6 +143,7 @@ func NewServer(
|
||||
integratedPeerValidator: integratedPeerValidator,
|
||||
networkMapController: networkMapController,
|
||||
oAuthConfigProvider: oAuthConfigProvider,
|
||||
sessionStore: sessionStore,
|
||||
|
||||
loginFilter: newLoginFilter(),
|
||||
|
||||
@@ -535,7 +539,7 @@ func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID st
|
||||
log.WithContext(ctx).Debugf("peer %s has been disconnected", peer.Key)
|
||||
}
|
||||
|
||||
func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, error) {
|
||||
func (s *Server) validateToken(ctx context.Context, peerKey, jwtToken string) (string, error) {
|
||||
if s.authManager == nil {
|
||||
return "", status.Errorf(codes.Internal, "missing auth manager")
|
||||
}
|
||||
@@ -545,6 +549,10 @@ func (s *Server) validateToken(ctx context.Context, jwtToken string) (string, er
|
||||
return "", status.Errorf(codes.InvalidArgument, "invalid jwt token, err: %v", err)
|
||||
}
|
||||
|
||||
if err := s.claimLoginToken(ctx, peerKey, jwtToken, token); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||
accountId, _, err := s.accountManager.GetAccountIDFromUserAuth(ctx, userAuth)
|
||||
if err != nil {
|
||||
@@ -828,6 +836,31 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
return loginResp, nil
|
||||
}
|
||||
|
||||
func (s *Server) claimLoginToken(ctx context.Context, peerKey, jwtToken string, token *jwtv5.Token) error {
|
||||
if s.sessionStore == nil || token == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
exp, err := token.Claims.GetExpirationTime()
|
||||
if err != nil || exp == nil {
|
||||
log.WithContext(ctx).Warnf("JWT has no usable exp claim for peer %s", peerKey)
|
||||
return status.Error(codes.Unauthenticated, "jwt token has no expiration")
|
||||
}
|
||||
|
||||
err = s.sessionStore.RegisterToken(ctx, jwtToken, exp.Time)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if errors.Is(err, auth.ErrTokenAlreadyUsed) || errors.Is(err, auth.ErrTokenExpired) {
|
||||
log.WithContext(ctx).Warnf("%v for peer %s", err, peerKey)
|
||||
return status.Error(codes.Unauthenticated, err.Error())
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Warnf("failed to claim JWT for peer %s: %v", peerKey, err)
|
||||
return status.Error(codes.Unavailable, "failed to claim jwt token")
|
||||
}
|
||||
|
||||
// processJwtToken validates the existence of a JWT token in the login request, and returns the corresponding user ID if
|
||||
// the token is valid.
|
||||
//
|
||||
@@ -838,7 +871,7 @@ func (s *Server) processJwtToken(ctx context.Context, loginReq *proto.LoginReque
|
||||
if loginReq.GetJwtToken() != "" {
|
||||
var err error
|
||||
for i := 0; i < 3; i++ {
|
||||
userID, err = s.validateToken(ctx, loginReq.GetJwtToken())
|
||||
userID, err = s.validateToken(ctx, peerKey.String(), loginReq.GetJwtToken())
|
||||
if err == nil {
|
||||
break
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user