Merge remote-tracking branch 'origin/main' into feature/add-serial-to-proxy

This commit is contained in:
pascal
2026-02-23 16:54:22 +01:00
26 changed files with 2616 additions and 67 deletions

View File

@@ -351,9 +351,13 @@ func (u *upstreamResolverBase) waitUntilResponse() {
return fmt.Errorf("upstream check call error")
}
err := backoff.Retry(operation, exponentialBackOff)
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
if err != nil {
log.Warn(err)
if errors.Is(err, context.Canceled) {
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
} else {
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
}
return
}

View File

@@ -70,6 +70,7 @@ type ServerConfig struct {
DisableGeoliteUpdate bool `yaml:"disableGeoliteUpdate"`
Auth AuthConfig `yaml:"auth"`
Store StoreConfig `yaml:"store"`
ActivityStore StoreConfig `yaml:"activityStore"`
ReverseProxy ReverseProxyConfig `yaml:"reverseProxy"`
}

View File

@@ -141,6 +141,17 @@ func initializeConfig() error {
}
}
if engine := config.Server.ActivityStore.Engine; engine != "" {
engineLower := strings.ToLower(engine)
if engineLower == "postgres" && config.Server.ActivityStore.DSN == "" {
return fmt.Errorf("activityStore.dsn is required when activityStore.engine is postgres")
}
os.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE", engineLower)
if dsn := config.Server.ActivityStore.DSN; dsn != "" {
os.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN", dsn)
}
}
log.Infof("Starting combined NetBird server")
logConfig(config)
logEnvVars()
@@ -668,8 +679,11 @@ func logEnvVars() {
if strings.HasPrefix(env, "NB_") {
key, _, _ := strings.Cut(env, "=")
value := os.Getenv(key)
if strings.Contains(strings.ToLower(key), "secret") || strings.Contains(strings.ToLower(key), "key") || strings.Contains(strings.ToLower(key), "password") {
keyLower := strings.ToLower(key)
if strings.Contains(keyLower, "secret") || strings.Contains(keyLower, "key") || strings.Contains(keyLower, "password") {
value = maskSecret(value)
} else if strings.Contains(keyLower, "dsn") {
value = maskDSNPassword(value)
}
log.Infof(" %s=%s", key, value)
found = true

View File

@@ -104,6 +104,11 @@ server:
dsn: "" # Connection string for postgres or mysql
encryptionKey: ""
# Activity events store configuration (optional, defaults to sqlite in dataDir)
# activityStore:
# engine: "sqlite" # sqlite or postgres
# dsn: "" # Connection string for postgres
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: []

View File

@@ -63,6 +63,8 @@ type Controller struct {
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -85,6 +87,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
newNetworkMapBuilder = false
}
compactedNetworkMap, err := strconv.ParseBool(os.Getenv(types.EnvNewNetworkMapCompacted))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", types.EnvNewNetworkMapCompacted, err)
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
@@ -108,6 +116,8 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -230,9 +240,12 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -355,9 +368,12 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
var remotePeerNetworkMap *types.NetworkMap
if c.experimentalNetworkMap(accountId) {
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
@@ -479,7 +495,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, account.GetActiveGroupUsers())
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
@@ -854,7 +875,12 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]

View File

@@ -14,6 +14,7 @@ type Manager interface {
CreateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
UpdateService(ctx context.Context, accountID, userID string, service *Service) (*Service, error)
DeleteService(ctx context.Context, accountID, userID, serviceID string) error
DeleteAllServices(ctx context.Context, accountID, userID string) error
SetCertificateIssuedAt(ctx context.Context, accountID, serviceID string) error
SetStatus(ctx context.Context, accountID, serviceID string, status Status) error
ReloadAllServicesForAccount(ctx context.Context, accountID string) error

View File

@@ -50,6 +50,20 @@ func (mr *MockManagerMockRecorder) CreateService(ctx, accountID, userID, service
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateService", reflect.TypeOf((*MockManager)(nil).CreateService), ctx, accountID, userID, service)
}
// DeleteAllServices mocks base method.
func (m *MockManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAllServices", ctx, accountID, userID)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAllServices indicates an expected call of DeleteAllServices.
func (mr *MockManagerMockRecorder) DeleteAllServices(ctx, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAllServices", reflect.TypeOf((*MockManager)(nil).DeleteAllServices), ctx, accountID, userID)
}
// DeleteService mocks base method.
func (m *MockManager) DeleteService(ctx context.Context, accountID, userID, serviceID string) error {
m.ctrl.T.Helper()

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -344,6 +345,22 @@ func (m *Manager) sendServiceUpdateNotifications(ctx context.Context, accountID
}
}
func (m *managerImpl) sendServiceUpdate(service *reverseproxy.Service, operation reverseproxy.Operation, cluster, oldService string) {
oidcCfg := m.proxyGRPCServer.GetOIDCValidationConfig()
mapping := service.ToProtoMapping(operation, oldService, oidcCfg)
m.sendMappingsToCluster([]*proto.ProxyMapping{mapping}, cluster)
}
func (m *managerImpl) sendMappingsToCluster(mappings []*proto.ProxyMapping, cluster string) {
if len(mappings) == 0 {
return
}
update := &proto.GetMappingUpdateResponse{
Mapping: mappings,
}
m.proxyGRPCServer.SendServiceUpdateToCluster(update, cluster)
}
// validateTargetReferences checks that all target IDs reference existing peers or resources in the account.
func validateTargetReferences(ctx context.Context, transaction store.Store, accountID string, targets []*service.Target) error {
for _, target := range targets {

View File

@@ -105,7 +105,7 @@ type proxyConnection struct {
proxyID string
address string
stream proto.ProxyService_GetMappingUpdateServer
sendChan chan *proto.ProxyMapping
sendChan chan *proto.GetMappingUpdateResponse
ctx context.Context
cancel context.CancelFunc
}
@@ -114,7 +114,6 @@ type proxyConnection struct {
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager) *ProxyServiceServer {
ctx, cancel := context.WithCancel(context.Background())
s := &ProxyServiceServer{
updatesChan: make(chan *proto.ProxyMapping, 100),
accessLogManager: accessLogMgr,
oidcConfig: oidcConfig,
tokenStore: tokenStore,
@@ -203,7 +202,7 @@ func (s *ProxyServiceServer) GetMappingUpdate(req *proto.GetMappingUpdateRequest
proxyID: proxyID,
address: proxyAddress,
stream: stream,
sendChan: make(chan *proto.ProxyMapping, 100),
sendChan: make(chan *proto.GetMappingUpdateResponse, 100),
ctx: connCtx,
cancel: cancel,
}
@@ -349,7 +348,7 @@ func (s *ProxyServiceServer) sender(conn *proxyConnection, errChan chan<- error)
for {
select {
case msg := <-conn.sendChan:
if err := conn.stream.Send(&proto.GetMappingUpdateResponse{Mapping: []*proto.ProxyMapping{msg}}); err != nil {
if err := conn.stream.Send(msg); err != nil {
errChan <- err
return
}
@@ -400,7 +399,7 @@ func (s *ProxyServiceServer) SendAccessLog(ctx context.Context, req *proto.SendA
// Management should call this when services are created/updated/removed.
// For create/update operations a unique one-time auth token is generated per
// proxy so that every replica can independently authenticate with management.
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
func (s *ProxyServiceServer) SendServiceUpdate(update *proto.GetMappingUpdateResponse) {
log.Debugf("Broadcasting service update to all connected proxy servers")
s.connectedProxies.Range(func(key, value interface{}) bool {
conn := value.(*proxyConnection)
@@ -410,7 +409,7 @@ func (s *ProxyServiceServer) SendServiceUpdate(update *proto.ProxyMapping) {
}
select {
case conn.sendChan <- msg:
log.Debugf("Sent service update with id %s to proxy server %s", update.Id, conn.proxyID)
log.Debugf("Sent service update to proxy server %s", conn.proxyID)
default:
log.Warnf("Failed to send service update to proxy server %s (channel full)", conn.proxyID)
}
@@ -494,23 +493,31 @@ func (s *ProxyServiceServer) SendServiceUpdateToCluster(ctx context.Context, upd
}
// perProxyMessage returns a copy of update with a fresh one-time token for
// create/update operations. For delete operations the original message is
// returned unchanged because proxies do not need to authenticate for removal.
// create/update operations. For delete operations the original mapping is
// used unchanged because proxies do not need to authenticate for removal.
// Returns nil if token generation fails (the proxy should be skipped).
func (s *ProxyServiceServer) perProxyMessage(update *proto.ProxyMapping, proxyID string) *proto.ProxyMapping {
if update.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED || update.AccountId == "" {
return update
func (s *ProxyServiceServer) perProxyMessage(update *proto.GetMappingUpdateResponse, proxyID string) *proto.GetMappingUpdateResponse {
resp := make([]*proto.ProxyMapping, 0, len(update.Mapping))
for _, mapping := range update.Mapping {
if mapping.Type == proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED {
resp = append(resp, mapping)
continue
}
token, err := s.tokenStore.GenerateToken(mapping.AccountId, mapping.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
}
msg := shallowCloneMapping(mapping)
msg.AuthToken = token
resp = append(resp, msg)
}
token, err := s.tokenStore.GenerateToken(update.AccountId, update.Id, 5*time.Minute)
if err != nil {
log.Warnf("Failed to generate token for proxy %s: %v", proxyID, err)
return nil
return &proto.GetMappingUpdateResponse{
Mapping: resp,
}
msg := shallowCloneMapping(update)
msg.AuthToken = token
return msg
}
// shallowCloneMapping creates a shallow copy of a ProxyMapping, reusing the

View File

@@ -17,6 +17,10 @@ type mockReverseProxyManager struct {
err error
}
func (m *mockReverseProxyManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *mockReverseProxyManager) GetAccountServices(ctx context.Context, accountID string) ([]*reverseproxy.Service, error) {
if m.err != nil {
return nil, m.err

View File

@@ -17,8 +17,8 @@ import (
// registerFakeProxy adds a fake proxy connection to the server's internal maps
// and returns the channel where messages will be received.
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.ProxyMapping {
ch := make(chan *proto.ProxyMapping, 10)
func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan *proto.GetMappingUpdateResponse {
ch := make(chan *proto.GetMappingUpdateResponse, 10)
conn := &proxyConnection{
proxyID: proxyID,
address: clusterAddr,
@@ -32,7 +32,7 @@ func registerFakeProxy(s *ProxyServiceServer, proxyID, clusterAddr string) chan
return ch
}
func drainChannel(ch chan *proto.ProxyMapping) *proto.ProxyMapping {
func drainChannel(ch chan *proto.GetMappingUpdateResponse) *proto.GetMappingUpdateResponse {
select {
case msg := <-ch:
return msg
@@ -46,20 +46,19 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
}
const cluster = "proxy.example.com"
const numProxies = 3
channels := make([]chan *proto.ProxyMapping, numProxies)
channels := make([]chan *proto.GetMappingUpdateResponse, numProxies)
for i := range numProxies {
id := "proxy-" + string(rune('a'+i))
channels[i] = registerFakeProxy(s, id, cluster)
}
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
@@ -73,10 +72,12 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) {
tokens := make([]string, numProxies)
for i, ch := range channels {
msg := drainChannel(ch)
require.NotNil(t, msg, "proxy %d should receive a message", i)
assert.Equal(t, update.Domain, msg.Domain)
assert.Equal(t, update.Id, msg.Id)
resp := drainChannel(ch)
require.NotNil(t, resp, "proxy %d should receive a message", i)
require.Len(t, resp.Mapping, 1, "proxy %d should receive exactly one mapping", i)
msg := resp.Mapping[0]
assert.Equal(t, mapping.Domain, msg.Domain)
assert.Equal(t, mapping.Id, msg.Id)
assert.NotEmpty(t, msg.AuthToken, "proxy %d should have a non-empty token", i)
tokens[i] = msg.AuthToken
}
@@ -101,15 +102,14 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
}
const cluster = "proxy.example.com"
ch1 := registerFakeProxy(s, "proxy-a", cluster)
ch2 := registerFakeProxy(s, "proxy-b", cluster)
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_REMOVED,
Id: "service-1",
AccountId: "account-1",
@@ -118,10 +118,12 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) {
s.SendServiceUpdateToCluster(context.Background(), update, cluster)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
// Delete operations should not generate tokens
assert.Empty(t, msg1.AuthToken)
@@ -133,27 +135,35 @@ func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) {
require.NoError(t, err)
s := &ProxyServiceServer{
tokenStore: tokenStore,
updatesChan: make(chan *proto.ProxyMapping, 100),
tokenStore: tokenStore,
}
// Register proxies in different clusters (SendServiceUpdate broadcasts to all)
ch1 := registerFakeProxy(s, "proxy-a", "cluster-a")
ch2 := registerFakeProxy(s, "proxy-b", "cluster-b")
update := &proto.ProxyMapping{
mapping := &proto.ProxyMapping{
Type: proto.ProxyMappingUpdateType_UPDATE_TYPE_CREATED,
Id: "service-1",
AccountId: "account-1",
Domain: "test.example.com",
}
update := &proto.GetMappingUpdateResponse{
Mapping: []*proto.ProxyMapping{mapping},
}
s.SendServiceUpdate(update)
msg1 := drainChannel(ch1)
msg2 := drainChannel(ch2)
require.NotNil(t, msg1)
require.NotNil(t, msg2)
resp1 := drainChannel(ch1)
resp2 := drainChannel(ch2)
require.NotNil(t, resp1)
require.NotNil(t, resp2)
require.Len(t, resp1.Mapping, 1)
require.Len(t, resp2.Mapping, 1)
msg1 := resp1.Mapping[0]
msg2 := resp2.Mapping[0]
assert.NotEmpty(t, msg1.AuthToken)
assert.NotEmpty(t, msg2.AuthToken)

View File

@@ -714,6 +714,11 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.reverseProxyManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.Id == userID {
continue

View File

@@ -31,6 +31,7 @@ import (
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
nbAccount "github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/cache"
@@ -3122,7 +3123,8 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
return nil, nil, err
}
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, nil, nil))
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil)
manager.SetServiceManager(reverseproxymanager.NewManager(store, manager, permissionsManager, proxyGrpcServer, nil))
return manager, updateManager, nil
}

View File

@@ -358,6 +358,10 @@ type testServiceManager struct {
store store.Store
}
func (m *testServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *testServiceManager) GetAllServices(_ context.Context, _, _ string) ([]*reverseproxy.Service, error) {
return nil, nil
}

View File

@@ -210,6 +210,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rosenpassEnabled int
localUsers int
idpUsers int
embeddedIdpTypes map[string]int
)
start := time.Now()
metricsProperties := make(properties)
@@ -218,6 +219,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
embeddedIdpTypes = make(map[string]int)
uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion()
@@ -277,6 +279,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
localUsers++
} else {
idpUsers++
idpType := extractIdpType(idpID)
embeddedIdpTypes[idpType]++
}
}
}
@@ -369,6 +373,11 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
metricsProperties["local_users_count"] = localUsers
metricsProperties["idp_users_count"] = idpUsers
metricsProperties["embedded_idp_count"] = len(embeddedIdpTypes)
for idpType, count := range embeddedIdpTypes {
metricsProperties["embedded_idp_users_"+idpType] = count
}
for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count
@@ -456,6 +465,17 @@ func createPostRequest(ctx context.Context, endpoint string, payloadStr string)
return req, cancel, nil
}
// extractIdpType extracts the IdP type from a Dex connector ID.
// Connector IDs are formatted as "<type>-<xid>" (e.g., "okta-abc123", "zitadel-xyz").
// Returns the type prefix, or "oidc" if no known prefix is found.
func extractIdpType(connectorID string) string {
idx := strings.LastIndex(connectorID, "-")
if idx <= 0 {
return "oidc"
}
return strings.ToLower(connectorID[:idx])
}
func getMinMaxVersion(inputList []string) (string, string) {
versions := make([]*version.Version, 0)

View File

@@ -27,7 +27,7 @@ func (mockDatasource) GetAllConnectedPeers() map[string]struct{} {
// GetAllAccounts returns a list of *server.Account for use in tests with predefined information
func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
localUserID := dex.EncodeDexUserID("10", "local")
idpUserID := dex.EncodeDexUserID("20", "zitadel")
idpUserID := dex.EncodeDexUserID("20", "zitadel-d5uv82dra0haedlf6kv0")
return []*types.Account{
{
Id: "1",
@@ -341,4 +341,37 @@ func TestGenerateProperties(t *testing.T) {
if properties["idp_users_count"] != 1 {
t.Errorf("expected 1 idp_users_count, got %d", properties["idp_users_count"])
}
if properties["embedded_idp_users_zitadel"] != 1 {
t.Errorf("expected 1 embedded_idp_users_zitadel, got %v", properties["embedded_idp_users_zitadel"])
}
if properties["embedded_idp_count"] != 1 {
t.Errorf("expected 1 embedded_idp_count, got %v", properties["embedded_idp_count"])
}
}
func TestExtractIdpType(t *testing.T) {
tests := []struct {
connectorID string
expected string
}{
{"okta-abc123def", "okta"},
{"zitadel-d5uv82dra0haedlf6kv0", "zitadel"},
{"entra-xyz789", "entra"},
{"google-abc123", "google"},
{"pocketid-abc123", "pocketid"},
{"microsoft-abc123", "microsoft"},
{"authentik-abc123", "authentik"},
{"keycloak-d5uv82dra0haedlf6kv0", "keycloak"},
{"local", "oidc"},
{"", "oidc"},
}
for _, tt := range tests {
t.Run(tt.connectorID, func(t *testing.T) {
result := extractIdpType(tt.connectorID)
if result != tt.expected {
t.Errorf("extractIdpType(%q) = %q, want %q", tt.connectorID, result, tt.expected)
}
})
}
}

View File

@@ -1124,6 +1124,21 @@ func (mr *MockStoreMockRecorder) GetAccountServices(ctx, lockStrength, accountID
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAccountServices", reflect.TypeOf((*MockStore)(nil).GetAccountServices), ctx, lockStrength, accountID)
}
// GetServicesByAccountID mocks base method.
func (m *MockStore) GetServicesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*reverseproxy.Service, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetServicesByAccountID", ctx, lockStrength, accountID)
ret0, _ := ret[0].([]*reverseproxy.Service)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetServicesByAccountID indicates an expected call of GetServicesByAccountID.
func (mr *MockStoreMockRecorder) GetServicesByAccountID(ctx, lockStrength, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetServicesByAccountID", reflect.TypeOf((*MockStore)(nil).GetServicesByAccountID), ctx, lockStrength, accountID)
}
// GetAccountSettings mocks base method.
func (m *MockStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types2.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -0,0 +1,576 @@
package types
import (
"context"
"slices"
"time"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/zones"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/route"
)
func (a *Account) GetPeerNetworkMapFromComponents(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
components := a.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
accountZones,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
return &NetworkMap{Network: a.Network.Copy()}
}
nm := CalculateNetworkMapFromComponents(ctx, components)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from components, "+
"peers: %d, offline peers: %d, routes: %d, firewall rules: %d, route firewall rules: %d",
a.Id, objectCount, len(nm.Peers), len(nm.OfflinePeers), len(nm.Routes), len(nm.FirewallRules), len(nm.RoutesFirewallRules))
}
}
return nm
}
func (a *Account) GetPeerNetworkMapComponents(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
validatedPeersMap map[string]struct{},
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
groupIDToUserIDs map[string][]string,
) *NetworkMapComponents {
peer := a.Peers[peerID]
if peer == nil {
return nil
}
if _, ok := validatedPeersMap[peerID]; !ok {
return nil
}
components := &NetworkMapComponents{
PeerID: peerID,
Network: a.Network.Copy(),
NameServerGroups: make([]*nbdns.NameServerGroup, 0),
CustomZoneDomain: peersCustomZone.Domain,
ResourcePoliciesMap: make(map[string][]*Policy),
RoutersMap: make(map[string]map[string]*routerTypes.NetworkRouter),
NetworkResources: make([]*resourceTypes.NetworkResource, 0),
PostureFailedPeers: make(map[string]map[string]struct{}, len(a.PostureChecks)),
RouterPeers: make(map[string]*nbpeer.Peer),
}
components.AccountSettings = &AccountSettingsInfo{
PeerLoginExpirationEnabled: a.Settings.PeerLoginExpirationEnabled,
PeerLoginExpiration: a.Settings.PeerLoginExpiration,
PeerInactivityExpirationEnabled: a.Settings.PeerInactivityExpirationEnabled,
PeerInactivityExpiration: a.Settings.PeerInactivityExpiration,
}
components.DNSSettings = &a.DNSSettings
relevantPeers, relevantGroups, relevantPolicies, relevantRoutes, sshReqs := a.getPeersGroupsPoliciesRoutes(ctx, peerID, peer.SSHEnabled, validatedPeersMap, &components.PostureFailedPeers)
if len(sshReqs.neededGroupIDs) > 0 {
components.GroupIDToUserIDs = filterGroupIDToUserIDs(groupIDToUserIDs, sshReqs.neededGroupIDs)
}
if sshReqs.needAllowedUserIDs {
components.AllowedUserIDs = a.getAllowedUserIDs()
}
components.Peers = relevantPeers
components.Groups = relevantGroups
components.Policies = relevantPolicies
components.Routes = relevantRoutes
components.AllDNSRecords = filterDNSRecordsByPeers(peersCustomZone.Records, relevantPeers)
peerGroups := a.GetPeerGroups(peerID)
components.AccountZones = filterPeerAppliedZones(ctx, accountZones, peerGroups)
for _, nsGroup := range a.NameServerGroups {
if nsGroup.Enabled {
for _, gID := range nsGroup.Groups {
if _, found := relevantGroups[gID]; found {
components.NameServerGroups = append(components.NameServerGroups, nsGroup)
break
}
}
}
}
for _, resource := range a.NetworkResources {
if !resource.Enabled {
continue
}
policies, exists := resourcePolicies[resource.ID]
if !exists {
continue
}
addSourcePeers := false
networkRoutingPeers, routerExists := routers[resource.NetworkID]
if routerExists {
if _, ok := networkRoutingPeers[peerID]; ok {
addSourcePeers = true
}
}
for _, policy := range policies {
if addSourcePeers {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = a.getUniquePeerIDsFromGroupsIDs(ctx, policy.SourceGroups())
}
for _, pID := range a.getPostureValidPeersSaveFailed(peers, policy.SourcePostureChecks, validatedPeersMap, &components.PostureFailedPeers) {
if _, exists := components.Peers[pID]; !exists {
components.Peers[pID] = a.GetPeer(pID)
}
}
} else {
peerInSources := false
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peerInSources = policy.Rules[0].SourceResource.ID == peerID
} else {
for _, groupID := range policy.SourceGroups() {
if group := a.GetGroup(groupID); group != nil && slices.Contains(group.Peers, peerID) {
peerInSources = true
break
}
}
}
if !peerInSources {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, policy.SourcePostureChecks, peerID)
if !isValid && len(pname) > 0 {
if _, ok := components.PostureFailedPeers[pname]; !ok {
components.PostureFailedPeers[pname] = make(map[string]struct{})
}
components.PostureFailedPeers[pname][peer.ID] = struct{}{}
continue
}
addSourcePeers = true
}
for _, rule := range policy.Rules {
for _, srcGroupID := range rule.Sources {
if g := a.Groups[srcGroupID]; g != nil {
if _, exists := components.Groups[srcGroupID]; !exists {
components.Groups[srcGroupID] = g
}
}
}
for _, dstGroupID := range rule.Destinations {
if g := a.Groups[dstGroupID]; g != nil {
if _, exists := components.Groups[dstGroupID]; !exists {
components.Groups[dstGroupID] = g
}
}
}
}
components.ResourcePoliciesMap[resource.ID] = policies
}
components.RoutersMap[resource.NetworkID] = networkRoutingPeers
for peerIDKey := range networkRoutingPeers {
if p := a.Peers[peerIDKey]; p != nil {
if _, exists := components.RouterPeers[peerIDKey]; !exists {
components.RouterPeers[peerIDKey] = p
}
if _, exists := components.Peers[peerIDKey]; !exists {
if _, validated := validatedPeersMap[peerIDKey]; validated {
components.Peers[peerIDKey] = p
}
}
}
}
if addSourcePeers {
components.NetworkResources = append(components.NetworkResources, resource)
}
}
filterGroupPeers(&components.Groups, components.Peers)
filterPostureFailedPeers(&components.PostureFailedPeers, components.Policies, components.ResourcePoliciesMap, components.Peers)
return components
}
type sshRequirements struct {
neededGroupIDs map[string]struct{}
needAllowedUserIDs bool
}
func (a *Account) getPeersGroupsPoliciesRoutes(
ctx context.Context,
peerID string,
peerSSHEnabled bool,
validatedPeersMap map[string]struct{},
postureFailedPeers *map[string]map[string]struct{},
) (map[string]*nbpeer.Peer, map[string]*Group, []*Policy, []*route.Route, sshRequirements) {
relevantPeerIDs := make(map[string]*nbpeer.Peer, len(a.Peers)/4)
relevantGroupIDs := make(map[string]*Group, len(a.Groups)/4)
relevantPolicies := make([]*Policy, 0, len(a.Policies))
relevantRoutes := make([]*route.Route, 0, len(a.Routes))
sshReqs := sshRequirements{neededGroupIDs: make(map[string]struct{})}
relevantPeerIDs[peerID] = a.GetPeer(peerID)
for groupID, group := range a.Groups {
if slices.Contains(group.Peers, peerID) {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
}
routeAccessControlGroups := make(map[string]struct{})
for _, r := range a.Routes {
for _, groupID := range r.Groups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
for _, groupID := range r.PeerGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
}
if r.Enabled {
for _, groupID := range r.AccessControlGroups {
relevantGroupIDs[groupID] = a.GetGroup(groupID)
routeAccessControlGroups[groupID] = struct{}{}
}
}
relevantRoutes = append(relevantRoutes, r)
}
for _, policy := range a.Policies {
if !policy.Enabled {
continue
}
policyRelevant := false
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
if len(routeAccessControlGroups) > 0 {
for _, destGroupID := range rule.Destinations {
if _, needed := routeAccessControlGroups[destGroupID]; needed {
policyRelevant = true
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
}
break
}
}
}
var sourcePeers, destinationPeers []string
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers = []string{rule.SourceResource.ID}
if rule.SourceResource.ID == peerID {
peerInSources = true
}
} else {
sourcePeers, peerInSources = a.getPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap, postureFailedPeers)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers = []string{rule.DestinationResource.ID}
if rule.DestinationResource.ID == peerID {
peerInDestinations = true
}
} else {
destinationPeers, peerInDestinations = a.getPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap, postureFailedPeers)
}
if peerInSources {
policyRelevant = true
for _, pid := range destinationPeers {
relevantPeerIDs[pid] = a.GetPeer(pid)
}
for _, dstGroupID := range rule.Destinations {
relevantGroupIDs[dstGroupID] = a.GetGroup(dstGroupID)
}
}
if peerInDestinations {
policyRelevant = true
for _, pid := range sourcePeers {
relevantPeerIDs[pid] = a.GetPeer(pid)
}
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
}
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID := range rule.AuthorizedGroups {
sshReqs.neededGroupIDs[groupID] = struct{}{}
}
case rule.AuthorizedUser != "":
default:
sshReqs.needAllowedUserIDs = true
}
} else if policyRuleImpliesLegacySSH(rule) && peerSSHEnabled {
sshReqs.needAllowedUserIDs = true
}
}
}
if policyRelevant {
relevantPolicies = append(relevantPolicies, policy)
}
}
return relevantPeerIDs, relevantGroupIDs, relevantPolicies, relevantRoutes, sshReqs
}
func (a *Account) getPeersFromGroups(ctx context.Context, groups []string, peerID string, sourcePostureChecksIDs []string,
validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) ([]string, bool) {
peerInGroups := false
filteredPeerIDs := make([]string, 0, len(a.Peers))
seenPeerIds := make(map[string]struct{}, len(groups))
for _, gid := range groups {
group := a.GetGroup(gid)
if group == nil {
continue
}
if group.IsGroupAll() || len(groups) == 1 {
filteredPeerIDs = filteredPeerIDs[:0]
peerInGroups = false
for _, pid := range group.Peers {
peer, ok := a.Peers[pid]
if !ok || peer == nil {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid && len(pname) > 0 {
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
}
return filteredPeerIDs, peerInGroups
}
for _, pid := range group.Peers {
if _, seen := seenPeerIds[pid]; seen {
continue
}
seenPeerIds[pid] = struct{}{}
peer, ok := a.Peers[pid]
if !ok || peer == nil {
continue
}
if _, ok := validatedPeersMap[peer.ID]; !ok {
continue
}
isValid, pname := a.validatePostureChecksOnPeerGetFailed(ctx, sourcePostureChecksIDs, peer.ID)
if !isValid && len(pname) > 0 {
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peer.ID] = struct{}{}
continue
}
if peer.ID == peerID {
peerInGroups = true
continue
}
filteredPeerIDs = append(filteredPeerIDs, peer.ID)
}
}
return filteredPeerIDs, peerInGroups
}
func (a *Account) validatePostureChecksOnPeerGetFailed(ctx context.Context, sourcePostureChecksID []string, peerID string) (bool, string) {
peer, ok := a.Peers[peerID]
if !ok || peer == nil {
return false, ""
}
for _, postureChecksID := range sourcePostureChecksID {
postureChecks := a.GetPostureChecks(postureChecksID)
if postureChecks == nil {
continue
}
for _, check := range postureChecks.GetChecks() {
isValid, _ := check.Check(ctx, *peer)
if !isValid {
return false, postureChecksID
}
}
}
return true, ""
}
func (a *Account) getPostureValidPeersSaveFailed(inputPeers []string, postureChecksIDs []string, validatedPeersMap map[string]struct{}, postureFailedPeers *map[string]map[string]struct{}) []string {
var dest []string
for _, peerID := range inputPeers {
if _, validated := validatedPeersMap[peerID]; !validated {
continue
}
valid, pname := a.validatePostureChecksOnPeerGetFailed(context.Background(), postureChecksIDs, peerID)
if valid {
dest = append(dest, peerID)
continue
}
if _, ok := (*postureFailedPeers)[pname]; !ok {
(*postureFailedPeers)[pname] = make(map[string]struct{})
}
(*postureFailedPeers)[pname][peerID] = struct{}{}
}
return dest
}
func filterGroupPeers(groups *map[string]*Group, peers map[string]*nbpeer.Peer) {
for groupID, groupInfo := range *groups {
filteredPeers := make([]string, 0, len(groupInfo.Peers))
for _, pid := range groupInfo.Peers {
if _, exists := peers[pid]; exists {
filteredPeers = append(filteredPeers, pid)
}
}
if len(filteredPeers) == 0 {
delete(*groups, groupID)
} else if len(filteredPeers) != len(groupInfo.Peers) {
ng := groupInfo.Copy()
ng.Peers = filteredPeers
(*groups)[groupID] = ng
}
}
}
func filterPostureFailedPeers(postureFailedPeers *map[string]map[string]struct{}, policies []*Policy, resourcePoliciesMap map[string][]*Policy, peers map[string]*nbpeer.Peer) {
if len(*postureFailedPeers) == 0 {
return
}
referencedPostureChecks := make(map[string]struct{})
for _, policy := range policies {
for _, checkID := range policy.SourcePostureChecks {
referencedPostureChecks[checkID] = struct{}{}
}
}
for _, resPolicies := range resourcePoliciesMap {
for _, policy := range resPolicies {
for _, checkID := range policy.SourcePostureChecks {
referencedPostureChecks[checkID] = struct{}{}
}
}
}
for checkID, failedPeers := range *postureFailedPeers {
if _, referenced := referencedPostureChecks[checkID]; !referenced {
delete(*postureFailedPeers, checkID)
continue
}
for peerID := range failedPeers {
if _, exists := peers[peerID]; !exists {
delete(failedPeers, peerID)
}
}
if len(failedPeers) == 0 {
delete(*postureFailedPeers, checkID)
}
}
}
func filterDNSRecordsByPeers(records []nbdns.SimpleRecord, peers map[string]*nbpeer.Peer) []nbdns.SimpleRecord {
if len(records) == 0 || len(peers) == 0 {
return nil
}
peerIPs := make(map[string]struct{}, len(peers))
for _, peer := range peers {
if peer != nil {
peerIPs[peer.IP.String()] = struct{}{}
}
}
filteredRecords := make([]nbdns.SimpleRecord, 0, len(records))
for _, record := range records {
if _, exists := peerIPs[record.RData]; exists {
filteredRecords = append(filteredRecords, record)
}
}
return filteredRecords
}
func filterGroupIDToUserIDs(fullMap map[string][]string, neededGroupIDs map[string]struct{}) map[string][]string {
if len(neededGroupIDs) == 0 {
return nil
}
filtered := make(map[string][]string, len(neededGroupIDs))
for groupID := range neededGroupIDs {
if users, ok := fullMap[groupID]; ok {
filtered[groupID] = users
}
}
return filtered
}

View File

@@ -0,0 +1,592 @@
package types
import (
"context"
"encoding/json"
"fmt"
"net"
"net/netip"
"os"
"path/filepath"
"sort"
"testing"
"time"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/route"
)
func TestNetworkMapComponents_CompareWithLegacy(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
if components == nil {
t.Fatal("GetPeerNetworkMapComponents returned nil")
}
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
if newNetworkMap == nil {
t.Fatal("CalculateNetworkMapFromComponents returned nil")
}
compareNetworkMaps(t, legacyNetworkMap, newNetworkMap)
}
func TestNetworkMapComponents_GoldenFileComparison(t *testing.T) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid == offlinePeerID {
continue
}
validatedPeersMap[pid] = struct{}{}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
legacyNetworkMap := account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
require.NotNil(t, components, "GetPeerNetworkMapComponents returned nil")
newNetworkMap := CalculateNetworkMapFromComponents(ctx, components)
require.NotNil(t, newNetworkMap, "CalculateNetworkMapFromComponents returned nil")
normalizeAndSortNetworkMap(legacyNetworkMap)
normalizeAndSortNetworkMap(newNetworkMap)
componentsJSON, err := json.MarshalIndent(components, "", " ")
require.NoError(t, err, "error marshaling components to JSON")
legacyJSON, err := json.MarshalIndent(legacyNetworkMap, "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
newJSON, err := json.MarshalIndent(newNetworkMap, "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
goldenDir := filepath.Join("testdata", "comparison")
err = os.MkdirAll(goldenDir, 0755)
require.NoError(t, err)
legacyGoldenPath := filepath.Join(goldenDir, "legacy_networkmap.json")
err = os.WriteFile(legacyGoldenPath, legacyJSON, 0644)
require.NoError(t, err, "error writing legacy golden file")
newGoldenPath := filepath.Join(goldenDir, "components_networkmap.json")
err = os.WriteFile(newGoldenPath, newJSON, 0644)
require.NoError(t, err, "error writing components golden file")
componentsPath := filepath.Join(goldenDir, "components.json")
err = os.WriteFile(componentsPath, componentsJSON, 0644)
require.NoError(t, err, "error writing components golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON),
"NetworkMaps from legacy and components approaches do not match.\n"+
"Legacy JSON saved to: %s\n"+
"Components JSON saved to: %s",
legacyGoldenPath, newGoldenPath)
t.Logf("✅ NetworkMaps are identical")
t.Logf(" Legacy NetworkMap: %s", legacyGoldenPath)
t.Logf(" Components NetworkMap: %s", newGoldenPath)
}
func normalizeAndSortNetworkMap(nm *NetworkMap) {
if nm == nil {
return
}
sort.Slice(nm.Peers, func(i, j int) bool {
return nm.Peers[i].ID < nm.Peers[j].ID
})
sort.Slice(nm.OfflinePeers, func(i, j int) bool {
return nm.OfflinePeers[i].ID < nm.OfflinePeers[j].ID
})
sort.Slice(nm.Routes, func(i, j int) bool {
return string(nm.Routes[i].ID) < string(nm.Routes[j].ID)
})
sort.Slice(nm.FirewallRules, func(i, j int) bool {
if nm.FirewallRules[i].PeerIP != nm.FirewallRules[j].PeerIP {
return nm.FirewallRules[i].PeerIP < nm.FirewallRules[j].PeerIP
}
if nm.FirewallRules[i].Direction != nm.FirewallRules[j].Direction {
return nm.FirewallRules[i].Direction < nm.FirewallRules[j].Direction
}
if nm.FirewallRules[i].Protocol != nm.FirewallRules[j].Protocol {
return nm.FirewallRules[i].Protocol < nm.FirewallRules[j].Protocol
}
if nm.FirewallRules[i].Port != nm.FirewallRules[j].Port {
return nm.FirewallRules[i].Port < nm.FirewallRules[j].Port
}
return nm.FirewallRules[i].PolicyID < nm.FirewallRules[j].PolicyID
})
for i := range nm.RoutesFirewallRules {
sort.Strings(nm.RoutesFirewallRules[i].SourceRanges)
}
sort.Slice(nm.RoutesFirewallRules, func(i, j int) bool {
if nm.RoutesFirewallRules[i].Destination != nm.RoutesFirewallRules[j].Destination {
return nm.RoutesFirewallRules[i].Destination < nm.RoutesFirewallRules[j].Destination
}
minLen := len(nm.RoutesFirewallRules[i].SourceRanges)
if len(nm.RoutesFirewallRules[j].SourceRanges) < minLen {
minLen = len(nm.RoutesFirewallRules[j].SourceRanges)
}
for k := 0; k < minLen; k++ {
if nm.RoutesFirewallRules[i].SourceRanges[k] != nm.RoutesFirewallRules[j].SourceRanges[k] {
return nm.RoutesFirewallRules[i].SourceRanges[k] < nm.RoutesFirewallRules[j].SourceRanges[k]
}
}
if len(nm.RoutesFirewallRules[i].SourceRanges) != len(nm.RoutesFirewallRules[j].SourceRanges) {
return len(nm.RoutesFirewallRules[i].SourceRanges) < len(nm.RoutesFirewallRules[j].SourceRanges)
}
if string(nm.RoutesFirewallRules[i].RouteID) != string(nm.RoutesFirewallRules[j].RouteID) {
return string(nm.RoutesFirewallRules[i].RouteID) < string(nm.RoutesFirewallRules[j].RouteID)
}
if nm.RoutesFirewallRules[i].PolicyID != nm.RoutesFirewallRules[j].PolicyID {
return nm.RoutesFirewallRules[i].PolicyID < nm.RoutesFirewallRules[j].PolicyID
}
if nm.RoutesFirewallRules[i].Port != nm.RoutesFirewallRules[j].Port {
return nm.RoutesFirewallRules[i].Port < nm.RoutesFirewallRules[j].Port
}
return nm.RoutesFirewallRules[i].Protocol < nm.RoutesFirewallRules[j].Protocol
})
if nm.DNSConfig.CustomZones != nil {
for i := range nm.DNSConfig.CustomZones {
sort.Slice(nm.DNSConfig.CustomZones[i].Records, func(a, b int) bool {
return nm.DNSConfig.CustomZones[i].Records[a].Name < nm.DNSConfig.CustomZones[i].Records[b].Name
})
}
}
if len(nm.DNSConfig.NameServerGroups) != 0 {
sort.Slice(nm.DNSConfig.NameServerGroups, func(a, b int) bool {
return nm.DNSConfig.NameServerGroups[a].Name < nm.DNSConfig.NameServerGroups[b].Name
})
}
}
func compareNetworkMaps(t *testing.T, legacy, current *NetworkMap) {
t.Helper()
if legacy.Network.Serial != current.Network.Serial {
t.Errorf("Network Serial mismatch: legacy=%d, current=%d", legacy.Network.Serial, current.Network.Serial)
}
if len(legacy.Peers) != len(current.Peers) {
t.Errorf("Peers count mismatch: legacy=%d, current=%d", len(legacy.Peers), len(current.Peers))
}
legacyPeerIDs := make(map[string]bool)
for _, p := range legacy.Peers {
legacyPeerIDs[p.ID] = true
}
for _, p := range current.Peers {
if !legacyPeerIDs[p.ID] {
t.Errorf("Current NetworkMap contains peer %s not in legacy", p.ID)
}
}
if len(legacy.OfflinePeers) != len(current.OfflinePeers) {
t.Errorf("OfflinePeers count mismatch: legacy=%d, current=%d", len(legacy.OfflinePeers), len(current.OfflinePeers))
}
if len(legacy.FirewallRules) != len(current.FirewallRules) {
t.Logf("FirewallRules count mismatch: legacy=%d, current=%d", len(legacy.FirewallRules), len(current.FirewallRules))
}
if len(legacy.Routes) != len(current.Routes) {
t.Logf("Routes count mismatch: legacy=%d, current=%d", len(legacy.Routes), len(current.Routes))
}
if len(legacy.RoutesFirewallRules) != len(current.RoutesFirewallRules) {
t.Logf("RoutesFirewallRules count mismatch: legacy=%d, current=%d", len(legacy.RoutesFirewallRules), len(current.RoutesFirewallRules))
}
if legacy.DNSConfig.ServiceEnable != current.DNSConfig.ServiceEnable {
t.Errorf("DNSConfig.ServiceEnable mismatch: legacy=%v, current=%v", legacy.DNSConfig.ServiceEnable, current.DNSConfig.ServiceEnable)
}
}
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
policyIDDevOps = "policy-dev-ops"
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
networkRouterID = "router-database"
nameserverGroupID = "ns-group-main"
testingPeerID = "peer-60"
expiredPeerID = "peer-98"
offlinePeerID = "peer-99"
routingPeerID = "peer-95"
testAccountID = "account-comparison-test"
)
func createTestAccount() *Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
ip := net.IP{100, 64, 0, byte(i + 1)}
wtVersion := "0.25.0"
if i%2 == 0 {
wtVersion = "0.40.0"
}
p := &nbpeer.Peer{
ID: peerID, IP: ip, Key: fmt.Sprintf("key-%s", peerID), DNSLabel: fmt.Sprintf("peer%d", i+1),
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin", Meta: nbpeer.PeerSystemMeta{WtVersion: wtVersion, GoOS: "linux"},
}
if peerID == expiredPeerID {
p.LoginExpirationEnabled = true
pastTimestamp := time.Now().Add(-2 * time.Hour)
p.LastLogin = &pastTimestamp
}
peers[peerID] = p
allGroupPeers = append(allGroupPeers, peerID)
if i < numPeers/2 {
devGroupPeers = append(devGroupPeers, peerID)
} else {
opsGroupPeers = append(opsGroupPeers, peerID)
}
}
groups := map[string]*Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
}
policies := []*Policy{
{
ID: policyIDAll, Name: "Default-Allow", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDAll, Name: "Allow All", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{allGroupID}, Destinations: []string{allGroupID},
}},
},
{
ID: policyIDDevOps, Name: "Dev to Ops Web Access", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDevOps, Name: "Dev -> Ops (HTTP Range)", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolTCP, Bidirectional: false,
PortRanges: []RulePortRange{{Start: 8080, End: 8090}},
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDDrop, Name: "Drop DB traffic", Enabled: true,
Rules: []*PolicyRule{{
ID: policyIDDrop, Name: "Drop DB", Enabled: true, Action: PolicyTrafficActionDrop,
Protocol: PolicyRuleProtocolTCP, Ports: []string{"5432"}, Bidirectional: true,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
}},
},
{
ID: policyIDPosture, Name: "Posture Check for DB Resource", Enabled: true,
SourcePostureChecks: []string{postureCheckID},
Rules: []*PolicyRule{{
ID: policyIDPosture, Name: "Allow DB Access", Enabled: true, Action: PolicyTrafficActionAccept,
Protocol: PolicyRuleProtocolALL, Bidirectional: true,
Sources: []string{opsGroupID}, DestinationResource: Resource{ID: networkResourceID},
}},
},
}
routes := map[route.ID]*route.Route{
routeID: {
ID: routeID, Network: netip.MustParsePrefix("192.168.10.0/24"),
Peer: peers["peer-75"].Key,
PeerID: "peer-75",
Description: "Route to internal resource", Enabled: true,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
},
routeHA1ID: {
ID: routeHA1ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-80"].Key,
PeerID: "peer-80",
Description: "HA Route 1", Enabled: true, Metric: 1000,
PeerGroups: []string{allGroupID},
Groups: []string{allGroupID},
AccessControlGroups: []string{allGroupID},
},
routeHA2ID: {
ID: routeHA2ID, Network: netip.MustParsePrefix("10.10.0.0/16"),
Peer: peers["peer-90"].Key,
PeerID: "peer-90",
Description: "HA Route 2", Enabled: true, Metric: 900,
PeerGroups: []string{devGroupID, opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{allGroupID},
},
}
account := &Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Network: &Network{
Identifier: "net-comparison-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
DNSSettings: DNSSettings{DisabledManagementGroups: []string{opsGroupID}},
NameServerGroups: map[string]*nbdns.NameServerGroup{
nameserverGroupID: {
ID: nameserverGroupID, Name: "Main NS", Enabled: true, Groups: []string{devGroupID},
NameServers: []nbdns.NameServer{{IP: netip.MustParseAddr("8.8.8.8"), NSType: nbdns.UDPNameServerType, Port: 53}},
},
},
PostureChecks: []*posture.Checks{
{ID: postureCheckID, Name: "Check version", Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.26.0"},
}},
},
NetworkResources: []*resourceTypes.NetworkResource{
{ID: networkResourceID, NetworkID: networkID, AccountID: testAccountID, Enabled: true, Address: "db.netbird.cloud"},
},
Networks: []*networkTypes.Network{{ID: networkID, Name: "DB Network", AccountID: testAccountID}},
NetworkRouters: []*routerTypes.NetworkRouter{
{ID: networkRouterID, NetworkID: networkID, Peer: routingPeerID, Enabled: true, AccountID: testAccountID},
},
Settings: &Settings{PeerLoginExpirationEnabled: true, PeerLoginExpiration: 1 * time.Hour},
}
for _, p := range account.Policies {
p.AccountID = account.Id
}
for _, r := range account.Routes {
r.AccountID = account.Id
}
return account
}
func BenchmarkLegacyNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
nil,
groupIDToUserIDs,
)
}
}
func BenchmarkComponentsNetworkMap(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}
func BenchmarkComponentsCreation(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
}
}
func BenchmarkCalculationFromComponents(b *testing.B) {
account := createTestAccount()
ctx := context.Background()
peerID := testingPeerID
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
pid := fmt.Sprintf("peer-%d", i)
if pid != offlinePeerID {
validatedPeersMap[pid] = struct{}{}
}
}
peersCustomZone := nbdns.CustomZone{}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
components := account.GetPeerNetworkMapComponents(
ctx,
peerID,
peersCustomZone,
nil,
validatedPeersMap,
resourcePolicies,
routers,
groupIDToUserIDs,
)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = CalculateNetworkMapFromComponents(ctx, components)
}
}

