mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
conflicts resolution
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -45,8 +46,10 @@ const (
|
||||
|
||||
// nativeSSHPortString defines the default port number as a string used for native SSH connections; this port is used by clients when hijacking ssh connections.
|
||||
nativeSSHPortString = "22022"
|
||||
nativeSSHPortNumber = 22022
|
||||
// defaultSSHPortString defines the standard SSH port number as a string, commonly used for default SSH connections.
|
||||
defaultSSHPortString = "22"
|
||||
defaultSSHPortNumber = 22
|
||||
)
|
||||
|
||||
type supportedFeatures struct {
|
||||
@@ -275,6 +278,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
resourcePolicies map[string][]*Policy,
|
||||
routers map[string]map[string]*routerTypes.NetworkRouter,
|
||||
metrics *telemetry.AccountManagerMetrics,
|
||||
groupIDToUserIDs map[string][]string,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
peer := a.Peers[peerID]
|
||||
@@ -290,7 +294,7 @@ func (a *Account) GetPeerNetworkMap(
|
||||
}
|
||||
}
|
||||
|
||||
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
|
||||
aclPeers, firewallRules, authorizedUsers, enableSSH := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap, groupIDToUserIDs)
|
||||
// exclude expired peers
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
var expiredPeers []*nbpeer.Peer
|
||||
@@ -338,6 +342,8 @@ func (a *Account) GetPeerNetworkMap(
|
||||
OfflinePeers: expiredPeers,
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: slices.Concat(networkResourcesFirewallRules, routesFirewallRules),
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
EnableSSH: enableSSH,
|
||||
}
|
||||
|
||||
if metrics != nil {
|
||||
@@ -1122,8 +1128,10 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
|
||||
// GetPeerConnectionResources for a given peer
|
||||
//
|
||||
// This function returns the list of peers and firewall rules that are applicable to a given peer.
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}, groupIDToUserIDs map[string][]string) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
|
||||
authorizedUsers := make(map[string]map[string]struct{}) // machine user to list of userIDs
|
||||
sshEnabled := false
|
||||
|
||||
for _, policy := range a.Policies {
|
||||
if !policy.Enabled {
|
||||
@@ -1166,10 +1174,58 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
|
||||
if peerInDestinations {
|
||||
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
|
||||
}
|
||||
|
||||
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||
userIDs, ok := groupIDToUserIDs[groupID]
|
||||
if !ok {
|
||||
log.WithContext(ctx).Tracef("no user IDs found for group ID %s", groupID)
|
||||
continue
|
||||
}
|
||||
|
||||
if len(localUsers) == 0 {
|
||||
localUsers = []string{auth.Wildcard}
|
||||
}
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
if authorizedUsers[localUser] == nil {
|
||||
authorizedUsers[localUser] = make(map[string]struct{})
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
authorizedUsers[localUser][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
case rule.AuthorizedUser != "":
|
||||
if authorizedUsers[auth.Wildcard] == nil {
|
||||
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||
}
|
||||
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getAccumulatedResources()
|
||||
peers, fwRules := getAccumulatedResources()
|
||||
return peers, fwRules, authorizedUsers, sshEnabled
|
||||
}
|
||||
|
||||
func (a *Account) getAllowedUserIDs() map[string]struct{} {
|
||||
users := make(map[string]struct{})
|
||||
for _, nbUser := range a.Users {
|
||||
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||
users[nbUser.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
// connResourcesGenerator returns generator and accumulator function which returns the result of generator calls
|
||||
@@ -1194,12 +1250,17 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
peersExists[peer.ID] = struct{}{}
|
||||
}
|
||||
|
||||
protocol := rule.Protocol
|
||||
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
protocol = PolicyRuleProtocolTCP
|
||||
}
|
||||
|
||||
fr := FirewallRule{
|
||||
PolicyID: rule.ID,
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
Protocol: string(protocol),
|
||||
}
|
||||
|
||||
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
|
||||
@@ -1221,6 +1282,28 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
|
||||
}
|
||||
}
|
||||
|
||||
func policyRuleImpliesLegacySSH(rule *PolicyRule) bool {
|
||||
return rule.Protocol == PolicyRuleProtocolALL || (rule.Protocol == PolicyRuleProtocolTCP && (portsIncludesSSH(rule.Ports) || portRangeIncludesSSH(rule.PortRanges)))
|
||||
}
|
||||
|
||||
func portRangeIncludesSSH(portRanges []RulePortRange) bool {
|
||||
for _, pr := range portRanges {
|
||||
if (pr.Start <= defaultSSHPortNumber && pr.End >= defaultSSHPortNumber) || (pr.Start <= nativeSSHPortNumber && pr.End >= nativeSSHPortNumber) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func portsIncludesSSH(ports []string) bool {
|
||||
for _, port := range ports {
|
||||
if port == defaultSSHPortString || port == nativeSSHPortString {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getAllPeersFromGroups for given peer ID and list of groups
|
||||
//
|
||||
// Returns a list of peers from specified groups that pass specified posture checks
|
||||
@@ -1265,7 +1348,11 @@ func (a *Account) getPeerFromResource(resource Resource, peerID string) ([]*nbpe
|
||||
return []*nbpeer.Peer{}, false
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, resource.ID == peerID
|
||||
if peer.ID == peerID {
|
||||
return []*nbpeer.Peer{}, true
|
||||
}
|
||||
|
||||
return []*nbpeer.Peer{peer}, false
|
||||
}
|
||||
|
||||
// validatePostureChecksOnPeer validates the posture checks on a peer
|
||||
@@ -1773,6 +1860,26 @@ func (a *Account) AddAllGroup(disableDefaultPolicy bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) GetActiveGroupUsers() map[string][]string {
|
||||
allGroupID := ""
|
||||
group, err := a.GetGroupAll()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get group all: %v", err)
|
||||
} else {
|
||||
allGroupID = group.ID
|
||||
}
|
||||
groups := make(map[string][]string, len(a.GroupsG))
|
||||
for _, user := range a.Users {
|
||||
if !user.IsBlocked() && !user.IsServiceUser {
|
||||
for _, groupID := range user.AutoGroups {
|
||||
groups[groupID] = append(groups[groupID], user.Id)
|
||||
}
|
||||
groups[allGroupID] = append(groups[allGroupID], user.Id)
|
||||
}
|
||||
}
|
||||
return groups
|
||||
}
|
||||
|
||||
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
|
||||
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
|
||||
features := peerSupportedFirewallFeatures(peer.Meta.WtVersion)
|
||||
@@ -1804,7 +1911,7 @@ func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer
|
||||
expanded = append(expanded, &fr)
|
||||
}
|
||||
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) {
|
||||
if shouldCheckRulesForNativeSSH(features.nativeSSH, rule, peer) || rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
expanded = addNativeSSHRule(base, expanded)
|
||||
}
|
||||
|
||||
|
||||
@@ -1105,6 +1105,193 @@ func Test_ExpandPortsAndRanges_SSHRuleExpansion(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func Test_GetActiveGroupUsers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected map[string][]string
|
||||
}{
|
||||
{
|
||||
name: "all users are active",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1", "user2"},
|
||||
"group3": {"user2"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "some users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2", "group3"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user3"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "all users are blocked",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "user with no auto groups",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user2"},
|
||||
"": {"user1", "user2"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty account",
|
||||
account: &Account{
|
||||
Users: map[string]*User{},
|
||||
},
|
||||
expected: map[string][]string{},
|
||||
},
|
||||
{
|
||||
name: "multiple users in same group",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group1"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1", "user2", "user3"},
|
||||
"": {"user1", "user2", "user3"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user in multiple groups with blocked users",
|
||||
account: &Account{
|
||||
Users: map[string]*User{
|
||||
"user1": {
|
||||
Id: "user1",
|
||||
AutoGroups: []string{"group1", "group2", "group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
"user2": {
|
||||
Id: "user2",
|
||||
AutoGroups: []string{"group1", "group2"},
|
||||
Blocked: true,
|
||||
},
|
||||
"user3": {
|
||||
Id: "user3",
|
||||
AutoGroups: []string{"group3"},
|
||||
Blocked: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: map[string][]string{
|
||||
"group1": {"user1"},
|
||||
"group2": {"user1"},
|
||||
"group3": {"user1", "user3"},
|
||||
"": {"user1", "user3"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetActiveGroupUsers()
|
||||
|
||||
// Check that the number of groups matches
|
||||
assert.Equal(t, len(tt.expected), len(result), "number of groups should match")
|
||||
|
||||
// Check each group's users
|
||||
for groupID, expectedUsers := range tt.expected {
|
||||
actualUsers, exists := result[groupID]
|
||||
assert.True(t, exists, "group %s should exist in result", groupID)
|
||||
assert.ElementsMatch(t, expectedUsers, actualUsers, "users in group %s should match", groupID)
|
||||
}
|
||||
|
||||
// Ensure no extra groups in result
|
||||
for groupID := range result {
|
||||
_, exists := tt.expected[groupID]
|
||||
assert.True(t, exists, "unexpected group %s in result", groupID)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_FilterZoneRecordsForPeers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -25,16 +25,20 @@ func (h *Holder) GetAccount(id string) *Account {
|
||||
func (h *Holder) AddAccount(account *Account) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
a := h.accounts[account.Id]
|
||||
if a != nil && a.Network.CurrentSerial() >= account.Network.CurrentSerial() {
|
||||
return
|
||||
}
|
||||
h.accounts[account.Id] = account
|
||||
}
|
||||
|
||||
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
func (h *Holder) LoadOrStoreFunc(ctx context.Context, id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
|
||||
h.mu.Lock()
|
||||
defer h.mu.Unlock()
|
||||
if acc, ok := h.accounts[id]; ok {
|
||||
return acc, nil
|
||||
}
|
||||
account, err := accGetter(context.Background(), id)
|
||||
account, err := accGetter(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
122
management/server/types/identity_provider.go
Normal file
122
management/server/types/identity_provider.go
Normal file
@@ -0,0 +1,122 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// Identity provider validation errors
|
||||
var (
|
||||
ErrIdentityProviderNameRequired = errors.New("identity provider name is required")
|
||||
ErrIdentityProviderTypeRequired = errors.New("identity provider type is required")
|
||||
ErrIdentityProviderTypeUnsupported = errors.New("unsupported identity provider type")
|
||||
ErrIdentityProviderIssuerRequired = errors.New("identity provider issuer is required")
|
||||
ErrIdentityProviderIssuerInvalid = errors.New("identity provider issuer must be a valid URL")
|
||||
ErrIdentityProviderClientIDRequired = errors.New("identity provider client ID is required")
|
||||
)
|
||||
|
||||
// IdentityProviderType is the type of identity provider
|
||||
type IdentityProviderType string
|
||||
|
||||
const (
|
||||
// IdentityProviderTypeOIDC is a generic OIDC identity provider
|
||||
IdentityProviderTypeOIDC IdentityProviderType = "oidc"
|
||||
// IdentityProviderTypeZitadel is the Zitadel identity provider
|
||||
IdentityProviderTypeZitadel IdentityProviderType = "zitadel"
|
||||
// IdentityProviderTypeEntra is the Microsoft Entra (Azure AD) identity provider
|
||||
IdentityProviderTypeEntra IdentityProviderType = "entra"
|
||||
// IdentityProviderTypeGoogle is the Google identity provider
|
||||
IdentityProviderTypeGoogle IdentityProviderType = "google"
|
||||
// IdentityProviderTypeOkta is the Okta identity provider
|
||||
IdentityProviderTypeOkta IdentityProviderType = "okta"
|
||||
// IdentityProviderTypePocketID is the PocketID identity provider
|
||||
IdentityProviderTypePocketID IdentityProviderType = "pocketid"
|
||||
// IdentityProviderTypeMicrosoft is the Microsoft identity provider
|
||||
IdentityProviderTypeMicrosoft IdentityProviderType = "microsoft"
|
||||
// IdentityProviderTypeAuthentik is the Authentik identity provider
|
||||
IdentityProviderTypeAuthentik IdentityProviderType = "authentik"
|
||||
// IdentityProviderTypeKeycloak is the Keycloak identity provider
|
||||
IdentityProviderTypeKeycloak IdentityProviderType = "keycloak"
|
||||
)
|
||||
|
||||
// IdentityProvider represents an identity provider configuration
|
||||
type IdentityProvider struct {
|
||||
// ID is the unique identifier of the identity provider
|
||||
ID string `gorm:"primaryKey"`
|
||||
// AccountID is a reference to Account that this object belongs
|
||||
AccountID string `json:"-" gorm:"index"`
|
||||
// Type is the type of identity provider
|
||||
Type IdentityProviderType
|
||||
// Name is a human-readable name for the identity provider
|
||||
Name string
|
||||
// Issuer is the OIDC issuer URL
|
||||
Issuer string
|
||||
// ClientID is the OAuth2 client ID
|
||||
ClientID string
|
||||
// ClientSecret is the OAuth2 client secret
|
||||
ClientSecret string
|
||||
}
|
||||
|
||||
// Copy returns a copy of the IdentityProvider
|
||||
func (idp *IdentityProvider) Copy() *IdentityProvider {
|
||||
return &IdentityProvider{
|
||||
ID: idp.ID,
|
||||
AccountID: idp.AccountID,
|
||||
Type: idp.Type,
|
||||
Name: idp.Name,
|
||||
Issuer: idp.Issuer,
|
||||
ClientID: idp.ClientID,
|
||||
ClientSecret: idp.ClientSecret,
|
||||
}
|
||||
}
|
||||
|
||||
// EventMeta returns a map of metadata for activity events
|
||||
func (idp *IdentityProvider) EventMeta() map[string]any {
|
||||
return map[string]any{
|
||||
"name": idp.Name,
|
||||
"type": string(idp.Type),
|
||||
"issuer": idp.Issuer,
|
||||
}
|
||||
}
|
||||
|
||||
// Validate validates the identity provider configuration
|
||||
func (idp *IdentityProvider) Validate() error {
|
||||
if idp.Name == "" {
|
||||
return ErrIdentityProviderNameRequired
|
||||
}
|
||||
if idp.Type == "" {
|
||||
return ErrIdentityProviderTypeRequired
|
||||
}
|
||||
if !idp.Type.IsValid() {
|
||||
return ErrIdentityProviderTypeUnsupported
|
||||
}
|
||||
if !idp.Type.HasBuiltInIssuer() && idp.Issuer == "" {
|
||||
return ErrIdentityProviderIssuerRequired
|
||||
}
|
||||
if idp.Issuer != "" {
|
||||
parsedURL, err := url.Parse(idp.Issuer)
|
||||
if err != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
|
||||
return ErrIdentityProviderIssuerInvalid
|
||||
}
|
||||
}
|
||||
if idp.ClientID == "" {
|
||||
return ErrIdentityProviderClientIDRequired
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsValid checks if the given type is a supported identity provider type
|
||||
func (t IdentityProviderType) IsValid() bool {
|
||||
switch t {
|
||||
case IdentityProviderTypeOIDC, IdentityProviderTypeZitadel, IdentityProviderTypeEntra,
|
||||
IdentityProviderTypeGoogle, IdentityProviderTypeOkta, IdentityProviderTypePocketID,
|
||||
IdentityProviderTypeMicrosoft, IdentityProviderTypeAuthentik, IdentityProviderTypeKeycloak:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HasBuiltInIssuer returns true for types that don't require an issuer URL
|
||||
func (t IdentityProviderType) HasBuiltInIssuer() bool {
|
||||
return t == IdentityProviderTypeGoogle || t == IdentityProviderTypeMicrosoft
|
||||
}
|
||||
137
management/server/types/identity_provider_test.go
Normal file
137
management/server/types/identity_provider_test.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestIdentityProvider_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
idp *IdentityProvider
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "valid OIDC provider",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "valid OIDC provider with path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "missing name",
|
||||
idp: &IdentityProvider{
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderNameRequired,
|
||||
},
|
||||
{
|
||||
name: "missing type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid type",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: "invalid",
|
||||
Issuer: "https://example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderTypeUnsupported,
|
||||
},
|
||||
{
|
||||
name: "missing issuer for OIDC",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerRequired,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no scheme",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "example.com",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - no host",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "invalid issuer URL - just path",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "/oauth2/issuer",
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderIssuerInvalid,
|
||||
},
|
||||
{
|
||||
name: "missing client ID",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Test Provider",
|
||||
Type: IdentityProviderTypeOIDC,
|
||||
Issuer: "https://example.com",
|
||||
},
|
||||
expectedErr: ErrIdentityProviderClientIDRequired,
|
||||
},
|
||||
{
|
||||
name: "Google provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Google SSO",
|
||||
Type: IdentityProviderTypeGoogle,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "Microsoft provider without issuer is valid",
|
||||
idp: &IdentityProvider{
|
||||
Name: "Microsoft SSO",
|
||||
Type: IdentityProviderTypeMicrosoft,
|
||||
ClientID: "client-id",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.idp.Validate()
|
||||
assert.Equal(t, tt.expectedErr, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -39,6 +39,8 @@ type NetworkMap struct {
|
||||
FirewallRules []*FirewallRule
|
||||
RoutesFirewallRules []*RouteFirewallRule
|
||||
ForwardingRules []*ForwardingRule
|
||||
AuthorizedUsers map[string]map[string]struct{}
|
||||
EnableSSH bool
|
||||
}
|
||||
|
||||
func (nm *NetworkMap) Merge(other *NetworkMap) {
|
||||
|
||||
@@ -47,14 +47,21 @@ func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
|
||||
return a.NetworkMapCache.OnPeerAddedIncremental(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeersAddedUpdNetworkMapCache(peerIds ...string) {
|
||||
if a.NetworkMapCache == nil {
|
||||
return
|
||||
}
|
||||
a.NetworkMapCache.EnqueuePeersForIncrementalAdd(a, peerIds...)
|
||||
}
|
||||
|
||||
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
|
||||
if a.NetworkMapCache == nil {
|
||||
return nil
|
||||
}
|
||||
return a.NetworkMapCache.OnPeerDeleted(peerId)
|
||||
return a.NetworkMapCache.OnPeerDeleted(a, peerId)
|
||||
}
|
||||
|
||||
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
|
||||
|
||||
@@ -25,15 +25,12 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// update flag is used to update the golden file.
|
||||
// example: go test ./... -v -update
|
||||
// var update = flag.Bool("update", false, "update golden files")
|
||||
|
||||
const (
|
||||
numPeers = 100
|
||||
devGroupID = "group-dev"
|
||||
opsGroupID = "group-ops"
|
||||
allGroupID = "group-all"
|
||||
sshUsersGroupID = "group-ssh-users"
|
||||
routeID = route.ID("route-main")
|
||||
routeHA1ID = route.ID("route-ha-1")
|
||||
routeHA2ID = route.ID("route-ha-2")
|
||||
@@ -41,6 +38,7 @@ const (
|
||||
policyIDAll = "policy-all"
|
||||
policyIDPosture = "policy-posture"
|
||||
policyIDDrop = "policy-drop"
|
||||
policyIDSSH = "policy-ssh"
|
||||
postureCheckID = "posture-check-ver"
|
||||
networkResourceID = "res-database"
|
||||
networkID = "net-database"
|
||||
@@ -51,6 +49,9 @@ const (
|
||||
offlinePeerID = "peer-99" // This peer will be completely offline.
|
||||
routingPeerID = "peer-95" // This peer is used for routing, it has a route to the network.
|
||||
testAccountID = "account-golden-test"
|
||||
userAdminID = "user-admin"
|
||||
userDevID = "user-dev"
|
||||
userOpsID = "user-ops"
|
||||
)
|
||||
|
||||
func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
@@ -69,61 +70,34 @@ func TestGetPeerNetworkMap_Golden(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden.json")
|
||||
|
||||
t.Log("Update golden file...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from OLD method does not match golden file")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new.json")
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
t.Log("Update golden file...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "resulted network map from NEW builder does not match golden file")
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
@@ -141,7 +115,7 @@ func BenchmarkGetPeerNetworkMap(b *testing.B) {
|
||||
b.Run("old builder", func(b *testing.B) {
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, peerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -169,6 +143,8 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newPeerID := "peer-new-101"
|
||||
newPeerIP := net.IP{100, 64, 1, 1}
|
||||
newPeer := &nbpeer.Peer{
|
||||
@@ -201,92 +177,36 @@ func TestGetPeerNetworkMap_Golden_WithNewPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
|
||||
|
||||
t.Log("Update golden file with new peer...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new peer does not match golden file")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAdded(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newPeerID := "peer-new-101"
|
||||
newPeerIP := net.IP{100, 64, 1, 1}
|
||||
newPeer := &nbpeer.Peer{
|
||||
ID: newPeerID,
|
||||
IP: newPeerIP,
|
||||
Key: fmt.Sprintf("key-%s", newPeerID),
|
||||
DNSLabel: "peernew101",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newPeerID] = newPeer
|
||||
|
||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
||||
devGroup.Peers = append(devGroup.Peers, newPeerID)
|
||||
}
|
||||
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newPeerID)
|
||||
}
|
||||
|
||||
validatedPeersMap[newPeerID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
err := builder.OnPeerAddedIncremental(newPeerID)
|
||||
err = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
require.NoError(t, err, "error adding peer to cache")
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded.json")
|
||||
t.Log("Update golden file with OnPeerAdded...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded does not match golden file")
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new peer from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
@@ -320,7 +240,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -328,7 +248,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerAdded(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(newPeerID)
|
||||
_ = builder.OnPeerAddedIncremental(account, newPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
}
|
||||
@@ -349,6 +269,8 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
@@ -395,106 +317,36 @@ func TestGetPeerNetworkMap_Golden_WithNewRoutingPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
|
||||
|
||||
t.Log("Update golden file with new router...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with new router does not match golden file")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
ID: newRouterID,
|
||||
IP: newRouterIP,
|
||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
||||
DNSLabel: "newrouter102",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newRouterID] = newRouter
|
||||
|
||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
||||
}
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID("route-new-router"),
|
||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
||||
Peer: newRouter.Key,
|
||||
PeerID: newRouterID,
|
||||
Description: "Route from new router",
|
||||
Enabled: true,
|
||||
PeerGroups: []string{opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Routes[newRoute.ID] = newRoute
|
||||
|
||||
validatedPeersMap[newRouterID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
err := builder.OnPeerAddedIncremental(newRouterID)
|
||||
err = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
require.NoError(t, err, "error adding router to cache")
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_new_router.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
t.Log("Update golden file with OnPeerAdded router...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with new router from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
@@ -550,7 +402,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.Run("old builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -558,7 +410,7 @@ func BenchmarkGetPeerNetworkMap_AfterRouterPeerAdded(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after add", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerAddedIncremental(newRouterID)
|
||||
_ = builder.OnPeerAddedIncremental(account, newRouterID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
}
|
||||
@@ -579,7 +431,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
deletedPeerID := "peer-25" // peer from devs group
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedPeerID := "peer-25"
|
||||
|
||||
delete(account.Peers, deletedPeerID)
|
||||
|
||||
@@ -604,85 +458,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
|
||||
|
||||
t.Log("Update golden file with deleted peer...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerDeleted(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedPeerID := "peer-25" // devs group peer
|
||||
|
||||
delete(account.Peers, deletedPeerID)
|
||||
|
||||
if devGroup, exists := account.Groups[devGroupID]; exists {
|
||||
devGroup.Peers = slices.DeleteFunc(devGroup.Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
}
|
||||
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = slices.DeleteFunc(allGroup.Peers, func(id string) bool {
|
||||
return id == deletedPeerID
|
||||
})
|
||||
}
|
||||
|
||||
delete(validatedPeersMap, deletedPeerID)
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
err := builder.OnPeerDeleted(deletedPeerID)
|
||||
err = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
require.NoError(t, err, "error deleting peer from cache")
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeerdeleted.json")
|
||||
t.Log("Update golden file with OnPeerDeleted...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerDeleted does not match golden file")
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted peer from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
@@ -698,7 +503,9 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
deletedRouterID := "peer-75" // router peer
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedRouterID := "peer-75"
|
||||
|
||||
var affectedRoute *route.Route
|
||||
for _, r := range account.Routes {
|
||||
@@ -730,93 +537,36 @@ func TestGetPeerNetworkMap_Golden_WithDeletedRouterPeer(t *testing.T) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
|
||||
networkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil)
|
||||
legacyNetworkMap := account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, resourcePolicies, routers, nil, account.GetActiveGroupUsers())
|
||||
normalizeAndSortNetworkMap(legacyNetworkMap)
|
||||
legacyJSON, err := json.MarshalIndent(toNetworkMapJSON(legacyNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling legacy network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
|
||||
|
||||
t.Log("Update golden file with deleted peer...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from OLD method with deleted peer does not match golden file")
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithDeletedRouterPeer(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
deletedRouterID := "peer-75" // router peer
|
||||
|
||||
var affectedRoute *route.Route
|
||||
for _, r := range account.Routes {
|
||||
if r.PeerID == deletedRouterID {
|
||||
affectedRoute = r
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotNil(t, affectedRoute, "Router peer should have a route")
|
||||
|
||||
for _, group := range account.Groups {
|
||||
group.Peers = slices.DeleteFunc(group.Peers, func(id string) bool {
|
||||
return id == deletedRouterID
|
||||
})
|
||||
}
|
||||
for routeID, r := range account.Routes {
|
||||
if r.Peer == account.Peers[deletedRouterID].Key || r.PeerID == deletedRouterID {
|
||||
delete(account.Routes, routeID)
|
||||
}
|
||||
}
|
||||
delete(account.Peers, deletedRouterID)
|
||||
delete(validatedPeersMap, deletedRouterID)
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
err := builder.OnPeerDeleted(deletedRouterID)
|
||||
err = builder.OnPeerDeleted(account, deletedRouterID)
|
||||
require.NoError(t, err, "error deleting routing peer from cache")
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
newNetworkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
normalizeAndSortNetworkMap(newNetworkMap)
|
||||
newJSON, err := json.MarshalIndent(toNetworkMapJSON(newNetworkMap), "", " ")
|
||||
require.NoError(t, err, "error marshaling new network map to JSON")
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
if string(legacyJSON) != string(newJSON) {
|
||||
legacyFilePath := filepath.Join("testdata", "networkmap_golden_with_deleted_router_peer.json")
|
||||
newFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err)
|
||||
err = os.MkdirAll(filepath.Dir(legacyFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_deleted_router.json")
|
||||
err = os.WriteFile(legacyFilePath, legacyJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved legacy network map to %s", legacyFilePath)
|
||||
|
||||
t.Log("Update golden file with deleted router...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(newFilePath, newJSON, 0644)
|
||||
require.NoError(t, err)
|
||||
t.Logf("Saved new network map to %s", newFilePath)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData),
|
||||
"network map after deleting router does not match golden file")
|
||||
require.JSONEq(t, string(legacyJSON), string(newJSON), "network maps with deleted router from legacy and new builder do not match")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
@@ -847,7 +597,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.Run("old builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil, nil, nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -855,7 +605,7 @@ func BenchmarkGetPeerNetworkMap_AfterPeerDeleted(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
b.Run("new builder after delete", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = builder.OnPeerDeleted(deletedPeerID)
|
||||
_ = builder.OnPeerDeleted(account, deletedPeerID)
|
||||
for _, testingPeerID := range peerIDs {
|
||||
_ = builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
}
|
||||
@@ -927,6 +677,54 @@ func normalizeAndSortNetworkMap(networkMap *types.NetworkMap) {
|
||||
}
|
||||
}
|
||||
|
||||
type networkMapJSON struct {
|
||||
Peers []*nbpeer.Peer `json:"Peers"`
|
||||
Network *types.Network `json:"Network"`
|
||||
Routes []*route.Route `json:"Routes"`
|
||||
DNSConfig dns.Config `json:"DNSConfig"`
|
||||
OfflinePeers []*nbpeer.Peer `json:"OfflinePeers"`
|
||||
FirewallRules []*types.FirewallRule `json:"FirewallRules"`
|
||||
RoutesFirewallRules []*types.RouteFirewallRule `json:"RoutesFirewallRules"`
|
||||
ForwardingRules []*types.ForwardingRule `json:"ForwardingRules"`
|
||||
AuthorizedUsers map[string][]string `json:"AuthorizedUsers,omitempty"`
|
||||
EnableSSH bool `json:"EnableSSH"`
|
||||
}
|
||||
|
||||
func toNetworkMapJSON(nm *types.NetworkMap) *networkMapJSON {
|
||||
result := &networkMapJSON{
|
||||
Peers: nm.Peers,
|
||||
Network: nm.Network,
|
||||
Routes: nm.Routes,
|
||||
DNSConfig: nm.DNSConfig,
|
||||
OfflinePeers: nm.OfflinePeers,
|
||||
FirewallRules: nm.FirewallRules,
|
||||
RoutesFirewallRules: nm.RoutesFirewallRules,
|
||||
ForwardingRules: nm.ForwardingRules,
|
||||
EnableSSH: nm.EnableSSH,
|
||||
}
|
||||
|
||||
if len(nm.AuthorizedUsers) > 0 {
|
||||
result.AuthorizedUsers = make(map[string][]string)
|
||||
localUsers := make([]string, 0, len(nm.AuthorizedUsers))
|
||||
for localUser := range nm.AuthorizedUsers {
|
||||
localUsers = append(localUsers, localUser)
|
||||
}
|
||||
sort.Strings(localUsers)
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
userIDs := nm.AuthorizedUsers[localUser]
|
||||
sortedUserIDs := make([]string, 0, len(userIDs))
|
||||
for userID := range userIDs {
|
||||
sortedUserIDs = append(sortedUserIDs, userID)
|
||||
}
|
||||
sort.Strings(sortedUserIDs)
|
||||
result.AuthorizedUsers[localUser] = sortedUserIDs
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func createTestAccountWithEntities() *types.Account {
|
||||
peers := make(map[string]*nbpeer.Peer)
|
||||
devGroupPeers, opsGroupPeers, allGroupPeers := []string{}, []string{}, []string{}
|
||||
@@ -962,9 +760,10 @@ func createTestAccountWithEntities() *types.Account {
|
||||
}
|
||||
|
||||
groups := map[string]*types.Group{
|
||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||
allGroupID: {ID: allGroupID, Name: "All", Peers: allGroupPeers},
|
||||
devGroupID: {ID: devGroupID, Name: "Developers", Peers: devGroupPeers},
|
||||
opsGroupID: {ID: opsGroupID, Name: "Operations", Peers: opsGroupPeers},
|
||||
sshUsersGroupID: {ID: sshUsersGroupID, Name: "SSH Users", Peers: []string{}},
|
||||
}
|
||||
|
||||
policies := []*types.Policy{
|
||||
@@ -1002,6 +801,15 @@ func createTestAccountWithEntities() *types.Account {
|
||||
Sources: []string{opsGroupID}, DestinationResource: types.Resource{ID: networkResourceID},
|
||||
}},
|
||||
},
|
||||
{
|
||||
ID: policyIDSSH, Name: "SSH Access Policy", Enabled: true,
|
||||
Rules: []*types.PolicyRule{{
|
||||
ID: policyIDSSH, Name: "Allow SSH to Ops", Enabled: true, Action: types.PolicyTrafficActionAccept,
|
||||
Protocol: types.PolicyRuleProtocolNetbirdSSH, Bidirectional: false,
|
||||
Sources: []string{devGroupID}, Destinations: []string{opsGroupID},
|
||||
AuthorizedGroups: map[string][]string{sshUsersGroupID: {"root", "admin"}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
|
||||
routes := map[route.ID]*route.Route{
|
||||
@@ -1034,8 +842,15 @@ func createTestAccountWithEntities() *types.Account {
|
||||
},
|
||||
}
|
||||
|
||||
users := map[string]*types.User{
|
||||
userAdminID: {Id: userAdminID, Role: types.UserRoleAdmin, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{allGroupID}},
|
||||
userDevID: {Id: userDevID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, devGroupID}},
|
||||
userOpsID: {Id: userOpsID, Role: types.UserRoleUser, IsServiceUser: false, AccountID: testAccountID, AutoGroups: []string{sshUsersGroupID, opsGroupID}},
|
||||
}
|
||||
|
||||
account := &types.Account{
|
||||
Id: testAccountID, Peers: peers, Groups: groups, Policies: policies, Routes: routes,
|
||||
Users: users,
|
||||
Network: &types.Network{
|
||||
Identifier: "net-golden-test", Net: net.IPNet{IP: net.IP{100, 64, 0, 0}, Mask: net.CIDRMask(16, 32)}, Serial: 1,
|
||||
},
|
||||
@@ -1071,6 +886,88 @@ func createTestAccountWithEntities() *types.Account {
|
||||
return account
|
||||
}
|
||||
|
||||
func TestGetPeerNetworkMap_Golden_New_WithOnPeerAddedRouter_Batched(t *testing.T) {
|
||||
account := createTestAccountWithEntities()
|
||||
|
||||
ctx := context.Background()
|
||||
validatedPeersMap := make(map[string]struct{})
|
||||
for i := range numPeers {
|
||||
peerID := fmt.Sprintf("peer-%d", i)
|
||||
if peerID == offlinePeerID {
|
||||
continue
|
||||
}
|
||||
validatedPeersMap[peerID] = struct{}{}
|
||||
}
|
||||
|
||||
builder := types.NewNetworkMapBuilder(account, validatedPeersMap)
|
||||
|
||||
newRouterID := "peer-new-router-102"
|
||||
newRouterIP := net.IP{100, 64, 1, 2}
|
||||
newRouter := &nbpeer.Peer{
|
||||
ID: newRouterID,
|
||||
IP: newRouterIP,
|
||||
Key: fmt.Sprintf("key-%s", newRouterID),
|
||||
DNSLabel: "newrouter102",
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now()},
|
||||
UserID: "user-admin",
|
||||
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.26.0", GoOS: "linux"},
|
||||
LastLogin: func() *time.Time { t := time.Now(); return &t }(),
|
||||
}
|
||||
|
||||
account.Peers[newRouterID] = newRouter
|
||||
|
||||
if opsGroup, exists := account.Groups[opsGroupID]; exists {
|
||||
opsGroup.Peers = append(opsGroup.Peers, newRouterID)
|
||||
}
|
||||
if allGroup, exists := account.Groups[allGroupID]; exists {
|
||||
allGroup.Peers = append(allGroup.Peers, newRouterID)
|
||||
}
|
||||
|
||||
newRoute := &route.Route{
|
||||
ID: route.ID("route-new-router"),
|
||||
Network: netip.MustParsePrefix("172.16.0.0/24"),
|
||||
Peer: newRouter.Key,
|
||||
PeerID: newRouterID,
|
||||
Description: "Route from new router",
|
||||
Enabled: true,
|
||||
PeerGroups: []string{opsGroupID},
|
||||
Groups: []string{devGroupID, opsGroupID},
|
||||
AccessControlGroups: []string{devGroupID},
|
||||
AccountID: account.Id,
|
||||
}
|
||||
account.Routes[newRoute.ID] = newRoute
|
||||
|
||||
validatedPeersMap[newRouterID] = struct{}{}
|
||||
|
||||
if account.Network != nil {
|
||||
account.Network.Serial++
|
||||
}
|
||||
|
||||
builder.EnqueuePeersForIncrementalAdd(account, newRouterID)
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
networkMap := builder.GetPeerNetworkMap(ctx, testingPeerID, dns.CustomZone{}, validatedPeersMap, nil)
|
||||
|
||||
normalizeAndSortNetworkMap(networkMap)
|
||||
|
||||
jsonData, err := json.MarshalIndent(networkMap, "", " ")
|
||||
require.NoError(t, err, "error marshaling network map to JSON")
|
||||
|
||||
goldenFilePath := filepath.Join("testdata", "networkmap_golden_new_with_onpeeradded_router.json")
|
||||
|
||||
t.Log("Update golden file with OnPeerAdded router...")
|
||||
err = os.MkdirAll(filepath.Dir(goldenFilePath), 0755)
|
||||
require.NoError(t, err)
|
||||
err = os.WriteFile(goldenFilePath, jsonData, 0644)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedJSON, err := os.ReadFile(goldenFilePath)
|
||||
require.NoError(t, err, "error reading golden file")
|
||||
|
||||
require.JSONEq(t, string(expectedJSON), string(jsonData), "network map from NEW builder with OnPeerAdded router does not match golden file")
|
||||
}
|
||||
|
||||
func createAccountFromFile() (*types.Account, error) {
|
||||
accraw := filepath.Join("testdata", "account_cnlf3j3l0ubs738o5d4g.json")
|
||||
data, err := os.ReadFile(accraw)
|
||||
@@ -1343,7 +1240,7 @@ func BenchmarkGetPeerNetworkMapCompactCached(b *testing.B) {
|
||||
b.Run("Legacy", func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
|
||||
_ = account.GetPeerNetworkMap(ctx, testingPeerID, customZone, validatedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil, account.GetActiveGroupUsers())
|
||||
}
|
||||
})
|
||||
b.Run("LegacyCompacted", func(b *testing.B) {
|
||||
|
||||
@@ -7,12 +7,12 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -27,6 +27,9 @@ const (
|
||||
v6AllWildcard = "::/0"
|
||||
fw = "fw:"
|
||||
rfw = "route-fw:"
|
||||
|
||||
szAddPeerBatch = 10
|
||||
maxPeerAddRetries = 20
|
||||
)
|
||||
|
||||
type NetworkMapCache struct {
|
||||
@@ -47,6 +50,10 @@ type NetworkMapCache struct {
|
||||
peerDNS map[string]*nbdns.Config
|
||||
peerFirewallRulesCompact map[string][]*FirewallRule
|
||||
peerRoutesCompact map[string][]*route.Route
|
||||
peerSSH map[string]*PeerSSHView
|
||||
|
||||
groupIDToUserIDs map[string][]string
|
||||
allowedUserIDs map[string]struct{}
|
||||
|
||||
resourceRouters map[string]map[string]*routerTypes.NetworkRouter
|
||||
resourcePolicies map[string][]*Policy
|
||||
@@ -76,41 +83,64 @@ type PeerRoutesView struct {
|
||||
RouteFirewallRuleIDs []string
|
||||
}
|
||||
|
||||
type PeerSSHView struct {
|
||||
EnableSSH bool
|
||||
AuthorizedUsers map[string]map[string]struct{}
|
||||
}
|
||||
|
||||
type NetworkMapBuilder struct {
|
||||
account atomic.Pointer[Account]
|
||||
account *Account
|
||||
cache *NetworkMapCache
|
||||
validatedPeers map[string]struct{}
|
||||
|
||||
apb addPeerBatch
|
||||
}
|
||||
|
||||
type addPeerBatch struct {
|
||||
mu sync.Mutex
|
||||
sg *sync.Cond
|
||||
ids []string
|
||||
la *Account
|
||||
retryCount map[string]int
|
||||
}
|
||||
|
||||
func NewNetworkMapBuilder(account *Account, validatedPeers map[string]struct{}) *NetworkMapBuilder {
|
||||
builder := &NetworkMapBuilder{
|
||||
cache: &NetworkMapCache{
|
||||
globalRoutes: make(map[route.ID]*route.Route),
|
||||
globalRules: make(map[string]*FirewallRule),
|
||||
globalRouteRules: make(map[string]*RouteFirewallRule),
|
||||
globalPeers: make(map[string]*nbpeer.Peer),
|
||||
groupToPeers: make(map[string][]string),
|
||||
peerToGroups: make(map[string][]string),
|
||||
policyToRules: make(map[string][]*PolicyRule),
|
||||
groupToPolicies: make(map[string][]*Policy),
|
||||
groupToRoutes: make(map[string][]*route.Route),
|
||||
peerToRoutes: make(map[string][]*route.Route),
|
||||
globalRoutes: make(map[route.ID]*route.Route),
|
||||
globalRules: make(map[string]*FirewallRule),
|
||||
globalRouteRules: make(map[string]*RouteFirewallRule),
|
||||
globalPeers: make(map[string]*nbpeer.Peer),
|
||||
groupToPeers: make(map[string][]string),
|
||||
peerToGroups: make(map[string][]string),
|
||||
policyToRules: make(map[string][]*PolicyRule),
|
||||
groupToPolicies: make(map[string][]*Policy),
|
||||
groupToRoutes: make(map[string][]*route.Route),
|
||||
peerToRoutes: make(map[string][]*route.Route),
|
||||
peerACLs: make(map[string]*PeerACLView),
|
||||
peerRoutes: make(map[string]*PeerRoutesView),
|
||||
peerDNS: make(map[string]*nbdns.Config),
|
||||
peerSSH: make(map[string]*PeerSSHView),
|
||||
groupIDToUserIDs: make(map[string][]string),
|
||||
allowedUserIDs: make(map[string]struct{}),
|
||||
peerFirewallRulesCompact: make(map[string][]*FirewallRule),
|
||||
peerRoutesCompact: make(map[string][]*route.Route),
|
||||
globalResources: make(map[string]*resourceTypes.NetworkResource),
|
||||
acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo),
|
||||
noACGRoutes: make(map[route.ID]*RouteOwnerInfo),
|
||||
acgToRoutes: make(map[string]map[route.ID]*RouteOwnerInfo),
|
||||
noACGRoutes: make(map[route.ID]*RouteOwnerInfo),
|
||||
},
|
||||
validatedPeers: make(map[string]struct{}),
|
||||
}
|
||||
builder.account.Store(account)
|
||||
builder.apb.sg = sync.NewCond(&builder.apb.mu)
|
||||
builder.apb.ids = make([]string, 0, szAddPeerBatch)
|
||||
builder.apb.la = account
|
||||
builder.apb.retryCount = make(map[string]int)
|
||||
|
||||
maps.Copy(builder.validatedPeers, validatedPeers)
|
||||
|
||||
builder.initialBuild(account)
|
||||
|
||||
go builder.incAddPeerLoop()
|
||||
return builder
|
||||
}
|
||||
|
||||
@@ -118,6 +148,8 @@ func (b *NetworkMapBuilder) initialBuild(account *Account) {
|
||||
b.cache.mu.Lock()
|
||||
defer b.cache.mu.Unlock()
|
||||
|
||||
b.account = account
|
||||
|
||||
start := time.Now()
|
||||
|
||||
b.buildGlobalIndexes(account)
|
||||
@@ -150,9 +182,15 @@ func (b *NetworkMapBuilder) buildGlobalIndexes(account *Account) {
|
||||
clear(b.cache.peerToRoutes)
|
||||
clear(b.cache.acgToRoutes)
|
||||
clear(b.cache.noACGRoutes)
|
||||
clear(b.cache.groupIDToUserIDs)
|
||||
clear(b.cache.allowedUserIDs)
|
||||
clear(b.cache.peerSSH)
|
||||
|
||||
maps.Copy(b.cache.globalPeers, account.Peers)
|
||||
|
||||
b.cache.groupIDToUserIDs = account.GetActiveGroupUsers()
|
||||
b.cache.allowedUserIDs = b.buildAllowedUserIDs(account)
|
||||
|
||||
for groupID, group := range account.Groups {
|
||||
peersCopy := make([]string, len(group.Peers))
|
||||
copy(peersCopy, group.Peers)
|
||||
@@ -227,7 +265,7 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) {
|
||||
return
|
||||
}
|
||||
|
||||
allPotentialPeers, firewallRules := b.getPeerConnectionResources(account, peer, b.validatedPeers)
|
||||
allPotentialPeers, firewallRules, authorizedUsers, sshEnabled := b.getPeerConnectionResources(account, peer, b.validatedPeers)
|
||||
|
||||
isRouter, networkResourcesRoutes, sourcePeers := b.getNetworkResourcesForPeer(account, peer)
|
||||
|
||||
@@ -260,12 +298,17 @@ func (b *NetworkMapBuilder) buildPeerACLView(account *Account, peerID string) {
|
||||
|
||||
compactedRules := compactFirewallRules(firewallRules)
|
||||
b.cache.peerFirewallRulesCompact[peerID] = compactedRules
|
||||
b.cache.peerSSH[peerID] = &PeerSSHView{
|
||||
EnableSSH: sshEnabled,
|
||||
AuthorizedUsers: authorizedUsers,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *nbpeer.Peer,
|
||||
validatedPeersMap map[string]struct{},
|
||||
) ([]*nbpeer.Peer, []*FirewallRule) {
|
||||
) ([]*nbpeer.Peer, []*FirewallRule, map[string]map[string]struct{}, bool) {
|
||||
peerID := peer.ID
|
||||
ctx := context.Background()
|
||||
|
||||
peerGroups := b.cache.peerToGroups[peerID]
|
||||
peerGroupsMap := make(map[string]struct{}, len(peerGroups))
|
||||
@@ -278,9 +321,15 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
|
||||
fwRules := make([]*FirewallRule, 0)
|
||||
peers := make([]*nbpeer.Peer, 0)
|
||||
|
||||
authorizedUsers := make(map[string]map[string]struct{})
|
||||
sshEnabled := false
|
||||
|
||||
for _, group := range peerGroups {
|
||||
policies := b.cache.groupToPolicies[group]
|
||||
for _, policy := range policies {
|
||||
if isValid := account.validatePostureChecksOnPeer(ctx, policy.SourcePostureChecks, peerID); !isValid {
|
||||
continue
|
||||
}
|
||||
rules := b.cache.policyToRules[policy.ID]
|
||||
for _, rule := range rules {
|
||||
var sourcePeers, destinationPeers []*nbpeer.Peer
|
||||
@@ -323,13 +372,13 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
|
||||
if rule.Bidirectional {
|
||||
if peerInSources {
|
||||
b.generateResourcescached(
|
||||
account, rule, destinationPeers, FirewallRuleDirectionIN,
|
||||
rule, destinationPeers, FirewallRuleDirectionIN,
|
||||
peer, &peers, &fwRules, peersExists, rulesExists,
|
||||
)
|
||||
}
|
||||
if peerInDestinations {
|
||||
b.generateResourcescached(
|
||||
account, rule, sourcePeers, FirewallRuleDirectionOUT,
|
||||
rule, sourcePeers, FirewallRuleDirectionOUT,
|
||||
peer, &peers, &fwRules, peersExists, rulesExists,
|
||||
)
|
||||
}
|
||||
@@ -337,22 +386,58 @@ func (b *NetworkMapBuilder) getPeerConnectionResources(account *Account, peer *n
|
||||
|
||||
if peerInSources {
|
||||
b.generateResourcescached(
|
||||
account, rule, destinationPeers, FirewallRuleDirectionOUT,
|
||||
rule, destinationPeers, FirewallRuleDirectionOUT,
|
||||
peer, &peers, &fwRules, peersExists, rulesExists,
|
||||
)
|
||||
}
|
||||
|
||||
if peerInDestinations {
|
||||
b.generateResourcescached(
|
||||
account, rule, sourcePeers, FirewallRuleDirectionIN,
|
||||
rule, sourcePeers, FirewallRuleDirectionIN,
|
||||
peer, &peers, &fwRules, peersExists, rulesExists,
|
||||
)
|
||||
|
||||
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
sshEnabled = true
|
||||
switch {
|
||||
case len(rule.AuthorizedGroups) > 0:
|
||||
for groupID, localUsers := range rule.AuthorizedGroups {
|
||||
userIDs, ok := b.cache.groupIDToUserIDs[groupID]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(localUsers) == 0 {
|
||||
localUsers = []string{auth.Wildcard}
|
||||
}
|
||||
|
||||
for _, localUser := range localUsers {
|
||||
if authorizedUsers[localUser] == nil {
|
||||
authorizedUsers[localUser] = make(map[string]struct{})
|
||||
}
|
||||
for _, userID := range userIDs {
|
||||
authorizedUsers[localUser][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
case rule.AuthorizedUser != "":
|
||||
if authorizedUsers[auth.Wildcard] == nil {
|
||||
authorizedUsers[auth.Wildcard] = make(map[string]struct{})
|
||||
}
|
||||
authorizedUsers[auth.Wildcard][rule.AuthorizedUser] = struct{}{}
|
||||
default:
|
||||
authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs)
|
||||
}
|
||||
} else if policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
|
||||
sshEnabled = true
|
||||
authorizedUsers[auth.Wildcard] = maps.Clone(b.cache.allowedUserIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peers, fwRules
|
||||
return peers, fwRules, authorizedUsers, sshEnabled
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) isPeerInGroupscached(groupIDs []string, peerGroupsMap map[string]struct{}) bool {
|
||||
@@ -405,14 +490,9 @@ func (b *NetworkMapBuilder) getPeersFromGroupscached(account *Account, groupIDs
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) generateResourcescached(
|
||||
account *Account, rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer,
|
||||
rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int, targetPeer *nbpeer.Peer,
|
||||
peers *[]*nbpeer.Peer, rules *[]*FirewallRule, peersExists map[string]struct{}, rulesExists map[string]struct{},
|
||||
) {
|
||||
isAll := false
|
||||
if allGroup, err := account.GetGroupAll(); err == nil {
|
||||
isAll = (len(allGroup.Peers) - 1) == len(groupPeers)
|
||||
}
|
||||
|
||||
for _, peer := range groupPeers {
|
||||
if peer == nil {
|
||||
continue
|
||||
@@ -427,11 +507,7 @@ func (b *NetworkMapBuilder) generateResourcescached(
|
||||
PeerIP: peer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
}
|
||||
|
||||
if isAll {
|
||||
fr.PeerIP = allPeers
|
||||
Protocol: firewallRuleProtocol(rule.Protocol),
|
||||
}
|
||||
|
||||
var s strings.Builder
|
||||
@@ -942,8 +1018,29 @@ func (b *NetworkMapBuilder) getPeerNSGroups(account *Account, peerID string, che
|
||||
return peerNSGroups
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) UpdateAccountPointer(account *Account) {
|
||||
b.account.Store(account)
|
||||
func (b *NetworkMapBuilder) buildAllowedUserIDs(account *Account) map[string]struct{} {
|
||||
users := make(map[string]struct{})
|
||||
for _, nbUser := range account.Users {
|
||||
if !nbUser.IsBlocked() && !nbUser.IsServiceUser {
|
||||
users[nbUser.Id] = struct{}{}
|
||||
}
|
||||
}
|
||||
return users
|
||||
}
|
||||
|
||||
func firewallRuleProtocol(protocol PolicyRuleProtocolType) string {
|
||||
if protocol == PolicyRuleProtocolNetbirdSSH {
|
||||
return string(PolicyRuleProtocolTCP)
|
||||
}
|
||||
return string(protocol)
|
||||
}
|
||||
|
||||
// lock should be held
|
||||
func (b *NetworkMapBuilder) updateAccountLocked(account *Account) *Account {
|
||||
if account.Network.CurrentSerial() > b.account.Network.CurrentSerial() {
|
||||
b.account = account
|
||||
}
|
||||
return b.account
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
@@ -951,25 +1048,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
account := b.account.Load()
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
account := b.account
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
aclView := b.cache.peerACLs[peerID]
|
||||
routesView := b.cache.peerRoutes[peerID]
|
||||
dnsConfig := b.cache.peerDNS[peerID]
|
||||
sshView := b.cache.peerSSH[peerID]
|
||||
|
||||
if aclView == nil || routesView == nil || dnsConfig == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
|
||||
nm := b.assembleNetworkMap(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
@@ -987,7 +1086,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMap(
|
||||
|
||||
func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
) *NetworkMap {
|
||||
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -1045,7 +1144,7 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
finalDNSConfig.CustomZones = zones
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: account.Network.Copy(),
|
||||
Routes: routes,
|
||||
@@ -1054,6 +1153,13 @@ func (b *NetworkMapBuilder) assembleNetworkMap(
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: routesFirewallRules,
|
||||
}
|
||||
|
||||
if sshView != nil {
|
||||
nm.EnableSSH = sshView.EnableSSH
|
||||
nm.AuthorizedUsers = sshView.AuthorizedUsers
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached(
|
||||
@@ -1061,64 +1167,27 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompactCached(
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
account := b.account.Load()
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
account := b.account
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
aclView := b.cache.peerACLs[peerID]
|
||||
routesView := b.cache.peerRoutes[peerID]
|
||||
dnsConfig := b.cache.peerDNS[peerID]
|
||||
sshView := b.cache.peerSSH[peerID]
|
||||
|
||||
if aclView == nil || routesView == nil || dnsConfig == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache",
|
||||
account.Id, objectCount)
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
account := b.account.Load()
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
aclView := b.cache.peerACLs[peerID]
|
||||
routesView := b.cache.peerRoutes[peerID]
|
||||
dnsConfig := b.cache.peerDNS[peerID]
|
||||
|
||||
if aclView == nil || routesView == nil || dnsConfig == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, peersCustomZone, validatedPeers)
|
||||
nm := b.assembleNetworkMapCompactCached(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
@@ -1136,7 +1205,7 @@ func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
|
||||
|
||||
func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
|
||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
) *NetworkMap {
|
||||
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -1182,7 +1251,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
|
||||
finalDNSConfig.CustomZones = zones
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: account.Network.Copy(),
|
||||
Routes: routes,
|
||||
@@ -1191,11 +1260,59 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompactCached(
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: routesFirewallRules,
|
||||
}
|
||||
|
||||
if sshView != nil {
|
||||
nm.EnableSSH = sshView.EnableSSH
|
||||
nm.AuthorizedUsers = sshView.AuthorizedUsers
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) GetPeerNetworkMapCompact(
|
||||
ctx context.Context, peerID string, peersCustomZone nbdns.CustomZone,
|
||||
validatedPeers map[string]struct{}, metrics *telemetry.AccountManagerMetrics,
|
||||
) *NetworkMap {
|
||||
start := time.Now()
|
||||
|
||||
b.cache.mu.RLock()
|
||||
defer b.cache.mu.RUnlock()
|
||||
|
||||
account := b.account
|
||||
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
aclView := b.cache.peerACLs[peerID]
|
||||
routesView := b.cache.peerRoutes[peerID]
|
||||
dnsConfig := b.cache.peerDNS[peerID]
|
||||
sshView := b.cache.peerSSH[peerID]
|
||||
|
||||
if aclView == nil || routesView == nil || dnsConfig == nil {
|
||||
return &NetworkMap{Network: account.Network.Copy()}
|
||||
}
|
||||
|
||||
nm := b.assembleNetworkMapCompact(account, peer, aclView, routesView, dnsConfig, sshView, peersCustomZone, validatedPeers)
|
||||
|
||||
if metrics != nil {
|
||||
objectCount := int64(len(nm.Peers) + len(nm.OfflinePeers) + len(nm.Routes) + len(nm.FirewallRules) + len(nm.RoutesFirewallRules))
|
||||
metrics.CountNetworkMapObjects(objectCount)
|
||||
metrics.CountGetPeerNetworkMapDuration(time.Since(start))
|
||||
|
||||
if objectCount > 5000 {
|
||||
log.WithContext(ctx).Tracef("account: %s has a total resource count of %d objects from cache",
|
||||
account.Id, objectCount)
|
||||
}
|
||||
}
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) assembleNetworkMapCompact(
|
||||
account *Account, peer *nbpeer.Peer, aclView *PeerACLView, routesView *PeerRoutesView,
|
||||
dnsConfig *nbdns.Config, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
dnsConfig *nbdns.Config, sshView *PeerSSHView, customZone nbdns.CustomZone, validatedPeers map[string]struct{},
|
||||
) *NetworkMap {
|
||||
|
||||
var peersToConnect []*nbpeer.Peer
|
||||
@@ -1279,7 +1396,7 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact(
|
||||
finalDNSConfig.CustomZones = zones
|
||||
}
|
||||
|
||||
return &NetworkMap{
|
||||
nm := &NetworkMap{
|
||||
Peers: peersToConnect,
|
||||
Network: account.Network.Copy(),
|
||||
Routes: routes,
|
||||
@@ -1288,14 +1405,13 @@ func (b *NetworkMapBuilder) assembleNetworkMapCompact(
|
||||
FirewallRules: firewallRules,
|
||||
RoutesFirewallRules: routesFirewallRules,
|
||||
}
|
||||
}
|
||||
|
||||
func splitRouteAndPeer(r *route.Route) (string, string) {
|
||||
parts := strings.Split(string(r.ID), ":")
|
||||
if len(parts) < 2 {
|
||||
return string(r.ID), ""
|
||||
if sshView != nil {
|
||||
nm.EnableSSH = sshView.EnableSSH
|
||||
nm.AuthorizedUsers = sshView.AuthorizedUsers
|
||||
}
|
||||
return parts[0], parts[1]
|
||||
|
||||
return nm
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) compactRoutesForPeer(peerID string, ownRouteIDs []route.ID, otherRouteIDs []route.ID) []*route.Route {
|
||||
@@ -1427,6 +1543,14 @@ func compactFirewallRules(expandedRules []*FirewallRule) []*FirewallRule {
|
||||
return compactRules
|
||||
}
|
||||
|
||||
func splitRouteAndPeer(r *route.Route) (string, string) {
|
||||
parts := strings.Split(string(r.ID), ":")
|
||||
if len(parts) < 2 {
|
||||
return string(r.ID), ""
|
||||
}
|
||||
return parts[0], parts[1]
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) generateFirewallRuleID(rule *FirewallRule) string {
|
||||
var s strings.Builder
|
||||
s.WriteString(fw)
|
||||
@@ -1501,6 +1625,106 @@ func (b *NetworkMapBuilder) isPeerRouter(account *Account, peerID string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) incAddPeerLoop() {
|
||||
for {
|
||||
b.apb.mu.Lock()
|
||||
if len(b.apb.ids) == 0 {
|
||||
b.apb.sg.Wait()
|
||||
}
|
||||
b.addPeersIncrementally()
|
||||
b.apb.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// lock on b.apb level should be held
|
||||
func (b *NetworkMapBuilder) addPeersIncrementally() {
|
||||
peers := slices.Clone(b.apb.ids)
|
||||
clear(b.apb.ids)
|
||||
b.apb.ids = b.apb.ids[:0]
|
||||
latestAcc := b.apb.la
|
||||
b.apb.mu.Unlock()
|
||||
|
||||
tt := time.Now()
|
||||
b.cache.mu.Lock()
|
||||
defer b.cache.mu.Unlock()
|
||||
|
||||
account := b.updateAccountLocked(latestAcc)
|
||||
|
||||
log.Debugf("NetworkMapBuilder: Starting incremental add of %d peers", len(peers))
|
||||
|
||||
allUpdates := make(map[string]*PeerUpdateDelta)
|
||||
|
||||
for _, peerID := range peers {
|
||||
peer := account.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
b.apb.mu.Lock()
|
||||
retries := b.apb.retryCount[peerID]
|
||||
b.apb.mu.Unlock()
|
||||
|
||||
if retries >= maxPeerAddRetries {
|
||||
log.Errorf("NetworkMapBuilder: peer %s not found in account %s after %d retries, giving up", peerID, account.Id, retries)
|
||||
b.apb.mu.Lock()
|
||||
delete(b.apb.retryCount, peerID)
|
||||
b.apb.mu.Unlock()
|
||||
continue
|
||||
}
|
||||
|
||||
log.Warnf("NetworkMapBuilder: peer %s not found in account %s, retry %d/%d", peerID, account.Id, retries+1, maxPeerAddRetries)
|
||||
b.apb.mu.Lock()
|
||||
b.apb.retryCount[peerID] = retries + 1
|
||||
b.apb.mu.Unlock()
|
||||
b.enqueuePeersForIncrementalAdd(latestAcc, peerID)
|
||||
continue
|
||||
}
|
||||
|
||||
b.apb.mu.Lock()
|
||||
delete(b.apb.retryCount, peerID)
|
||||
b.apb.mu.Unlock()
|
||||
|
||||
b.validatedPeers[peerID] = struct{}{}
|
||||
b.cache.globalPeers[peerID] = peer
|
||||
|
||||
peerGroups := b.updateIndexesForNewPeer(account, peerID)
|
||||
b.buildPeerACLView(account, peerID)
|
||||
b.buildPeerRoutesView(account, peerID)
|
||||
b.buildPeerDNSView(account, peerID)
|
||||
|
||||
peerDeltas := b.collectDeltasForNewPeer(account, peerID, peerGroups)
|
||||
for affectedPeerID, delta := range peerDeltas {
|
||||
if existing, ok := allUpdates[affectedPeerID]; ok {
|
||||
existing.mergeFrom(delta)
|
||||
continue
|
||||
}
|
||||
allUpdates[affectedPeerID] = delta
|
||||
}
|
||||
}
|
||||
|
||||
for affectedPeerID, delta := range allUpdates {
|
||||
b.applyDeltaToPeer(account, affectedPeerID, delta)
|
||||
}
|
||||
|
||||
log.Debugf("NetworkMapBuilder: Added %d peers to cache, affected %d peers, took %s", len(peers), len(allUpdates), time.Since(tt))
|
||||
|
||||
b.apb.mu.Lock()
|
||||
if len(b.apb.ids) > 0 {
|
||||
b.apb.sg.Signal()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) enqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) {
|
||||
b.apb.mu.Lock()
|
||||
b.apb.ids = append(b.apb.ids, peerIDs...)
|
||||
if b.apb.la != nil && acc.Network.CurrentSerial() > b.apb.la.Network.CurrentSerial() {
|
||||
b.apb.la = acc
|
||||
}
|
||||
b.apb.sg.Signal()
|
||||
b.apb.mu.Unlock()
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) EnqueuePeersForIncrementalAdd(acc *Account, peerIDs ...string) {
|
||||
b.enqueuePeersForIncrementalAdd(acc, peerIDs...)
|
||||
}
|
||||
|
||||
type ViewDelta struct {
|
||||
AddedPeerIDs []string
|
||||
RemovedPeerIDs []string
|
||||
@@ -1508,17 +1732,18 @@ type ViewDelta struct {
|
||||
RemovedRuleIDs []string
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) OnPeerAddedIncremental(peerID string) error {
|
||||
func (b *NetworkMapBuilder) OnPeerAddedIncremental(acc *Account, peerID string) error {
|
||||
tt := time.Now()
|
||||
account := b.account.Load()
|
||||
peer := account.GetPeer(peerID)
|
||||
peer := acc.GetPeer(peerID)
|
||||
if peer == nil {
|
||||
return fmt.Errorf("peer %s not found in account", peerID)
|
||||
return fmt.Errorf("NetworkMapBuilder: peer %s not found in account", peerID)
|
||||
}
|
||||
|
||||
b.cache.mu.Lock()
|
||||
defer b.cache.mu.Unlock()
|
||||
|
||||
account := b.updateAccountLocked(acc)
|
||||
|
||||
log.Debugf("NetworkMapBuilder: Adding peer %s (IP: %s) to cache", peerID, peer.IP.String())
|
||||
|
||||
b.validatedPeers[peerID] = struct{}{}
|
||||
@@ -1577,6 +1802,13 @@ func (b *NetworkMapBuilder) updateIndexesForNewPeer(account *Account, peerID str
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, newPeerID string, peerGroups []string) {
|
||||
updates := b.collectDeltasForNewPeer(account, newPeerID, peerGroups)
|
||||
for affectedPeerID, delta := range updates {
|
||||
b.applyDeltaToPeer(account, affectedPeerID, delta)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) collectDeltasForNewPeer(account *Account, newPeerID string, peerGroups []string) map[string]*PeerUpdateDelta {
|
||||
updates := b.calculateIncrementalUpdates(account, newPeerID, peerGroups)
|
||||
|
||||
if b.isPeerRouter(account, newPeerID) {
|
||||
@@ -1596,9 +1828,7 @@ func (b *NetworkMapBuilder) incrementalUpdateAffectedPeers(account *Account, new
|
||||
}
|
||||
}
|
||||
|
||||
for affectedPeerID, delta := range updates {
|
||||
b.applyDeltaToPeer(account, affectedPeerID, delta)
|
||||
}
|
||||
return updates
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) findPeersAffectedByNewRouter(account *Account, newRouterID string, routerGroups []string) map[string]struct{} {
|
||||
@@ -1792,8 +2022,8 @@ func (b *NetworkMapBuilder) calculateNewRouterNetworkResourceUpdates(
|
||||
updates[peerID] = delta
|
||||
}
|
||||
|
||||
if delta.AddConnectedPeer == "" {
|
||||
delta.AddConnectedPeer = newPeerID
|
||||
if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
|
||||
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
|
||||
}
|
||||
|
||||
delta.RebuildRoutesView = true
|
||||
@@ -1922,8 +2152,8 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates(
|
||||
updates[routerPeerID] = delta
|
||||
}
|
||||
|
||||
if delta.AddConnectedPeer == "" {
|
||||
delta.AddConnectedPeer = newPeerID
|
||||
if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
|
||||
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
|
||||
}
|
||||
|
||||
delta.RebuildRoutesView = true
|
||||
@@ -1933,13 +2163,63 @@ func (b *NetworkMapBuilder) calculateNetworkResourceFirewallUpdates(
|
||||
|
||||
type PeerUpdateDelta struct {
|
||||
PeerID string
|
||||
AddConnectedPeer string
|
||||
AddConnectedPeers []string
|
||||
AddFirewallRules []*FirewallRuleDelta
|
||||
AddRoutes []route.ID
|
||||
UpdateRouteFirewallRules []*RouteFirewallRuleUpdate
|
||||
UpdateDNS bool
|
||||
RebuildRoutesView bool
|
||||
}
|
||||
|
||||
func (d *PeerUpdateDelta) mergeFrom(other *PeerUpdateDelta) {
|
||||
for _, peerID := range other.AddConnectedPeers {
|
||||
if !slices.Contains(d.AddConnectedPeers, peerID) {
|
||||
d.AddConnectedPeers = append(d.AddConnectedPeers, peerID)
|
||||
}
|
||||
}
|
||||
|
||||
existingRuleIDs := make(map[string]struct{}, len(d.AddFirewallRules))
|
||||
for _, rule := range d.AddFirewallRules {
|
||||
existingRuleIDs[rule.RuleID] = struct{}{}
|
||||
}
|
||||
for _, rule := range other.AddFirewallRules {
|
||||
if _, exists := existingRuleIDs[rule.RuleID]; !exists {
|
||||
d.AddFirewallRules = append(d.AddFirewallRules, rule)
|
||||
existingRuleIDs[rule.RuleID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for _, routeID := range other.AddRoutes {
|
||||
if !slices.Contains(d.AddRoutes, routeID) {
|
||||
d.AddRoutes = append(d.AddRoutes, routeID)
|
||||
}
|
||||
}
|
||||
|
||||
existingRouteUpdates := make(map[string]map[string]struct{})
|
||||
for _, update := range d.UpdateRouteFirewallRules {
|
||||
if existingRouteUpdates[update.RuleID] == nil {
|
||||
existingRouteUpdates[update.RuleID] = make(map[string]struct{})
|
||||
}
|
||||
existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{}
|
||||
}
|
||||
for _, update := range other.UpdateRouteFirewallRules {
|
||||
if existingRouteUpdates[update.RuleID] == nil {
|
||||
existingRouteUpdates[update.RuleID] = make(map[string]struct{})
|
||||
}
|
||||
if _, exists := existingRouteUpdates[update.RuleID][update.AddSourceIP]; !exists {
|
||||
d.UpdateRouteFirewallRules = append(d.UpdateRouteFirewallRules, update)
|
||||
existingRouteUpdates[update.RuleID][update.AddSourceIP] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if other.UpdateDNS {
|
||||
d.UpdateDNS = true
|
||||
}
|
||||
if other.RebuildRoutesView {
|
||||
d.RebuildRoutesView = true
|
||||
}
|
||||
}
|
||||
|
||||
type FirewallRuleDelta struct {
|
||||
Rule *FirewallRule
|
||||
RuleID string
|
||||
@@ -1977,7 +2257,7 @@ func (b *NetworkMapBuilder) addUpdateForPeersInGroups(
|
||||
PeerIP: newPeer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
Protocol: firewallRuleProtocol(rule.Protocol),
|
||||
}
|
||||
for _, peerID := range peers {
|
||||
if peerID == newPeerID {
|
||||
@@ -2028,7 +2308,7 @@ func (b *NetworkMapBuilder) addUpdateForDirectPeerResource(
|
||||
PeerIP: newPeer.IP.String(),
|
||||
Direction: direction,
|
||||
Action: string(rule.Action),
|
||||
Protocol: string(rule.Protocol),
|
||||
Protocol: firewallRuleProtocol(rule.Protocol),
|
||||
}
|
||||
|
||||
b.addOrUpdateFirewallRuleInDelta(updates, targetPeerID, newPeerID, rule, direction, fr, fr.PeerIP, targetPeer)
|
||||
@@ -2041,11 +2321,13 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta(
|
||||
delta := updates[targetPeerID]
|
||||
if delta == nil {
|
||||
delta = &PeerUpdateDelta{
|
||||
PeerID: targetPeerID,
|
||||
AddConnectedPeer: newPeerID,
|
||||
AddFirewallRules: make([]*FirewallRuleDelta, 0),
|
||||
PeerID: targetPeerID,
|
||||
AddConnectedPeers: []string{newPeerID},
|
||||
AddFirewallRules: make([]*FirewallRuleDelta, 0),
|
||||
}
|
||||
updates[targetPeerID] = delta
|
||||
} else if !slices.Contains(delta.AddConnectedPeers, newPeerID) {
|
||||
delta.AddConnectedPeers = append(delta.AddConnectedPeers, newPeerID)
|
||||
}
|
||||
|
||||
baseRule.PeerIP = peerIP
|
||||
@@ -2071,10 +2353,12 @@ func (b *NetworkMapBuilder) addOrUpdateFirewallRuleInDelta(
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) applyDeltaToPeer(account *Account, peerID string, delta *PeerUpdateDelta) {
|
||||
if delta.AddConnectedPeer != "" || len(delta.AddFirewallRules) > 0 {
|
||||
if len(delta.AddConnectedPeers) > 0 || len(delta.AddFirewallRules) > 0 {
|
||||
if aclView := b.cache.peerACLs[peerID]; aclView != nil {
|
||||
if delta.AddConnectedPeer != "" && !slices.Contains(aclView.ConnectedPeerIDs, delta.AddConnectedPeer) {
|
||||
aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, delta.AddConnectedPeer)
|
||||
for _, connectedPeerID := range delta.AddConnectedPeers {
|
||||
if !slices.Contains(aclView.ConnectedPeerIDs, connectedPeerID) {
|
||||
aclView.ConnectedPeerIDs = append(aclView.ConnectedPeerIDs, connectedPeerID)
|
||||
}
|
||||
}
|
||||
|
||||
for _, ruleDelta := range delta.AddFirewallRules {
|
||||
@@ -2130,11 +2414,11 @@ func (b *NetworkMapBuilder) updateRouteFirewallRules(routesView *PeerRoutesView,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
|
||||
func (b *NetworkMapBuilder) OnPeerDeleted(acc *Account, peerID string) error {
|
||||
b.cache.mu.Lock()
|
||||
defer b.cache.mu.Unlock()
|
||||
|
||||
account := b.account.Load()
|
||||
account := b.updateAccountLocked(acc)
|
||||
|
||||
deletedPeer := b.cache.globalPeers[peerID]
|
||||
if deletedPeer == nil {
|
||||
@@ -2190,6 +2474,7 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
|
||||
delete(b.cache.peerACLs, peerID)
|
||||
delete(b.cache.peerRoutes, peerID)
|
||||
delete(b.cache.peerDNS, peerID)
|
||||
delete(b.cache.peerSSH, peerID)
|
||||
|
||||
delete(b.cache.globalPeers, peerID)
|
||||
|
||||
@@ -2240,11 +2525,16 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
|
||||
b.buildPeerRoutesView(account, affectedPeerID)
|
||||
}
|
||||
|
||||
peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP)
|
||||
peersToRebuildACL := make(map[string]struct{})
|
||||
peerDeletionUpdates := b.findPeersAffectedByDeletedPeerACL(peerID, peerIP, peerGroups, peersToRebuildACL)
|
||||
for affectedPeerID, updates := range peerDeletionUpdates {
|
||||
b.applyDeletionUpdates(affectedPeerID, updates)
|
||||
}
|
||||
|
||||
for affectedPeerID := range peersToRebuildACL {
|
||||
b.buildPeerACLView(account, affectedPeerID)
|
||||
}
|
||||
|
||||
b.cleanupUnusedRules()
|
||||
|
||||
log.Debugf("NetworkMapBuilder: Deleted peer %s, affected %d other peers", peerID, len(affectedPeers))
|
||||
@@ -2255,6 +2545,8 @@ func (b *NetworkMapBuilder) OnPeerDeleted(peerID string) error {
|
||||
func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL(
|
||||
deletedPeerID string,
|
||||
peerIP string,
|
||||
peerGroups []string,
|
||||
peersToRebuildACL map[string]struct{},
|
||||
) map[string]*PeerDeletionUpdate {
|
||||
|
||||
affected := make(map[string]*PeerDeletionUpdate)
|
||||
@@ -2264,26 +2556,47 @@ func (b *NetworkMapBuilder) findPeersAffectedByDeletedPeerACL(
|
||||
continue
|
||||
}
|
||||
|
||||
if !slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) {
|
||||
continue
|
||||
}
|
||||
if affected[peerID] == nil {
|
||||
affected[peerID] = &PeerDeletionUpdate{
|
||||
RemovePeerID: deletedPeerID,
|
||||
PeerIP: peerIP,
|
||||
if slices.Contains(aclView.ConnectedPeerIDs, deletedPeerID) {
|
||||
peersToRebuildACL[peerID] = struct{}{}
|
||||
if affected[peerID] == nil {
|
||||
affected[peerID] = &PeerDeletionUpdate{
|
||||
RemovePeerID: deletedPeerID,
|
||||
PeerIP: peerIP,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, ruleID := range aclView.FirewallRuleIDs {
|
||||
if rule := b.cache.globalRules[ruleID]; rule != nil && rule.PeerIP == peerIP {
|
||||
affected[peerID].RemoveFirewallRuleIDs = append(
|
||||
affected[peerID].RemoveFirewallRuleIDs,
|
||||
ruleID,
|
||||
)
|
||||
affectedRouteOwners := make(map[string]struct{})
|
||||
|
||||
for _, groupID := range peerGroups {
|
||||
if routeMap, ok := b.cache.acgToRoutes[groupID]; ok {
|
||||
for _, info := range routeMap {
|
||||
if info.PeerID != deletedPeerID {
|
||||
affectedRouteOwners[info.PeerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, info := range b.cache.noACGRoutes {
|
||||
if info.PeerID != deletedPeerID {
|
||||
affectedRouteOwners[info.PeerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
for ownerPeerID := range affectedRouteOwners {
|
||||
if affected[ownerPeerID] == nil {
|
||||
affected[ownerPeerID] = &PeerDeletionUpdate{
|
||||
RemovePeerID: deletedPeerID,
|
||||
PeerIP: peerIP,
|
||||
RemoveFromSourceRanges: true,
|
||||
}
|
||||
} else {
|
||||
affected[ownerPeerID].RemoveFromSourceRanges = true
|
||||
}
|
||||
}
|
||||
|
||||
return affected
|
||||
}
|
||||
|
||||
@@ -2296,18 +2609,6 @@ type PeerDeletionUpdate struct {
|
||||
}
|
||||
|
||||
func (b *NetworkMapBuilder) applyDeletionUpdates(peerID string, updates *PeerDeletionUpdate) {
|
||||
if aclView := b.cache.peerACLs[peerID]; aclView != nil {
|
||||
aclView.ConnectedPeerIDs = slices.DeleteFunc(aclView.ConnectedPeerIDs, func(id string) bool {
|
||||
return id == updates.RemovePeerID
|
||||
})
|
||||
|
||||
if len(updates.RemoveFirewallRuleIDs) > 0 {
|
||||
aclView.FirewallRuleIDs = slices.DeleteFunc(aclView.FirewallRuleIDs, func(ruleID string) bool {
|
||||
return slices.Contains(updates.RemoveFirewallRuleIDs, ruleID)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
if routesView := b.cache.peerRoutes[peerID]; routesView != nil {
|
||||
if len(updates.RemoveRouteIDs) > 0 {
|
||||
routesView.NetworkResourceIDs = slices.DeleteFunc(routesView.NetworkResourceIDs, func(routeID route.ID) bool {
|
||||
|
||||
@@ -23,6 +23,8 @@ const (
|
||||
PolicyRuleProtocolUDP = PolicyRuleProtocolType("udp")
|
||||
// PolicyRuleProtocolICMP type of traffic
|
||||
PolicyRuleProtocolICMP = PolicyRuleProtocolType("icmp")
|
||||
// PolicyRuleProtocolNetbirdSSH type of traffic
|
||||
PolicyRuleProtocolNetbirdSSH = PolicyRuleProtocolType("netbird-ssh")
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -167,6 +169,8 @@ func ParseRuleString(rule string) (PolicyRuleProtocolType, RulePortRange, error)
|
||||
protocol = PolicyRuleProtocolUDP
|
||||
case "icmp":
|
||||
return "", RulePortRange{}, errors.New("icmp does not accept ports; use 'icmp' without '/…'")
|
||||
case "netbird-ssh":
|
||||
return PolicyRuleProtocolNetbirdSSH, RulePortRange{Start: nativeSSHPortNumber, End: nativeSSHPortNumber}, nil
|
||||
default:
|
||||
return "", RulePortRange{}, fmt.Errorf("invalid protocol: %q", protoStr)
|
||||
}
|
||||
|
||||
@@ -80,6 +80,12 @@ type PolicyRule struct {
|
||||
|
||||
// PortRanges a list of port ranges.
|
||||
PortRanges []RulePortRange `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedGroups is a map of groupIDs and their respective access to local users via ssh
|
||||
AuthorizedGroups map[string][]string `gorm:"serializer:json"`
|
||||
|
||||
// AuthorizedUser is a list of userIDs that are authorized to access local resources via ssh
|
||||
AuthorizedUser string
|
||||
}
|
||||
|
||||
// Copy returns a copy of a policy rule
|
||||
@@ -99,10 +105,16 @@ func (pm *PolicyRule) Copy() *PolicyRule {
|
||||
Protocol: pm.Protocol,
|
||||
Ports: make([]string, len(pm.Ports)),
|
||||
PortRanges: make([]RulePortRange, len(pm.PortRanges)),
|
||||
AuthorizedGroups: make(map[string][]string, len(pm.AuthorizedGroups)),
|
||||
AuthorizedUser: pm.AuthorizedUser,
|
||||
}
|
||||
copy(rule.Destinations, pm.Destinations)
|
||||
copy(rule.Sources, pm.Sources)
|
||||
copy(rule.Ports, pm.Ports)
|
||||
copy(rule.PortRanges, pm.PortRanges)
|
||||
for k, v := range pm.AuthorizedGroups {
|
||||
rule.AuthorizedGroups[k] = make([]string, len(v))
|
||||
copy(rule.AuthorizedGroups[k], v)
|
||||
}
|
||||
return rule
|
||||
}
|
||||
|
||||
@@ -52,6 +52,9 @@ type Settings struct {
|
||||
|
||||
// LazyConnectionEnabled indicates if the experimental feature is enabled or disabled
|
||||
LazyConnectionEnabled bool `gorm:"default:false"`
|
||||
|
||||
// AutoUpdateVersion client auto-update version
|
||||
AutoUpdateVersion string `gorm:"default:'disabled'"`
|
||||
}
|
||||
|
||||
// Copy copies the Settings struct
|
||||
@@ -72,6 +75,7 @@ func (s *Settings) Copy() *Settings {
|
||||
LazyConnectionEnabled: s.LazyConnectionEnabled,
|
||||
DNSDomain: s.DNSDomain,
|
||||
NetworkRange: s.NetworkRange,
|
||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||
}
|
||||
if s.Extra != nil {
|
||||
settings.Extra = s.Extra.Copy()
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/integration_reference"
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -65,7 +66,11 @@ type UserInfo struct {
|
||||
LastLogin time.Time `json:"last_login"`
|
||||
Issued string `json:"issued"`
|
||||
PendingApproval bool `json:"pending_approval"`
|
||||
Password string `json:"password"`
|
||||
IntegrationReference integration_reference.IntegrationReference `json:"-"`
|
||||
// IdPID is the identity provider ID (connector ID) extracted from the Dex-encoded user ID.
|
||||
// This field is only populated when the user ID can be decoded from Dex's format.
|
||||
IdPID string `json:"idp_id,omitempty"`
|
||||
}
|
||||
|
||||
// User represents a user of the system
|
||||
@@ -96,6 +101,9 @@ type User struct {
|
||||
Issued string `gorm:"default:api"`
|
||||
|
||||
IntegrationReference integration_reference.IntegrationReference `gorm:"embedded;embeddedPrefix:integration_ref_"`
|
||||
|
||||
Name string `gorm:"default:''"`
|
||||
Email string `gorm:"default:''"`
|
||||
}
|
||||
|
||||
// IsBlocked returns true if the user is blocked, false otherwise
|
||||
@@ -143,10 +151,16 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
}
|
||||
|
||||
if userData == nil {
|
||||
|
||||
name := u.Name
|
||||
if u.IsServiceUser {
|
||||
name = u.ServiceUserName
|
||||
}
|
||||
|
||||
return &UserInfo{
|
||||
ID: u.Id,
|
||||
Email: "",
|
||||
Name: u.ServiceUserName,
|
||||
Email: u.Email,
|
||||
Name: name,
|
||||
Role: string(u.Role),
|
||||
AutoGroups: u.AutoGroups,
|
||||
Status: string(UserStatusActive),
|
||||
@@ -178,6 +192,7 @@ func (u *User) ToUserInfo(userData *idp.UserData) (*UserInfo, error) {
|
||||
LastLogin: u.GetLastLogin(),
|
||||
Issued: u.Issued,
|
||||
PendingApproval: u.PendingApproval,
|
||||
Password: userData.Password,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -204,11 +219,13 @@ func (u *User) Copy() *User {
|
||||
CreatedAt: u.CreatedAt,
|
||||
Issued: u.Issued,
|
||||
IntegrationReference: u.IntegrationReference,
|
||||
Email: u.Email,
|
||||
Name: u.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// NewUser creates a new user
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string) *User {
|
||||
func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, serviceUserName string, autoGroups []string, issued string, email string, name string) *User {
|
||||
return &User{
|
||||
Id: id,
|
||||
Role: role,
|
||||
@@ -218,20 +235,70 @@ func NewUser(id string, role UserRole, isServiceUser bool, nonDeletable bool, se
|
||||
AutoGroups: autoGroups,
|
||||
Issued: issued,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Name: name,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// NewRegularUser creates a new user with role UserRoleUser
|
||||
func NewRegularUser(id string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewRegularUser(id, email, name string) *User {
|
||||
return NewUser(id, UserRoleUser, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// NewAdminUser creates a new user with role UserRoleAdmin
|
||||
func NewAdminUser(id string) *User {
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI)
|
||||
return NewUser(id, UserRoleAdmin, false, false, "", []string{}, UserIssuedAPI, "", "")
|
||||
}
|
||||
|
||||
// NewOwnerUser creates a new user with role UserRoleOwner
|
||||
func NewOwnerUser(id string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI)
|
||||
func NewOwnerUser(id string, email string, name string) *User {
|
||||
return NewUser(id, UserRoleOwner, false, false, "", []string{}, UserIssuedAPI, email, name)
|
||||
}
|
||||
|
||||
// EncryptSensitiveData encrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) EncryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Encrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Encrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("encrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DecryptSensitiveData decrypts the user's sensitive fields (Email and Name) in place.
|
||||
func (u *User) DecryptSensitiveData(enc *crypt.FieldEncrypt) error {
|
||||
if enc == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
if u.Email != "" {
|
||||
u.Email, err = enc.Decrypt(u.Email)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt email: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if u.Name != "" {
|
||||
u.Name, err = enc.Decrypt(u.Name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decrypt name: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
298
management/server/types/user_test.go
Normal file
298
management/server/types/user_test.go
Normal file
@@ -0,0 +1,298 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/util/crypt"
|
||||
)
|
||||
|
||||
func TestUser_EncryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("encrypt email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only email", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: "test@example.com",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, "test@example.com", user.Email, "email should be encrypted")
|
||||
assert.NotEmpty(t, user.Email, "encrypted email should not be empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("encrypt only name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.NotEqual(t, "Test User", user.Name, "name should be encrypted")
|
||||
assert.NotEmpty(t, user.Name, "encrypted name should not be empty")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_DecryptSensitiveData(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("decrypt email and name", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-1",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("decrypt empty email and name", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-2",
|
||||
Email: "",
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only email", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
|
||||
user := &User{
|
||||
Id: "user-3",
|
||||
Email: originalEmail,
|
||||
Name: "",
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, originalEmail, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, "", user.Name, "empty name should remain empty")
|
||||
})
|
||||
|
||||
t.Run("decrypt only name", func(t *testing.T) {
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-4",
|
||||
Email: "",
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "", user.Email, "empty email should remain empty")
|
||||
assert.Equal(t, originalName, user.Name, "decrypted name should match original")
|
||||
})
|
||||
|
||||
t.Run("nil encryptor returns no error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-5",
|
||||
Email: "test@example.com",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "test@example.com", user.Email, "email should remain unchanged with nil encryptor")
|
||||
assert.Equal(t, "Test User", user.Name, "name should remain unchanged with nil encryptor")
|
||||
})
|
||||
|
||||
t.Run("decrypt with invalid ciphertext returns error", func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "user-6",
|
||||
Email: "not-valid-base64-ciphertext!!!",
|
||||
Name: "Test User",
|
||||
}
|
||||
|
||||
err := user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
|
||||
t.Run("decrypt with wrong key returns error", func(t *testing.T) {
|
||||
originalEmail := "test@example.com"
|
||||
originalName := "Test User"
|
||||
|
||||
user := &User{
|
||||
Id: "user-7",
|
||||
Email: originalEmail,
|
||||
Name: originalName,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
differentKey, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
differentEncrypt, err := crypt.NewFieldEncrypt(differentKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = user.DecryptSensitiveData(differentEncrypt)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "decrypt email")
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_EncryptDecryptRoundTrip(t *testing.T) {
|
||||
key, err := crypt.GenerateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
fieldEncrypt, err := crypt.NewFieldEncrypt(key)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
email string
|
||||
uname string
|
||||
}{
|
||||
{
|
||||
name: "standard email and name",
|
||||
email: "user@example.com",
|
||||
uname: "John Doe",
|
||||
},
|
||||
{
|
||||
name: "email with special characters",
|
||||
email: "user+tag@sub.example.com",
|
||||
uname: "O'Brien, Mary-Jane",
|
||||
},
|
||||
{
|
||||
name: "unicode characters",
|
||||
email: "user@example.com",
|
||||
uname: "Jean-Pierre Müller 日本語",
|
||||
},
|
||||
{
|
||||
name: "long values",
|
||||
email: "very.long.email.address.that.is.quite.extended@subdomain.example.organization.com",
|
||||
uname: "A Very Long Name That Contains Many Words And Is Quite Extended For Testing Purposes",
|
||||
},
|
||||
{
|
||||
name: "empty email only",
|
||||
email: "",
|
||||
uname: "Name Only",
|
||||
},
|
||||
{
|
||||
name: "empty name only",
|
||||
email: "email@only.com",
|
||||
uname: "",
|
||||
},
|
||||
{
|
||||
name: "both empty",
|
||||
email: "",
|
||||
uname: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
user := &User{
|
||||
Id: "test-user",
|
||||
Email: tc.email,
|
||||
Name: tc.uname,
|
||||
}
|
||||
|
||||
err := user.EncryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tc.email != "" {
|
||||
assert.NotEqual(t, tc.email, user.Email, "email should be encrypted")
|
||||
}
|
||||
if tc.uname != "" {
|
||||
assert.NotEqual(t, tc.uname, user.Name, "name should be encrypted")
|
||||
}
|
||||
|
||||
err = user.DecryptSensitiveData(fieldEncrypt)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tc.email, user.Email, "decrypted email should match original")
|
||||
assert.Equal(t, tc.uname, user.Name, "decrypted name should match original")
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user