From 00c3b671821ec9be52d412948b3fc48dae4b341c Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Thu, 28 Nov 2024 11:13:01 +0100 Subject: [PATCH] [management] refactor to use account object instead of separate db calls for peer update (#2957) --- management/server/peer.go | 8 ++-- management/server/peer_test.go | 33 ++++++++++------ management/server/posture_checks.go | 61 +++++++++++------------------ 3 files changed, 48 insertions(+), 54 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index dcb47af3b..d45bb1a50 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -617,7 +617,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) + postureChecks, err := am.getPeerPostureChecks(account, newPeer.ID) if err != nil { return nil, nil, nil, err } @@ -707,7 +707,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, fmt.Errorf("failed to get validated peers: %w", err) } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } @@ -885,7 +885,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is return nil, nil, nil, err } - postureChecks, err = am.getPeerPostureChecks(ctx, account.Id, peer.ID) + postureChecks, err = am.getPeerPostureChecks(account, peer.ID) if err != nil { return nil, nil, nil, err } @@ -1042,7 +1042,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account defer wg.Done() defer func() { <-semaphore }() - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, p.ID) + postureChecks, err := am.getPeerPostureChecks(account, p.ID) if err != nil { log.WithContext(ctx).Errorf("failed to send out updates to peers, failed to get peer: %s posture checks: %v", p.ID, err) return diff --git a/management/server/peer_test.go b/management/server/peer_test.go index e410fa892..e5eaa20d6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -833,19 +833,20 @@ func BenchmarkGetPeers(b *testing.B) { }) } } - func BenchmarkUpdateAccountPeers(b *testing.B) { benchCases := []struct { - name string - peers int - groups int + name string + peers int + groups int + minMsPerOp float64 + maxMsPerOp float64 }{ - {"Small", 50, 5}, - {"Medium", 500, 10}, - {"Large", 5000, 20}, - {"Small single", 50, 1}, - {"Medium single", 500, 1}, - {"Large 5", 5000, 5}, + {"Small", 50, 5, 90, 120}, + {"Medium", 500, 100, 110, 140}, + {"Large", 5000, 200, 800, 1300}, + {"Small single", 50, 10, 90, 120}, + {"Medium single", 500, 10, 110, 170}, + {"Large 5", 5000, 15, 1300, 1800}, } log.SetOutput(io.Discard) @@ -881,8 +882,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { } duration := time.Since(start) - b.ReportMetric(float64(duration.Nanoseconds())/float64(b.N)/1e6, "ms/op") - b.ReportMetric(0, "ns/op") + msPerOp := float64(duration.Nanoseconds()) / float64(b.N) / 1e6 + b.ReportMetric(msPerOp, "ms/op") + + if msPerOp < bc.minMsPerOp { + b.Fatalf("Benchmark %s failed: too fast (%.2f ms/op, minimum %.2f ms/op)", bc.name, msPerOp, bc.minMsPerOp) + } + + if msPerOp > bc.maxMsPerOp { + b.Fatalf("Benchmark %s failed: too slow (%.2f ms/op, maximum %.2f ms/op)", bc.name, msPerOp, bc.maxMsPerOp) + } }) } } diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 59e726c41..0467efedb 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -2,14 +2,16 @@ package server import ( "context" + "errors" "fmt" "slices" + "github.com/rs/xid" + "golang.org/x/exp/maps" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/server/status" - "github.com/rs/xid" - "golang.org/x/exp/maps" ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { @@ -149,38 +151,21 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI } // getPeerPostureChecks returns the posture checks applied for a given peer. -func (am *DefaultAccountManager) getPeerPostureChecks(ctx context.Context, accountID string, peerID string) ([]*posture.Checks, error) { +func (am *DefaultAccountManager) getPeerPostureChecks(account *Account, peerID string) ([]*posture.Checks, error) { peerPostureChecks := make(map[string]*posture.Checks) - err := am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { - postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) - if err != nil { - return err + if len(account.PostureChecks) == 0 { + return nil, nil + } + + for _, policy := range account.Policies { + if !policy.Enabled || len(policy.SourcePostureChecks) == 0 { + continue } - if len(postureChecks) == 0 { - return nil + if err := addPolicyPostureChecks(account, peerID, policy, peerPostureChecks); err != nil { + return nil, err } - - policies, err := transaction.GetAccountPolicies(ctx, LockingStrengthShare, accountID) - if err != nil { - return err - } - - for _, policy := range policies { - if !policy.Enabled { - continue - } - - if err = addPolicyPostureChecks(ctx, transaction, accountID, peerID, policy, peerPostureChecks); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return nil, err } return maps.Values(peerPostureChecks), nil @@ -241,8 +226,8 @@ func validatePostureChecks(ctx context.Context, transaction Store, accountID str } // addPolicyPostureChecks adds posture checks from a policy to the peer posture checks map if the peer is in the policy's source groups. -func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { - isInGroup, err := isPeerInPolicySourceGroups(ctx, transaction, accountID, peerID, policy) +func addPolicyPostureChecks(account *Account, peerID string, policy *Policy, peerPostureChecks map[string]*posture.Checks) error { + isInGroup, err := isPeerInPolicySourceGroups(account, peerID, policy) if err != nil { return err } @@ -252,9 +237,9 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } for _, sourcePostureCheckID := range policy.SourcePostureChecks { - postureCheck, err := transaction.GetPostureChecksByID(ctx, LockingStrengthShare, accountID, sourcePostureCheckID) - if err != nil { - return err + postureCheck := account.getPostureChecks(sourcePostureCheckID) + if postureCheck == nil { + return errors.New("failed to add policy posture checks: posture checks not found") } peerPostureChecks[sourcePostureCheckID] = postureCheck } @@ -263,16 +248,16 @@ func addPolicyPostureChecks(ctx context.Context, transaction Store, accountID, p } // isPeerInPolicySourceGroups checks if a peer is present in any of the policy rule source groups. -func isPeerInPolicySourceGroups(ctx context.Context, transaction Store, accountID, peerID string, policy *Policy) (bool, error) { +func isPeerInPolicySourceGroups(account *Account, peerID string, policy *Policy) (bool, error) { for _, rule := range policy.Rules { if !rule.Enabled { continue } for _, sourceGroup := range rule.Sources { - group, err := transaction.GetGroupByID(ctx, LockingStrengthShare, accountID, sourceGroup) - if err != nil { - return false, fmt.Errorf("failed to check peer in policy source group: %w", err) + group := account.GetGroup(sourceGroup) + if group == nil { + return false, fmt.Errorf("failed to check peer in policy source group: group not found") } if slices.Contains(group.Peers, peerID) {