View File

@@ -0,0 +1,938 @@
package types
import (
"context"
"maps"
"net"
"net/netip"
"slices"
"strconv"
"strings"
"time"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const EnvNewNetworkMapCompacted = "NB_NETWORK_MAP_COMPACTED"
type NetworkMapComponents struct {
PeerID string
Network *Network
AccountSettings *AccountSettingsInfo
DNSSettings *DNSSettings
CustomZoneDomain string
Peers map[string]*nbpeer.Peer
Groups map[string]*Group
Policies []*Policy
Routes []*route.Route
NameServerGroups []*nbdns.NameServerGroup
AllDNSRecords []nbdns.SimpleRecord
AccountZones []nbdns.CustomZone
ResourcePoliciesMap map[string][]*Policy
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
NetworkResources []*resourceTypes.NetworkResource
GroupIDToUserIDs map[string][]string
AllowedUserIDs map[string]struct{}
PostureFailedPeers map[string]map[string]struct{}
RouterPeers map[string]*nbpeer.Peer
}
type AccountSettingsInfo struct {
PeerLoginExpirationEnabled bool
PeerLoginExpiration time.Duration
PeerInactivityExpirationEnabled bool
PeerInactivityExpiration time.Duration
}
func (c *NetworkMapComponents) GetPeerInfo(peerID string) *nbpeer.Peer {
return c.Peers[peerID]
}
func (c *NetworkMapComponents) GetRouterPeerInfo(peerID string) *nbpeer.Peer {
return c.RouterPeers[peerID]
}
func (c *NetworkMapComponents) GetGroupInfo(groupID string) *Group {
return c.Groups[groupID]
}
func (c *NetworkMapComponents) IsPeerInGroup(peerID, groupID string) bool {
group := c.GetGroupInfo(groupID)
if group == nil {
return false
}
return slices.Contains(group.Peers, peerID)
}
func (c *NetworkMapComponents) GetPeerGroups(peerID string) map[string]struct{} {
groups := make(map[string]struct{})
for groupID, group := range c.Groups {
if slices.Contains(group.Peers, peerID) {
groups[groupID] = struct{}{}
}
}
return groups
}
func (c *NetworkMapComponents) ValidatePostureChecksOnPeer(peerID string, postureCheckIDs []string) bool {
_, exists := c.Peers[peerID]
if !exists {
return false
}
if len(postureCheckIDs) == 0 {
return true
}
for _, checkID := range postureCheckIDs {
if failedPeers, exists := c.PostureFailedPeers[checkID]; exists {
if _, failed := failedPeers[peerID]; failed {
return false
}
}
}
return true
}
func CalculateNetworkMapFromComponents(ctx context.Context, components *NetworkMapComponents) *NetworkMap {
return components.Calculate(ctx)
}
func (c *NetworkMapComponents) Calculate(ctx context.Context) *NetworkMap {
targetPeerID := c.PeerID
peerGroups := c.GetPeerGroups(targetPeerID)
aclPeers, firewallRules, authorizedUsers, sshEnabled := c.getPeerConnectionResources(targetPeerID)
peersToConnect, expiredPeers := c.filterPeersByLoginExpiration(aclPeers)
routesUpdate := c.getRoutesToSync(targetPeerID, peersToConnect, peerGroups)
routesFirewallRules := c.getPeerRoutesFirewallRules(ctx, targetPeerID)
isRouter, networkResourcesRoutes, sourcePeers := c.getNetworkResourcesRoutesToSync(targetPeerID)
var networkResourcesFirewallRules []*RouteFirewallRule
if isRouter {
networkResourcesFirewallRules = c.getPeerNetworkResourceFirewallRules(ctx, targetPeerID, networkResourcesRoutes)
}
peersToConnectIncludingRouters := c.addNetworksRoutingPeers(
networkResourcesRoutes,
targetPeerID,
peersToConnect,
expiredPeers,
isRouter,
sourcePeers,
)
dnsManagementStatus := c.getPeerDNSManagementStatus(targetPeerID)
dnsUpdate := nbdns.Config{
ServiceEnable: dnsManagementStatus,
}
if dnsManagementStatus {
var customZones []nbdns.CustomZone
if c.CustomZoneDomain != "" && len(c.AllDNSRecords) > 0 {
customZones = append(customZones, nbdns.CustomZone{
Domain: c.CustomZoneDomain,
Records: c.AllDNSRecords,
})
}
customZones = append(customZones, c.AccountZones...)
dnsUpdate.CustomZones = customZones
dnsUpdate.NameServerGroups = c.getPeerNSGroups(targetPeerID)
}
return &NetworkMap{
Peers: peersToConnectIncludingRouters,
Network: c.Network.Copy(),
Routes: append(networkResourcesRoutes, routesUpdate...),
DNSConfig: dnsUpdate,
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: append(networkResourcesFirewallRules, routesFirewallRules...),
AuthorizedUsers: authorizedUsers,
EnableSSH: sshEnabled,
}
}
func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
targetPeer := c.GetPeerInfo(targetPeerID)
if targetPeer == nil {
return nil, nil, nil, false
}
generateResources, getAccumulatedResources := c.connResourcesGenerator(targetPeer)
authorizedUsers := make(map[string]map[string]struct{})
sshEnabled := false
for _, policy := range c.Policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
var sourcePeers, destinationPeers []*nbpeer.Peer
var peerInSources, peerInDestinations bool
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers, peerInSources = c.getPeerFromResource(rule.SourceResource, targetPeerID)
} else {
sourcePeers, peerInSources = c.getAllPeersFromGroups(rule.Sources, targetPeerID, policy.SourcePostureChecks)
}
if rule.DestinationResource.Type == ResourceTypePeer && rule.DestinationResource.ID != "" {
destinationPeers, peerInDestinations = c.getPeerFromResource(rule.DestinationResource, targetPeerID)
} else {
destinationPeers, peerInDestinations = c.getAllPeersFromGroups(rule.Destinations, targetPeerID, nil)
}
if rule.Bidirectional {
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionIN)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionOUT)
}
}
if peerInSources {
generateResources(rule, destinationPeers, FirewallRuleDirectionOUT)
}
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := c.GroupIDToUserIDs[groupID]
if !ok {
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
}
}
peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
func (c *NetworkMapComponents) getAllowedUserIDs() map[string]struct{} {
if c.AllowedUserIDs != nil {
result := make(map[string]struct{}, len(c.AllowedUserIDs))
maps.Copy(result, c.AllowedUserIDs)
return result
}
return make(map[string]struct{})
}
func (c *NetworkMapComponents) connResourcesGenerator(targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
for _, peer := range groupPeers {
if peer == nil {
continue
}
if _, ok := peersExists[peer.ID]; !ok {
peers = append(peers, peer)
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: net.IP(peer.IP).String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {
continue
}
rulesExists[ruleID] = struct{}{}
if len(rule.Ports) == 0 && len(rule.PortRanges) == 0 {
rules = append(rules, &fr)
continue
}
rules = append(rules, expandPortsAndRanges(fr, &PolicyRule{
ID: rule.ID,
Ports: rule.Ports,
PortRanges: rule.PortRanges,
Protocol: rule.Protocol,
Action: rule.Action,
}, targetPeer)...)
}
}, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules
}
}
func (c *NetworkMapComponents) getAllPeersFromGroups(groups []string, peerID string, sourcePostureChecksIDs []string) ([]*nbpeer.Peer, bool) {
peerInGroups := false
uniquePeerIDs := c.getUniquePeerIDsFromGroupsIDs(groups)
filteredPeers := make([]*nbpeer.Peer, 0, len(uniquePeerIDs))
for _, p := range uniquePeerIDs {
peerInfo := c.GetPeerInfo(p)
if peerInfo == nil {
continue
}
if _, ok := c.Peers[p]; !ok {
continue
}
if !c.ValidatePostureChecksOnPeer(p, sourcePostureChecksIDs) {
continue
}
if p == peerID {
peerInGroups = true
continue
}
filteredPeers = append(filteredPeers, peerInfo)
}
return filteredPeers, peerInGroups
}
func (c *NetworkMapComponents) getUniquePeerIDsFromGroupsIDs(groups []string) []string {
peerIDs := make(map[string]struct{}, len(groups))
for _, groupID := range groups {
group := c.GetGroupInfo(groupID)
if group == nil {
continue
}
if group.IsGroupAll() || len(groups) == 1 {
return group.Peers
}
for _, peerID := range group.Peers {
peerIDs[peerID] = struct{}{}
}
}
ids := make([]string, 0, len(peerIDs))
for peerID := range peerIDs {
ids = append(ids, peerID)
}
return ids
}
func (c *NetworkMapComponents) getPeerFromResource(resource Resource, peerID string) ([]*nbpeer.Peer, bool) {
if resource.ID == peerID {
return []*nbpeer.Peer{}, true
}
peerInfo := c.GetPeerInfo(resource.ID)
if peerInfo == nil {
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peerInfo}, false
}
func (c *NetworkMapComponents) filterPeersByLoginExpiration(aclPeers []*nbpeer.Peer) ([]*nbpeer.Peer, []*nbpeer.Peer) {
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
for _, p := range aclPeers {
expired, _ := p.LoginExpired(c.AccountSettings.PeerLoginExpiration)
if c.AccountSettings.PeerLoginExpirationEnabled && expired {
expiredPeers = append(expiredPeers, p)
continue
}
peersToConnect = append(peersToConnect, p)
}
return peersToConnect, expiredPeers
}
func (c *NetworkMapComponents) getPeerDNSManagementStatus(peerID string) bool {
peerGroups := c.GetPeerGroups(peerID)
enabled := true
for _, groupID := range c.DNSSettings.DisabledManagementGroups {
if _, found := peerGroups[groupID]; found {
enabled = false
break
}
}
return enabled
}
func (c *NetworkMapComponents) getPeerNSGroups(peerID string) []*nbdns.NameServerGroup {
groupList := c.GetPeerGroups(peerID)
var peerNSGroups []*nbdns.NameServerGroup
for _, nsGroup := range c.NameServerGroups {
if !nsGroup.Enabled {
continue
}
for _, gID := range nsGroup.Groups {
_, found := groupList[gID]
if found {
targetPeerInfo := c.GetPeerInfo(peerID)
if targetPeerInfo != nil && !c.peerIsNameserver(targetPeerInfo, nsGroup) {
peerNSGroups = append(peerNSGroups, nsGroup.Copy())
break
}
}
}
}
return peerNSGroups
}
func (c *NetworkMapComponents) peerIsNameserver(peerInfo *nbpeer.Peer, nsGroup *nbdns.NameServerGroup) bool {
for _, ns := range nsGroup.NameServers {
if peerInfo.IP.String() == ns.IP.String() {
return true
}
}
return false
}
func (c *NetworkMapComponents) getRoutesToSync(peerID string, aclPeers []*nbpeer.Peer, peerGroups LookupMap) []*route.Route {
routes, peerDisabledRoutes := c.getRoutingPeerRoutes(peerID)
peerRoutesMembership := make(LookupMap)
for _, r := range append(routes, peerDisabledRoutes...) {
peerRoutesMembership[string(r.GetHAUniqueID())] = struct{}{}
}
for _, peer := range aclPeers {
activeRoutes, _ := c.getRoutingPeerRoutes(peer.ID)
groupFilteredRoutes := c.filterRoutesByGroups(activeRoutes, peerGroups)
filteredRoutes := c.filterRoutesFromPeersOfSameHAGroup(groupFilteredRoutes, peerRoutesMembership)
routes = append(routes, filteredRoutes...)
}
return routes
}
func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoutes []*route.Route, disabledRoutes []*route.Route) {
peerInfo := c.GetPeerInfo(peerID)
if peerInfo == nil {
peerInfo = c.GetRouterPeerInfo(peerID)
}
if peerInfo == nil {
return enabledRoutes, disabledRoutes
}
seenRoute := make(map[route.ID]struct{})
takeRoute := func(r *route.Route) {
if _, ok := seenRoute[r.ID]; ok {
return
}
seenRoute[r.ID] = struct{}{}
routeObj := c.copyRoute(r)
routeObj.Peer = peerInfo.Key
if r.Enabled {
enabledRoutes = append(enabledRoutes, routeObj)
return
}
disabledRoutes = append(disabledRoutes, routeObj)
}
for _, r := range c.Routes {
for _, groupID := range r.PeerGroups {
group := c.GetGroupInfo(groupID)
if group == nil {
continue
}
for _, id := range group.Peers {
if id != peerID {
continue
}
newPeerRoute := c.copyRoute(r)
newPeerRoute.Peer = id
newPeerRoute.PeerGroups = nil
newPeerRoute.ID = route.ID(string(r.ID) + ":" + id)
takeRoute(newPeerRoute)
break
}
}
if r.Peer == peerID {
takeRoute(c.copyRoute(r))
}
}
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) copyRoute(r *route.Route) *route.Route {
var groups, accessControlGroups, peerGroups []string
var domains domain.List
if r.Groups != nil {
groups = append([]string{}, r.Groups...)
}
if r.AccessControlGroups != nil {
accessControlGroups = append([]string{}, r.AccessControlGroups...)
}
if r.PeerGroups != nil {
peerGroups = append([]string{}, r.PeerGroups...)
}
if r.Domains != nil {
domains = append(domain.List{}, r.Domains...)
}
return &route.Route{
ID: r.ID,
AccountID: r.AccountID,
Network: r.Network,
NetworkType: r.NetworkType,
Description: r.Description,
Peer: r.Peer,
PeerID: r.PeerID,
Metric: r.Metric,
Masquerade: r.Masquerade,
NetID: r.NetID,
Enabled: r.Enabled,
Groups: groups,
AccessControlGroups: accessControlGroups,
PeerGroups: peerGroups,
Domains: domains,
KeepRoute: r.KeepRoute,
SkipAutoApply: r.SkipAutoApply,
}
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
for _, groupID := range r.Groups {
_, found := groupListMap[groupID]
if found {
filteredRoutes = append(filteredRoutes, r)
break
}
}
}
return filteredRoutes
}
func (c *NetworkMapComponents) filterRoutesFromPeersOfSameHAGroup(routes []*route.Route, peerMemberships LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {
_, found := peerMemberships[string(r.GetHAUniqueID())]
if !found {
filteredRoutes = append(filteredRoutes, r)
}
}
return filteredRoutes
}
func (c *NetworkMapComponents) getPeerRoutesFirewallRules(ctx context.Context, peerID string) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
enabledRoutes, _ := c.getRoutingPeerRoutes(peerID)
for _, r := range enabledRoutes {
if len(r.AccessControlGroups) == 0 {
defaultPermit := c.getDefaultPermit(r)
routesFirewallRules = append(routesFirewallRules, defaultPermit...)
continue
}
distributionPeers := c.getDistributionGroupsPeers(r)
for _, accessGroup := range r.AccessControlGroups {
policies := c.getAllRoutePoliciesFromGroups([]string{accessGroup})
rules := c.getRouteFirewallRules(ctx, peerID, policies, r, distributionPeers)
routesFirewallRules = append(routesFirewallRules, rules...)
}
}
return routesFirewallRules
}
func (c *NetworkMapComponents) getDefaultPermit(r *route.Route) []*RouteFirewallRule {
var rules []*RouteFirewallRule
sources := []string{"0.0.0.0/0"}
if r.Network.Addr().Is6() {
sources = []string{"::/0"}
}
rule := RouteFirewallRule{
SourceRanges: sources,
Action: string(PolicyTrafficActionAccept),
Destination: r.Network.String(),
Protocol: string(PolicyRuleProtocolALL),
Domains: r.Domains,
IsDynamic: r.IsDynamic(),
RouteID: r.ID,
}
rules = append(rules, &rule)
if r.IsDynamic() {
ruleV6 := rule
ruleV6.SourceRanges = []string{"::/0"}
rules = append(rules, &ruleV6)
}
return rules
}
func (c *NetworkMapComponents) getDistributionGroupsPeers(r *route.Route) map[string]struct{} {
distPeers := make(map[string]struct{})
for _, id := range r.Groups {
group := c.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
distPeers[pID] = struct{}{}
}
}
return distPeers
}
func (c *NetworkMapComponents) getAllRoutePoliciesFromGroups(accessControlGroups []string) []*Policy {
routePolicies := make([]*Policy, 0)
for _, groupID := range accessControlGroups {
for _, policy := range c.Policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Destinations, groupID) {
routePolicies = append(routePolicies, policy)
}
}
}
}
return routePolicies
}
func (c *NetworkMapComponents) getRouteFirewallRules(ctx context.Context, peerID string, policies []*Policy, route *route.Route, distributionPeers map[string]struct{}) []*RouteFirewallRule {
var fwRules []*RouteFirewallRule
for _, policy := range policies {
if !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
rulePeers := c.getRulePeers(rule, policy.SourcePostureChecks, peerID, distributionPeers)
rules := generateRouteFirewallRules(ctx, route, rule, rulePeers, FirewallRuleDirectionIN)
fwRules = append(fwRules, rules...)
}
}
return fwRules
}
func (c *NetworkMapComponents) getRulePeers(rule *PolicyRule, postureChecks []string, peerID string, distributionPeers map[string]struct{}) []*nbpeer.Peer {
distPeersWithPolicy := make(map[string]struct{})
for _, id := range rule.Sources {
group := c.GetGroupInfo(id)
if group == nil {
continue
}
for _, pID := range group.Peers {
if pID == peerID {
continue
}
_, distPeer := distributionPeers[pID]
_, valid := c.Peers[pID]
if distPeer && valid && c.ValidatePostureChecksOnPeer(pID, postureChecks) {
distPeersWithPolicy[pID] = struct{}{}
}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
_, distPeer := distributionPeers[rule.SourceResource.ID]
_, valid := c.Peers[rule.SourceResource.ID]
if distPeer && valid && c.ValidatePostureChecksOnPeer(rule.SourceResource.ID, postureChecks) {
distPeersWithPolicy[rule.SourceResource.ID] = struct{}{}
}
}
distributionGroupPeers := make([]*nbpeer.Peer, 0, len(distPeersWithPolicy))
for pID := range distPeersWithPolicy {
peerInfo := c.GetPeerInfo(pID)
if peerInfo == nil {
continue
}
distributionGroupPeers = append(distributionGroupPeers, peerInfo)
}
return distributionGroupPeers
}
func (c *NetworkMapComponents) getNetworkResourcesRoutesToSync(peerID string) (bool, []*route.Route, map[string]struct{}) {
var isRoutingPeer bool
var routes []*route.Route
allSourcePeers := make(map[string]struct{})
for _, resource := range c.NetworkResources {
if !resource.Enabled {
continue
}
var addSourcePeers bool
networkRoutingPeers, exists := c.RoutersMap[resource.NetworkID]
if exists {
if router, ok := networkRoutingPeers[peerID]; ok {
isRoutingPeer, addSourcePeers = true, true
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerID, router)...)
}
}
addedResourceRoute := false
for _, policy := range c.ResourcePoliciesMap[resource.ID] {
var peers []string
if policy.Rules[0].SourceResource.Type == ResourceTypePeer && policy.Rules[0].SourceResource.ID != "" {
peers = []string{policy.Rules[0].SourceResource.ID}
} else {
peers = c.getUniquePeerIDsFromGroupsIDs(policy.SourceGroups())
}
if addSourcePeers {
for _, pID := range c.getPostureValidPeers(peers, policy.SourcePostureChecks) {
allSourcePeers[pID] = struct{}{}
}
} else if slices.Contains(peers, peerID) && c.ValidatePostureChecksOnPeer(peerID, policy.SourcePostureChecks) {
for peerId, router := range networkRoutingPeers {
routes = append(routes, c.getNetworkResourcesRoutes(resource, peerId, router)...)
}
addedResourceRoute = true
}
if addedResourceRoute {
break
}
}
}
return isRoutingPeer, routes, allSourcePeers
}
func (c *NetworkMapComponents) getNetworkResourcesRoutes(resource *resourceTypes.NetworkResource, peerID string, router *routerTypes.NetworkRouter) []*route.Route {
resourceAppliedPolicies := c.ResourcePoliciesMap[resource.ID]
var routes []*route.Route
if len(resourceAppliedPolicies) > 0 {
peerInfo := c.GetPeerInfo(peerID)
if peerInfo != nil {
routes = append(routes, c.networkResourceToRoute(resource, peerInfo, router))
}
}
return routes
}
func (c *NetworkMapComponents) networkResourceToRoute(resource *resourceTypes.NetworkResource, peer *nbpeer.Peer, router *routerTypes.NetworkRouter) *route.Route {
r := &route.Route{
ID: route.ID(resource.ID + ":" + peer.ID),
AccountID: resource.AccountID,
Peer: peer.Key,
PeerID: peer.ID,
Metric: router.Metric,
Masquerade: router.Masquerade,
Enabled: resource.Enabled,
KeepRoute: true,
NetID: route.NetID(resource.Name),
Description: resource.Description,
}
if resource.Type == resourceTypes.Host || resource.Type == resourceTypes.Subnet {
r.Network = resource.Prefix
r.NetworkType = route.IPv4Network
if resource.Prefix.Addr().Is6() {
r.NetworkType = route.IPv6Network
}
}
if resource.Type == resourceTypes.Domain {
domainList, err := domain.FromStringList([]string{resource.Domain})
if err == nil {
r.Domains = domainList
r.NetworkType = route.DomainNetwork
r.Network = netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
}
return r
}
func (c *NetworkMapComponents) getPostureValidPeers(inputPeers []string, postureChecksIDs []string) []string {
var dest []string
for _, peerID := range inputPeers {
if c.ValidatePostureChecksOnPeer(peerID, postureChecksIDs) {
dest = append(dest, peerID)
}
}
return dest
}
func (c *NetworkMapComponents) getPeerNetworkResourceFirewallRules(ctx context.Context, peerID string, routes []*route.Route) []*RouteFirewallRule {
routesFirewallRules := make([]*RouteFirewallRule, 0)
peerInfo := c.GetPeerInfo(peerID)
if peerInfo == nil {
return routesFirewallRules
}
for _, r := range routes {
if r.Peer != peerInfo.Key {
continue
}
resourceID := string(r.GetResourceID())
resourcePolicies := c.ResourcePoliciesMap[resourceID]
distributionPeers := c.getPoliciesSourcePeers(resourcePolicies)
rules := c.getRouteFirewallRules(ctx, peerID, resourcePolicies, r, distributionPeers)
for _, rule := range rules {
if len(rule.SourceRanges) > 0 {
routesFirewallRules = append(routesFirewallRules, rule)
}
}
}
return routesFirewallRules
}
func (c *NetworkMapComponents) getPoliciesSourcePeers(policies []*Policy) map[string]struct{} {
sourcePeers := make(map[string]struct{})
for _, policy := range policies {
for _, rule := range policy.Rules {
for _, sourceGroup := range rule.Sources {
group := c.GetGroupInfo(sourceGroup)
if group == nil {
continue
}
for _, peer := range group.Peers {
sourcePeers[peer] = struct{}{}
}
}
if rule.SourceResource.Type == ResourceTypePeer && rule.SourceResource.ID != "" {
sourcePeers[rule.SourceResource.ID] = struct{}{}
}
}
}
return sourcePeers
}
func (c *NetworkMapComponents) addNetworksRoutingPeers(
networkResourcesRoutes []*route.Route,
peerID string,
peersToConnect []*nbpeer.Peer,
expiredPeers []*nbpeer.Peer,
isRouter bool,
sourcePeers map[string]struct{},
) []*nbpeer.Peer {
networkRoutesPeers := make(map[string]struct{}, len(networkResourcesRoutes))
for _, r := range networkResourcesRoutes {
networkRoutesPeers[r.PeerID] = struct{}{}
}
delete(sourcePeers, peerID)
delete(networkRoutesPeers, peerID)
for _, existingPeer := range peersToConnect {
delete(sourcePeers, existingPeer.ID)
delete(networkRoutesPeers, existingPeer.ID)
}
for _, expPeer := range expiredPeers {
delete(sourcePeers, expPeer.ID)
delete(networkRoutesPeers, expPeer.ID)
}
missingPeers := make(map[string]struct{}, len(sourcePeers)+len(networkRoutesPeers))
if isRouter {
for p := range sourcePeers {
missingPeers[p] = struct{}{}
}
}
for p := range networkRoutesPeers {
missingPeers[p] = struct{}{}
}
for p := range missingPeers {
peerInfo := c.GetPeerInfo(p)
if peerInfo == nil {
peerInfo = c.GetRouterPeerInfo(p)
}
if peerInfo != nil {
peersToConnect = append(peersToConnect, peerInfo)
}
}
return peersToConnect
}

View File

@@ -0,0 +1,230 @@
package types
import (
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/route"
)
type GroupCompact struct {
Name string
PeerIndexes []int
}
type NetworkMapComponentsCompact struct {
PeerID string
Network *Network
AccountSettings *AccountSettingsInfo
DNSSettings *DNSSettings
CustomZoneDomain string
AllPeers []*nbpeer.Peer
PeerIndexes []int
RouterPeerIndexes []int
Groups map[string]*GroupCompact
AllPolicies []*Policy
PolicyIndexes []int
ResourcePoliciesMap map[string][]int
Routes []*route.Route
NameServerGroups []*nbdns.NameServerGroup
AllDNSRecords []nbdns.SimpleRecord
AccountZones []nbdns.CustomZone
RoutersMap map[string]map[string]*routerTypes.NetworkRouter
NetworkResources []*resourceTypes.NetworkResource
GroupIDToUserIDs map[string][]string
AllowedUserIDs map[string]struct{}
PostureFailedPeers map[string]map[string]struct{}
}
func (c *NetworkMapComponents) ToCompact() *NetworkMapComponentsCompact {
peerToIndex := make(map[string]int)
var allPeers []*nbpeer.Peer
for id, peer := range c.Peers {
if _, exists := peerToIndex[id]; !exists {
peerToIndex[id] = len(allPeers)
allPeers = append(allPeers, peer)
}
}
for id, peer := range c.RouterPeers {
if _, exists := peerToIndex[id]; !exists {
peerToIndex[id] = len(allPeers)
allPeers = append(allPeers, peer)
}
}
peerIndexes := make([]int, 0, len(c.Peers))
for id := range c.Peers {
peerIndexes = append(peerIndexes, peerToIndex[id])
}
routerPeerIndexes := make([]int, 0, len(c.RouterPeers))
for id := range c.RouterPeers {
routerPeerIndexes = append(routerPeerIndexes, peerToIndex[id])
}
groups := make(map[string]*GroupCompact, len(c.Groups))
for id, group := range c.Groups {
peerIdxs := make([]int, 0, len(group.Peers))
for _, peerID := range group.Peers {
if idx, ok := peerToIndex[peerID]; ok {
peerIdxs = append(peerIdxs, idx)
}
}
groups[id] = &GroupCompact{
Name: group.Name,
PeerIndexes: peerIdxs,
}
}
policyToIndex := make(map[*Policy]int)
var allPolicies []*Policy
for _, policy := range c.Policies {
if _, exists := policyToIndex[policy]; !exists {
policyToIndex[policy] = len(allPolicies)
allPolicies = append(allPolicies, policy)
}
}
for _, policies := range c.ResourcePoliciesMap {
for _, policy := range policies {
if _, exists := policyToIndex[policy]; !exists {
policyToIndex[policy] = len(allPolicies)
allPolicies = append(allPolicies, policy)
}
}
}
policyIndexes := make([]int, len(c.Policies))
for i, policy := range c.Policies {
policyIndexes[i] = policyToIndex[policy]
}
var resourcePoliciesMap map[string][]int
if len(c.ResourcePoliciesMap) > 0 {
resourcePoliciesMap = make(map[string][]int, len(c.ResourcePoliciesMap))
for resID, policies := range c.ResourcePoliciesMap {
indexes := make([]int, len(policies))
for i, policy := range policies {
indexes[i] = policyToIndex[policy]
}
resourcePoliciesMap[resID] = indexes
}
}
return &NetworkMapComponentsCompact{
PeerID: c.PeerID,
Network: c.Network,
AccountSettings: c.AccountSettings,
DNSSettings: c.DNSSettings,
CustomZoneDomain: c.CustomZoneDomain,
AllPeers: allPeers,
PeerIndexes: peerIndexes,
RouterPeerIndexes: routerPeerIndexes,
Groups: groups,
AllPolicies: allPolicies,
PolicyIndexes: policyIndexes,
ResourcePoliciesMap: resourcePoliciesMap,
Routes: c.Routes,
NameServerGroups: c.NameServerGroups,
AllDNSRecords: c.AllDNSRecords,
AccountZones: c.AccountZones,
RoutersMap: c.RoutersMap,
NetworkResources: c.NetworkResources,
GroupIDToUserIDs: c.GroupIDToUserIDs,
AllowedUserIDs: c.AllowedUserIDs,
PostureFailedPeers: c.PostureFailedPeers,
}
}
func (c *NetworkMapComponentsCompact) ToFull() *NetworkMapComponents {
peers := make(map[string]*nbpeer.Peer, len(c.PeerIndexes))
for _, idx := range c.PeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peer := c.AllPeers[idx]
peers[peer.ID] = peer
}
}
routerPeers := make(map[string]*nbpeer.Peer, len(c.RouterPeerIndexes))
for _, idx := range c.RouterPeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peer := c.AllPeers[idx]
routerPeers[peer.ID] = peer
}
}
groups := make(map[string]*Group, len(c.Groups))
for id, gc := range c.Groups {
peerIDs := make([]string, 0, len(gc.PeerIndexes))
for _, idx := range gc.PeerIndexes {
if idx >= 0 && idx < len(c.AllPeers) {
peerIDs = append(peerIDs, c.AllPeers[idx].ID)
}
}
groups[id] = &Group{
ID: id,
Name: gc.Name,
Peers: peerIDs,
}
}
policies := make([]*Policy, len(c.PolicyIndexes))
for i, idx := range c.PolicyIndexes {
if idx >= 0 && idx < len(c.AllPolicies) {
policies[i] = c.AllPolicies[idx]
}
}
var resourcePoliciesMap map[string][]*Policy
if len(c.ResourcePoliciesMap) > 0 {
resourcePoliciesMap = make(map[string][]*Policy, len(c.ResourcePoliciesMap))
for resID, indexes := range c.ResourcePoliciesMap {
pols := make([]*Policy, 0, len(indexes))
for _, idx := range indexes {
if idx >= 0 && idx < len(c.AllPolicies) {
pols = append(pols, c.AllPolicies[idx])
}
}
resourcePoliciesMap[resID] = pols
}
}
return &NetworkMapComponents{
PeerID: c.PeerID,
Network: c.Network,
AccountSettings: c.AccountSettings,
DNSSettings: c.DNSSettings,
CustomZoneDomain: c.CustomZoneDomain,
Peers: peers,
RouterPeers: routerPeers,
Groups: groups,
Policies: policies,
Routes: c.Routes,
NameServerGroups: c.NameServerGroups,
AllDNSRecords: c.AllDNSRecords,
AccountZones: c.AccountZones,
ResourcePoliciesMap: resourcePoliciesMap,
RoutersMap: c.RoutersMap,
NetworkResources: c.NetworkResources,
GroupIDToUserIDs: c.GroupIDToUserIDs,
AllowedUserIDs: c.AllowedUserIDs,
PostureFailedPeers: c.PostureFailedPeers,
}
}

