mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[proxy] Send proxy updates on account delete (#5375)
This commit is contained in:
@@ -12,6 +12,7 @@ type Manager interface {
|
||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||
SetStatus(ctx context.Context, accountID, serviceID string, status ProxyStatus) error
|
||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||
|
||||
@@ -49,6 +49,20 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||
}
|
||||
|
||||
// DeleteAllServices mocks base method.
|
||||
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||
}
|
||||
|
||||
// DeleteService mocks base method.
|
||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -150,7 +151,7 @@ func (m *managerImpl) CreateService(ctx context.Context, accountID, userID strin
|
||||
return nil, fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -330,21 +331,35 @@ func (m *managerImpl) preserveServiceMetadata(service, existingService *reversep
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdateNotifications(service *reverseproxy.Service, updateInfo *serviceUpdateInfo) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
switch {
|
||||
case updateInfo.domainChanged && updateInfo.oldCluster != service.ProxyCluster:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), updateInfo.oldCluster)
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, updateInfo.oldCluster, "")
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
case !service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
case service.Enabled && updateInfo.serviceEnabledChanged:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Create, "", oidcCfg), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Create, service.ProxyCluster, "")
|
||||
default:
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", oidcCfg), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
|
||||
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
|
||||
}
|
||||
|
||||
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
|
||||
if len(mappings) == 0 {
|
||||
return
|
||||
}
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: mappings,
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
|
||||
}
|
||||
|
||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*reverseproxy.Target) error {
|
||||
for _, target := range targets {
|
||||
@@ -397,7 +412,54 @@ func (m *managerImpl) DeleteService(ctx context.Context, accountID, userID, serv
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, serviceID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Delete, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Delete, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Services, operations.Delete)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
if !ok {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
var services []*reverseproxy.Service
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
var err error
|
||||
services, err = transaction.GetServicesByAccountID(ctx, store.LockingStrengthUpdate, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, service := range services {
|
||||
if err = transaction.DeleteService(ctx, accountID, service.ID); err != nil {
|
||||
return fmt.Errorf("failed to delete service: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
m.accountManager.StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, service.EventMeta())
|
||||
mapping := service.ToProtoMapping(reverseproxy.Delete, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
}
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -452,7 +514,7 @@ func (m *managerImpl) ReloadService(ctx context.Context, accountID, serviceID st
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
m.sendServiceUpdate(service, reverseproxy.Update, service.ProxyCluster, "")
|
||||
|
||||
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
|
||||
@@ -465,12 +527,20 @@ func (m *managerImpl) ReloadAllServicesForAccount(ctx context.Context, accountID
|
||||
return fmt.Errorf("failed to get services: %w", err)
|
||||
}
|
||||
|
||||
clusterMappings := make(map[string][]*proto.ProxyMapping)
|
||||
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||
|
||||
for _, service := range services {
|
||||
err = m.replaceHostByLookup(ctx, accountID, service)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to replace host by lookup for service %s: %w", service.ID, err)
|
||||
}
|
||||
m.proxyGRPCServer.SendServiceUpdateToCluster(service.ToProtoMapping(reverseproxy.Update, "", m.proxyGRPCServer.GetOIDCValidationConfig()), service.ProxyCluster)
|
||||
mapping := service.ToProtoMapping(reverseproxy.Update, "", oidcCfg)
|
||||
clusterMappings[service.ProxyCluster] = append(clusterMappings[service.ProxyCluster], mapping)
|
||||
}
|
||||
|
||||
for cluster, mappings := range clusterMappings {
|
||||
m.sendMappingsToCluster(mappings, cluster)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
@@ -61,9 +61,6 @@ type ProxyServiceServer struct {
|
||||
// Map of cluster address -> set of proxy IDs
|
||||
clusterProxies sync.Map
|
||||
|
||||
// Channel for broadcasting reverse proxy updates to all proxies
|
||||
updatesChan chan *proto.ProxyMapping
|
||||
|
||||
// Manager for access logs
|
||||
accessLogManager accesslogs.Manager
|
||||
|
||||
@@ -101,7 +98,7 @@ type proxyConnection struct {
|
||||
proxyID string
|
||||
address string
|
||||
stream proto.ProxyService_GetMappingUpdateServer
|
||||
sendChan chan *proto.ProxyMapping
|
||||
sendChan chan *proto.GetMappingUpdateResponse
|
||||
ctx context.Context
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
@@ -110,7 +107,6 @@ type proxyConnection struct {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
accessLogManager: accessLogMgr,
|
||||
oidcConfig: oidcConfig,
|
||||
tokenStore: tokenStore,
|
||||
@@ -177,7 +173,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
||||
proxyID: proxyID,
|
||||
address: proxyAddress,
|
||||
stream: stream,
|
||||
sendChan: make(chan *proto.ProxyMapping, 100),
|
||||
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||
ctx: connCtx,
|
||||
cancel: cancel,
|
||||
}
|
||||
@@ -288,7 +284,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
||||
for {
|
||||
select {
|
||||
case msg := <-conn.sendChan:
|
||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
|
||||
if err := conn.stream.Send(msg); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
@@ -339,7 +335,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
||||
// Management should call this when services are created/updated/removed.
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||
conn := value.(*proxyConnection)
|
||||
@@ -349,7 +345,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
|
||||
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||
}
|
||||
@@ -418,7 +414,7 @@ func (s *ProxyServiceServer) removeFromCluster(clusterAddr, proxyID string) {
|
||||
// If clusterAddr is empty, broadcasts to all connected proxy servers (backward compatibility).
|
||||
// For create/update operations a unique one-time auth token is generated per
|
||||
// proxy so that every replica can independently authenticate with management.
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMapping, clusterAddr string) {
|
||||
func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.GetMappingUpdateResponse, clusterAddr string) {
|
||||
if clusterAddr == "" {
|
||||
s.SendServiceUpdate(update)
|
||||
return
|
||||
@@ -441,7 +437,7 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi
|
||||
}
|
||||
select {
|
||||
case conn.sendChan <- msg:
|
||||
log.Debugf("Sent service update with id %s to proxy %s in cluster %s", update.Id, proxyID, clusterAddr)
|
||||
log.Debugf("Sent service update to proxy %s in cluster %s", proxyID, clusterAddr)
|
||||
default:
|
||||
log.Warnf("Failed to send service update to proxy %s in cluster %s (channel full)", proxyID, clusterAddr)
|
||||
}
|
||||
@@ -451,23 +447,31 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(update *proto.ProxyMappi
|
||||
}
|
||||
|
||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||
// create/update operations. For delete operations the original message is
|
||||
// returned unchanged because proxies do not need to authenticate for removal.
|
||||
// create/update operations. For delete operations the original mapping is
|
||||
// used unchanged because proxies do not need to authenticate for removal.
|
||||
// Returns nil if token generation fails (the proxy should be skipped).
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
|
||||
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
|
||||
return update
|
||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||
for _, mapping := range update.Mapping {
|
||||
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||
resp = append(resp, mapping)
|
||||
continue
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
msg := shallowCloneMapping(mapping)
|
||||
msg.AuthToken = token
|
||||
resp = append(resp, msg)
|
||||
}
|
||||
|
||||
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||
return nil
|
||||
return &proto.GetMappingUpdateResponse{
|
||||
Mapping: resp,
|
||||
}
|
||||
|
||||
msg := shallowCloneMapping(update)
|
||||
msg.AuthToken = token
|
||||
return msg
|
||||
}
|
||||
|
||||
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
||||
|
||||
@@ -17,6 +17,10 @@ type mockReverseProxyManager struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
|
||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||
// and returns the channel where messages will be received.
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
||||
ch := make(chan *proto.ProxyMapping, 10)
|
||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||
conn := &proxyConnection{
|
||||
proxyID: proxyID,
|
||||
address: clusterAddr,
|
||||
@@ -31,7 +31,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
||||
return ch
|
||||
}
|
||||
|
||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
||||
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||
select {
|
||||
case msg := <-ch:
|
||||
return msg
|
||||
@@ -45,20 +45,19 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
const numProxies = 3
|
||||
|
||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
||||
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
|
||||
for i := range numProxies {
|
||||
id := "proxy-" + string(rune('a'+i))
|
||||
channels[i] = registerFakeProxy(s, id, cluster)
|
||||
}
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
@@ -68,14 +67,20 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
tokens := make([]string, numProxies)
|
||||
for i, ch := range channels {
|
||||
msg := drainChannel(ch)
|
||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
||||
assert.Equal(t, update.Domain, msg.Domain)
|
||||
assert.Equal(t, update.Id, msg.Id)
|
||||
resp := drainChannel(ch)
|
||||
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
||||
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
||||
msg := resp.Mapping[0]
|
||||
assert.Equal(t, mapping.Domain, msg.Domain)
|
||||
assert.Equal(t, mapping.Id, msg.Id)
|
||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||
tokens[i] = msg.AuthToken
|
||||
}
|
||||
@@ -100,31 +105,36 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
|
||||
const cluster = "proxy.example.com"
|
||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdateToCluster(update, cluster)
|
||||
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
|
||||
// Delete operations should not generate tokens
|
||||
assert.Empty(t, msg1.AuthToken)
|
||||
assert.Empty(t, msg2.AuthToken)
|
||||
assert.Empty(t, resp1.Mapping[0].AuthToken)
|
||||
assert.Empty(t, resp2.Mapping[0].AuthToken)
|
||||
|
||||
// No tokens should have been created
|
||||
assert.Equal(t, 0, tokenStore.GetTokenCount())
|
||||
@@ -135,27 +145,35 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
||||
defer tokenStore.Close()
|
||||
|
||||
s := &ProxyServiceServer{
|
||||
tokenStore: tokenStore,
|
||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
||||
tokenStore: tokenStore,
|
||||
}
|
||||
|
||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||
|
||||
update := &proto.ProxyMapping{
|
||||
mapping := &proto.ProxyMapping{
|
||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||
Id: "service-1",
|
||||
AccountId: "account-1",
|
||||
Domain: "test.example.com",
|
||||
}
|
||||
|
||||
update := &proto.GetMappingUpdateResponse{
|
||||
Mapping: []*proto.ProxyMapping{mapping},
|
||||
}
|
||||
|
||||
s.SendServiceUpdate(update)
|
||||
|
||||
msg1 := drainChannel(ch1)
|
||||
msg2 := drainChannel(ch2)
|
||||
require.NotNil(t, msg1)
|
||||
require.NotNil(t, msg2)
|
||||
resp1 := drainChannel(ch1)
|
||||
resp2 := drainChannel(ch2)
|
||||
require.NotNil(t, resp1)
|
||||
require.NotNil(t, resp2)
|
||||
require.Len(t, resp1.Mapping, 1)
|
||||
require.Len(t, resp2.Mapping, 1)
|
||||
|
||||
msg1 := resp1.Mapping[0]
|
||||
msg2 := resp2.Mapping[0]
|
||||
|
||||
assert.NotEmpty(t, msg1.AuthToken)
|
||||
assert.NotEmpty(t, msg2.AuthToken)
|
||||
|
||||
@@ -714,6 +714,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.Id == userID {
|
||||
continue
|
||||
|
||||
@@ -31,6 +31,7 @@ import (
|
||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/manager"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/cache"
|
||||
@@ -3122,7 +3123,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
|
||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil))
|
||||
|
||||
return manager, updateManager, nil
|
||||
}
|
||||
|
||||
@@ -357,6 +357,10 @@ type testServiceManager struct {
|
||||
store store.Store
|
||||
}
|
||||
|
||||
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -4906,6 +4906,28 @@ func (s *SqlStore) GetServiceByID(ctx context.Context, lockStrength LockingStren
|
||||
return service, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
tx := s.db.Preload("Targets")
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
|
||||
var serviceList []*reverseproxy.Service
|
||||
result := tx.Find(&serviceList, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get services from the store: %s", result.Error)
|
||||
return nil, status.Errorf(status.Internal, "failed to get services from store")
|
||||
}
|
||||
|
||||
for _, service := range serviceList {
|
||||
if err := service.DecryptSensitiveData(s.fieldEncrypt); err != nil {
|
||||
return nil, fmt.Errorf("decrypt service data: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return serviceList, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error) {
|
||||
var service *reverseproxy.Service
|
||||
result := s.db.Preload("Targets").Where("account_id = ? AND domain = ?", accountID, domain).First(&service)
|
||||
|
||||
@@ -256,6 +256,7 @@ type Store interface {
|
||||
UpdateService(ctx context.Context, service *reverseproxy.Service) error
|
||||
DeleteService(ctx context.Context, accountID, serviceID string) error
|
||||
GetServiceByID(ctx context.Context, lockStrength LockingStrength, accountID, serviceID string) (*reverseproxy.Service, error)
|
||||
GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
GetServiceByDomain(ctx context.Context, accountID, domain string) (*reverseproxy.Service, error)
|
||||
GetServices(ctx context.Context, lockStrength LockingStrength) ([]*reverseproxy.Service, error)
|
||||
GetAccountServices(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error)
|
||||
|
||||
@@ -1109,6 +1109,21 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetServicesByAccountID mocks base method.
|
||||
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||
}
|
||||
|
||||
// GetAccountSettings mocks base method.
|
||||
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -191,6 +191,10 @@ type storeBackedServiceManager struct {
|
||||
tokenStore *nbgrpc.OneTimeTokenStore
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user