Conflict resolution

This commit is contained in:
crn4
2026-01-08 13:25:10 +01:00
235 changed files with 24100 additions and 2460 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -45,8 +46,10 @@ const (
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
nativeSSHPortString = "22022"
nativeSSHPortNumber = 22022
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
defaultSSHPortString = "22"
defaultSSHPortNumber = 22
)
type supportedFeatures struct {
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
resourcePolicies map[string][]*Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
metrics *telemetry.AccountManagerMetrics,
groupIDToUserIDs map[string][]string,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
}
}
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
// exclude expired peers
var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
OfflinePeers: expiredPeers,
FirewallRules: firewallRules,
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
AuthorizedUsers: authorizedUsers,
EnableSSH: enableSSH,
}
if metrics != nil {
@@ -1122,8 +1128,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer
//
// This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
sshEnabled := false
for _, policy := range a.Policies {
if !policy.Enabled {
@@ -1166,10 +1174,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
if peerInDestinations {
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := groupIDToUserIDs[groupID]
if !ok {
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
}
}
return getAccumulatedResources()
peers, fwRules := getAccumulatedResources()
return peers, fwRules, authorizedUsers, sshEnabled
}
// getAllowedUserIDs returns the set of user IDs for every account user that
// is neither blocked nor a service user. The result is used as the wildcard
// SSH authorization set.
func (a *Account) getAllowedUserIDs() map[string]struct{} {
	allowed := make(map[string]struct{}, len(a.Users))
	for _, u := range a.Users {
		if u.IsBlocked() || u.IsServiceUser {
			continue
		}
		allowed[u.Id] = struct{}{}
	}
	return allowed
}
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
@@ -1194,12 +1250,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
peersExists[peer.ID] = struct{}{}
}
protocol := rule.Protocol
if protocol == PolicyRuleProtocolNetbirdSSH {
protocol = PolicyRuleProtocolTCP
}
fr := FirewallRule{
PolicyID: rule.ID,
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
Protocol: string(protocol),
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
@@ -1221,6 +1282,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
}
}
// policyRuleImpliesLegacySSH reports whether a policy rule implicitly allows
// SSH under the legacy model: it either permits all protocols, or permits TCP
// on ports (or port ranges) that include an SSH port.
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
	switch rule.Protocol {
	case PolicyRuleProtocolALL:
		return true
	case PolicyRuleProtocolTCP:
		return portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)
	default:
		return false
	}
}
// portRangeIncludesSSH reports whether any of the given port ranges covers
// the default SSH port (22) or the native NetBird SSH port (22022).
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
	for _, r := range portRanges {
		coversDefault := r.Start <= defaultSSHPortNumber && r.End >= defaultSSHPortNumber
		coversNative := r.Start <= nativeSSHPortNumber && r.End >= nativeSSHPortNumber
		if coversDefault || coversNative {
			return true
		}
	}
	return false
}
// portsIncludesSSH reports whether the given port list contains the default
// SSH port ("22") or the native NetBird SSH port ("22022").
func portsIncludesSSH(ports []string) bool {
	for _, p := range ports {
		switch p {
		case defaultSSHPortString, nativeSSHPortString:
			return true
		}
	}
	return false
}
// getAllPeersFromGroups for given peer ID and list of groups
//
// Returns a list of peers from specified groups that pass specified posture checks
@@ -1265,7 +1348,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe
return []*nbpeer.Peer{}, false
}
return []*nbpeer.Peer{peer}, resource.ID == peerID
if peer.ID == peerID {
return []*nbpeer.Peer{}, true
}
return []*nbpeer.Peer{peer}, false
}
// validatePostureChecksOnPeer validates the posture checks on a peer
@@ -1773,6 +1860,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
return nil
}
// GetActiveGroupUsers returns, per group ID, the IDs of all active users
// (neither blocked nor service users) auto-assigned to that group. Every
// active user is additionally listed under the "All" group's ID.
//
// NOTE(review): when the "All" group cannot be resolved, active users are
// collected under the empty-string key "" — confirm callers tolerate this
// fallback key.
func (a *Account) GetActiveGroupUsers() map[string][]string {
	allGroupID := ""
	if group, err := a.GetGroupAll(); err != nil {
		log.Errorf("failed to get group all: %v", err)
	} else {
		allGroupID = group.ID
	}

	result := make(map[string][]string, len(a.GroupsG))
	for _, u := range a.Users {
		if u.IsBlocked() || u.IsServiceUser {
			continue
		}
		for _, groupID := range u.AutoGroups {
			result[groupID] = append(result[groupID], u.Id)
		}
		result[allGroupID] = append(result[allGroupID], u.Id)
	}
	return result
}
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
@@ -1804,7 +1911,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
expanded = append(expanded, &fr)
}
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
expanded = addNativeSSHRule(base, expanded)
}

View File