View File

@@ -53,6 +53,7 @@ var (
certLockMethod string
wgPort int
proxyProtocol bool
preSharedKey string
)
var rootCmd = &cobra.Command{
@@ -84,6 +85,7 @@ func init() {
rootCmd.Flags().StringVar(&certLockMethod, "cert-lock-method", envStringOrDefault("NB_PROXY_CERT_LOCK_METHOD", "auto"), "Certificate lock method for cross-replica coordination: auto, flock, or k8s-lease")
rootCmd.Flags().IntVar(&wgPort, "wg-port", envIntOrDefault("NB_PROXY_WG_PORT", 0), "WireGuard listen port (0 = random). Fixed port only works with single-account deployments")
rootCmd.Flags().BoolVar(&proxyProtocol, "proxy-protocol", envBoolOrDefault("NB_PROXY_PROXY_PROTOCOL", false), "Enable PROXY protocol on TCP listeners to preserve client IPs behind L4 proxies")
rootCmd.Flags().StringVar(&preSharedKey, "preshared-key", envStringOrDefault("NB_PROXY_PRESHARED_KEY", ""), "Define a pre-shared key for the tunnel between proxy and peers")
}
// Execute runs the root command.
@@ -156,6 +158,7 @@ func runServer(cmd *cobra.Command, args []string) error {
CertLockMethod: nbacme.CertLockMethod(certLockMethod),
WireguardPort: wgPort,
ProxyProtocol: proxyProtocol,
PreSharedKey: preSharedKey,
}
ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGTERM, syscall.SIGINT)

