mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 16:26:38 +00:00
Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy
This commit is contained in:
@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
return fmt.Errorf("upstream check call error")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, exponentialBackOff)
|
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||||
|
} else {
|
||||||
|
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -70,6 +70,7 @@ type ServerConfig struct {
|
|||||||
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
|
||||||
Auth AuthConfig `yaml:"auth"`
|
Auth AuthConfig `yaml:"auth"`
|
||||||
Store StoreConfig `yaml:"store"`
|
Store StoreConfig `yaml:"store"`
|
||||||
|
ActivityStore StoreConfig `yaml:"activityStore"`
|
||||||
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,17 @@ func initializeConfig() error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if engine := config.Server.ActivityStore.Engine; engine != "" {
|
||||||
|
engineLower := strings.ToLower(engine)
|
||||||
|
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
|
||||||
|
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
|
||||||
|
}
|
||||||
|
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
|
||||||
|
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
|
||||||
|
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("Starting combined NetBird server")
|
log.Infof("Starting combined NetBird server")
|
||||||
logConfig(config)
|
logConfig(config)
|
||||||
logEnvVars()
|
logEnvVars()
|
||||||
@@ -668,8 +679,11 @@ func logEnvVars() {
|
|||||||
if strings.HasPrefix(env, "NB_") {
|
if strings.HasPrefix(env, "NB_") {
|
||||||
key, _, _ := strings.Cut(env, "=")
|
key, _, _ := strings.Cut(env, "=")
|
||||||
value := os.Getenv(key)
|
value := os.Getenv(key)
|
||||||
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
|
keyLower := strings.ToLower(key)
|
||||||
|
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
|
||||||
value = maskSecret(value)
|
value = maskSecret(value)
|
||||||
|
} else if strings.Contains(keyLower, "dsn") {
|
||||||
|
value = maskDSNPassword(value)
|
||||||
}
|
}
|
||||||
log.Infof(" %s=%s", key, value)
|
log.Infof(" %s=%s", key, value)
|
||||||
found = true
|
found = true
|
||||||
|
|||||||
@@ -104,6 +104,11 @@ server:
|
|||||||
dsn: "" # Connection string for postgres or mysql
|
dsn: "" # Connection string for postgres or mysql
|
||||||
encryptionKey: ""
|
encryptionKey: ""
|
||||||
|
|
||||||
|
# Activity events store configuration (optional, defaults to sqlite in dataDir)
|
||||||
|
# activityStore:
|
||||||
|
# engine: "sqlite" # sqlite or postgres
|
||||||
|
# dsn: "" # Connection string for postgres
|
||||||
|
|
||||||
# Reverse proxy settings (optional)
|
# Reverse proxy settings (optional)
|
||||||
# reverseProxy:
|
# reverseProxy:
|
||||||
# trustedHTTPProxies: []
|
# trustedHTTPProxies: []
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ type Controller struct {
|
|||||||
|
|
||||||
expNewNetworkMap bool
|
expNewNetworkMap bool
|
||||||
expNewNetworkMapAIDs map[string]struct{}
|
expNewNetworkMapAIDs map[string]struct{}
|
||||||
|
|
||||||
|
compactedNetworkMap bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
newNetworkMapBuilder = false
|
newNetworkMapBuilder = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
|
||||||
|
compactedNetworkMap = false
|
||||||
|
}
|
||||||
|
|
||||||
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||||
expIDs := make(map[string]struct{}, len(ids))
|
expIDs := make(map[string]struct{}, len(ids))
|
||||||
for _, id := range ids {
|
for _, id := range ids {
|
||||||
@@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
holder: types.NewHolder(),
|
holder: types.NewHolder(),
|
||||||
expNewNetworkMap: newNetworkMapBuilder,
|
expNewNetworkMap: newNetworkMapBuilder,
|
||||||
expNewNetworkMapAIDs: expIDs,
|
expNewNetworkMapAIDs: expIDs,
|
||||||
|
|
||||||
|
compactedNetworkMap: compactedNetworkMap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountID) {
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountID):
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
|
|
||||||
var remotePeerNetworkMap *types.NetworkMap
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
if c.experimentalNetworkMap(accountId) {
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountId):
|
||||||
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
} else {
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
} else {
|
} else {
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
@@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
account.InjectProxyPolicies(ctx)
|
account.InjectProxyPolicies(ctx)
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ type Manager interface {
|
|||||||
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
|
||||||
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
|
||||||
|
DeleteAllServices(ctx context.Context, accountID, userID string) error
|
||||||
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
|
||||||
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
|
||||||
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
ReloadAllServicesForAccount(ctx context.Context, accountID string) error
|
||||||
|
|||||||
@@ -50,6 +50,20 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices mocks base method.
|
||||||
|
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
|
||||||
|
ret0, _ := ret[0].(error)
|
||||||
|
return ret0
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteAllServices indicates an expected call of DeleteAllServices.
|
||||||
|
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
|
||||||
|
}
|
||||||
|
|
||||||
// DeleteService mocks base method.
|
// DeleteService mocks base method.
|
||||||
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -344,6 +345,22 @@ func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
|
||||||
|
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
|
||||||
|
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
|
||||||
|
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
|
||||||
|
if len(mappings) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
update := &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: mappings,
|
||||||
|
}
|
||||||
|
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
|
||||||
|
}
|
||||||
|
|
||||||
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
|
||||||
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
|
|||||||
@@ -105,7 +105,7 @@ type proxyConnection struct {
|
|||||||
proxyID string
|
proxyID string
|
||||||
address string
|
address string
|
||||||
stream proto.ProxyService_GetMappingUpdateServer
|
stream proto.ProxyService_GetMappingUpdateServer
|
||||||
sendChan chan *proto.ProxyMapping
|
sendChan chan *proto.GetMappingUpdateResponse
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
@@ -114,7 +114,6 @@ type proxyConnection struct {
|
|||||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
accessLogManager: accessLogMgr,
|
accessLogManager: accessLogMgr,
|
||||||
oidcConfig: oidcConfig,
|
oidcConfig: oidcConfig,
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
@@ -203,7 +202,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
|
|||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: proxyAddress,
|
address: proxyAddress,
|
||||||
stream: stream,
|
stream: stream,
|
||||||
sendChan: make(chan *proto.ProxyMapping, 100),
|
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
|
||||||
ctx: connCtx,
|
ctx: connCtx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
}
|
}
|
||||||
@@ -349,7 +348,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case msg := <-conn.sendChan:
|
case msg := <-conn.sendChan:
|
||||||
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
|
if err := conn.stream.Send(msg); err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -400,7 +399,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
|
|||||||
// Management should call this when services are created/updated/removed.
|
// Management should call this when services are created/updated/removed.
|
||||||
// For create/update operations a unique one-time auth token is generated per
|
// For create/update operations a unique one-time auth token is generated per
|
||||||
// proxy so that every replica can independently authenticate with management.
|
// proxy so that every replica can independently authenticate with management.
|
||||||
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
|
||||||
log.Debugf("Broadcasting service update to all connected proxy servers")
|
log.Debugf("Broadcasting service update to all connected proxy servers")
|
||||||
s.connectedProxies.Range(func(key, value interface{}) bool {
|
s.connectedProxies.Range(func(key, value interface{}) bool {
|
||||||
conn := value.(*proxyConnection)
|
conn := value.(*proxyConnection)
|
||||||
@@ -410,7 +409,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
|
|||||||
}
|
}
|
||||||
select {
|
select {
|
||||||
case conn.sendChan <- msg:
|
case conn.sendChan <- msg:
|
||||||
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
|
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
|
||||||
default:
|
default:
|
||||||
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
|
||||||
}
|
}
|
||||||
@@ -494,23 +493,31 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
|
|||||||
}
|
}
|
||||||
|
|
||||||
// perProxyMessage returns a copy of update with a fresh one-time token for
|
// perProxyMessage returns a copy of update with a fresh one-time token for
|
||||||
// create/update operations. For delete operations the original message is
|
// create/update operations. For delete operations the original mapping is
|
||||||
// returned unchanged because proxies do not need to authenticate for removal.
|
// used unchanged because proxies do not need to authenticate for removal.
|
||||||
// Returns nil if token generation fails (the proxy should be skipped).
|
// Returns nil if token generation fails (the proxy should be skipped).
|
||||||
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
|
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
|
||||||
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
|
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
|
||||||
return update
|
for _, mapping := range update.Mapping {
|
||||||
|
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
|
||||||
|
resp = append(resp, mapping)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := shallowCloneMapping(mapping)
|
||||||
|
msg.AuthToken = token
|
||||||
|
resp = append(resp, msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
|
return &proto.GetMappingUpdateResponse{
|
||||||
if err != nil {
|
Mapping: resp,
|
||||||
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
msg := shallowCloneMapping(update)
|
|
||||||
msg.AuthToken = token
|
|
||||||
return msg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the
|
||||||
|
|||||||
@@ -17,6 +17,10 @@ type mockReverseProxyManager struct {
|
|||||||
err error
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
if m.err != nil {
|
if m.err != nil {
|
||||||
return nil, m.err
|
return nil, m.err
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ import (
|
|||||||
|
|
||||||
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
// registerFakeProxy adds a fake proxy connection to the server's internal maps
|
||||||
// and returns the channel where messages will be received.
|
// and returns the channel where messages will be received.
|
||||||
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
|
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
|
||||||
ch := make(chan *proto.ProxyMapping, 10)
|
ch := make(chan *proto.GetMappingUpdateResponse, 10)
|
||||||
conn := &proxyConnection{
|
conn := &proxyConnection{
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
address: clusterAddr,
|
address: clusterAddr,
|
||||||
@@ -32,7 +32,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
|
|||||||
return ch
|
return ch
|
||||||
}
|
}
|
||||||
|
|
||||||
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
|
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
|
||||||
select {
|
select {
|
||||||
case msg := <-ch:
|
case msg := <-ch:
|
||||||
return msg
|
return msg
|
||||||
@@ -46,20 +46,19 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
const cluster = "proxy.example.com"
|
||||||
const numProxies = 3
|
const numProxies = 3
|
||||||
|
|
||||||
channels := make([]chan *proto.ProxyMapping, numProxies)
|
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
|
||||||
for i := range numProxies {
|
for i := range numProxies {
|
||||||
id := "proxy-" + string(rune('a'+i))
|
id := "proxy-" + string(rune('a'+i))
|
||||||
channels[i] = registerFakeProxy(s, id, cluster)
|
channels[i] = registerFakeProxy(s, id, cluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
@@ -73,10 +72,12 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
|
|
||||||
tokens := make([]string, numProxies)
|
tokens := make([]string, numProxies)
|
||||||
for i, ch := range channels {
|
for i, ch := range channels {
|
||||||
msg := drainChannel(ch)
|
resp := drainChannel(ch)
|
||||||
require.NotNil(t, msg, "proxy %d should receive a message", i)
|
require.NotNil(t, resp, "proxy %d should receive a message", i)
|
||||||
assert.Equal(t, update.Domain, msg.Domain)
|
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
|
||||||
assert.Equal(t, update.Id, msg.Id)
|
msg := resp.Mapping[0]
|
||||||
|
assert.Equal(t, mapping.Domain, msg.Domain)
|
||||||
|
assert.Equal(t, mapping.Id, msg.Id)
|
||||||
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
|
||||||
tokens[i] = msg.AuthToken
|
tokens[i] = msg.AuthToken
|
||||||
}
|
}
|
||||||
@@ -101,15 +102,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const cluster = "proxy.example.com"
|
const cluster = "proxy.example.com"
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
ch1 := registerFakeProxy(s, "proxy-a", cluster)
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
ch2 := registerFakeProxy(s, "proxy-b", cluster)
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
@@ -118,10 +118,12 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
|
|||||||
|
|
||||||
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
resp1 := drainChannel(ch1)
|
||||||
msg2 := drainChannel(ch2)
|
resp2 := drainChannel(ch2)
|
||||||
require.NotNil(t, msg1)
|
require.NotNil(t, resp1)
|
||||||
require.NotNil(t, msg2)
|
require.NotNil(t, resp2)
|
||||||
|
require.Len(t, resp1.Mapping, 1)
|
||||||
|
require.Len(t, resp2.Mapping, 1)
|
||||||
|
|
||||||
// Delete operations should not generate tokens
|
// Delete operations should not generate tokens
|
||||||
assert.Empty(t, msg1.AuthToken)
|
assert.Empty(t, msg1.AuthToken)
|
||||||
@@ -133,27 +135,35 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
s := &ProxyServiceServer{
|
s := &ProxyServiceServer{
|
||||||
tokenStore: tokenStore,
|
tokenStore: tokenStore,
|
||||||
updatesChan: make(chan *proto.ProxyMapping, 100),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
|
||||||
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
|
||||||
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
|
||||||
|
|
||||||
update := &proto.ProxyMapping{
|
mapping := &proto.ProxyMapping{
|
||||||
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
|
||||||
Id: "service-1",
|
Id: "service-1",
|
||||||
AccountId: "account-1",
|
AccountId: "account-1",
|
||||||
Domain: "test.example.com",
|
Domain: "test.example.com",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
update := &proto.GetMappingUpdateResponse{
|
||||||
|
Mapping: []*proto.ProxyMapping{mapping},
|
||||||
|
}
|
||||||
|
|
||||||
s.SendServiceUpdate(update)
|
s.SendServiceUpdate(update)
|
||||||
|
|
||||||
msg1 := drainChannel(ch1)
|
resp1 := drainChannel(ch1)
|
||||||
msg2 := drainChannel(ch2)
|
resp2 := drainChannel(ch2)
|
||||||
require.NotNil(t, msg1)
|
require.NotNil(t, resp1)
|
||||||
require.NotNil(t, msg2)
|
require.NotNil(t, resp2)
|
||||||
|
require.Len(t, resp1.Mapping, 1)
|
||||||
|
require.Len(t, resp2.Mapping, 1)
|
||||||
|
|
||||||
|
msg1 := resp1.Mapping[0]
|
||||||
|
msg2 := resp2.Mapping[0]
|
||||||
|
|
||||||
assert.NotEmpty(t, msg1.AuthToken)
|
assert.NotEmpty(t, msg1.AuthToken)
|
||||||
assert.NotEmpty(t, msg2.AuthToken)
|
assert.NotEmpty(t, msg2.AuthToken)
|
||||||
|
|||||||
@@ -714,6 +714,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
|||||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
|
||||||
|
if err != nil {
|
||||||
|
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||||
|
}
|
||||||
|
|
||||||
for _, otherUser := range account.Users {
|
for _, otherUser := range account.Users {
|
||||||
if otherUser.Id == userID {
|
if otherUser.Id == userID {
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ import (
|
|||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
nbAccount "github.com/netbirdio/netbird/management/server/account"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/cache"
|
"github.com/netbirdio/netbird/management/server/cache"
|
||||||
@@ -3122,7 +3123,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
|||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
|
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
|
||||||
|
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil))
|
||||||
|
|
||||||
return manager, updateManager, nil
|
return manager, updateManager, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -358,6 +358,10 @@ type testServiceManager struct {
|
|||||||
store store.Store
|
store store.Store
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -210,6 +210,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
rosenpassEnabled int
|
rosenpassEnabled int
|
||||||
localUsers int
|
localUsers int
|
||||||
idpUsers int
|
idpUsers int
|
||||||
|
embeddedIdpTypes map[string]int
|
||||||
)
|
)
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
metricsProperties := make(properties)
|
metricsProperties := make(properties)
|
||||||
@@ -218,6 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
rulesProtocol = make(map[string]int)
|
rulesProtocol = make(map[string]int)
|
||||||
rulesDirection = make(map[string]int)
|
rulesDirection = make(map[string]int)
|
||||||
activeUsersLastDay = make(map[string]struct{})
|
activeUsersLastDay = make(map[string]struct{})
|
||||||
|
embeddedIdpTypes = make(map[string]int)
|
||||||
uptime = time.Since(w.startupTime).Seconds()
|
uptime = time.Since(w.startupTime).Seconds()
|
||||||
connections := w.connManager.GetAllConnectedPeers()
|
connections := w.connManager.GetAllConnectedPeers()
|
||||||
version = nbversion.NetbirdVersion()
|
version = nbversion.NetbirdVersion()
|
||||||
@@ -277,6 +279,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
localUsers++
|
localUsers++
|
||||||
} else {
|
} else {
|
||||||
idpUsers++
|
idpUsers++
|
||||||
|
idpType := extractIdpType(idpID)
|
||||||
|
embeddedIdpTypes[idpType]++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -369,6 +373,11 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
|
|||||||
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
|
||||||
metricsProperties["local_users_count"] = localUsers
|
metricsProperties["local_users_count"] = localUsers
|
||||||
metricsProperties["idp_users_count"] = idpUsers
|
metricsProperties["idp_users_count"] = idpUsers
|
||||||
|
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
|
||||||
|
|
||||||
|
for idpType, count := range embeddedIdpTypes {
|
||||||
|
metricsProperties["embedded_idp_users_"+idpType] = count
|
||||||
|
}
|
||||||
|
|
||||||
for protocol, count := range rulesProtocol {
|
for protocol, count := range rulesProtocol {
|
||||||
metricsProperties["rules_protocol_"+protocol] = count
|
metricsProperties["rules_protocol_"+protocol] = count
|
||||||
@@ -456,6 +465,17 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
|
|||||||
return req, cancel, nil
|
return req, cancel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractIdpType extracts the IdP type from a Dex connector ID.
|
||||||
|
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
|
||||||
|
// Returns the type prefix, or "oidc" if no known prefix is found.
|
||||||
|
func extractIdpType(connectorID string) string {
|
||||||
|
idx := strings.LastIndex(connectorID, "-")
|
||||||
|
if idx <= 0 {
|
||||||
|
return "oidc"
|
||||||
|
}
|
||||||
|
return strings.ToLower(connectorID[:idx])
|
||||||
|
}
|
||||||
|
|
||||||
func getMinMaxVersion(inputList []string) (string, string) {
|
func getMinMaxVersion(inputList []string) (string, string) {
|
||||||
versions := make([]*version.Version, 0)
|
versions := make([]*version.Version, 0)
|
||||||
|
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
|
|||||||
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
|
||||||
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
|
||||||
localUserID := dex.EncodeDexUserID("10", "local")
|
localUserID := dex.EncodeDexUserID("10", "local")
|
||||||
idpUserID := dex.EncodeDexUserID("20", "zitadel")
|
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
|
||||||
return []*types.Account{
|
return []*types.Account{
|
||||||
{
|
{
|
||||||
Id: "1",
|
Id: "1",
|
||||||
@@ -341,4 +341,37 @@ func TestGenerateProperties(t *testing.T) {
|
|||||||
if properties["idp_users_count"] != 1 {
|
if properties["idp_users_count"] != 1 {
|
||||||
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
|
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
|
||||||
}
|
}
|
||||||
|
if properties["embedded_idp_users_zitadel"] != 1 {
|
||||||
|
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
|
||||||
|
}
|
||||||
|
if properties["embedded_idp_count"] != 1 {
|
||||||
|
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractIdpType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
connectorID string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"okta-abc123def", "okta"},
|
||||||
|
{"zitadel-d5uv82dra0haedlf6kv0", "zitadel"},
|
||||||
|
{"entra-xyz789", "entra"},
|
||||||
|
{"google-abc123", "google"},
|
||||||
|
{"pocketid-abc123", "pocketid"},
|
||||||
|
{"microsoft-abc123", "microsoft"},
|
||||||
|
{"authentik-abc123", "authentik"},
|
||||||
|
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
|
||||||
|
{"local", "oidc"},
|
||||||
|
{"", "oidc"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.connectorID, func(t *testing.T) {
|
||||||
|
result := extractIdpType(tt.connectorID)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1124,6 +1124,21 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetServicesByAccountID mocks base method.
|
||||||
|
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
|
||||||
|
ret0, _ := ret[0].([]*reverseproxy.Service)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
|
||||||
|
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
|
||||||
|
}
|
||||||
|
|
||||||
// GetAccountSettings mocks base method.
|
// GetAccountSettings mocks base method.
|
||||||
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
576
management/server/types/account_components.go
Normal file
576
management/server/types/account_components.go
Normal file
@@ -0,0 +1,576 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (a *Account) GetPeerNetworkMapFromComponents(
|
||||||
|
ctx context.Context,
|
||||||
|
peerID string,
|
||||||
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
|
validatedPeersMap map[string]struct{},
|
||||||
|
resourcePolicies map[string][]*Policy,
|
||||||
|
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||||
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
groupIDToUserIDs map[string][]string,
|
||||||
|
) *NetworkMap {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
components := a.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
accountZones,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if components == nil {
|
||||||
|
return &NetworkMap{Network: a.Network.Copy()}
|
||||||
|
}
|
||||||
|
|
||||||
|
nm := CalculateNetworkMapFromComponents(ctx, components)
|
||||||
|
|
||||||
|
if metrics != nil {
|
||||||
|
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||||
|
metrics.CountNetworkMapObjects(objectCount)
|
||||||
|
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
|
if objectCount > 5000 {
|
||||||
|
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from components, "+
|
||||||
|
"peers: %d, offline peers: %d, routes: %d, firewall rules: %d, route firewall rules: %d",
|
||||||
|
a.Id, objectCount, len(nm.Peers), len(nm.OfflinePeers), len(nm.Routes), len(nm.FirewallRules), len(nm.RoutesFirewallRules))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nm
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) GetPeerNetworkMapComponents(
|
||||||
|
ctx context.Context,
|
||||||
|
peerID string,
|
||||||
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
|
validatedPeersMap map[string]struct{},
|
||||||
|
resourcePolicies map[string][]*Policy,
|
||||||
|
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||||
|
groupIDToUserIDs map[string][]string,
|
||||||
|
) *NetworkMapComponents {
|
||||||
|
|
||||||
|
peer := a.Peers[peerID]
|
||||||
|
if peer == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := validatedPeersMap[peerID]; !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
components := &NetworkMapComponents{
|
||||||
|
PeerID: peerID,
|
||||||
|
Network: a.Network.Copy(),
|
||||||
|
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
|
||||||
|
CustomZoneDomain: peersCustomZone.Domain,
|
||||||
|
ResourcePoliciesMap: make(map[string][]*Policy),
|
||||||
|
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
|
||||||
|
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
|
||||||
|
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
|
||||||
|
RouterPeers: make(map[string]*nbpeer.Peer),
|
||||||
|
}
|
||||||
|
|
||||||
|
components.AccountSettings = &AccountSettingsInfo{
|
||||||
|
PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled,
|
||||||
|
PeerLoginExpiration: a.Settings.PeerLoginExpiration,
|
||||||
|
PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled,
|
||||||
|
PeerInactivityExpiration: a.Settings.PeerInactivityExpiration,
|
||||||
|
}
|
||||||
|
|
||||||
|
components.DNSSettings = &a.DNSSettings
|
||||||
|
|
||||||
|
relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers)
|
||||||
|
|
||||||
|
if len(sshReqs.neededGroupIDs) > 0 {
|
||||||
|
components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs)
|
||||||
|
}
|
||||||
|
if sshReqs.needAllowedUserIDs {
|
||||||
|
components.AllowedUserIDs = a.getAllowedUserIDs()
|
||||||
|
}
|
||||||
|
|
||||||
|
components.Peers = relevantPeers
|
||||||
|
components.Groups = relevantGroups
|
||||||
|
components.Policies = relevantPolicies
|
||||||
|
components.Routes = relevantRoutes
|
||||||
|
components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers)
|
||||||
|
|
||||||
|
peerGroups := a.GetPeerGroups(peerID)
|
||||||
|
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
|
||||||
|
|
||||||
|
for _, nsGroup := range a.NameServerGroups {
|
||||||
|
if nsGroup.Enabled {
|
||||||
|
for _, gID := range nsGroup.Groups {
|
||||||
|
if _, found := relevantGroups[gID]; found {
|
||||||
|
components.NameServerGroups = append(components.NameServerGroups, nsGroup)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, resource := range a.NetworkResources {
|
||||||
|
if !resource.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
policies, exists := resourcePolicies[resource.ID]
|
||||||
|
if !exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
addSourcePeers := false
|
||||||
|
|
||||||
|
networkRoutingPeers, routerExists := routers[resource.NetworkID]
|
||||||
|
if routerExists {
|
||||||
|
if _, ok := networkRoutingPeers[peerID]; ok {
|
||||||
|
addSourcePeers = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range policies {
|
||||||
|
if addSourcePeers {
|
||||||
|
var peers []string
|
||||||
|
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||||
|
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||||
|
} else {
|
||||||
|
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
|
||||||
|
}
|
||||||
|
for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, validatedPeersMap, &components.PostureFailedPeers) {
|
||||||
|
if _, exists := components.Peers[pID]; !exists {
|
||||||
|
components.Peers[pID] = a.GetPeer(pID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
peerInSources := false
|
||||||
|
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||||
|
peerInSources = policy.Rules[0].SourceResource.ID == peerID
|
||||||
|
} else {
|
||||||
|
for _, groupID := range policy.SourceGroups() {
|
||||||
|
if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) {
|
||||||
|
peerInSources = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !peerInSources {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID)
|
||||||
|
if !isValid && len(pname) > 0 {
|
||||||
|
if _, ok := components.PostureFailedPeers[pname]; !ok {
|
||||||
|
components.PostureFailedPeers[pname] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
components.PostureFailedPeers[pname][peer.ID] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
addSourcePeers = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
for _, srcGroupID := range rule.Sources {
|
||||||
|
if g := a.Groups[srcGroupID]; g != nil {
|
||||||
|
if _, exists := components.Groups[srcGroupID]; !exists {
|
||||||
|
components.Groups[srcGroupID] = g
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, dstGroupID := range rule.Destinations {
|
||||||
|
if g := a.Groups[dstGroupID]; g != nil {
|
||||||
|
if _, exists := components.Groups[dstGroupID]; !exists {
|
||||||
|
components.Groups[dstGroupID] = g
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
components.ResourcePoliciesMap[resource.ID] = policies
|
||||||
|
}
|
||||||
|
|
||||||
|
components.RoutersMap[resource.NetworkID] = networkRoutingPeers
|
||||||
|
for peerIDKey := range networkRoutingPeers {
|
||||||
|
if p := a.Peers[peerIDKey]; p != nil {
|
||||||
|
if _, exists := components.RouterPeers[peerIDKey]; !exists {
|
||||||
|
components.RouterPeers[peerIDKey] = p
|
||||||
|
}
|
||||||
|
if _, exists := components.Peers[peerIDKey]; !exists {
|
||||||
|
if _, validated := validatedPeersMap[peerIDKey]; validated {
|
||||||
|
components.Peers[peerIDKey] = p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if addSourcePeers {
|
||||||
|
components.NetworkResources = append(components.NetworkResources, resource)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filterGroupPeers(&components.Groups, components.Peers)
|
||||||
|
filterPostureFailedPeers(&components.PostureFailedPeers, components.Policies, components.ResourcePoliciesMap, components.Peers)
|
||||||
|
|
||||||
|
return components
|
||||||
|
}
|
||||||
|
|
||||||
|
type sshRequirements struct {
|
||||||
|
neededGroupIDs map[string]struct{}
|
||||||
|
needAllowedUserIDs bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getPeersGroupsPoliciesRoutes(
|
||||||
|
ctx context.Context,
|
||||||
|
peerID string,
|
||||||
|
peerSSHEnabled bool,
|
||||||
|
validatedPeersMap map[string]struct{},
|
||||||
|
postureFailedPeers *map[string]map[string]struct{},
|
||||||
|
) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) {
|
||||||
|
relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4)
|
||||||
|
relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4)
|
||||||
|
relevantPolicies := make([]*Policy, 0, len(a.Policies))
|
||||||
|
relevantRoutes := make([]*route.Route, 0, len(a.Routes))
|
||||||
|
sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})}
|
||||||
|
|
||||||
|
relevantPeerIDs[peerID] = a.GetPeer(peerID)
|
||||||
|
|
||||||
|
for groupID, group := range a.Groups {
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routeAccessControlGroups := make(map[string]struct{})
|
||||||
|
for _, r := range a.Routes {
|
||||||
|
for _, groupID := range r.Groups {
|
||||||
|
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||||
|
}
|
||||||
|
for _, groupID := range r.PeerGroups {
|
||||||
|
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||||
|
}
|
||||||
|
if r.Enabled {
|
||||||
|
for _, groupID := range r.AccessControlGroups {
|
||||||
|
relevantGroupIDs[groupID] = a.GetGroup(groupID)
|
||||||
|
routeAccessControlGroups[groupID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
relevantRoutes = append(relevantRoutes, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range a.Policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
policyRelevant := false
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if !rule.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(routeAccessControlGroups) > 0 {
|
||||||
|
for _, destGroupID := range rule.Destinations {
|
||||||
|
if _, needed := routeAccessControlGroups[destGroupID]; needed {
|
||||||
|
policyRelevant = true
|
||||||
|
for _, srcGroupID := range rule.Sources {
|
||||||
|
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
|
||||||
|
}
|
||||||
|
for _, dstGroupID := range rule.Destinations {
|
||||||
|
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourcePeers, destinationPeers []string
|
||||||
|
var peerInSources, peerInDestinations bool
|
||||||
|
|
||||||
|
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||||
|
sourcePeers = []string{rule.SourceResource.ID}
|
||||||
|
if rule.SourceResource.ID == peerID {
|
||||||
|
peerInSources = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||||
|
destinationPeers = []string{rule.DestinationResource.ID}
|
||||||
|
if rule.DestinationResource.ID == peerID {
|
||||||
|
peerInDestinations = true
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInSources {
|
||||||
|
policyRelevant = true
|
||||||
|
for _, pid := range destinationPeers {
|
||||||
|
relevantPeerIDs[pid] = a.GetPeer(pid)
|
||||||
|
}
|
||||||
|
for _, dstGroupID := range rule.Destinations {
|
||||||
|
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInDestinations {
|
||||||
|
policyRelevant = true
|
||||||
|
for _, pid := range sourcePeers {
|
||||||
|
relevantPeerIDs[pid] = a.GetPeer(pid)
|
||||||
|
}
|
||||||
|
for _, srcGroupID := range rule.Sources {
|
||||||
|
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
switch {
|
||||||
|
case len(rule.AuthorizedGroups) > 0:
|
||||||
|
for groupID := range rule.AuthorizedGroups {
|
||||||
|
sshReqs.neededGroupIDs[groupID] = struct{}{}
|
||||||
|
}
|
||||||
|
case rule.AuthorizedUser != "":
|
||||||
|
default:
|
||||||
|
sshReqs.needAllowedUserIDs = true
|
||||||
|
}
|
||||||
|
} else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
|
||||||
|
sshReqs.needAllowedUserIDs = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if policyRelevant {
|
||||||
|
relevantPolicies = append(relevantPolicies, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string,
|
||||||
|
validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) {
|
||||||
|
peerInGroups := false
|
||||||
|
filteredPeerIDs := make([]string, 0, len(a.Peers))
|
||||||
|
seenPeerIds := make(map[string]struct{}, len(groups))
|
||||||
|
|
||||||
|
for _, gid := range groups {
|
||||||
|
group := a.GetGroup(gid)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if group.IsGroupAll() || len(groups) == 1 {
|
||||||
|
filteredPeerIDs = filteredPeerIDs[:0]
|
||||||
|
peerInGroups = false
|
||||||
|
for _, pid := range group.Peers {
|
||||||
|
peer, ok := a.Peers[pid]
|
||||||
|
if !ok || peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
|
||||||
|
if !isValid && len(pname) > 0 {
|
||||||
|
if _, ok := (*postureFailedPeers)[pname]; !ok {
|
||||||
|
(*postureFailedPeers)[pname] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.ID == peerID {
|
||||||
|
peerInGroups = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
|
||||||
|
}
|
||||||
|
return filteredPeerIDs, peerInGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pid := range group.Peers {
|
||||||
|
if _, seen := seenPeerIds[pid]; seen {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seenPeerIds[pid] = struct{}{}
|
||||||
|
peer, ok := a.Peers[pid]
|
||||||
|
if !ok || peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := validatedPeersMap[peer.ID]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
|
||||||
|
if !isValid && len(pname) > 0 {
|
||||||
|
if _, ok := (*postureFailedPeers)[pname]; !ok {
|
||||||
|
(*postureFailedPeers)[pname] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.ID == peerID {
|
||||||
|
peerInGroups = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredPeerIDs, peerInGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) {
|
||||||
|
peer, ok := a.Peers[peerID]
|
||||||
|
if !ok || peer == nil {
|
||||||
|
return false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, postureChecksID := range sourcePostureChecksID {
|
||||||
|
postureChecks := a.GetPostureChecks(postureChecksID)
|
||||||
|
if postureChecks == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, check := range postureChecks.GetChecks() {
|
||||||
|
isValid, _ := check.Check(ctx, *peer)
|
||||||
|
if !isValid {
|
||||||
|
return false, postureChecksID
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) []string {
|
||||||
|
var dest []string
|
||||||
|
for _, peerID := range inputPeers {
|
||||||
|
if _, validated := validatedPeersMap[peerID]; !validated {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID)
|
||||||
|
if valid {
|
||||||
|
dest = append(dest, peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := (*postureFailedPeers)[pname]; !ok {
|
||||||
|
(*postureFailedPeers)[pname] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
(*postureFailedPeers)[pname][peerID] = struct{}{}
|
||||||
|
}
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
|
||||||
|
for groupID, groupInfo := range *groups {
|
||||||
|
filteredPeers := make([]string, 0, len(groupInfo.Peers))
|
||||||
|
for _, pid := range groupInfo.Peers {
|
||||||
|
if _, exists := peers[pid]; exists {
|
||||||
|
filteredPeers = append(filteredPeers, pid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(filteredPeers) == 0 {
|
||||||
|
delete(*groups, groupID)
|
||||||
|
} else if len(filteredPeers) != len(groupInfo.Peers) {
|
||||||
|
ng := groupInfo.Copy()
|
||||||
|
ng.Peers = filteredPeers
|
||||||
|
(*groups)[groupID] = ng
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{}, policies []*Policy, resourcePoliciesMap map[string][]*Policy, peers map[string]*nbpeer.Peer) {
|
||||||
|
if len(*postureFailedPeers) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
referencedPostureChecks := make(map[string]struct{})
|
||||||
|
for _, policy := range policies {
|
||||||
|
for _, checkID := range policy.SourcePostureChecks {
|
||||||
|
referencedPostureChecks[checkID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, resPolicies := range resourcePoliciesMap {
|
||||||
|
for _, policy := range resPolicies {
|
||||||
|
for _, checkID := range policy.SourcePostureChecks {
|
||||||
|
referencedPostureChecks[checkID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for checkID, failedPeers := range *postureFailedPeers {
|
||||||
|
if _, referenced := referencedPostureChecks[checkID]; !referenced {
|
||||||
|
delete(*postureFailedPeers, checkID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for peerID := range failedPeers {
|
||||||
|
if _, exists := peers[peerID]; !exists {
|
||||||
|
delete(failedPeers, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(failedPeers) == 0 {
|
||||||
|
delete(*postureFailedPeers, checkID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord {
|
||||||
|
if len(records) == 0 || len(peers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIPs := make(map[string]struct{}, len(peers))
|
||||||
|
for _, peer := range peers {
|
||||||
|
if peer != nil {
|
||||||
|
peerIPs[peer.IP.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredRecords := make([]nbdns.SimpleRecord, 0, len(records))
|
||||||
|
for _, record := range records {
|
||||||
|
if _, exists := peerIPs[record.RData]; exists {
|
||||||
|
filteredRecords = append(filteredRecords, record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string {
|
||||||
|
if len(neededGroupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered := make(map[string][]string, len(neededGroupIDs))
|
||||||
|
for groupID := range neededGroupIDs {
|
||||||
|
if users, ok := fullMap[groupID]; ok {
|
||||||
|
filtered[groupID] = users
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
592
management/server/types/networkmap_comparison_test.go
Normal file
592
management/server/types/networkmap_comparison_test.go
Normal file
@@ -0,0 +1,592 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/management/server/posture"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid == offlinePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
nil,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
components := account.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if components == nil {
|
||||||
|
t.Fatal("GetPeerNetworkMapComponents returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||||
|
|
||||||
|
if newNetworkMap == nil {
|
||||||
|
t.Fatal("CalculateNetworkMapFromComponents returned nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid == offlinePeerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
legacyNetworkMap := account.GetPeerNetworkMap(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
nil,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
components := account.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
|
||||||
|
|
||||||
|
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
|
||||||
|
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
|
||||||
|
|
||||||
|
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||||
|
normalizeAndSortNetworkMap(newNetworkMap)
|
||||||
|
|
||||||
|
componentsJSON, err := json.MarshalIndent(components, "", " ")
|
||||||
|
require.NoError(t, err, "error marshaling components to JSON")
|
||||||
|
|
||||||
|
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
|
||||||
|
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||||
|
|
||||||
|
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
|
||||||
|
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||||
|
|
||||||
|
goldenDir := filepath.Join("testdata", "comparison")
|
||||||
|
err = os.MkdirAll(goldenDir, 0755)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
|
||||||
|
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
|
||||||
|
require.NoError(t, err, "error writing legacy golden file")
|
||||||
|
|
||||||
|
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
|
||||||
|
err = os.WriteFile(newGoldenPath, newJSON, 0644)
|
||||||
|
require.NoError(t, err, "error writing components golden file")
|
||||||
|
|
||||||
|
componentsPath := filepath.Join(goldenDir, "components.json")
|
||||||
|
err = os.WriteFile(componentsPath, componentsJSON, 0644)
|
||||||
|
require.NoError(t, err, "error writing components golden file")
|
||||||
|
|
||||||
|
require.JSONEq(t, string(legacyJSON), string(newJSON),
|
||||||
|
"NetworkMaps from legacy and components approaches do not match.\n"+
|
||||||
|
"Legacy JSON saved to: %s\n"+
|
||||||
|
"Components JSON saved to: %s",
|
||||||
|
legacyGoldenPath, newGoldenPath)
|
||||||
|
|
||||||
|
t.Logf("✅ NetworkMaps are identical")
|
||||||
|
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
|
||||||
|
t.Logf(" Components NetworkMap: %s", newGoldenPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeAndSortNetworkMap(nm *NetworkMap) {
|
||||||
|
if nm == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(nm.Peers, func(i, j int) bool {
|
||||||
|
return nm.Peers[i].ID < nm.Peers[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
|
||||||
|
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(nm.Routes, func(i, j int) bool {
|
||||||
|
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
sort.Slice(nm.FirewallRules, func(i, j int) bool {
|
||||||
|
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
|
||||||
|
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
|
||||||
|
}
|
||||||
|
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
|
||||||
|
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
|
||||||
|
}
|
||||||
|
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
|
||||||
|
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
|
||||||
|
}
|
||||||
|
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
|
||||||
|
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
|
||||||
|
}
|
||||||
|
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range nm.RoutesFirewallRules {
|
||||||
|
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
|
||||||
|
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
|
||||||
|
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
|
||||||
|
}
|
||||||
|
|
||||||
|
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
|
||||||
|
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
|
||||||
|
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||||
|
}
|
||||||
|
for k := 0; k < minLen; k++ {
|
||||||
|
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
|
||||||
|
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
|
||||||
|
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
|
||||||
|
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
|
||||||
|
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
|
||||||
|
}
|
||||||
|
|
||||||
|
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
|
||||||
|
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
|
||||||
|
}
|
||||||
|
|
||||||
|
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
|
||||||
|
})
|
||||||
|
|
||||||
|
if nm.DNSConfig.CustomZones != nil {
|
||||||
|
for i := range nm.DNSConfig.CustomZones {
|
||||||
|
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
|
||||||
|
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(nm.DNSConfig.NameServerGroups) != 0 {
|
||||||
|
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
|
||||||
|
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
if legacy.Network.Serial != current.Network.Serial {
|
||||||
|
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(legacy.Peers) != len(current.Peers) {
|
||||||
|
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
|
||||||
|
}
|
||||||
|
|
||||||
|
legacyPeerIDs := make(map[string]bool)
|
||||||
|
for _, p := range legacy.Peers {
|
||||||
|
legacyPeerIDs[p.ID] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range current.Peers {
|
||||||
|
if !legacyPeerIDs[p.ID] {
|
||||||
|
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
|
||||||
|
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(legacy.FirewallRules) != len(current.FirewallRules) {
|
||||||
|
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(legacy.Routes) != len(current.Routes) {
|
||||||
|
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
|
||||||
|
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
|
||||||
|
}
|
||||||
|
|
||||||
|
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
|
||||||
|
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
numPeers = 100
|
||||||
|
devGroupID = "group-dev"
|
||||||
|
opsGroupID = "group-ops"
|
||||||
|
allGroupID = "group-all"
|
||||||
|
routeID = route.ID("route-main")
|
||||||
|
routeHA1ID = route.ID("route-ha-1")
|
||||||
|
routeHA2ID = route.ID("route-ha-2")
|
||||||
|
policyIDDevOps = "policy-dev-ops"
|
||||||
|
policyIDAll = "policy-all"
|
||||||
|
policyIDPosture = "policy-posture"
|
||||||
|
policyIDDrop = "policy-drop"
|
||||||
|
postureCheckID = "posture-check-ver"
|
||||||
|
networkResourceID = "res-database"
|
||||||
|
networkID = "net-database"
|
||||||
|
networkRouterID = "router-database"
|
||||||
|
nameserverGroupID = "ns-group-main"
|
||||||
|
testingPeerID = "peer-60"
|
||||||
|
expiredPeerID = "peer-98"
|
||||||
|
offlinePeerID = "peer-99"
|
||||||
|
routingPeerID = "peer-95"
|
||||||
|
testAccountID = "account-comparison-test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func createTestAccount() *Account {
|
||||||
|
peers := make(map[string]*nbpeer.Peer)
|
||||||
|
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
||||||
|
|
||||||
|
for i := range numPeers {
|
||||||
|
peerID := fmt.Sprintf("peer-%d", i)
|
||||||
|
ip := net.IP{100, 64, 0, byte(i + 1)}
|
||||||
|
wtVersion := "0.25.0"
|
||||||
|
if i%2 == 0 {
|
||||||
|
wtVersion = "0.40.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
p := &nbpeer.Peer{
|
||||||
|
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
|
||||||
|
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||||
|
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerID == expiredPeerID {
|
||||||
|
p.LoginExpirationEnabled = true
|
||||||
|
pastTimestamp := time.Now().Add(-2 * time.Hour)
|
||||||
|
p.LastLogin = &pastTimestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
peers[peerID] = p
|
||||||
|
allGroupPeers = append(allGroupPeers, peerID)
|
||||||
|
if i < numPeers/2 {
|
||||||
|
devGroupPeers = append(devGroupPeers, peerID)
|
||||||
|
} else {
|
||||||
|
opsGroupPeers = append(opsGroupPeers, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := map[string]*Group{
|
||||||
|
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||||
|
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||||
|
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := []*Policy{
|
||||||
|
{
|
||||||
|
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
|
||||||
|
Rules: []*PolicyRule{{
|
||||||
|
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||||
|
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
|
||||||
|
Rules: []*PolicyRule{{
|
||||||
|
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
|
||||||
|
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
|
||||||
|
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
|
||||||
|
Rules: []*PolicyRule{{
|
||||||
|
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
|
||||||
|
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
|
||||||
|
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
|
||||||
|
SourcePostureChecks: []string{postureCheckID},
|
||||||
|
Rules: []*PolicyRule{{
|
||||||
|
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
|
||||||
|
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
|
||||||
|
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
routes := map[route.ID]*route.Route{
|
||||||
|
routeID: {
|
||||||
|
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
|
||||||
|
Peer: peers["peer-75"].Key,
|
||||||
|
PeerID: "peer-75",
|
||||||
|
Description: "Route to internal resource", Enabled: true,
|
||||||
|
PeerGroups: []string{devGroupID, opsGroupID},
|
||||||
|
Groups: []string{devGroupID, opsGroupID},
|
||||||
|
AccessControlGroups: []string{devGroupID},
|
||||||
|
},
|
||||||
|
routeHA1ID: {
|
||||||
|
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||||
|
Peer: peers["peer-80"].Key,
|
||||||
|
PeerID: "peer-80",
|
||||||
|
Description: "HA Route 1", Enabled: true, Metric: 1000,
|
||||||
|
PeerGroups: []string{allGroupID},
|
||||||
|
Groups: []string{allGroupID},
|
||||||
|
AccessControlGroups: []string{allGroupID},
|
||||||
|
},
|
||||||
|
routeHA2ID: {
|
||||||
|
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
|
||||||
|
Peer: peers["peer-90"].Key,
|
||||||
|
PeerID: "peer-90",
|
||||||
|
Description: "HA Route 2", Enabled: true, Metric: 900,
|
||||||
|
PeerGroups: []string{devGroupID, opsGroupID},
|
||||||
|
Groups: []string{devGroupID, opsGroupID},
|
||||||
|
AccessControlGroups: []string{allGroupID},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
account := &Account{
|
||||||
|
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||||
|
Network: &Network{
|
||||||
|
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||||
|
},
|
||||||
|
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
|
||||||
|
NameServerGroups: map[string]*nbdns.NameServerGroup{
|
||||||
|
nameserverGroupID: {
|
||||||
|
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
|
||||||
|
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PostureChecks: []*posture.Checks{
|
||||||
|
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
|
||||||
|
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
NetworkResources: []*resourceTypes.NetworkResource{
|
||||||
|
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
|
||||||
|
},
|
||||||
|
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
|
||||||
|
NetworkRouters: []*routerTypes.NetworkRouter{
|
||||||
|
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
|
||||||
|
},
|
||||||
|
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range account.Policies {
|
||||||
|
p.AccountID = account.Id
|
||||||
|
}
|
||||||
|
for _, r := range account.Routes {
|
||||||
|
r.AccountID = account.Id
|
||||||
|
}
|
||||||
|
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLegacyNetworkMap(b *testing.B) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid != offlinePeerID {
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = account.GetPeerNetworkMap(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
nil,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkComponentsNetworkMap(b *testing.B) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid != offlinePeerID {
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
components := account.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkComponentsCreation(b *testing.B) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid != offlinePeerID {
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = account.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkCalculationFromComponents(b *testing.B) {
|
||||||
|
account := createTestAccount()
|
||||||
|
ctx := context.Background()
|
||||||
|
peerID := testingPeerID
|
||||||
|
validatedPeersMap := make(map[string]struct{})
|
||||||
|
for i := range numPeers {
|
||||||
|
pid := fmt.Sprintf("peer-%d", i)
|
||||||
|
if pid != offlinePeerID {
|
||||||
|
validatedPeersMap[pid] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peersCustomZone := nbdns.CustomZone{}
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
components := account.GetPeerNetworkMapComponents(
|
||||||
|
ctx,
|
||||||
|
peerID,
|
||||||
|
peersCustomZone,
|
||||||
|
nil,
|
||||||
|
validatedPeersMap,
|
||||||
|
resourcePolicies,
|
||||||
|
routers,
|
||||||
|
groupIDToUserIDs,
|
||||||
|
)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = CalculateNetworkMapFromComponents(ctx, components)
|
||||||
|
}
|
||||||
|
}
|
||||||
938
management/server/types/networkmap_components.go
Normal file
938
management/server/types/networkmap_components.go
Normal file
@@ -0,0 +1,938 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"maps"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
|
||||||
|
|
||||||
|
type NetworkMapComponents struct {
|
||||||
|
PeerID string
|
||||||
|
|
||||||
|
Network *Network
|
||||||
|
AccountSettings *AccountSettingsInfo
|
||||||
|
DNSSettings *DNSSettings
|
||||||
|
CustomZoneDomain string
|
||||||
|
|
||||||
|
Peers map[string]*nbpeer.Peer
|
||||||
|
Groups map[string]*Group
|
||||||
|
Policies []*Policy
|
||||||
|
Routes []*route.Route
|
||||||
|
NameServerGroups []*nbdns.NameServerGroup
|
||||||
|
AllDNSRecords []nbdns.SimpleRecord
|
||||||
|
AccountZones []nbdns.CustomZone
|
||||||
|
ResourcePoliciesMap map[string][]*Policy
|
||||||
|
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
|
||||||
|
NetworkResources []*resourceTypes.NetworkResource
|
||||||
|
|
||||||
|
GroupIDToUserIDs map[string][]string
|
||||||
|
AllowedUserIDs map[string]struct{}
|
||||||
|
PostureFailedPeers map[string]map[string]struct{}
|
||||||
|
|
||||||
|
RouterPeers map[string]*nbpeer.Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
type AccountSettingsInfo struct {
|
||||||
|
PeerLoginExpirationEnabled bool
|
||||||
|
PeerLoginExpiration time.Duration
|
||||||
|
PeerInactivityExpirationEnabled bool
|
||||||
|
PeerInactivityExpiration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer {
|
||||||
|
return c.Peers[peerID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) GetRouterPeerInfo(peerID string) *nbpeer.Peer {
|
||||||
|
return c.RouterPeers[peerID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group {
|
||||||
|
return c.Groups[groupID]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool {
|
||||||
|
group := c.GetGroupInfo(groupID)
|
||||||
|
if group == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return slices.Contains(group.Peers, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} {
|
||||||
|
groups := make(map[string]struct{})
|
||||||
|
for groupID, group := range c.Groups {
|
||||||
|
if slices.Contains(group.Peers, peerID) {
|
||||||
|
groups[groupID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool {
|
||||||
|
_, exists := c.Peers[peerID]
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(postureCheckIDs) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
for _, checkID := range postureCheckIDs {
|
||||||
|
if failedPeers, exists := c.PostureFailedPeers[checkID]; exists {
|
||||||
|
if _, failed := failedPeers[peerID]; failed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
|
||||||
|
return components.Calculate(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
|
||||||
|
targetPeerID := c.PeerID
|
||||||
|
|
||||||
|
peerGroups := c.GetPeerGroups(targetPeerID)
|
||||||
|
|
||||||
|
aclPeers, firewallRules, authorizedUsers, sshEnabled := c.getPeerConnectionResources(targetPeerID)
|
||||||
|
|
||||||
|
peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers)
|
||||||
|
|
||||||
|
routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups)
|
||||||
|
routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID)
|
||||||
|
|
||||||
|
isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID)
|
||||||
|
var networkResourcesFirewallRules []*RouteFirewallRule
|
||||||
|
if isRouter {
|
||||||
|
networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes)
|
||||||
|
}
|
||||||
|
|
||||||
|
peersToConnectIncludingRouters := c.addNetworksRoutingPeers(
|
||||||
|
networkResourcesRoutes,
|
||||||
|
targetPeerID,
|
||||||
|
peersToConnect,
|
||||||
|
expiredPeers,
|
||||||
|
isRouter,
|
||||||
|
sourcePeers,
|
||||||
|
)
|
||||||
|
|
||||||
|
dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID)
|
||||||
|
dnsUpdate := nbdns.Config{
|
||||||
|
ServiceEnable: dnsManagementStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
if dnsManagementStatus {
|
||||||
|
var customZones []nbdns.CustomZone
|
||||||
|
|
||||||
|
if c.CustomZoneDomain != "" && len(c.AllDNSRecords) > 0 {
|
||||||
|
customZones = append(customZones, nbdns.CustomZone{
|
||||||
|
Domain: c.CustomZoneDomain,
|
||||||
|
Records: c.AllDNSRecords,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
customZones = append(customZones, c.AccountZones...)
|
||||||
|
|
||||||
|
dnsUpdate.CustomZones = customZones
|
||||||
|
dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NetworkMap{
|
||||||
|
Peers: peersToConnectIncludingRouters,
|
||||||
|
Network: c.Network.Copy(),
|
||||||
|
Routes: append(networkResourcesRoutes, routesUpdate...),
|
||||||
|
DNSConfig: dnsUpdate,
|
||||||
|
OfflinePeers: expiredPeers,
|
||||||
|
FirewallRules: firewallRules,
|
||||||
|
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
|
||||||
|
AuthorizedUsers: authorizedUsers,
|
||||||
|
EnableSSH: sshEnabled,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||||
|
targetPeer := c.GetPeerInfo(targetPeerID)
|
||||||
|
if targetPeer == nil {
|
||||||
|
return nil, nil, nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
generateResources, getAccumulatedResources := c.connResourcesGenerator(targetPeer)
|
||||||
|
authorizedUsers := make(map[string]map[string]struct{})
|
||||||
|
sshEnabled := false
|
||||||
|
|
||||||
|
for _, policy := range c.Policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if !rule.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var sourcePeers, destinationPeers []*nbpeer.Peer
|
||||||
|
var peerInSources, peerInDestinations bool
|
||||||
|
|
||||||
|
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||||
|
sourcePeers, peerInSources = c.getPeerFromResource(rule.SourceResource, targetPeerID)
|
||||||
|
} else {
|
||||||
|
sourcePeers, peerInSources = c.getAllPeersFromGroups(rule.Sources, targetPeerID, policy.SourcePostureChecks)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
|
||||||
|
destinationPeers, peerInDestinations = c.getPeerFromResource(rule.DestinationResource, targetPeerID)
|
||||||
|
} else {
|
||||||
|
destinationPeers, peerInDestinations = c.getAllPeersFromGroups(rule.Destinations, targetPeerID, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Bidirectional {
|
||||||
|
if peerInSources {
|
||||||
|
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
|
||||||
|
}
|
||||||
|
if peerInDestinations {
|
||||||
|
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInSources {
|
||||||
|
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInDestinations {
|
||||||
|
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||||
|
}
|
||||||
|
|
||||||
|
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
sshEnabled = true
|
||||||
|
switch {
|
||||||
|
case len(rule.AuthorizedGroups) > 0:
|
||||||
|
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||||
|
userIDs, ok := c.GroupIDToUserIDs[groupID]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(localUsers) == 0 {
|
||||||
|
localUsers = []string{auth.Wildcard}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, localUser := range localUsers {
|
||||||
|
if authorizedUsers[localUser] == nil {
|
||||||
|
authorizedUsers[localUser] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
for _, userID := range userIDs {
|
||||||
|
authorizedUsers[localUser][userID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case rule.AuthorizedUser != "":
|
||||||
|
if authorizedUsers[auth.Wildcard] == nil {
|
||||||
|
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||||
|
default:
|
||||||
|
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||||
|
}
|
||||||
|
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
|
||||||
|
sshEnabled = true
|
||||||
|
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peers, fwRules := getAccumulatedResources()
|
||||||
|
return peers, fwRules, authorizedUsers, sshEnabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getAllowedUserIDs() map[string]struct{} {
|
||||||
|
if c.AllowedUserIDs != nil {
|
||||||
|
result := make(map[string]struct{}, len(c.AllowedUserIDs))
|
||||||
|
maps.Copy(result, c.AllowedUserIDs)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return make(map[string]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
|
||||||
|
rulesExists := make(map[string]struct{})
|
||||||
|
peersExists := make(map[string]struct{})
|
||||||
|
rules := make([]*FirewallRule, 0)
|
||||||
|
peers := make([]*nbpeer.Peer, 0)
|
||||||
|
|
||||||
|
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
|
||||||
|
for _, peer := range groupPeers {
|
||||||
|
if peer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := peersExists[peer.ID]; !ok {
|
||||||
|
peers = append(peers, peer)
|
||||||
|
peersExists[peer.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := rule.Protocol
|
||||||
|
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||||
|
protocol = PolicyRuleProtocolTCP
|
||||||
|
}
|
||||||
|
|
||||||
|
fr := FirewallRule{
|
||||||
|
PolicyID: rule.ID,
|
||||||
|
PeerIP: net.IP(peer.IP).String(),
|
||||||
|
Direction: direction,
|
||||||
|
Action: string(rule.Action),
|
||||||
|
Protocol: string(protocol),
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||||
|
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
|
||||||
|
if _, ok := rulesExists[ruleID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rulesExists[ruleID] = struct{}{}
|
||||||
|
|
||||||
|
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
|
||||||
|
rules = append(rules, &fr)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{
|
||||||
|
ID: rule.ID,
|
||||||
|
Ports: rule.Ports,
|
||||||
|
PortRanges: rule.PortRanges,
|
||||||
|
Protocol: rule.Protocol,
|
||||||
|
Action: rule.Action,
|
||||||
|
}, targetPeer)...)
|
||||||
|
}
|
||||||
|
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
|
||||||
|
return peers, rules
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getAllPeersFromGroups(groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) {
|
||||||
|
peerInGroups := false
|
||||||
|
uniquePeerIDs := c.getUniquePeerIDsFromGroupsIDs(groups)
|
||||||
|
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
|
||||||
|
|
||||||
|
for _, p := range uniquePeerIDs {
|
||||||
|
peerInfo := c.GetPeerInfo(p)
|
||||||
|
if peerInfo == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := c.Peers[p]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !c.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if p == peerID {
|
||||||
|
peerInGroups = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filteredPeers = append(filteredPeers, peerInfo)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredPeers, peerInGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getUniquePeerIDsFromGroupsIDs(groups []string) []string {
|
||||||
|
peerIDs := make(map[string]struct{}, len(groups))
|
||||||
|
for _, groupID := range groups {
|
||||||
|
group := c.GetGroupInfo(groupID)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if group.IsGroupAll() || len(groups) == 1 {
|
||||||
|
return group.Peers
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
peerIDs[peerID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := make([]string, 0, len(peerIDs))
|
||||||
|
for peerID := range peerIDs {
|
||||||
|
ids = append(ids, peerID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
|
||||||
|
if resource.ID == peerID {
|
||||||
|
return []*nbpeer.Peer{}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
peerInfo := c.GetPeerInfo(resource.ID)
|
||||||
|
if peerInfo == nil {
|
||||||
|
return []*nbpeer.Peer{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return []*nbpeer.Peer{peerInfo}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) {
|
||||||
|
var peersToConnect []*nbpeer.Peer
|
||||||
|
var expiredPeers []*nbpeer.Peer
|
||||||
|
|
||||||
|
for _, p := range aclPeers {
|
||||||
|
expired, _ := p.LoginExpired(c.AccountSettings.PeerLoginExpiration)
|
||||||
|
if c.AccountSettings.PeerLoginExpirationEnabled && expired {
|
||||||
|
expiredPeers = append(expiredPeers, p)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
peersToConnect = append(peersToConnect, p)
|
||||||
|
}
|
||||||
|
|
||||||
|
return peersToConnect, expiredPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool {
|
||||||
|
peerGroups := c.GetPeerGroups(peerID)
|
||||||
|
enabled := true
|
||||||
|
for _, groupID := range c.DNSSettings.DisabledManagementGroups {
|
||||||
|
if _, found := peerGroups[groupID]; found {
|
||||||
|
enabled = false
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup {
|
||||||
|
groupList := c.GetPeerGroups(peerID)
|
||||||
|
|
||||||
|
var peerNSGroups []*nbdns.NameServerGroup
|
||||||
|
|
||||||
|
for _, nsGroup := range c.NameServerGroups {
|
||||||
|
if !nsGroup.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, gID := range nsGroup.Groups {
|
||||||
|
_, found := groupList[gID]
|
||||||
|
if found {
|
||||||
|
targetPeerInfo := c.GetPeerInfo(peerID)
|
||||||
|
if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) {
|
||||||
|
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peerNSGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
|
||||||
|
for _, ns := range nsGroup.NameServers {
|
||||||
|
if peerInfo.IP.String() == ns.IP.String() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
|
||||||
|
routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID)
|
||||||
|
peerRoutesMembership := make(LookupMap)
|
||||||
|
for _, r := range append(routes, peerDisabledRoutes...) {
|
||||||
|
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range aclPeers {
|
||||||
|
activeRoutes, _ := c.getRoutingPeerRoutes(peer.ID)
|
||||||
|
groupFilteredRoutes := c.filterRoutesByGroups(activeRoutes, peerGroups)
|
||||||
|
filteredRoutes := c.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
|
||||||
|
routes = append(routes, filteredRoutes...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
|
||||||
|
peerInfo := c.GetPeerInfo(peerID)
|
||||||
|
if peerInfo == nil {
|
||||||
|
peerInfo = c.GetRouterPeerInfo(peerID)
|
||||||
|
}
|
||||||
|
if peerInfo == nil {
|
||||||
|
return enabledRoutes, disabledRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
seenRoute := make(map[route.ID]struct{})
|
||||||
|
|
||||||
|
takeRoute := func(r *route.Route) {
|
||||||
|
if _, ok := seenRoute[r.ID]; ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seenRoute[r.ID] = struct{}{}
|
||||||
|
|
||||||
|
routeObj := c.copyRoute(r)
|
||||||
|
routeObj.Peer = peerInfo.Key
|
||||||
|
|
||||||
|
if r.Enabled {
|
||||||
|
enabledRoutes = append(enabledRoutes, routeObj)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
disabledRoutes = append(disabledRoutes, routeObj)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range c.Routes {
|
||||||
|
for _, groupID := range r.PeerGroups {
|
||||||
|
group := c.GetGroupInfo(groupID)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, id := range group.Peers {
|
||||||
|
if id != peerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
newPeerRoute := c.copyRoute(r)
|
||||||
|
newPeerRoute.Peer = id
|
||||||
|
newPeerRoute.PeerGroups = nil
|
||||||
|
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id)
|
||||||
|
takeRoute(newPeerRoute)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if r.Peer == peerID {
|
||||||
|
takeRoute(c.copyRoute(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return enabledRoutes, disabledRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route {
|
||||||
|
var groups, accessControlGroups, peerGroups []string
|
||||||
|
var domains domain.List
|
||||||
|
|
||||||
|
if r.Groups != nil {
|
||||||
|
groups = append([]string{}, r.Groups...)
|
||||||
|
}
|
||||||
|
if r.AccessControlGroups != nil {
|
||||||
|
accessControlGroups = append([]string{}, r.AccessControlGroups...)
|
||||||
|
}
|
||||||
|
if r.PeerGroups != nil {
|
||||||
|
peerGroups = append([]string{}, r.PeerGroups...)
|
||||||
|
}
|
||||||
|
if r.Domains != nil {
|
||||||
|
domains = append(domain.List{}, r.Domains...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &route.Route{
|
||||||
|
ID: r.ID,
|
||||||
|
AccountID: r.AccountID,
|
||||||
|
Network: r.Network,
|
||||||
|
NetworkType: r.NetworkType,
|
||||||
|
Description: r.Description,
|
||||||
|
Peer: r.Peer,
|
||||||
|
PeerID: r.PeerID,
|
||||||
|
Metric: r.Metric,
|
||||||
|
Masquerade: r.Masquerade,
|
||||||
|
NetID: r.NetID,
|
||||||
|
Enabled: r.Enabled,
|
||||||
|
Groups: groups,
|
||||||
|
AccessControlGroups: accessControlGroups,
|
||||||
|
PeerGroups: peerGroups,
|
||||||
|
Domains: domains,
|
||||||
|
KeepRoute: r.KeepRoute,
|
||||||
|
SkipAutoApply: r.SkipAutoApply,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
|
||||||
|
var filteredRoutes []*route.Route
|
||||||
|
for _, r := range routes {
|
||||||
|
for _, groupID := range r.Groups {
|
||||||
|
_, found := groupListMap[groupID]
|
||||||
|
if found {
|
||||||
|
filteredRoutes = append(filteredRoutes, r)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filteredRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
|
||||||
|
var filteredRoutes []*route.Route
|
||||||
|
for _, r := range routes {
|
||||||
|
_, found := peerMemberships[string(r.GetHAUniqueID())]
|
||||||
|
if !found {
|
||||||
|
filteredRoutes = append(filteredRoutes, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return filteredRoutes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule {
|
||||||
|
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||||
|
|
||||||
|
enabledRoutes, _ := c.getRoutingPeerRoutes(peerID)
|
||||||
|
for _, r := range enabledRoutes {
|
||||||
|
if len(r.AccessControlGroups) == 0 {
|
||||||
|
defaultPermit := c.getDefaultPermit(r)
|
||||||
|
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
distributionPeers := c.getDistributionGroupsPeers(r)
|
||||||
|
|
||||||
|
for _, accessGroup := range r.AccessControlGroups {
|
||||||
|
policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup})
|
||||||
|
rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers)
|
||||||
|
routesFirewallRules = append(routesFirewallRules, rules...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routesFirewallRules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule {
|
||||||
|
var rules []*RouteFirewallRule
|
||||||
|
|
||||||
|
sources := []string{"0.0.0.0/0"}
|
||||||
|
if r.Network.Addr().Is6() {
|
||||||
|
sources = []string{"::/0"}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule := RouteFirewallRule{
|
||||||
|
SourceRanges: sources,
|
||||||
|
Action: string(PolicyTrafficActionAccept),
|
||||||
|
Destination: r.Network.String(),
|
||||||
|
Protocol: string(PolicyRuleProtocolALL),
|
||||||
|
Domains: r.Domains,
|
||||||
|
IsDynamic: r.IsDynamic(),
|
||||||
|
RouteID: r.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = append(rules, &rule)
|
||||||
|
|
||||||
|
if r.IsDynamic() {
|
||||||
|
ruleV6 := rule
|
||||||
|
ruleV6.SourceRanges = []string{"::/0"}
|
||||||
|
rules = append(rules, &ruleV6)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getDistributionGroupsPeers(r *route.Route) map[string]struct{} {
|
||||||
|
distPeers := make(map[string]struct{})
|
||||||
|
for _, id := range r.Groups {
|
||||||
|
group := c.GetGroupInfo(id)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pID := range group.Peers {
|
||||||
|
distPeers[pID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return distPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy {
|
||||||
|
routePolicies := make([]*Policy, 0)
|
||||||
|
for _, groupID := range accessControlGroups {
|
||||||
|
for _, policy := range c.Policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if slices.Contains(rule.Destinations, groupID) {
|
||||||
|
routePolicies = append(routePolicies, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routePolicies
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule {
|
||||||
|
var fwRules []*RouteFirewallRule
|
||||||
|
for _, policy := range policies {
|
||||||
|
if !policy.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
if !rule.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
|
||||||
|
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
|
||||||
|
fwRules = append(fwRules, rules...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fwRules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer {
|
||||||
|
distPeersWithPolicy := make(map[string]struct{})
|
||||||
|
for _, id := range rule.Sources {
|
||||||
|
group := c.GetGroupInfo(id)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, pID := range group.Peers {
|
||||||
|
if pID == peerID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, distPeer := distributionPeers[pID]
|
||||||
|
_, valid := c.Peers[pID]
|
||||||
|
if distPeer && valid && c.ValidatePostureChecksOnPeer(pID, postureChecks) {
|
||||||
|
distPeersWithPolicy[pID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||||
|
_, distPeer := distributionPeers[rule.SourceResource.ID]
|
||||||
|
_, valid := c.Peers[rule.SourceResource.ID]
|
||||||
|
if distPeer && valid && c.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) {
|
||||||
|
distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
|
||||||
|
for pID := range distPeersWithPolicy {
|
||||||
|
peerInfo := c.GetPeerInfo(pID)
|
||||||
|
if peerInfo == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
distributionGroupPeers = append(distributionGroupPeers, peerInfo)
|
||||||
|
}
|
||||||
|
return distributionGroupPeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (bool, []*route.Route, map[string]struct{}) {
|
||||||
|
var isRoutingPeer bool
|
||||||
|
var routes []*route.Route
|
||||||
|
allSourcePeers := make(map[string]struct{})
|
||||||
|
|
||||||
|
for _, resource := range c.NetworkResources {
|
||||||
|
if !resource.Enabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var addSourcePeers bool
|
||||||
|
|
||||||
|
networkRoutingPeers, exists := c.RoutersMap[resource.NetworkID]
|
||||||
|
if exists {
|
||||||
|
if router, ok := networkRoutingPeers[peerID]; ok {
|
||||||
|
isRoutingPeer, addSourcePeers = true, true
|
||||||
|
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerID, router)...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
addedResourceRoute := false
|
||||||
|
for _, policy := range c.ResourcePoliciesMap[resource.ID] {
|
||||||
|
var peers []string
|
||||||
|
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
|
||||||
|
peers = []string{policy.Rules[0].SourceResource.ID}
|
||||||
|
} else {
|
||||||
|
peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups())
|
||||||
|
}
|
||||||
|
if addSourcePeers {
|
||||||
|
for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) {
|
||||||
|
allSourcePeers[pID] = struct{}{}
|
||||||
|
}
|
||||||
|
} else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) {
|
||||||
|
for peerId, router := range networkRoutingPeers {
|
||||||
|
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...)
|
||||||
|
}
|
||||||
|
addedResourceRoute = true
|
||||||
|
}
|
||||||
|
if addedResourceRoute {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return isRoutingPeer, routes, allSourcePeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route {
|
||||||
|
resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID]
|
||||||
|
|
||||||
|
var routes []*route.Route
|
||||||
|
if len(resourceAppliedPolicies) > 0 {
|
||||||
|
peerInfo := c.GetPeerInfo(peerID)
|
||||||
|
if peerInfo != nil {
|
||||||
|
routes = append(routes, c.networkResourceToRoute(resource, peerInfo, router))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routes
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route {
|
||||||
|
r := &route.Route{
|
||||||
|
ID: route.ID(resource.ID + ":" + peer.ID),
|
||||||
|
AccountID: resource.AccountID,
|
||||||
|
Peer: peer.Key,
|
||||||
|
PeerID: peer.ID,
|
||||||
|
Metric: router.Metric,
|
||||||
|
Masquerade: router.Masquerade,
|
||||||
|
Enabled: resource.Enabled,
|
||||||
|
KeepRoute: true,
|
||||||
|
NetID: route.NetID(resource.Name),
|
||||||
|
Description: resource.Description,
|
||||||
|
}
|
||||||
|
|
||||||
|
if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet {
|
||||||
|
r.Network = resource.Prefix
|
||||||
|
|
||||||
|
r.NetworkType = route.IPv4Network
|
||||||
|
if resource.Prefix.Addr().Is6() {
|
||||||
|
r.NetworkType = route.IPv6Network
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resource.Type == resourceTypes.Domain {
|
||||||
|
domainList, err := domain.FromStringList([]string{resource.Domain})
|
||||||
|
if err == nil {
|
||||||
|
r.Domains = domainList
|
||||||
|
r.NetworkType = route.DomainNetwork
|
||||||
|
r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
|
||||||
|
var dest []string
|
||||||
|
for _, peerID := range inputPeers {
|
||||||
|
if c.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) {
|
||||||
|
dest = append(dest, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dest
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule {
|
||||||
|
routesFirewallRules := make([]*RouteFirewallRule, 0)
|
||||||
|
|
||||||
|
peerInfo := c.GetPeerInfo(peerID)
|
||||||
|
if peerInfo == nil {
|
||||||
|
return routesFirewallRules
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.Peer != peerInfo.Key {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
resourceID := string(r.GetResourceID())
|
||||||
|
resourcePolicies := c.ResourcePoliciesMap[resourceID]
|
||||||
|
distributionPeers := c.getPoliciesSourcePeers(resourcePolicies)
|
||||||
|
|
||||||
|
rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers)
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.SourceRanges) > 0 {
|
||||||
|
routesFirewallRules = append(routesFirewallRules, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routesFirewallRules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} {
|
||||||
|
sourcePeers := make(map[string]struct{})
|
||||||
|
|
||||||
|
for _, policy := range policies {
|
||||||
|
for _, rule := range policy.Rules {
|
||||||
|
for _, sourceGroup := range rule.Sources {
|
||||||
|
group := c.GetGroupInfo(sourceGroup)
|
||||||
|
if group == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range group.Peers {
|
||||||
|
sourcePeers[peer] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||||
|
sourcePeers[rule.SourceResource.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sourcePeers
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) addNetworksRoutingPeers(
|
||||||
|
networkResourcesRoutes []*route.Route,
|
||||||
|
peerID string,
|
||||||
|
peersToConnect []*nbpeer.Peer,
|
||||||
|
expiredPeers []*nbpeer.Peer,
|
||||||
|
isRouter bool,
|
||||||
|
sourcePeers map[string]struct{},
|
||||||
|
) []*nbpeer.Peer {
|
||||||
|
|
||||||
|
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
|
||||||
|
for _, r := range networkResourcesRoutes {
|
||||||
|
networkRoutesPeers[r.PeerID] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(sourcePeers, peerID)
|
||||||
|
delete(networkRoutesPeers, peerID)
|
||||||
|
|
||||||
|
for _, existingPeer := range peersToConnect {
|
||||||
|
delete(sourcePeers, existingPeer.ID)
|
||||||
|
delete(networkRoutesPeers, existingPeer.ID)
|
||||||
|
}
|
||||||
|
for _, expPeer := range expiredPeers {
|
||||||
|
delete(sourcePeers, expPeer.ID)
|
||||||
|
delete(networkRoutesPeers, expPeer.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
|
||||||
|
if isRouter {
|
||||||
|
for p := range sourcePeers {
|
||||||
|
missingPeers[p] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for p := range networkRoutesPeers {
|
||||||
|
missingPeers[p] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for p := range missingPeers {
|
||||||
|
peerInfo := c.GetPeerInfo(p)
|
||||||
|
if peerInfo == nil {
|
||||||
|
peerInfo = c.GetRouterPeerInfo(p)
|
||||||
|
}
|
||||||
|
if peerInfo != nil {
|
||||||
|
peersToConnect = append(peersToConnect, peerInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return peersToConnect
|
||||||
|
}
|
||||||
230
management/server/types/networkmap_components_compact.go
Normal file
230
management/server/types/networkmap_components_compact.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
import (
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||||
|
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||||
|
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupCompact struct {
|
||||||
|
Name string
|
||||||
|
PeerIndexes []int
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkMapComponentsCompact struct {
|
||||||
|
PeerID string
|
||||||
|
|
||||||
|
Network *Network
|
||||||
|
AccountSettings *AccountSettingsInfo
|
||||||
|
DNSSettings *DNSSettings
|
||||||
|
CustomZoneDomain string
|
||||||
|
|
||||||
|
AllPeers []*nbpeer.Peer
|
||||||
|
PeerIndexes []int
|
||||||
|
RouterPeerIndexes []int
|
||||||
|
|
||||||
|
Groups map[string]*GroupCompact
|
||||||
|
AllPolicies []*Policy
|
||||||
|
PolicyIndexes []int
|
||||||
|
ResourcePoliciesMap map[string][]int
|
||||||
|
Routes []*route.Route
|
||||||
|
NameServerGroups []*nbdns.NameServerGroup
|
||||||
|
AllDNSRecords []nbdns.SimpleRecord
|
||||||
|
AccountZones []nbdns.CustomZone
|
||||||
|
|
||||||
|
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
|
||||||
|
NetworkResources []*resourceTypes.NetworkResource
|
||||||
|
|
||||||
|
GroupIDToUserIDs map[string][]string
|
||||||
|
AllowedUserIDs map[string]struct{}
|
||||||
|
PostureFailedPeers map[string]map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponents) ToCompact() *NetworkMapComponentsCompact {
|
||||||
|
peerToIndex := make(map[string]int)
|
||||||
|
var allPeers []*nbpeer.Peer
|
||||||
|
|
||||||
|
for id, peer := range c.Peers {
|
||||||
|
if _, exists := peerToIndex[id]; !exists {
|
||||||
|
peerToIndex[id] = len(allPeers)
|
||||||
|
allPeers = append(allPeers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, peer := range c.RouterPeers {
|
||||||
|
if _, exists := peerToIndex[id]; !exists {
|
||||||
|
peerToIndex[id] = len(allPeers)
|
||||||
|
allPeers = append(allPeers, peer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIndexes := make([]int, 0, len(c.Peers))
|
||||||
|
for id := range c.Peers {
|
||||||
|
peerIndexes = append(peerIndexes, peerToIndex[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
routerPeerIndexes := make([]int, 0, len(c.RouterPeers))
|
||||||
|
for id := range c.RouterPeers {
|
||||||
|
routerPeerIndexes = append(routerPeerIndexes, peerToIndex[id])
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make(map[string]*GroupCompact, len(c.Groups))
|
||||||
|
for id, group := range c.Groups {
|
||||||
|
peerIdxs := make([]int, 0, len(group.Peers))
|
||||||
|
for _, peerID := range group.Peers {
|
||||||
|
if idx, ok := peerToIndex[peerID]; ok {
|
||||||
|
peerIdxs = append(peerIdxs, idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groups[id] = &GroupCompact{
|
||||||
|
Name: group.Name,
|
||||||
|
PeerIndexes: peerIdxs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
policyToIndex := make(map[*Policy]int)
|
||||||
|
var allPolicies []*Policy
|
||||||
|
|
||||||
|
for _, policy := range c.Policies {
|
||||||
|
if _, exists := policyToIndex[policy]; !exists {
|
||||||
|
policyToIndex[policy] = len(allPolicies)
|
||||||
|
allPolicies = append(allPolicies, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policies := range c.ResourcePoliciesMap {
|
||||||
|
for _, policy := range policies {
|
||||||
|
if _, exists := policyToIndex[policy]; !exists {
|
||||||
|
policyToIndex[policy] = len(allPolicies)
|
||||||
|
allPolicies = append(allPolicies, policy)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
policyIndexes := make([]int, len(c.Policies))
|
||||||
|
for i, policy := range c.Policies {
|
||||||
|
policyIndexes[i] = policyToIndex[policy]
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourcePoliciesMap map[string][]int
|
||||||
|
if len(c.ResourcePoliciesMap) > 0 {
|
||||||
|
resourcePoliciesMap = make(map[string][]int, len(c.ResourcePoliciesMap))
|
||||||
|
for resID, policies := range c.ResourcePoliciesMap {
|
||||||
|
indexes := make([]int, len(policies))
|
||||||
|
for i, policy := range policies {
|
||||||
|
indexes[i] = policyToIndex[policy]
|
||||||
|
}
|
||||||
|
resourcePoliciesMap[resID] = indexes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NetworkMapComponentsCompact{
|
||||||
|
PeerID: c.PeerID,
|
||||||
|
Network: c.Network,
|
||||||
|
AccountSettings: c.AccountSettings,
|
||||||
|
DNSSettings: c.DNSSettings,
|
||||||
|
CustomZoneDomain: c.CustomZoneDomain,
|
||||||
|
|
||||||
|
AllPeers: allPeers,
|
||||||
|
PeerIndexes: peerIndexes,
|
||||||
|
RouterPeerIndexes: routerPeerIndexes,
|
||||||
|
|
||||||
|
Groups: groups,
|
||||||
|
AllPolicies: allPolicies,
|
||||||
|
PolicyIndexes: policyIndexes,
|
||||||
|
ResourcePoliciesMap: resourcePoliciesMap,
|
||||||
|
Routes: c.Routes,
|
||||||
|
NameServerGroups: c.NameServerGroups,
|
||||||
|
AllDNSRecords: c.AllDNSRecords,
|
||||||
|
AccountZones: c.AccountZones,
|
||||||
|
|
||||||
|
RoutersMap: c.RoutersMap,
|
||||||
|
NetworkResources: c.NetworkResources,
|
||||||
|
|
||||||
|
GroupIDToUserIDs: c.GroupIDToUserIDs,
|
||||||
|
AllowedUserIDs: c.AllowedUserIDs,
|
||||||
|
PostureFailedPeers: c.PostureFailedPeers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *NetworkMapComponentsCompact) ToFull() *NetworkMapComponents {
|
||||||
|
peers := make(map[string]*nbpeer.Peer, len(c.PeerIndexes))
|
||||||
|
for _, idx := range c.PeerIndexes {
|
||||||
|
if idx >= 0 && idx < len(c.AllPeers) {
|
||||||
|
peer := c.AllPeers[idx]
|
||||||
|
peers[peer.ID] = peer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
routerPeers := make(map[string]*nbpeer.Peer, len(c.RouterPeerIndexes))
|
||||||
|
for _, idx := range c.RouterPeerIndexes {
|
||||||
|
if idx >= 0 && idx < len(c.AllPeers) {
|
||||||
|
peer := c.AllPeers[idx]
|
||||||
|
routerPeers[peer.ID] = peer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make(map[string]*Group, len(c.Groups))
|
||||||
|
for id, gc := range c.Groups {
|
||||||
|
peerIDs := make([]string, 0, len(gc.PeerIndexes))
|
||||||
|
for _, idx := range gc.PeerIndexes {
|
||||||
|
if idx >= 0 && idx < len(c.AllPeers) {
|
||||||
|
peerIDs = append(peerIDs, c.AllPeers[idx].ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
groups[id] = &Group{
|
||||||
|
ID: id,
|
||||||
|
Name: gc.Name,
|
||||||
|
Peers: peerIDs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
policies := make([]*Policy, len(c.PolicyIndexes))
|
||||||
|
for i, idx := range c.PolicyIndexes {
|
||||||
|
if idx >= 0 && idx < len(c.AllPolicies) {
|
||||||
|
policies[i] = c.AllPolicies[idx]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var resourcePoliciesMap map[string][]*Policy
|
||||||
|
if len(c.ResourcePoliciesMap) > 0 {
|
||||||
|
resourcePoliciesMap = make(map[string][]*Policy, len(c.ResourcePoliciesMap))
|
||||||
|
for resID, indexes := range c.ResourcePoliciesMap {
|
||||||
|
pols := make([]*Policy, 0, len(indexes))
|
||||||
|
for _, idx := range indexes {
|
||||||
|
if idx >= 0 && idx < len(c.AllPolicies) {
|
||||||
|
pols = append(pols, c.AllPolicies[idx])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resourcePoliciesMap[resID] = pols
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &NetworkMapComponents{
|
||||||
|
PeerID: c.PeerID,
|
||||||
|
Network: c.Network,
|
||||||
|
AccountSettings: c.AccountSettings,
|
||||||
|
DNSSettings: c.DNSSettings,
|
||||||
|
CustomZoneDomain: c.CustomZoneDomain,
|
||||||
|
|
||||||
|
Peers: peers,
|
||||||
|
RouterPeers: routerPeers,
|
||||||
|
|
||||||
|
Groups: groups,
|
||||||
|
Policies: policies,
|
||||||
|
Routes: c.Routes,
|
||||||
|
NameServerGroups: c.NameServerGroups,
|
||||||
|
AllDNSRecords: c.AllDNSRecords,
|
||||||
|
AccountZones: c.AccountZones,
|
||||||
|
|
||||||
|
ResourcePoliciesMap: resourcePoliciesMap,
|
||||||
|
RoutersMap: c.RoutersMap,
|
||||||
|
NetworkResources: c.NetworkResources,
|
||||||
|
|
||||||
|
GroupIDToUserIDs: c.GroupIDToUserIDs,
|
||||||
|
AllowedUserIDs: c.AllowedUserIDs,
|
||||||
|
PostureFailedPeers: c.PostureFailedPeers,
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -53,6 +53,7 @@ var (
|
|||||||
certLockMethod string
|
certLockMethod string
|
||||||
wgPort int
|
wgPort int
|
||||||
proxyProtocol bool
|
proxyProtocol bool
|
||||||
|
preSharedKey string
|
||||||
)
|
)
|
||||||
|
|
||||||
var rootCmd = &cobra.Command{
|
var rootCmd = &cobra.Command{
|
||||||
@@ -84,6 +85,7 @@ func init() {
|
|||||||
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
|
||||||
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
|
||||||
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
|
||||||
|
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute runs the root command.
|
// Execute runs the root command.
|
||||||
@@ -156,6 +158,7 @@ func runServer(cmd *cobra.Command, args []string) error {
|
|||||||
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
|
||||||
WireguardPort: wgPort,
|
WireguardPort: wgPort,
|
||||||
ProxyProtocol: proxyProtocol,
|
ProxyProtocol: proxyProtocol,
|
||||||
|
PreSharedKey: preSharedKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)
|
||||||
|
|||||||
@@ -86,6 +86,13 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ClientConfig holds configuration for the embedded NetBird client.
|
||||||
|
type ClientConfig struct {
|
||||||
|
MgmtAddr string
|
||||||
|
WGPort int
|
||||||
|
PreSharedKey string
|
||||||
|
}
|
||||||
|
|
||||||
type statusNotifier interface {
|
type statusNotifier interface {
|
||||||
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
|
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
|
||||||
}
|
}
|
||||||
@@ -98,10 +105,9 @@ type managementClient interface {
|
|||||||
// backed by underlying NetBird connections.
|
// backed by underlying NetBird connections.
|
||||||
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
|
||||||
type NetBird struct {
|
type NetBird struct {
|
||||||
mgmtAddr string
|
|
||||||
proxyID string
|
proxyID string
|
||||||
proxyAddr string
|
proxyAddr string
|
||||||
wgPort int
|
clientCfg ClientConfig
|
||||||
logger *log.Logger
|
logger *log.Logger
|
||||||
mgmtClient managementClient
|
mgmtClient managementClient
|
||||||
transportCfg transportConfig
|
transportCfg transportConfig
|
||||||
@@ -229,11 +235,12 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
|||||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||||
client, err := embed.New(embed.Options{
|
client, err := embed.New(embed.Options{
|
||||||
DeviceName: deviceNamePrefix + n.proxyID,
|
DeviceName: deviceNamePrefix + n.proxyID,
|
||||||
ManagementURL: n.mgmtAddr,
|
ManagementURL: n.clientCfg.MgmtAddr,
|
||||||
PrivateKey: privateKey.String(),
|
PrivateKey: privateKey.String(),
|
||||||
LogLevel: log.WarnLevel.String(),
|
LogLevel: log.WarnLevel.String(),
|
||||||
BlockInbound: true,
|
BlockInbound: true,
|
||||||
WireguardPort: &n.wgPort,
|
WireguardPort: &n.clientCfg.WGPort,
|
||||||
|
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||||
@@ -536,18 +543,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random
|
// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
|
||||||
// OS-assigned port. A fixed port only works with single-account deployments;
|
// OS-assigned port. A fixed port only works with single-account deployments;
|
||||||
// multiple accounts will fail to bind the same port.
|
// multiple accounts will fail to bind the same port.
|
||||||
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
logger = log.StandardLogger()
|
logger = log.StandardLogger()
|
||||||
}
|
}
|
||||||
return &NetBird{
|
return &NetBird{
|
||||||
mgmtAddr: mgmtAddr,
|
|
||||||
proxyID: proxyID,
|
proxyID: proxyID,
|
||||||
proxyAddr: proxyAddr,
|
proxyAddr: proxyAddr,
|
||||||
wgPort: wgPort,
|
clientCfg: clientCfg,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
clients: make(map[types.AccountID]*clientEntry),
|
clients: make(map[types.AccountID]*clientEntry),
|
||||||
statusNotifier: notifier,
|
statusNotifier: notifier,
|
||||||
|
|||||||
@@ -49,7 +49,11 @@ func (m *mockStatusNotifier) calls() []statusCall {
|
|||||||
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
// mockNetBird creates a NetBird instance for testing without actually connecting.
|
||||||
// It uses an invalid management URL to prevent real connections.
|
// It uses an invalid management URL to prevent real connections.
|
||||||
func mockNetBird() *NetBird {
|
func mockNetBird() *NetBird {
|
||||||
return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{})
|
return NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||||
|
MgmtAddr: "http://invalid.test:9999",
|
||||||
|
WGPort: 0,
|
||||||
|
PreSharedKey: "",
|
||||||
|
}, nil, nil, &mockMgmtClient{})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
|
||||||
@@ -282,7 +286,11 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
|
|||||||
|
|
||||||
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
||||||
notifier := &mockStatusNotifier{}
|
notifier := &mockStatusNotifier{}
|
||||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||||
|
MgmtAddr: "http://invalid.test:9999",
|
||||||
|
WGPort: 0,
|
||||||
|
PreSharedKey: "",
|
||||||
|
}, nil, notifier, &mockMgmtClient{})
|
||||||
accountID := types.AccountID("account-1")
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
// Add first domain — creates a new client entry.
|
// Add first domain — creates a new client entry.
|
||||||
@@ -308,7 +316,11 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
|
|||||||
|
|
||||||
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
|
||||||
notifier := &mockStatusNotifier{}
|
notifier := &mockStatusNotifier{}
|
||||||
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
|
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
|
||||||
|
MgmtAddr: "http://invalid.test:9999",
|
||||||
|
WGPort: 0,
|
||||||
|
PreSharedKey: "",
|
||||||
|
}, nil, notifier, &mockMgmtClient{})
|
||||||
accountID := types.AccountID("account-1")
|
accountID := types.AccountID("account-1")
|
||||||
|
|
||||||
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")
|
||||||
|
|||||||
@@ -192,6 +192,10 @@ type storeBackedServiceManager struct {
|
|||||||
tokenStore *nbgrpc.OneTimeTokenStore
|
tokenStore *nbgrpc.OneTimeTokenStore
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
|
||||||
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -114,6 +114,8 @@ type Server struct {
|
|||||||
// When enabled, the real client IP is extracted from the PROXY header
|
// When enabled, the real client IP is extracted from the PROXY header
|
||||||
// sent by upstream L4 proxies that support PROXY protocol.
|
// sent by upstream L4 proxies that support PROXY protocol.
|
||||||
ProxyProtocol bool
|
ProxyProtocol bool
|
||||||
|
// PreSharedKey used for tunnel between proxy and peers (set globally not per account)
|
||||||
|
PreSharedKey string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NotifyStatus sends a status update to management about tunnel connectivity
|
// NotifyStatus sends a status update to management about tunnel connectivity
|
||||||
@@ -163,7 +165,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
|
|||||||
|
|
||||||
// Initialize the netbird client, this is required to build peer connections
|
// Initialize the netbird client, this is required to build peer connections
|
||||||
// to proxy over.
|
// to proxy over.
|
||||||
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
|
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
|
||||||
|
MgmtAddr: s.ManagementAddress,
|
||||||
|
WGPort: s.WireguardPort,
|
||||||
|
PreSharedKey: s.PreSharedKey,
|
||||||
|
}, s.Logger, s, s.mgmtClient)
|
||||||
|
|
||||||
tlsConfig, err := s.configureTLS(ctx)
|
tlsConfig, err := s.configureTLS(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user