@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
}
}
// Test_GetActiveGroupUsers verifies that Account.GetActiveGroupUsers builds a
// map from group ID to the IDs of all active (non-blocked) users auto-assigned
// to that group, and excludes blocked users entirely.
//
// NOTE(review): these test accounts have no "All" group, so GetGroupAll fails
// inside GetActiveGroupUsers and active users are additionally collected under
// the empty group ID "" — that is why the expected maps contain a "" key.
// Confirm this fallback behavior is intended.
func Test_GetActiveGroupUsers(t *testing.T) {
	tests := []struct {
		name     string              // test case description
		account  *Account            // input account with users and auto-group memberships
		expected map[string][]string // expected group ID -> active user IDs (order-insensitive)
	}{
		{
			name: "all users are active",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{"group1", "group2"},
						Blocked:    false,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group2", "group3"},
						Blocked:    false,
					},
					"user3": {
						Id:         "user3",
						AutoGroups: []string{"group1"},
						Blocked:    false,
					},
				},
			},
			expected: map[string][]string{
				"group1": {"user1", "user3"},
				"group2": {"user1", "user2"},
				"group3": {"user2"},
				// "" is the fallback All-group key (no All group in this account).
				"": {"user1", "user2", "user3"},
			},
		},
		{
			name: "some users are blocked",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{"group1", "group2"},
						Blocked:    false,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group2", "group3"},
						Blocked:    true,
					},
					"user3": {
						Id:         "user3",
						AutoGroups: []string{"group1", "group3"},
						Blocked:    false,
					},
				},
			},
			expected: map[string][]string{
				"group1": {"user1", "user3"},
				"group2": {"user1"},
				"group3": {"user3"},
				"":       {"user1", "user3"},
			},
		},
		{
			// No active users at all must produce an empty (not nil-keyed) map.
			name: "all users are blocked",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{"group1"},
						Blocked:    true,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group2"},
						Blocked:    true,
					},
				},
			},
			expected: map[string][]string{},
		},
		{
			// A user with no auto groups still appears in the All-group bucket.
			name: "user with no auto groups",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{},
						Blocked:    false,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group1"},
						Blocked:    false,
					},
				},
			},
			expected: map[string][]string{
				"group1": {"user2"},
				"":       {"user1", "user2"},
			},
		},
		{
			name: "empty account",
			account: &Account{
				Users: map[string]*User{},
			},
			expected: map[string][]string{},
		},
		{
			name: "multiple users in same group",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{"group1"},
						Blocked:    false,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group1"},
						Blocked:    false,
					},
					"user3": {
						Id:         "user3",
						AutoGroups: []string{"group1"},
						Blocked:    false,
					},
				},
			},
			expected: map[string][]string{
				"group1": {"user1", "user2", "user3"},
				"":       {"user1", "user2", "user3"},
			},
		},
		{
			name: "user in multiple groups with blocked users",
			account: &Account{
				Users: map[string]*User{
					"user1": {
						Id:         "user1",
						AutoGroups: []string{"group1", "group2", "group3"},
						Blocked:    false,
					},
					"user2": {
						Id:         "user2",
						AutoGroups: []string{"group1", "group2"},
						Blocked:    true,
					},
					"user3": {
						Id:         "user3",
						AutoGroups: []string{"group3"},
						Blocked:    false,
					},
				},
			},
			expected: map[string][]string{
				"group1": {"user1"},
				"group2": {"user1"},
				"group3": {"user1", "user3"},
				"":       {"user1", "user3"},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := tt.account.GetActiveGroupUsers()

			// Check that the number of groups matches
			assert.Equal(t, len(tt.expected), len(result), "number of groups should match")

			// Check each group's users; ElementsMatch ignores ordering, since
			// map iteration order in GetActiveGroupUsers is not deterministic.
			for groupID, expectedUsers := range tt.expected {
				actualUsers, exists := result[groupID]
				assert.True(t, exists, "group %s should exist in result", groupID)
				assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
			}

			// Ensure no extra groups in result
			for groupID := range result {
				_, exists := tt.expected[groupID]
				assert.True(t, exists, "unexpected group %s in result", groupID)
			}
		})
	}
}
func Test_FilterZoneRecordsForPeers(t *testing.T) {
tests := []struct {
name string

View File

@@ -25,16 +25,20 @@ func (h *Holder) GetAccount(id string) *Account {
// AddAccount caches the given account unless an already cached copy is at
// least as new, judged by the network serial; stale snapshots never replace
// newer cached state.
func (h *Holder) AddAccount(account *Account) {
	h.mu.Lock()
	defer h.mu.Unlock()

	if cached := h.accounts[account.Id]; cached != nil &&
		cached.Network.CurrentSerial() >= account.Network.CurrentSerial() {
		// Cached copy is same age or newer; keep it.
		return
	}
	h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(context.Background(), id)
account, err := accGetter(ctx, id)
if err != nil {
return nil, err
}

View File

@@ -0,0 +1,122 @@
package types
import (
"errors"
"net/url"
)
// Identity provider validation errors returned by (*IdentityProvider).Validate.
// They are sentinel values; callers can match them with errors.Is.
var (
	ErrIdentityProviderNameRequired     = errors.New("identity provider name is required")
	ErrIdentityProviderTypeRequired     = errors.New("identity provider type is required")
	ErrIdentityProviderTypeUnsupported  = errors.New("unsupported identity provider type")
	ErrIdentityProviderIssuerRequired   = errors.New("identity provider issuer is required")
	ErrIdentityProviderIssuerInvalid    = errors.New("identity provider issuer must be a valid URL")
	ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
)
// IdentityProviderType is the type of identity provider.
type IdentityProviderType string

// Supported identity provider types.
// NOTE: when adding a new type here, extend IsValid (and HasBuiltInIssuer if
// the provider has a well-known issuer) accordingly.
const (
	// IdentityProviderTypeOIDC is a generic OIDC identity provider
	IdentityProviderTypeOIDC IdentityProviderType = "oidc"
	// IdentityProviderTypeZitadel is the Zitadel identity provider
	IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
	// IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider
	IdentityProviderTypeEntra IdentityProviderType = "entra"
	// IdentityProviderTypeGoogle is the Google identity provider
	IdentityProviderTypeGoogle IdentityProviderType = "google"
	// IdentityProviderTypeOkta is the Okta identity provider
	IdentityProviderTypeOkta IdentityProviderType = "okta"
	// IdentityProviderTypePocketID is the PocketID identity provider
	IdentityProviderTypePocketID IdentityProviderType = "pocketid"
	// IdentityProviderTypeMicrosoft is the Microsoft identity provider
	IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
	// IdentityProviderTypeAuthentik is the Authentik identity provider
	IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
	// IdentityProviderTypeKeycloak is the Keycloak identity provider
	IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
)
// IdentityProvider represents an identity provider configuration persisted
// via GORM (ID is the primary key, AccountID is indexed).
type IdentityProvider struct {
	// ID is the unique identifier of the identity provider
	ID string `gorm:"primaryKey"`
	// AccountID is a reference to Account that this object belongs; excluded
	// from JSON serialization.
	AccountID string `json:"-" gorm:"index"`
	// Type is the type of identity provider
	Type IdentityProviderType
	// Name is a human-readable name for the identity provider
	Name string
	// Issuer is the OIDC issuer URL; may be empty for types with a built-in
	// issuer (see HasBuiltInIssuer)
	Issuer string
	// ClientID is the OAuth2 client ID
	ClientID string
	// ClientSecret is the OAuth2 client secret
	// NOTE(review): stored as a plain string — confirm it is encrypted at
	// rest and excluded from serialization where required.
	ClientSecret string
}
// Copy returns a copy of the IdentityProvider. All fields are plain value
// types, so a shallow struct copy is a complete copy.
func (idp *IdentityProvider) Copy() *IdentityProvider {
	c := *idp
	return &c
}
// EventMeta returns a map of metadata for activity events. The client secret
// is deliberately not included.
func (idp *IdentityProvider) EventMeta() map[string]any {
	meta := make(map[string]any, 3)
	meta["name"] = idp.Name
	meta["type"] = string(idp.Type)
	meta["issuer"] = idp.Issuer
	return meta
}
// Validate validates the identity provider configuration, returning the first
// matching sentinel error: name, then type, then issuer, then client ID.
// Types with a built-in issuer (see HasBuiltInIssuer) may omit the issuer;
// when an issuer is set, it must parse as an absolute URL with a host.
func (idp *IdentityProvider) Validate() error {
	if idp.Name == "" {
		return ErrIdentityProviderNameRequired
	}
	switch {
	case idp.Type == "":
		return ErrIdentityProviderTypeRequired
	case !idp.Type.IsValid():
		return ErrIdentityProviderTypeUnsupported
	}
	if idp.Issuer == "" {
		if !idp.Type.HasBuiltInIssuer() {
			return ErrIdentityProviderIssuerRequired
		}
	} else if u, err := url.Parse(idp.Issuer); err != nil || u.Scheme == "" || u.Host == "" {
		return ErrIdentityProviderIssuerInvalid
	}
	if idp.ClientID == "" {
		return ErrIdentityProviderClientIDRequired
	}
	return nil
}
// IsValid reports whether t is one of the supported identity provider types.
func (t IdentityProviderType) IsValid() bool {
	switch t {
	case IdentityProviderTypeOIDC,
		IdentityProviderTypeZitadel,
		IdentityProviderTypeEntra,
		IdentityProviderTypeGoogle,
		IdentityProviderTypeOkta,
		IdentityProviderTypePocketID,
		IdentityProviderTypeMicrosoft,
		IdentityProviderTypeAuthentik,
		IdentityProviderTypeKeycloak:
		return true
	default:
		return false
	}
}
// HasBuiltInIssuer returns true for types whose issuer is well known and
// therefore does not need to be configured explicitly.
func (t IdentityProviderType) HasBuiltInIssuer() bool {
	switch t {
	case IdentityProviderTypeGoogle, IdentityProviderTypeMicrosoft:
		return true
	}
	return false
}

View File

@@ -0,0 +1,137 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestIdentityProvider_Validate exercises IdentityProvider.Validate across
// valid and invalid configurations, asserting the exact sentinel error (or
// nil) returned for each case, including the built-in-issuer exemption for
// the Google and Microsoft provider types.
func TestIdentityProvider_Validate(t *testing.T) {
	tests := []struct {
		name        string            // test case description
		idp         *IdentityProvider // provider configuration under test
		expectedErr error             // expected sentinel error; nil means valid
	}{
		{
			name: "valid OIDC provider",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "https://example.com",
				ClientID: "client-id",
			},
			expectedErr: nil,
		},
		{
			// Issuer URLs with a path component are accepted.
			name: "valid OIDC provider with path",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "https://example.com/oauth2/issuer",
				ClientID: "client-id",
			},
			expectedErr: nil,
		},
		{
			name: "missing name",
			idp: &IdentityProvider{
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "https://example.com",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderNameRequired,
		},
		{
			name: "missing type",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Issuer:   "https://example.com",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderTypeRequired,
		},
		{
			name: "invalid type",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     "invalid",
				Issuer:   "https://example.com",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderTypeUnsupported,
		},
		{
			// OIDC has no built-in issuer, so the issuer is mandatory.
			name: "missing issuer for OIDC",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderIssuerRequired,
		},
		{
			name: "invalid issuer URL - no scheme",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "example.com",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderIssuerInvalid,
		},
		{
			name: "invalid issuer URL - no host",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "https://",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderIssuerInvalid,
		},
		{
			name: "invalid issuer URL - just path",
			idp: &IdentityProvider{
				Name:     "Test Provider",
				Type:     IdentityProviderTypeOIDC,
				Issuer:   "/oauth2/issuer",
				ClientID: "client-id",
			},
			expectedErr: ErrIdentityProviderIssuerInvalid,
		},
		{
			name: "missing client ID",
			idp: &IdentityProvider{
				Name:   "Test Provider",
				Type:   IdentityProviderTypeOIDC,
				Issuer: "https://example.com",
			},
			expectedErr: ErrIdentityProviderClientIDRequired,
		},
		{
			// Google has a built-in issuer, so omitting it is allowed.
			name: "Google provider without issuer is valid",
			idp: &IdentityProvider{
				Name:     "Google SSO",
				Type:     IdentityProviderTypeGoogle,
				ClientID: "client-id",
			},
			expectedErr: nil,
		},
		{
			// Microsoft has a built-in issuer, so omitting it is allowed.
			name: "Microsoft provider without issuer is valid",
			idp: &IdentityProvider{
				Name:     "Microsoft SSO",
				Type:     IdentityProviderTypeMicrosoft,
				ClientID: "client-id",
			},
			expectedErr: nil,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.idp.Validate()
			// Validate returns sentinel errors, so direct equality suffices.
			assert.Equal(t, tt.expectedErr, err)
		})
	}
}

View File

@@ -39,6 +39,8 @@ type NetworkMap struct {
FirewallRules []*FirewallRule
RoutesFirewallRules []*RouteFirewallRule
ForwardingRules []*ForwardingRule
AuthorizedUsers map[string]map[string]struct{}
EnableSSH bool
}
func (nm *NetworkMap) Merge(other *NetworkMap) {

View File

@@ -47,14 +47,21 @@ func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
}
// OnPeersAddedUpdNetworkMapCache enqueues the given peers for an incremental
// network-map cache update; it is a no-op when no cache is configured.
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
	if cache := a.NetworkMapCache; cache != nil {
		cache.EnqueuePeersForIncrementalAdd(a, peerIds...)
	}
}
// OnPeerDeletedUpdNetworkMapCache removes the given peer from the network-map
// cache; it is a no-op (returning nil) when no cache is configured.
//
// Fix: conflict-resolution residue left both the old call
// `OnPeerDeleted(peerId)` and the new call `OnPeerDeleted(a, peerId)` as
// consecutive return statements, which does not compile (unreachable code /
// wrong arity). Keep only the new signature, consistent with the other cache
// hooks that now pass the account (e.g. OnPeerAddedIncremental(a, peerId)).
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
	if a.NetworkMapCache == nil {
		return nil
	}
	return a.NetworkMapCache.OnPeerDeleted(a, peerId)
}
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {

View File

@@ -25,15 +25,12 @@ import (
"github.com/netbirdio/netbird/route"
)
// update flag is used to update the golden file.
// example: go test ./... -v -update
// var update = flag.Bool("update", false, "update golden files")
const (
numPeers = 100
devGroupID = "group-dev"
opsGroupID = "group-ops"
allGroupID = "group-all"
sshUsersGroupID = "group-ssh-users"
routeID = route.ID("route-main")
routeHA1ID = route.ID("route-ha-1")
routeHA2ID = route.ID("route-ha-2")
@@ -41,6 +38,7 @@ const (
policyIDAll = "policy-all"
policyIDPosture = "policy-posture"
policyIDDrop = "policy-drop"
policyIDSSH = "policy-ssh"
postureCheckID = "posture-check-ver"
networkResourceID = "res-database"
networkID = "net-database"
@@ -51,6 +49,9 @@ const (
offlinePeerID = "peer-99" // This peer will be completely offline.
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
testAccountID = "account-golden-test"
userAdminID = "user-admin"
userDevID = "user-dev"
userOpsID = "user-ops"
)
func TestGetPeerNetworkMap_Golden(t *testing.T) {
@@ -69,61 +70,34 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden.json")
t.Log("Update golden file...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file")
}
func TestGetPeerNetworkMap_Golden_New(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
normalizeAndSortNetworkMap(networkMap)
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
t.Log("Update golden file...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap(b *testing.B) {
@@ -141,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
b.Run("old builder", func(b *testing.B) {
for range b.N {
for _, peerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -169,6 +143,8 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeerIP := net.IP{100, 64, 1, 1}
newPeer := &nbpeer.Peer{
@@ -201,92 +177,36 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
t.Log("Update golden file with new peer...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file")
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newPeerID := "peer-new-101"
newPeerIP := net.IP{100, 64, 1, 1}
newPeer := &nbpeer.Peer{
ID: newPeerID,
IP: newPeerIP,
Key: fmt.Sprintf("key-%s", newPeerID),
DNSLabel: "peernew101",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newPeerID] = newPeer
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = append(devGroup.Peers, newPeerID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newPeerID)
}
validatedPeersMap[newPeerID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
err := builder.OnPeerAddedIncremental(newPeerID)
err = builder.OnPeerAddedIncremental(account, newPeerID)
require.NoError(t, err, "error adding peer to cache")
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
normalizeAndSortNetworkMap(networkMap)
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
t.Log("Update golden file with OnPeerAdded...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
@@ -320,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -328,7 +248,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(newPeerID)
_ = builder.OnPeerAddedIncremental(account, newPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
}
@@ -349,6 +269,8 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
@@ -395,106 +317,36 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
t.Log("Update golden file with new router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file")
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
err := builder.OnPeerAddedIncremental(newRouterID)
err = builder.OnPeerAddedIncremental(account, newRouterID)
require.NoError(t, err, "error adding router to cache")
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
normalizeAndSortNetworkMap(networkMap)
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
t.Log("Update golden file with OnPeerAdded router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
@@ -550,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.Run("old builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -558,7 +410,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
b.ResetTimer()
b.Run("new builder after add", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerAddedIncremental(newRouterID)
_ = builder.OnPeerAddedIncremental(account, newRouterID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
}
@@ -579,7 +431,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
validatedPeersMap[peerID] = struct{}{}
}
deletedPeerID := "peer-25" // peer from devs group
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedPeerID := "peer-25"
delete(account.Peers, deletedPeerID)
@@ -604,85 +458,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
t.Log("Update golden file with deleted peer...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedPeerID := "peer-25" // devs group peer
delete(account.Peers, deletedPeerID)
if devGroup, exists := account.Groups[devGroupID]; exists {
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
return id == deletedPeerID
})
}
delete(validatedPeersMap, deletedPeerID)
if account.Network != nil {
account.Network.Serial++
}
err := builder.OnPeerDeleted(deletedPeerID)
err = builder.OnPeerDeleted(account, deletedPeerID)
require.NoError(t, err, "error deleting peer from cache")
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
normalizeAndSortNetworkMap(networkMap)
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
t.Log("Update golden file with OnPeerDeleted...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
}
}
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
@@ -698,7 +503,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
validatedPeersMap[peerID] = struct{}{}
}
deletedRouterID := "peer-75" // router peer
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedRouterID := "peer-75"
var affectedRoute *route.Route
for _, r := range account.Routes {
@@ -730,93 +537,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
normalizeAndSortNetworkMap(legacyNetworkMap)
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
require.NoError(t, err, "error marshaling legacy network map to JSON")
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
t.Log("Update golden file with deleted peer...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
}
func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
deletedRouterID := "peer-75" // router peer
var affectedRoute *route.Route
for _, r := range account.Routes {
if r.PeerID == deletedRouterID {
affectedRoute = r
break
}
}
require.NotNil(t, affectedRoute, "Router peer should have a route")
for _, group := range account.Groups {
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
return id == deletedRouterID
})
}
for routeID, r := range account.Routes {
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
delete(account.Routes, routeID)
}
}
delete(account.Peers, deletedRouterID)
delete(validatedPeersMap, deletedRouterID)
if account.Network != nil {
account.Network.Serial++
}
err := builder.OnPeerDeleted(deletedRouterID)
err = builder.OnPeerDeleted(account, deletedRouterID)
require.NoError(t, err, "error deleting routing peer from cache")
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(newNetworkMap)
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
require.NoError(t, err, "error marshaling new network map to JSON")
normalizeAndSortNetworkMap(networkMap)
if string(legacyJSON) != string(newJSON) {
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err)
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
require.NoError(t, err)
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
require.NoError(t, err)
t.Logf("Saved legacy network map to %s", legacyFilePath)
t.Log("Update golden file with deleted router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
err = os.WriteFile(newFilePath, newJSON, 0644)
require.NoError(t, err)
t.Logf("Saved new network map to %s", newFilePath)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err)
require.JSONEq(t, string(expectedJSON), string(jsonData),
"network map after deleting router does not match golden file")
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
}
}
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
@@ -847,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.Run("old builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, testingPeerID := range peerIDs {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
}
}
})
@@ -855,7 +605,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
b.ResetTimer()
b.Run("new builder after delete", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = builder.OnPeerDeleted(deletedPeerID)
_ = builder.OnPeerDeleted(account, deletedPeerID)
for _, testingPeerID := range peerIDs {
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
}
@@ -927,6 +677,54 @@ func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
}
}
type networkMapJSON struct {
Peers []*nbpeer.Peer `json:"Peers"`
Network *types.Network `json:"Network"`
Routes []*route.Route `json:"Routes"`
DNSConfig dns.Config `json:"DNSConfig"`
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
EnableSSH bool `json:"EnableSSH"`
}
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
result := &networkMapJSON{
Peers: nm.Peers,
Network: nm.Network,
Routes: nm.Routes,
DNSConfig: nm.DNSConfig,
OfflinePeers: nm.OfflinePeers,
FirewallRules: nm.FirewallRules,
RoutesFirewallRules: nm.RoutesFirewallRules,
ForwardingRules: nm.ForwardingRules,
EnableSSH: nm.EnableSSH,
}
if len(nm.AuthorizedUsers) > 0 {
result.AuthorizedUsers = make(map[string][]string)
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
for localUser := range nm.AuthorizedUsers {
localUsers = append(localUsers, localUser)
}
sort.Strings(localUsers)
for _, localUser := range localUsers {
userIDs := nm.AuthorizedUsers[localUser]
sortedUserIDs := make([]string, 0, len(userIDs))
for userID := range userIDs {
sortedUserIDs = append(sortedUserIDs, userID)
}
sort.Strings(sortedUserIDs)
result.AuthorizedUsers[localUser] = sortedUserIDs
}
}
return result
}
func createTestAccountWithEntities() *types.Account {
peers := make(map[string]*nbpeer.Peer)
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
@@ -962,9 +760,10 @@ func createTestAccountWithEntities() *types.Account {
}
groups := map[string]*types.Group{
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
}
policies := []*types.Policy{
@@ -1002,6 +801,15 @@ func createTestAccountWithEntities() *types.Account {
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
}},
},
{
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
Rules: []*types.PolicyRule{{
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
}},
},
}
routes := map[route.ID]*route.Route{
@@ -1034,8 +842,15 @@ func createTestAccountWithEntities() *types.Account {
},
}
users := map[string]*types.User{
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
}
account := &types.Account{
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
Users: users,
Network: &types.Network{
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
},
@@ -1071,6 +886,88 @@ func createTestAccountWithEntities() *types.Account {
return account
}
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
account := createTestAccountWithEntities()
ctx := context.Background()
validatedPeersMap := make(map[string]struct{})
for i := range numPeers {
peerID := fmt.Sprintf("peer-%d", i)
if peerID == offlinePeerID {
continue
}
validatedPeersMap[peerID] = struct{}{}
}
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
newRouterID := "peer-new-router-102"
newRouterIP := net.IP{100, 64, 1, 2}
newRouter := &nbpeer.Peer{
ID: newRouterID,
IP: newRouterIP,
Key: fmt.Sprintf("key-%s", newRouterID),
DNSLabel: "newrouter102",
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
UserID: "user-admin",
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
}
account.Peers[newRouterID] = newRouter
if opsGroup, exists := account.Groups[opsGroupID]; exists {
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
}
if allGroup, exists := account.Groups[allGroupID]; exists {
allGroup.Peers = append(allGroup.Peers, newRouterID)
}
newRoute := &route.Route{
ID: route.ID("route-new-router"),
Network: netip.MustParsePrefix("172.16.0.0/24"),
Peer: newRouter.Key,
PeerID: newRouterID,
Description: "Route from new router",
Enabled: true,
PeerGroups: []string{opsGroupID},
Groups: []string{devGroupID, opsGroupID},
AccessControlGroups: []string{devGroupID},
AccountID: account.Id,
}
account.Routes[newRoute.ID] = newRoute
validatedPeersMap[newRouterID] = struct{}{}
if account.Network != nil {
account.Network.Serial++
}
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
time.Sleep(100 * time.Millisecond)
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
normalizeAndSortNetworkMap(networkMap)
jsonData, err := json.MarshalIndent(networkMap, "", " ")
require.NoError(t, err, "error marshaling network map to JSON")
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
t.Log("Update golden file with OnPeerAdded router...")
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
require.NoError(t, err)
err = os.WriteFile(goldenFilePath, jsonData, 0644)
require.NoError(t, err)
expectedJSON, err := os.ReadFile(goldenFilePath)
require.NoError(t, err, "error reading golden file")
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
}
func createAccountFromFile() (*types.Account, error) {
accraw := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json")
data, err := os.ReadFile(accraw)
@@ -1343,7 +1240,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) {
b.Run("Legacy", func(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
}
})
b.Run("LegacyCompacted", func(b *testing.B) {

View File

@@ -7,12 +7,12 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -27,6 +27,9 @@ const (
v6AllWildcard = "::/0"
fw = "fw:"
rfw = "route-fw:"
szAddPeerBatch = 10
maxPeerAddRetries = 20
)
type NetworkMapCache struct {
@@ -47,6 +50,10 @@ type NetworkMapCache struct {
peerDNS map[string]*nbdns.Config
peerFirewallRulesCompact map[string][]*FirewallRule
peerRoutesCompact map[string][]*route.Route
peerSSH map[string]*PeerSSHView
groupIDToUserIDs map[string][]string
allowedUserIDs map[string]struct{}
resourceRouters map[string]map[string]*routerTypes.NetworkRouter
resourcePolicies map[string][]*Policy
@@ -76,41 +83,64 @@ type PeerRoutesView struct {
RouteFirewallRuleIDs []string
}
type PeerSSHView struct {
EnableSSH bool
AuthorizedUsers map[string]map[string]struct{}
}
type NetworkMapBuilder struct {
account atomic.Pointer[Account]
account *Account
cache *NetworkMapCache
validatedPeers map[string]struct{}
apb addPeerBatch
}
type addPeerBatch struct {
mu sync.Mutex
sg *sync.Cond
ids []string
la *Account
retryCount map[string]int
}
func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder {
builder := &NetworkMapBuilder{
cache: &NetworkMapCache{
globalRoutes: make(map[route.ID]*route.Route),
globalRules: make(map[string]*FirewallRule),
globalRouteRules: make(map[string]*RouteFirewallRule),
globalPeers: make(map[string]*nbpeer.Peer),
groupToPeers: make(map[string][]string),
peerToGroups: make(map[string][]string),
policyToRules: make(map[string][]*PolicyRule),
groupToPolicies: make(map[string][]*Policy),
groupToRoutes: make(map[string][]*route.Route),
peerToRoutes: make(map[string][]*route.Route),
globalRoutes: make(map[route.ID]*route.Route),
globalRules: make(map[string]*FirewallRule),
globalRouteRules: make(map[string]*RouteFirewallRule),
globalPeers: make(map[string]*nbpeer.Peer),
groupToPeers: make(map[string][]string),
peerToGroups: make(map[string][]string),
policyToRules: make(map[string][]*PolicyRule),
groupToPolicies: make(map[string][]*Policy),
groupToRoutes: make(map[string][]*route.Route),
peerToRoutes: make(map[string][]*route.Route),
peerACLs: make(map[string]*PeerACLView),
peerRoutes: make(map[string]*PeerRoutesView),
peerDNS: make(map[string]*nbdns.Config),
peerSSH: make(map[string]*PeerSSHView),
groupIDToUserIDs: make(map[string][]string),
allowedUserIDs: make(map[string]struct{}),
peerFirewallRulesCompact: make(map[string][]*FirewallRule),
peerRoutesCompact: make(map[string][]*route.Route),
globalResources: make(map[string]*resourceTypes.NetworkResource),
acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo),
noACGRoutes: make(map[route.ID]*RouteOwnerInfo),
acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo),
noACGRoutes: make(map[route.ID]*RouteOwnerInfo),
},
validatedPeers: make(map[string]struct{}),
}
builder.account.Store(account)
builder.apb.sg = sync.NewCond(&builder.apb.mu)
builder.apb.ids = make([]string, 0, szAddPeerBatch)
builder.apb.la = account
builder.apb.retryCount = make(map[string]int)
maps.Copy(builder.validatedPeers, validatedPeers)
builder.initialBuild(account)
go builder.incAddPeerLoop()
return builder
}
@@ -118,6 +148,8 @@ func (b *NetworkMapBuilder) initialBuild(account *Account) {
b.cache.mu.Lock()
defer b.cache.mu.Unlock()
b.account = account
start := time.Now()
b.buildGlobalIndexes(account)
@@ -150,9 +182,15 @@ func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) {
clear(b.cache.peerToRoutes)
clear(b.cache.acgToRoutes)
clear(b.cache.noACGRoutes)
clear(b.cache.groupIDToUserIDs)
clear(b.cache.allowedUserIDs)
clear(b.cache.peerSSH)
maps.Copy(b.cache.globalPeers, account.Peers)
b.cache.groupIDToUserIDs = account.GetActiveGroupUsers()
b.cache.allowedUserIDs = b.buildAllowedUserIDs(account)
for groupID, group := range account.Groups {
peersCopy := make([]string, len(group.Peers))
copy(peersCopy, group.Peers)
@@ -227,7 +265,7 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) {
return
}
allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers)
allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers)
isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer)
@@ -260,12 +298,17 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) {
compactedRules := compactFirewallRules(firewallRules)
b.cache.peerFirewallRulesCompact[peerID] = compactedRules
b.cache.peerSSH[peerID] = &PeerSSHView{
EnableSSH: sshEnabled,
AuthorizedUsers: authorizedUsers,
}
}
func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer,
validatedPeersMap map[string]struct{},
) ([]*nbpeer.Peer, []*FirewallRule) {
) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
peerID := peer.ID
ctx := context.Background()
peerGroups := b.cache.peerToGroups[peerID]
peerGroupsMap := make(map[string]struct{}, len(peerGroups))
@@ -278,9 +321,15 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
fwRules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
authorizedUsers := make(map[string]map[string]struct{})
sshEnabled := false
for _, group := range peerGroups {
policies := b.cache.groupToPolicies[group]
for _, policy := range policies {
if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid {
continue
}
rules := b.cache.policyToRules[policy.ID]
for _, rule := range rules {
var sourcePeers, destinationPeers []*nbpeer.Peer
@@ -323,13 +372,13 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
if rule.Bidirectional {
if peerInSources {
b.generateResourcescached(
account, rule, destinationPeers, FirewallRuleDirectionIN,
rule, destinationPeers, FirewallRuleDirectionIN,
peer, &peers, &fwRules, peersExists, rulesExists,
)
}
if peerInDestinations {
b.generateResourcescached(
account, rule, sourcePeers, FirewallRuleDirectionOUT,
rule, sourcePeers, FirewallRuleDirectionOUT,
peer, &peers, &fwRules, peersExists, rulesExists,
)
}
@@ -337,22 +386,58 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
if peerInSources {
b.generateResourcescached(
account, rule, destinationPeers, FirewallRuleDirectionOUT,
rule, destinationPeers, FirewallRuleDirectionOUT,
peer, &peers, &fwRules, peersExists, rulesExists,
)
}
if peerInDestinations {
b.generateResourcescached(
account, rule, sourcePeers, FirewallRuleDirectionIN,
rule, sourcePeers, FirewallRuleDirectionIN,
peer, &peers, &fwRules, peersExists, rulesExists,
)
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
for groupID, localUsers := range rule.AuthorizedGroups {
userIDs, ok := b.cache.groupIDToUserIDs[groupID]
if !ok {
continue
}
if len(localUsers) == 0 {
localUsers = []string{auth.Wildcard}
}
for _, localUser := range localUsers {
if authorizedUsers[localUser] == nil {
authorizedUsers[localUser] = make(map[string]struct{})
}
for _, userID := range userIDs {
authorizedUsers[localUser][userID] = struct{}{}
}
}
}
case rule.AuthorizedUser != "":
if authorizedUsers[auth.Wildcard] == nil {
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
}
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
default:
authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs)
}
} else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs)
}
}
}
}
}
return peers, fwRules
return peers, fwRules, authorizedUsers, sshEnabled
}
func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool {
@@ -405,14 +490,9 @@ func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs
}
func (b *NetworkMapBuilder) generateResourcescached(
account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer,
rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer,
peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{},
) {
isAll := false
if allGroup, err := account.GetGroupAll(); err == nil {
isAll = (len(allGroup.Peers) - 1) == len(groupPeers)
}
for _, peer := range groupPeers {
if peer == nil {
continue
@@ -427,11 +507,7 @@ func (b *NetworkMapBuilder) generateResourcescached(
PeerIP: peer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = allPeers
Protocol: firewallRuleProtocol(rule.Protocol),
}
var s strings.Builder
@@ -942,8 +1018,29 @@ func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, che
return peerNSGroups
}
func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) {
b.account.Store(account)
func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} {
users := make(map[string]struct{})
for _, nbUser := range account.Users {
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
users[nbUser.Id] = struct{}{}
}
}
return users
}
func firewallRuleProtocol(protocol PolicyRuleProtocolType) string {
if protocol == PolicyRuleProtocolNetbirdSSH {
return string(PolicyRuleProtocolTCP)
}
return string(protocol)
}
// lock should be held
func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() {
b.account = account
}
return b.account
}
func (b *NetworkMapBuilder) GetPeerNetworkMap(
@@ -951,25 +1048,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
account := b.account.Load()
b.cache.mu.RLock()
defer b.cache.mu.RUnlock()
account := b.account
peer := account.GetPeer(peerID)
if peer == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
b.cache.mu.RLock()
defer b.cache.mu.RUnlock()
aclView := b.cache.peerACLs[peerID]
routesView := b.cache.peerRoutes[peerID]
dnsConfig := b.cache.peerDNS[peerID]
sshView := b.cache.peerSSH[peerID]
if aclView == nil || routesView == nil || dnsConfig == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
@@ -987,7 +1086,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
func (b *NetworkMapBuilder) assembleNetworkMap(
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
) *NetworkMap {
var peersToConnect []*nbpeer.Peer
@@ -1045,7 +1144,7 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
finalDNSConfig.CustomZones = zones
}
return &NetworkMap{
nm := &NetworkMap{
Peers: peersToConnect,
Network: account.Network.Copy(),
Routes: routes,
@@ -1054,6 +1153,13 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
FirewallRules: firewallRules,
RoutesFirewallRules: routesFirewallRules,
}
if sshView != nil {
nm.EnableSSH = sshView.EnableSSH
nm.AuthorizedUsers = sshView.AuthorizedUsers
}
return nm
}
func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached(
@@ -1061,64 +1167,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached(
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
account := b.account.Load()
b.cache.mu.RLock()
defer b.cache.mu.RUnlock()
account := b.account
peer := account.GetPeer(peerID)
if peer == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
b.cache.mu.RLock()
defer b.cache.mu.RUnlock()
aclView := b.cache.peerACLs[peerID]
routesView := b.cache.peerRoutes[peerID]
dnsConfig := b.cache.peerDNS[peerID]
sshView := b.cache.peerSSH[peerID]
if aclView == nil || routesView == nil || dnsConfig == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
metrics.CountNetworkMapObjects(objectCount)
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
if objectCount > 5000 {
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache",
account.Id, objectCount)
}
}
return nm
}
func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
account := b.account.Load()
peer := account.GetPeer(peerID)
if peer == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
b.cache.mu.RLock()
defer b.cache.mu.RUnlock()
aclView := b.cache.peerACLs[peerID]
routesView := b.cache.peerRoutes[peerID]
dnsConfig := b.cache.peerDNS[peerID]
if aclView == nil || routesView == nil || dnsConfig == nil {
return &NetworkMap{Network: account.Network.Copy()}
}
nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
if metrics != nil {
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
@@ -1136,7 +1205,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
) *NetworkMap {
var peersToConnect []*nbpeer.Peer
@@ -1182,7 +1251,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
finalDNSConfig.CustomZones = zones
}
return &NetworkMap{
nm := &NetworkMap{
Peers: peersToConnect,
Network: account.Network.Copy(),
Routes: routes,
@@ -1191,11 +1260,59 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
FirewallRules: firewallRules,
RoutesFirewallRules: routesFirewallRules,
}
if sshView != nil {
nm.EnableSSH = sshView.EnableSSH
nm.AuthorizedUsers = sshView.AuthorizedUsers
}
return nm
}
// GetPeerNetworkMapCompact builds a compact NetworkMap for the given peer
// from the precomputed per-peer cache views (ACL, routes, DNS, SSH). It
// returns a map containing only a copy of the network when the peer is
// unknown or its cache views have not been built yet.
// The whole read happens under the cache read lock.
func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
	ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
	validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
	start := time.Now()
	b.cache.mu.RLock()
	defer b.cache.mu.RUnlock()
	account := b.account
	peer := account.GetPeer(peerID)
	if peer == nil {
		return &NetworkMap{Network: account.Network.Copy()}
	}
	aclView := b.cache.peerACLs[peerID]
	routesView := b.cache.peerRoutes[peerID]
	dnsConfig := b.cache.peerDNS[peerID]
	sshView := b.cache.peerSSH[peerID]
	// sshView is deliberately not part of this check: assembleNetworkMapCompact
	// receives it and must tolerate a nil SSH view.
	if aclView == nil || routesView == nil || dnsConfig == nil {
		return &NetworkMap{Network: account.Network.Copy()}
	}
	nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
	if metrics != nil {
		// Total object count across all map sections, used both as a metric
		// and as a threshold for the oversized-map trace below.
		objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
		metrics.CountNetworkMapObjects(objectCount)
		metrics.CountGetPeerNetworkMapDuration(time.Since(start))
		if objectCount > 5000 {
			log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache",
				account.Id, objectCount)
		}
	}
	return nm
}
func (b *NetworkMapBuilder) assembleNetworkMapCompact(
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
) *NetworkMap {
var peersToConnect []*nbpeer.Peer
@@ -1279,7 +1396,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact(
finalDNSConfig.CustomZones = zones
}
return &NetworkMap{
nm := &NetworkMap{
Peers: peersToConnect,
Network: account.Network.Copy(),
Routes: routes,
@@ -1288,14 +1405,13 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact(
FirewallRules: firewallRules,
RoutesFirewallRules: routesFirewallRules,
}
}
func splitRouteAndPeer(r *route.Route) (string, string) {
parts := strings.Split(string(r.ID), ":")
if len(parts) < 2 {
return string(r.ID), ""
if sshView != nil {
nm.EnableSSH = sshView.EnableSSH
nm.AuthorizedUsers = sshView.AuthorizedUsers
}
return parts[0], parts[1]
return nm
}
func (b *NetworkMapBuilder) compactRoutesForPeer(peerID string, ownRouteIDs []route.ID, otherRouteIDs []route.ID) []*route.Route {
@@ -1427,6 +1543,14 @@ func compactFirewallRules(expandedRules []*FirewallRule) []*FirewallRule {
return compactRules
}
// splitRouteAndPeer decomposes a composite route ID of the form
// "<first>:<second>[:...]" into its first two colon-separated segments.
// An ID without a ':' is returned unchanged with an empty second part.
func splitRouteAndPeer(r *route.Route) (string, string) {
	// Limit 3 keeps the first two segments identical to a full Split
	// while avoiding splitting the remainder further than needed.
	parts := strings.SplitN(string(r.ID), ":", 3)
	if len(parts) < 2 {
		return string(r.ID), ""
	}
	return parts[0], parts[1]
}
func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string {
var s strings.Builder
s.WriteString(fw)
@@ -1501,6 +1625,106 @@ func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool {
return false
}
// incAddPeerLoop is the long-running worker that services the incremental
// peer-add queue (b.apb.ids). It blocks while the queue is empty and
// otherwise drains it via addPeersIncrementally.
//
// NOTE(review): addPeersIncrementally releases b.apb.mu internally and
// re-acquires it before returning, so the Unlock below pairs with that
// re-acquired lock. b.apb.sg is presumably a sync.Cond over b.apb.mu
// (Wait atomically releases and re-acquires the lock) — confirm.
func (b *NetworkMapBuilder) incAddPeerLoop() {
	for {
		b.apb.mu.Lock()
		if len(b.apb.ids) == 0 {
			b.apb.sg.Wait()
		}
		b.addPeersIncrementally()
		b.apb.mu.Unlock()
	}
}
// addPeersIncrementally drains the pending-peer queue, rebuilds the cached
// per-peer views (ACL, routes, DNS) for each queued peer, and applies the
// merged deltas to every affected peer.
//
// Locking contract: the caller must hold b.apb.mu on entry. The method
// releases b.apb.mu while doing the heavy work under b.cache.mu and
// re-acquires it before returning, so the caller's Unlock stays balanced.
// lock on b.apb level should be held
func (b *NetworkMapBuilder) addPeersIncrementally() {
	// Snapshot and reset the queue while still holding b.apb.mu.
	peers := slices.Clone(b.apb.ids)
	clear(b.apb.ids)
	b.apb.ids = b.apb.ids[:0]
	latestAcc := b.apb.la
	b.apb.mu.Unlock()
	tt := time.Now()
	b.cache.mu.Lock()
	defer b.cache.mu.Unlock()
	// Adopt the queued account snapshot if it is newer than the current one.
	account := b.updateAccountLocked(latestAcc)
	log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers))
	allUpdates := make(map[string]*PeerUpdateDelta)
	for _, peerID := range peers {
		peer := account.GetPeer(peerID)
		if peer == nil {
			// Peer not yet visible in the account snapshot: re-enqueue and
			// retry up to maxPeerAddRetries times, then give up.
			b.apb.mu.Lock()
			retries := b.apb.retryCount[peerID]
			b.apb.mu.Unlock()
			if retries >= maxPeerAddRetries {
				log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries)
				b.apb.mu.Lock()
				delete(b.apb.retryCount, peerID)
				b.apb.mu.Unlock()
				continue
			}
			log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries)
			b.apb.mu.Lock()
			b.apb.retryCount[peerID] = retries + 1
			b.apb.mu.Unlock()
			b.enqueuePeersForIncrementalAdd(latestAcc, peerID)
			continue
		}
		// Peer found: clear any retry bookkeeping and register it.
		b.apb.mu.Lock()
		delete(b.apb.retryCount, peerID)
		b.apb.mu.Unlock()
		b.validatedPeers[peerID] = struct{}{}
		b.cache.globalPeers[peerID] = peer
		// Rebuild this peer's cached views, then collect the deltas it
		// induces on other peers; deltas for the same affected peer from
		// several new peers are merged before being applied once.
		peerGroups := b.updateIndexesForNewPeer(account, peerID)
		b.buildPeerACLView(account, peerID)
		b.buildPeerRoutesView(account, peerID)
		b.buildPeerDNSView(account, peerID)
		peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups)
		for affectedPeerID, delta := range peerDeltas {
			if existing, ok := allUpdates[affectedPeerID]; ok {
				existing.mergeFrom(delta)
				continue
			}
			allUpdates[affectedPeerID] = delta
		}
	}
	// Apply each merged delta exactly once per affected peer.
	for affectedPeerID, delta := range allUpdates {
		b.applyDeltaToPeer(account, affectedPeerID, delta)
	}
	log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt))
	// Re-acquire b.apb.mu for the caller; signal if new IDs arrived while
	// the lock was released so the worker loops again without waiting.
	b.apb.mu.Lock()
	if len(b.apb.ids) > 0 {
		b.apb.sg.Signal()
	}
}
// enqueuePeersForIncrementalAdd queues peer IDs for the incremental-add
// worker and records the freshest account snapshot for it to use.
//
// Fix: previously the snapshot was only replaced when b.apb.la was already
// non-nil, so the very first enqueue left b.apb.la nil and the worker would
// later call updateAccountLocked(nil) and dereference a nil account. The
// snapshot is now seeded on first use as well.
func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) {
	b.apb.mu.Lock()
	b.apb.ids = append(b.apb.ids, peerIDs...)
	// Keep the newest snapshot: seed it when unset, replace it when the
	// incoming account carries a higher network serial.
	if b.apb.la == nil || acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() {
		b.apb.la = acc
	}
	b.apb.sg.Signal()
	b.apb.mu.Unlock()
}
// EnqueuePeersForIncrementalAdd is the exported entry point for queueing
// peers for incremental addition; it delegates to the internal
// enqueuePeersForIncrementalAdd.
func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) {
	b.enqueuePeersForIncrementalAdd(acc, peerIDs...)
}
type ViewDelta struct {
AddedPeerIDs []string
RemovedPeerIDs []string
@@ -1508,17 +1732,18 @@ type ViewDelta struct {
RemovedRuleIDs []string
}
func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error {
func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error {
tt := time.Now()
account := b.account.Load()
peer := account.GetPeer(peerID)
peer := acc.GetPeer(peerID)
if peer == nil {
return fmt.Errorf("peer %s not found in account", peerID)
return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID)
}
b.cache.mu.Lock()
defer b.cache.mu.Unlock()
account := b.updateAccountLocked(acc)
log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String())
b.validatedPeers[peerID] = struct{}{}
@@ -1577,6 +1802,13 @@ func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID str
}
func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) {
updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups)
for affectedPeerID, delta := range updates {
b.applyDeltaToPeer(account, affectedPeerID, delta)
}
}
func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta {
updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups)
if b.isPeerRouter(account, newPeerID) {
@@ -1596,9 +1828,7 @@ func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, new
}
}
for affectedPeerID, delta := range updates {
b.applyDeltaToPeer(account, affectedPeerID, delta)
}
return updates
}
func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} {
@@ -1792,8 +2022,8 @@ func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates(
updates[peerID] = delta
}
if delta.AddConnectedPeer == "" {
delta.AddConnectedPeer = newPeerID
if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
}
delta.RebuildRoutesView = true
@@ -1922,8 +2152,8 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates(
updates[routerPeerID] = delta
}
if delta.AddConnectedPeer == "" {
delta.AddConnectedPeer = newPeerID
if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
}
delta.RebuildRoutesView = true
@@ -1933,13 +2163,63 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates(
type PeerUpdateDelta struct {
PeerID string
AddConnectedPeer string
AddConnectedPeers []string
AddFirewallRules []*FirewallRuleDelta
AddRoutes []route.ID
UpdateRouteFirewallRules []*RouteFirewallRuleUpdate
UpdateDNS bool
RebuildRoutesView bool
}
// mergeFrom folds the contents of other into d. Every collection is
// deduplicated (existing entries in d win), and the boolean flags are OR-ed.
func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) {
	// Union of connected-peer IDs.
	for _, id := range other.AddConnectedPeers {
		if !slices.Contains(d.AddConnectedPeers, id) {
			d.AddConnectedPeers = append(d.AddConnectedPeers, id)
		}
	}

	// Firewall rules are deduplicated by RuleID.
	seenRules := make(map[string]struct{}, len(d.AddFirewallRules))
	for _, fr := range d.AddFirewallRules {
		seenRules[fr.RuleID] = struct{}{}
	}
	for _, fr := range other.AddFirewallRules {
		if _, dup := seenRules[fr.RuleID]; dup {
			continue
		}
		seenRules[fr.RuleID] = struct{}{}
		d.AddFirewallRules = append(d.AddFirewallRules, fr)
	}

	// Union of route IDs.
	for _, rid := range other.AddRoutes {
		if !slices.Contains(d.AddRoutes, rid) {
			d.AddRoutes = append(d.AddRoutes, rid)
		}
	}

	// Route firewall updates are deduplicated by (RuleID, AddSourceIP).
	seenSources := make(map[string]map[string]struct{})
	for _, upd := range d.UpdateRouteFirewallRules {
		if seenSources[upd.RuleID] == nil {
			seenSources[upd.RuleID] = make(map[string]struct{})
		}
		seenSources[upd.RuleID][upd.AddSourceIP] = struct{}{}
	}
	for _, upd := range other.UpdateRouteFirewallRules {
		if seenSources[upd.RuleID] == nil {
			seenSources[upd.RuleID] = make(map[string]struct{})
		}
		if _, dup := seenSources[upd.RuleID][upd.AddSourceIP]; dup {
			continue
		}
		seenSources[upd.RuleID][upd.AddSourceIP] = struct{}{}
		d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, upd)
	}

	d.UpdateDNS = d.UpdateDNS || other.UpdateDNS
	d.RebuildRoutesView = d.RebuildRoutesView || other.RebuildRoutesView
}
type FirewallRuleDelta struct {
Rule *FirewallRule
RuleID string
@@ -1977,7 +2257,7 @@ func (b *NetworkMapBuilder) addUpdateForPeersInGroups(
PeerIP: newPeer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
Protocol: firewallRuleProtocol(rule.Protocol),
}
for _, peerID := range peers {
if peerID == newPeerID {
@@ -2028,7 +2308,7 @@ func (b *NetworkMapBuilder) addUpdateForDirectPeerResource(
PeerIP: newPeer.IP.String(),
Direction: direction,
Action: string(rule.Action),
Protocol: string(rule.Protocol),
Protocol: firewallRuleProtocol(rule.Protocol),
}
b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer)
@@ -2041,11 +2321,13 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta(
delta := updates[targetPeerID]
if delta == nil {
delta = &PeerUpdateDelta{
PeerID: targetPeerID,
AddConnectedPeer: newPeerID,
AddFirewallRules: make([]*FirewallRuleDelta, 0),
PeerID: targetPeerID,
AddConnectedPeers: []string{newPeerID},
AddFirewallRules: make([]*FirewallRuleDelta, 0),
}
updates[targetPeerID] = delta
} else if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
}
baseRule.PeerIP = peerIP
@@ -2071,10 +2353,12 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta(
}
func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) {
if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 {
if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 {
if aclView := b.cache.peerACLs[peerID]; aclView != nil {
if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) {
aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer)
for _, connectedPeerID := range delta.AddConnectedPeers {
if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) {
aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID)
}
}
for _, ruleDelta := range delta.AddFirewallRules {
@@ -2130,11 +2414,11 @@ func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView,
}
}
func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error {
b.cache.mu.Lock()
defer b.cache.mu.Unlock()
account := b.account.Load()
account := b.updateAccountLocked(acc)
deletedPeer := b.cache.globalPeers[peerID]
if deletedPeer == nil {
@@ -2190,6 +2474,7 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
delete(b.cache.peerACLs, peerID)
delete(b.cache.peerRoutes, peerID)
delete(b.cache.peerDNS, peerID)
delete(b.cache.peerSSH, peerID)
delete(b.cache.globalPeers, peerID)
@@ -2240,11 +2525,16 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
b.buildPeerRoutesView(account, affectedPeerID)
}
peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP)
peersToRebuildACL := make(map[string]struct{})
peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL)
for affectedPeerID, updates := range peerDeletionUpdates {
b.applyDeletionUpdates(affectedPeerID, updates)
}
for affectedPeerID := range peersToRebuildACL {
b.buildPeerACLView(account, affectedPeerID)
}
b.cleanupUnusedRules()
log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers))
@@ -2255,6 +2545,8 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL(
deletedPeerID string,
peerIP string,
peerGroups []string,
peersToRebuildACL map[string]struct{},
) map[string]*PeerDeletionUpdate {
affected := make(map[string]*PeerDeletionUpdate)
@@ -2264,26 +2556,47 @@ func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL(
continue
}
if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) {
continue
}
if affected[peerID] == nil {
affected[peerID] = &PeerDeletionUpdate{
RemovePeerID: deletedPeerID,
PeerIP: peerIP,
if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) {
peersToRebuildACL[peerID] = struct{}{}
if affected[peerID] == nil {
affected[peerID] = &PeerDeletionUpdate{
RemovePeerID: deletedPeerID,
PeerIP: peerIP,
}
}
}
}
for _, ruleID := range aclView.FirewallRuleIDs {
if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP {
affected[peerID].RemoveFirewallRuleIDs = append(
affected[peerID].RemoveFirewallRuleIDs,
ruleID,
)
affectedRouteOwners := make(map[string]struct{})
for _, groupID := range peerGroups {
if routeMap, ok := b.cache.acgToRoutes[groupID]; ok {
for _, info := range routeMap {
if info.PeerID != deletedPeerID {
affectedRouteOwners[info.PeerID] = struct{}{}
}
}
}
}
for _, info := range b.cache.noACGRoutes {
if info.PeerID != deletedPeerID {
affectedRouteOwners[info.PeerID] = struct{}{}
}
}
for ownerPeerID := range affectedRouteOwners {
if affected[ownerPeerID] == nil {
affected[ownerPeerID] = &PeerDeletionUpdate{
RemovePeerID: deletedPeerID,
PeerIP: peerIP,
RemoveFromSourceRanges: true,
}
} else {
affected[ownerPeerID].RemoveFromSourceRanges = true
}
}
return affected
}
@@ -2296,18 +2609,6 @@ type PeerDeletionUpdate struct {
}
func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) {
if aclView := b.cache.peerACLs[peerID]; aclView != nil {
aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool {
return id == updates.RemovePeerID
})
if len(updates.RemoveFirewallRuleIDs) > 0 {
aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool {
return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID)
})
}
}
if routesView := b.cache.peerRoutes[peerID]; routesView != nil {
if len(updates.RemoveRouteIDs) > 0 {
routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool {

View File

@@ -23,6 +23,8 @@ const (
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
// PolicyRuleProtocolICMP type of traffic
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
// PolicyRuleProtocolNetbirdSSH type of traffic
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
)
const (
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
protocol = PolicyRuleProtocolUDP
case "icmp":
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
case "netbird-ssh":
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
default:
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
}

View File

@@ -80,6 +80,12 @@ type PolicyRule struct {
// PortRanges a list of port ranges.
PortRanges []RulePortRange `gorm:"serializer:json"`
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
AuthorizedUser string
}
// Copy returns a copy of a policy rule
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
Protocol: pm.Protocol,
Ports: make([]string, len(pm.Ports)),
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
AuthorizedUser: pm.AuthorizedUser,
}
copy(rule.Destinations, pm.Destinations)
copy(rule.Sources, pm.Sources)
copy(rule.Ports, pm.Ports)
copy(rule.PortRanges, pm.PortRanges)
for k, v := range pm.AuthorizedGroups {
rule.AuthorizedGroups[k] = make([]string, len(v))
copy(rule.AuthorizedGroups[k], v)
}
return rule
}

View File

@@ -52,6 +52,9 @@ type Settings struct {
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
LazyConnectionEnabled bool `gorm:"default:false"`
// AutoUpdateVersion client auto-update version
AutoUpdateVersion string `gorm:"default:'disabled'"`
}
// Copy copies the Settings struct
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
LazyConnectionEnabled: s.LazyConnectionEnabled,
DNSDomain: s.DNSDomain,
NetworkRange: s.NetworkRange,
AutoUpdateVersion: s.AutoUpdateVersion,
}
if s.Extra != nil {
settings.Extra = s.Extra.Copy()

View File

@@ -7,6 +7,7 @@ import (
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
"github.com/netbirdio/netbird/util/crypt"
)
const (
@@ -65,7 +66,11 @@ type UserInfo struct {
LastLogin time.Time `json:"last_login"`
Issued string `json:"issued"`
PendingApproval bool `json:"pending_approval"`
Password string `json:"password"`
IntegrationReference integration_reference.IntegrationReference `json:"-"`
// IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID.
// This field is only populated when the user ID can be decoded from Dex's format.
IdPID string `json:"idp_id,omitempty"`
}
// User represents a user of the system
@@ -96,6 +101,9 @@ type User struct {
Issued string `gorm:"default:api"`
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
Name string `gorm:"default:''"`
Email string `gorm:"default:''"`
}
// IsBlocked returns true if the user is blocked, false otherwise
@@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
}
if userData == nil {
name := u.Name
if u.IsServiceUser {
name = u.ServiceUserName
}
return &UserInfo{
ID: u.Id,
Email: "",
Name: u.ServiceUserName,
Email: u.Email,
Name: name,
Role: string(u.Role),
AutoGroups: u.AutoGroups,
Status: string(UserStatusActive),
@@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
LastLogin: u.GetLastLogin(),
Issued: u.Issued,
PendingApproval: u.PendingApproval,
Password: userData.Password,
}, nil
}
@@ -204,11 +219,13 @@ func (u *User) Copy() *User {
CreatedAt: u.CreatedAt,
Issued: u.Issued,
IntegrationReference: u.IntegrationReference,
Email: u.Email,
Name: u.Name,
}
}
// NewUser creates a new user
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User {
return &User{
Id: id,
Role: role,
@@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
AutoGroups: autoGroups,
Issued: issued,
CreatedAt: time.Now().UTC(),
Name: name,
Email: email,
}
}
// NewRegularUser creates a new user with role UserRoleUser
func NewRegularUser(id string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
func NewRegularUser(id, email, name string) *User {
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name)
}
// NewAdminUser creates a new user with role UserRoleAdmin
func NewAdminUser(id string) *User {
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "")
}
// NewOwnerUser creates a new user with role UserRoleOwner
func NewOwnerUser(id string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
func NewOwnerUser(id string, email string, name string) *User {
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name)
}
// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name)
// in place. A nil encryptor is a no-op and empty fields are left untouched.
func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
	if enc == nil {
		return nil
	}
	if u.Email != "" {
		encrypted, err := enc.Encrypt(u.Email)
		if err != nil {
			return fmt.Errorf("encrypt email: %w", err)
		}
		u.Email = encrypted
	}
	if u.Name != "" {
		encrypted, err := enc.Encrypt(u.Name)
		if err != nil {
			return fmt.Errorf("encrypt name: %w", err)
		}
		u.Name = encrypted
	}
	return nil
}
// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name)
// in place. A nil encryptor is a no-op and empty fields are left untouched.
func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
	if enc == nil {
		return nil
	}
	if u.Email != "" {
		decrypted, err := enc.Decrypt(u.Email)
		if err != nil {
			return fmt.Errorf("decrypt email: %w", err)
		}
		u.Email = decrypted
	}
	if u.Name != "" {
		decrypted, err := enc.Decrypt(u.Name)
		if err != nil {
			return fmt.Errorf("decrypt name: %w", err)
		}
		u.Name = decrypted
	}
	return nil
}