View File

@@ -86,6 +86,13 @@ func (e *clientEntry) acquireInflight(backend backendKey) (release func(), ok bo
}
}
// ClientConfig holds configuration for the embedded NetBird client.
type ClientConfig struct {
MgmtAddr string
WGPort int
PreSharedKey string
}
type statusNotifier interface {
NotifyStatus(ctx context.Context, accountID, serviceID, domain string, connected bool) error
}
@@ -98,10 +105,9 @@ type managementClient interface {
// backed by underlying NetBird connections.
// Clients are keyed by AccountID, allowing multiple domains to share the same connection.
type NetBird struct {
mgmtAddr string
proxyID string
proxyAddr string
wgPort int
clientCfg ClientConfig
logger *log.Logger
mgmtClient managementClient
transportCfg transportConfig
@@ -229,11 +235,12 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// The peer has already been created via CreateProxyPeer RPC with the public key.
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.mgmtAddr,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
LogLevel: log.WarnLevel.String(),
BlockInbound: true,
WireguardPort: &n.wgPort,
WireguardPort: &n.clientCfg.WGPort,
PreSharedKey: n.clientCfg.PreSharedKey,
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
@@ -536,18 +543,17 @@ func (n *NetBird) ListClientsForStartup() map[types.AccountID]*embed.Client {
return result
}
// NewNetBird creates a new NetBird transport. Set wgPort to 0 for a random
// NewNetBird creates a new NetBird transport. Set clientCfg.WGPort to 0 for a random
// OS-assigned port. A fixed port only works with single-account deployments;
// multiple accounts will fail to bind the same port.
func NewNetBird(mgmtAddr, proxyID, proxyAddr string, wgPort int, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
func NewNetBird(proxyID, proxyAddr string, clientCfg ClientConfig, logger *log.Logger, notifier statusNotifier, mgmtClient managementClient) *NetBird {
if logger == nil {
logger = log.StandardLogger()
}
return &NetBird{
mgmtAddr: mgmtAddr,
proxyID: proxyID,
proxyAddr: proxyAddr,
wgPort: wgPort,
clientCfg: clientCfg,
logger: logger,
clients: make(map[types.AccountID]*clientEntry),
statusNotifier: notifier,

View File

@@ -49,7 +49,11 @@ func (m *mockStatusNotifier) calls() []statusCall {
// mockNetBird creates a NetBird instance for testing without actually connecting.
// It uses an invalid management URL to prevent real connections.
func mockNetBird() *NetBird {
return NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, nil, &mockMgmtClient{})
return NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, nil, &mockMgmtClient{})
}
func TestNetBird_AddPeer_CreatesClientForNewAccount(t *testing.T) {
@@ -282,7 +286,11 @@ func TestNetBird_RoundTrip_RequiresExistingClient(t *testing.T) {
func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
// Add first domain — creates a new client entry.
@@ -308,7 +316,11 @@ func TestNetBird_AddPeer_ExistingStartedClient_NotifiesStatus(t *testing.T) {
func TestNetBird_RemovePeer_NotifiesDisconnection(t *testing.T) {
notifier := &mockStatusNotifier{}
nb := NewNetBird("http://invalid.test:9999", "test-proxy", "invalid.test", 0, nil, notifier, &mockMgmtClient{})
nb := NewNetBird("test-proxy", "invalid.test", ClientConfig{
MgmtAddr: "http://invalid.test:9999",
WGPort: 0,
PreSharedKey: "",
}, nil, notifier, &mockMgmtClient{})
accountID := types.AccountID("account-1")
err := nb.AddPeer(context.Background(), accountID, domain.Domain("domain1.test"), "key-1", "svc-1")

View File

@@ -192,6 +192,10 @@ type storeBackedServiceManager struct {
tokenStore *nbgrpc.OneTimeTokenStore
}
func (m *storeBackedServiceManager) DeleteAllServices(ctx context.Context, accountID, userID string) error {
return nil
}
func (m *storeBackedServiceManager) GetAllServices(ctx context.Context, accountID, userID string) ([]*reverseproxy.Service, error) {
return m.store.GetAccountServices(ctx, store.LockingStrengthNone, accountID)
}

View File

@@ -114,6 +114,8 @@ type Server struct {
// When enabled, the real client IP is extracted from the PROXY header
// sent by upstream L4 proxies that support PROXY protocol.
ProxyProtocol bool
// PreSharedKey used for tunnel between proxy and peers (set globally not per account)
PreSharedKey string
}
// NotifyStatus sends a status update to management about tunnel connectivity
@@ -163,7 +165,11 @@ func (s *Server) ListenAndServe(ctx context.Context, addr string) (err error) {
// Initialize the netbird client, this is required to build peer connections
// to proxy over.
s.netbird = roundtrip.NewNetBird(s.ManagementAddress, s.ID, s.ProxyURL, s.WireguardPort, s.Logger, s, s.mgmtClient)
s.netbird = roundtrip.NewNetBird(s.ID, s.ProxyURL, roundtrip.ClientConfig{
MgmtAddr: s.ManagementAddress,
WGPort: s.WireguardPort,
PreSharedKey: s.PreSharedKey,
}, s.Logger, s, s.mgmtClient)
tlsConfig, err := s.configureTLS(ctx)
if err != nil {