View File

@@ -0,0 +1,298 @@
package types
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/util/crypt"
)
// TestUser_EncryptSensitiveData covers in-place encryption of Email and
// Name: both set, both empty, each individually, and a nil encryptor.
func TestUser_EncryptSensitiveData(t *testing.T) {
	key, err := crypt.GenerateKey()
	require.NoError(t, err)
	enc, err := crypt.NewFieldEncrypt(key)
	require.NoError(t, err)

	t.Run("encrypt email and name", func(t *testing.T) {
		u := &User{Id: "user-1", Email: "test@example.com", Name: "Test User"}
		require.NoError(t, u.EncryptSensitiveData(enc))
		assert.NotEqual(t, "test@example.com", u.Email, "email should be encrypted")
		assert.NotEqual(t, "Test User", u.Name, "name should be encrypted")
		assert.NotEmpty(t, u.Email, "encrypted email should not be empty")
		assert.NotEmpty(t, u.Name, "encrypted name should not be empty")
	})

	t.Run("encrypt empty email and name", func(t *testing.T) {
		u := &User{Id: "user-2"}
		require.NoError(t, u.EncryptSensitiveData(enc))
		assert.Equal(t, "", u.Email, "empty email should remain empty")
		assert.Equal(t, "", u.Name, "empty name should remain empty")
	})

	t.Run("encrypt only email", func(t *testing.T) {
		u := &User{Id: "user-3", Email: "test@example.com"}
		require.NoError(t, u.EncryptSensitiveData(enc))
		assert.NotEqual(t, "test@example.com", u.Email, "email should be encrypted")
		assert.NotEmpty(t, u.Email, "encrypted email should not be empty")
		assert.Equal(t, "", u.Name, "empty name should remain empty")
	})

	t.Run("encrypt only name", func(t *testing.T) {
		u := &User{Id: "user-4", Name: "Test User"}
		require.NoError(t, u.EncryptSensitiveData(enc))
		assert.Equal(t, "", u.Email, "empty email should remain empty")
		assert.NotEqual(t, "Test User", u.Name, "name should be encrypted")
		assert.NotEmpty(t, u.Name, "encrypted name should not be empty")
	})

	t.Run("nil encryptor returns no error", func(t *testing.T) {
		u := &User{Id: "user-5", Email: "test@example.com", Name: "Test User"}
		require.NoError(t, u.EncryptSensitiveData(nil))
		assert.Equal(t, "test@example.com", u.Email, "email should remain unchanged with nil encryptor")
		assert.Equal(t, "Test User", u.Name, "name should remain unchanged with nil encryptor")
	})
}
// TestUser_DecryptSensitiveData covers in-place decryption of Email and
// Name, including empty fields, a nil encryptor, invalid ciphertext, and a
// mismatched key.
func TestUser_DecryptSensitiveData(t *testing.T) {
	key, err := crypt.GenerateKey()
	require.NoError(t, err)
	enc, err := crypt.NewFieldEncrypt(key)
	require.NoError(t, err)

	t.Run("decrypt email and name", func(t *testing.T) {
		const wantEmail, wantName = "test@example.com", "Test User"
		u := &User{Id: "user-1", Email: wantEmail, Name: wantName}
		require.NoError(t, u.EncryptSensitiveData(enc))
		require.NoError(t, u.DecryptSensitiveData(enc))
		assert.Equal(t, wantEmail, u.Email, "decrypted email should match original")
		assert.Equal(t, wantName, u.Name, "decrypted name should match original")
	})

	t.Run("decrypt empty email and name", func(t *testing.T) {
		u := &User{Id: "user-2"}
		require.NoError(t, u.DecryptSensitiveData(enc))
		assert.Equal(t, "", u.Email, "empty email should remain empty")
		assert.Equal(t, "", u.Name, "empty name should remain empty")
	})

	t.Run("decrypt only email", func(t *testing.T) {
		const wantEmail = "test@example.com"
		u := &User{Id: "user-3", Email: wantEmail}
		require.NoError(t, u.EncryptSensitiveData(enc))
		require.NoError(t, u.DecryptSensitiveData(enc))
		assert.Equal(t, wantEmail, u.Email, "decrypted email should match original")
		assert.Equal(t, "", u.Name, "empty name should remain empty")
	})

	t.Run("decrypt only name", func(t *testing.T) {
		const wantName = "Test User"
		u := &User{Id: "user-4", Name: wantName}
		require.NoError(t, u.EncryptSensitiveData(enc))
		require.NoError(t, u.DecryptSensitiveData(enc))
		assert.Equal(t, "", u.Email, "empty email should remain empty")
		assert.Equal(t, wantName, u.Name, "decrypted name should match original")
	})

	t.Run("nil encryptor returns no error", func(t *testing.T) {
		u := &User{Id: "user-5", Email: "test@example.com", Name: "Test User"}
		require.NoError(t, u.DecryptSensitiveData(nil))
		assert.Equal(t, "test@example.com", u.Email, "email should remain unchanged with nil encryptor")
		assert.Equal(t, "Test User", u.Name, "name should remain unchanged with nil encryptor")
	})

	t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) {
		u := &User{Id: "user-6", Email: "not-valid-base64-ciphertext!!!", Name: "Test User"}
		err := u.DecryptSensitiveData(enc)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "decrypt email")
	})

	t.Run("decrypt with wrong key returns error", func(t *testing.T) {
		u := &User{Id: "user-7", Email: "test@example.com", Name: "Test User"}
		require.NoError(t, u.EncryptSensitiveData(enc))

		otherKey, err := crypt.GenerateKey()
		require.NoError(t, err)
		otherEnc, err := crypt.NewFieldEncrypt(otherKey)
		require.NoError(t, err)

		err = u.DecryptSensitiveData(otherEnc)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "decrypt email")
	})
}
// TestUser_EncryptDecryptRoundTrip verifies that encrypting and then
// decrypting restores the original Email and Name for a range of inputs,
// including unicode, long values, and empty fields.
func TestUser_EncryptDecryptRoundTrip(t *testing.T) {
	key, err := crypt.GenerateKey()
	require.NoError(t, err)
	enc, err := crypt.NewFieldEncrypt(key)
	require.NoError(t, err)

	cases := []struct {
		name  string
		email string
		uname string
	}{
		{name: "standard email and name", email: "user@example.com", uname: "John Doe"},
		{name: "email with special characters", email: "user+tag@sub.example.com", uname: "O'Brien, Mary-Jane"},
		{name: "unicode characters", email: "user@example.com", uname: "Jean-Pierre Müller 日本語"},
		{name: "long values", email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com", uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes"},
		{name: "empty email only", email: "", uname: "Name Only"},
		{name: "empty name only", email: "email@only.com", uname: ""},
		{name: "both empty", email: "", uname: ""},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			u := &User{Id: "test-user", Email: c.email, Name: c.uname}

			require.NoError(t, u.EncryptSensitiveData(enc))
			if c.email != "" {
				assert.NotEqual(t, c.email, u.Email, "email should be encrypted")
			}
			if c.uname != "" {
				assert.NotEqual(t, c.uname, u.Name, "name should be encrypted")
			}

			require.NoError(t, u.DecryptSensitiveData(enc))
			assert.Equal(t, c.email, u.Email, "decrypted email should match original")
			assert.Equal(t, c.uname, u.Name, "decrypted name should match original")
		})
	}
}