Compare commits

13 Commits

Author SHA1 Message Date
crn4
c8a9af2482 Merge branch 'main' into vk/debug/nmap-both 2025-11-11 14:32:04 +01:00
crn4
4bcda3e2ba use old map, new in goroutine 2025-11-11 14:31:47 +01:00
Zoltan Papp
c28275611b Fix agent reference (#4776) 2025-11-11 13:59:32 +01:00
Vlad
56f169eede [management] fix pg db deadlock after app panic (#4772) 2025-11-10 23:43:08 +01:00
crn4
b780f1c09d add both network maps compilation for debug 2025-11-10 21:19:07 +01:00
Viktor Liu
07cf9d5895 [client] Create networkd.conf.d if it doesn't exist (#4764) 2025-11-08 10:54:37 +01:00
Pascal Fischer
7df49e249d [management ] remove timing logs (#4761) 2025-11-07 20:14:52 +01:00
Pascal Fischer
dbfc8a52c9 [management] remove GLOBAL when disabling foreign keys on mysql (#4615) 2025-11-07 16:03:14 +01:00
Vlad
98ddac07bf [management] remove toAll firewall rule (#4725) 2025-11-07 15:50:58 +01:00
Pascal Fischer
48475ddc05 [management] add pat rate limiting (#4741) 2025-11-07 15:50:18 +01:00
Vlad
6aa4ba7af4 [management] incremental network map builder (#4753) 2025-11-07 10:44:46 +01:00
dependabot[bot]
2e16c9914a [management] Bump github.com/containerd/containerd from 1.7.27 to 1.7.29 (#4756)
Bumps [github.com/containerd/containerd](https://github.com/containerd/containerd) from 1.7.27 to 1.7.29.
- [Release notes](https://github.com/containerd/containerd/releases)
- [Changelog](https://github.com/containerd/containerd/blob/main/RELEASES.md)
- [Commits](https://github.com/containerd/containerd/compare/v1.7.27...v1.7.29)

---
updated-dependencies:
- dependency-name: github.com/containerd/containerd
  dependency-version: 1.7.29
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-06 19:01:44 +03:00
Pascal Fischer
5c29d395b2 [management] activity events on group updates (#4750) 2025-11-06 12:51:14 +01:00
43 changed files with 7816 additions and 153 deletions

View File

@@ -259,6 +259,7 @@ func isServiceRunning() (bool, error) {
 }
 const (
+	networkdConf        = "/etc/systemd/networkd.conf"
 	networkdConfDir     = "/etc/systemd/networkd.conf.d"
 	networkdConfFile    = "/etc/systemd/networkd.conf.d/99-netbird.conf"
 	networkdConfContent = `# Created by NetBird to prevent systemd-networkd from removing
@@ -273,12 +274,16 @@ ManageForeignRoutingPolicyRules=no
 // configureSystemdNetworkd creates a drop-in configuration file to prevent
 // systemd-networkd from removing NetBird's routes and policy rules.
 func configureSystemdNetworkd() error {
-	parentDir := filepath.Dir(networkdConfDir)
-	if _, err := os.Stat(parentDir); os.IsNotExist(err) {
-		log.Debug("systemd networkd.conf.d parent directory does not exist, skipping configuration")
+	if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
+		log.Debug("systemd-networkd not in use, skipping configuration")
 		return nil
 	}
+	// nolint:gosec // standard networkd permissions
+	if err := os.MkdirAll(networkdConfDir, 0755); err != nil {
+		return fmt.Errorf("create networkd.conf.d directory: %w", err)
+	}
+	// nolint:gosec // standard networkd permissions
 	if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0644); err != nil {
 		return fmt.Errorf("write networkd configuration: %w", err)

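Note on the change above: the guard is inverted. Instead of skipping when the conf.d parent directory is missing, the code now skips only when systemd-networkd's own config file is absent (networkd not in use) and otherwise creates the drop-in directory itself before writing. A condensed, standalone sketch of the resulting flow, with the drop-in content abbreviated:

package main

import (
    "fmt"
    "log"
    "os"
)

const (
    networkdConf     = "/etc/systemd/networkd.conf"
    networkdConfDir  = "/etc/systemd/networkd.conf.d"
    networkdConfFile = "/etc/systemd/networkd.conf.d/99-netbird.conf"
    // abbreviated; the real drop-in carries a NetBird comment header
    networkdConfContent = "[Network]\nManageForeignRoutes=no\nManageForeignRoutingPolicyRules=no\n"
)

func configureSystemdNetworkd() error {
    // /etc/systemd/networkd.conf ships with systemd-networkd; if it is
    // missing there is nothing to protect against, so do nothing.
    if _, err := os.Stat(networkdConf); os.IsNotExist(err) {
        return nil
    }
    if err := os.MkdirAll(networkdConfDir, 0o755); err != nil {
        return fmt.Errorf("create networkd.conf.d directory: %w", err)
    }
    if err := os.WriteFile(networkdConfFile, []byte(networkdConfContent), 0o644); err != nil {
        return fmt.Errorf("write networkd configuration: %w", err)
    }
    return nil
}

func main() {
    if err := configureSystemdNetworkd(); err != nil {
        log.Fatal(err)
    }
}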
View File

@@ -411,7 +411,7 @@ func (w *WorkerICE) onConnectionStateChange(agent *icemaker.ThreadSafeAgent, dia
 func (w *WorkerICE) turnAgentDial(ctx context.Context, agent *icemaker.ThreadSafeAgent, remoteOfferAnswer *OfferAnswer) (*ice.Conn, error) {
 	if isController(w.config) {
-		return w.agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
+		return agent.Dial(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
 	} else {
 		return agent.Accept(ctx, remoteOfferAnswer.IceCredentials.UFrag, remoteOfferAnswer.IceCredentials.Pwd)
 	}

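The one-line fix above dials on the agent passed in by the caller rather than on the w.agent field, which may not be the agent this dial attempt was started for. The bug class, reduced to a hypothetical minimal form (illustrative types, not the NetBird API):

type dialer interface{ Dial() error }

type worker struct{ agent dialer }

func (w *worker) turnDial(current dialer) error {
    // buggy: return w.agent.Dial() would use the stored agent, which may
    // already have been torn down and replaced
    return current.Dial() // fixed: use the agent supplied for this attempt
}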
go.mod (10 changed lines)
View File

@@ -56,6 +56,7 @@ require (
 	github.com/hashicorp/go-multierror v1.1.1
 	github.com/hashicorp/go-secure-stdlib/base62 v0.1.2
 	github.com/hashicorp/go-version v1.6.0
+	github.com/jackc/pgx/v5 v5.5.5
 	github.com/libdns/route53 v1.5.0
 	github.com/libp2p/go-netroute v0.2.1
 	github.com/mdlayher/socket v0.5.1
@@ -102,11 +103,12 @@ require (
 	goauthentik.io/api/v3 v3.2023051.3
 	golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
 	golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
-	golang.org/x/mod v0.25.0
+	golang.org/x/mod v0.26.0
 	golang.org/x/net v0.42.0
-	golang.org/x/oauth2 v0.28.0
+	golang.org/x/oauth2 v0.30.0
 	golang.org/x/sync v0.16.0
 	golang.org/x/term v0.33.0
+	golang.org/x/time v0.12.0
 	google.golang.org/api v0.177.0
 	gopkg.in/yaml.v3 v3.0.1
 	gorm.io/driver/mysql v1.5.7
@@ -146,7 +148,7 @@ require (
 	github.com/beorn7/perks v1.0.1 // indirect
 	github.com/caddyserver/zerossl v0.1.3 // indirect
 	github.com/cespare/xxhash/v2 v2.3.0 // indirect
-	github.com/containerd/containerd v1.7.27 // indirect
+	github.com/containerd/containerd v1.7.29 // indirect
 	github.com/containerd/log v0.1.0 // indirect
 	github.com/containerd/platforms v0.2.1 // indirect
 	github.com/cpuguy83/dockercfg v0.3.2 // indirect
@@ -183,7 +185,6 @@ require (
 	github.com/inconshreveable/mousetrap v1.1.0 // indirect
 	github.com/jackc/pgpassfile v1.0.0 // indirect
 	github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
-	github.com/jackc/pgx/v5 v5.5.5 // indirect
 	github.com/jackc/puddle/v2 v2.2.1 // indirect
 	github.com/jeandeaual/go-locale v0.0.0-20240223122105-ce5225dcaa49 // indirect
 	github.com/jinzhu/inflection v1.0.0 // indirect
@@ -245,7 +246,6 @@ require (
 	go.uber.org/multierr v1.11.0 // indirect
 	golang.org/x/image v0.18.0 // indirect
 	golang.org/x/text v0.27.0 // indirect
-	golang.org/x/time v0.5.0 // indirect
 	golang.org/x/tools v0.34.0 // indirect
 	golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
 	google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect

go.sum (16 changed lines)
View File

@@ -142,8 +142,8 @@ github.com/cncf/udpa/go v0.0.0-20200629203442-efcf912fb354/go.mod h1:WmhPx2Nbnht
 github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk=
 github.com/coder/websocket v1.8.13 h1:f3QZdXy7uGVz+4uCJy2nTZyM0yTBj8yANEHhqlXZ9FE=
 github.com/coder/websocket v1.8.13/go.mod h1:LNVeNrXQZfe5qhS9ALED3uA+l5pPqvwXg3CKoDBB2gs=
-github.com/containerd/containerd v1.7.27 h1:yFyEyojddO3MIGVER2xJLWoCIn+Up4GaHFquP7hsFII=
-github.com/containerd/containerd v1.7.27/go.mod h1:xZmPnl75Vc+BLGt4MIfu6bp+fy03gdHAn9bz+FreFR0=
+github.com/containerd/containerd v1.7.29 h1:90fWABQsaN9mJhGkoVnuzEY+o1XDPbg9BTC9QTAHnuE=
+github.com/containerd/containerd v1.7.29/go.mod h1:azUkWcOvHrWvaiUjSQH0fjzuHIwSPg1WL5PshGP4Szs=
 github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I=
 github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo=
 github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpSBQv6A=
@@ -818,8 +818,8 @@ golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
-golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
-golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
+golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
+golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
 golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
 golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -880,8 +880,8 @@ golang.org/x/oauth2 v0.0.0-20210220000619-9bb904979d93/go.mod h1:KelEdhl1UZF7XfJ
 golang.org/x/oauth2 v0.0.0-20210313182246-cd4f82c27b84/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
 golang.org/x/oauth2 v0.0.0-20210402161424-2e8d93401602/go.mod h1:KelEdhl1UZF7XfJ4dDtk6s++YSgaE7mD/BuKKDLBl4A=
 golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE=
-golang.org/x/oauth2 v0.28.0 h1:CrgCKl8PPAVtLnU3c+EDw6x11699EWlsDeWNWKdIOkc=
-golang.org/x/oauth2 v0.28.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8=
+golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
+golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
 golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -993,8 +993,8 @@ golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
-golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
-golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
+golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
+golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=

View File

@@ -1,11 +1,19 @@
 package main
 import (
-	"github.com/netbirdio/netbird/management/cmd"
+	"log"
+	"net/http"
+	// nolint:gosec
+	_ "net/http/pprof"
 	"os"
+	"github.com/netbirdio/netbird/management/cmd"
 )
 func main() {
+	go func() {
+		log.Println(http.ListenAndServe("localhost:6060", nil))
+	}()
 	if err := cmd.Execute(); err != nil {
 		os.Exit(1)
 	}

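With this change the management binary serves Go's pprof handlers on localhost:6060. Profiles can be pulled with the standard tooling, for example "go tool pprof http://localhost:6060/debug/pprof/heap" for heap data or "go tool pprof http://localhost:6060/debug/pprof/profile" for a 30-second CPU profile; http://localhost:6060/debug/pprof/ lists all available endpoints. Note the listener error is only logged, so a failure to bind (e.g. port already in use) will not stop the server.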
View File

@@ -53,6 +53,9 @@ const (
peerSchedulerRetryInterval = 3 * time.Second
emptyUserID = "empty user ID in claims"
errorGettingDomainAccIDFmt = "error getting account ID by private domain: %v"
envNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
envNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
)
type userLoggedInOnce bool
@@ -109,6 +112,11 @@ type DefaultAccountManager struct {
loginFilter *loginFilter
disableDefaultPolicy bool
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
}
func isUniqueConstraintError(err error) bool {
@@ -196,6 +204,18 @@ func BuildManager(
log.WithContext(ctx).Debugf("took %v to instantiate account manager", time.Since(start))
}()
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(envNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", envNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
ids := strings.Split(os.Getenv(envNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
am := &DefaultAccountManager{
Store: store,
geo: geo,
@@ -217,6 +237,10 @@ func BuildManager(
permissionsManager: permissionsManager,
loginFilter: newLoginFilter(),
disableDefaultPolicy: disableDefaultPolicy,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
}
am.startWarmup(ctx)
@@ -395,6 +419,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
go am.UpdateAccountPeers(ctx, accountID)
}
@@ -1477,6 +1504,10 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, userAuth.AccountId); err != nil {
return err
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", userAuth.UserId)
am.BufferUpdateAccountPeers(ctx, userAuth.AccountId)
}
@@ -1641,11 +1672,6 @@ func (am *DefaultAccountManager) AllowSync(wgPubKey string, metahash uint64) boo
 }
 func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
-	start := time.Now()
-	defer func() {
-		log.WithContext(ctx).Debugf("SyncAndMarkPeer: took %v", time.Since(start))
-	}()
 	peer, netMap, postureChecks, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
 	if err != nil {
 		return nil, nil, nil, fmt.Errorf("error syncing peer: %w", err)
@@ -2129,6 +2155,11 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
}
if updateNetworkMap {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil {
return err
}
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
am.BufferUpdateAccountPeers(ctx, accountID)
}
return nil

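The two experiment variables combine as an OR (see experimentalNetworkMap further down in this diff): NB_EXPERIMENT_NETWORK_MAP=true enables the new builder for every account, while NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS lists specific account IDs. A minimal sketch of the parsing behavior, assuming the semantics shown in the hunks:

package main

import (
    "fmt"
    "os"
    "strconv"
    "strings"
)

func main() {
    // unset or malformed values fail ParseBool, so the flag defaults to off
    enabled, err := strconv.ParseBool(os.Getenv("NB_EXPERIMENT_NETWORK_MAP"))
    if err != nil {
        enabled = false
    }

    // note: strings.Split("", ",") yields [""], so an empty variable puts
    // the empty string in the set; harmless, since account IDs are not empty
    ids := make(map[string]struct{})
    for _, id := range strings.Split(os.Getenv("NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"), ",") {
        ids[id] = struct{}{}
    }

    accountID := "account-1" // hypothetical account ID
    _, listed := ids[accountID]
    fmt.Println("experimental network map:", enabled || listed)
}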
View File

@@ -128,4 +128,5 @@ type Manager interface {
GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error)
SetEphemeralManager(em ephemeral.Manager)
AllowSync(string, uint64) bool
RecalculateNetworkMapCache(ctx context.Context, accountId string) error
}

View File

@@ -1154,7 +1154,16 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates_SaveGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
testAccountManager_NetworkUpdates_SaveGroup(t)
}
func testAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
group := types.Group{
@@ -1205,7 +1214,16 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePolicy(t)
}
func testAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1239,7 +1257,16 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
testAccountManager_NetworkUpdates_SavePolicy(t)
}
func testAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, peer2, _ := setupNetworkMapTest(t)
group := types.Group{
@@ -1288,7 +1315,16 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
testAccountManager_NetworkUpdates_DeletePeer(t)
}
func testAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
group := types.Group{
@@ -1341,7 +1377,16 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
testAccountManager_NetworkUpdates_DeleteGroup(t)
}
func testAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
@@ -1377,6 +1422,14 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
return
}
for drained := false; !drained; {
select {
case <-updMsg:
default:
drained = true
}
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
@@ -1736,7 +1789,9 @@ func TestAccount_Copy(t *testing.T) {
Address: "172.12.6.1/24",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()
err := hasNilField(account)
if err != nil {
t.Fatal(err)

View File

@@ -117,6 +117,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -114,6 +114,9 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -138,6 +141,11 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
 		return err
 	}
+	newGroup.AccountID = accountID
+	events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
+	eventsToStore = append(eventsToStore, events...)
 	oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
 	if err != nil {
 		return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
@@ -157,11 +165,6 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
 		}
 	}
-	newGroup.AccountID = accountID
-	events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
-	eventsToStore = append(eventsToStore, events...)
 	updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
 	if err != nil {
 		return err
@@ -182,6 +185,9 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -250,6 +256,9 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -318,6 +327,9 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -335,6 +347,16 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
if oldGroup.Name != newGroup.Name {
eventsToStore = append(eventsToStore, func() {
meta := map[string]any{
"old_name": oldGroup.Name,
"new_name": newGroup.Name,
}
am.StoreEvent(ctx, userID, newGroup.ID, accountID, activity.GroupUpdated, meta)
})
}
} else {
addedPeers = append(addedPeers, newGroup.Peers...)
eventsToStore = append(eventsToStore, func() {
@@ -471,6 +493,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -509,6 +534,9 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -537,6 +565,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -575,6 +606,9 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
}
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

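The rename event added to prepareGroupEvents follows the file's existing pattern: events are captured as closures while still inside the store transaction and emitted only after it commits, so nothing is recorded for rolled-back changes. A self-contained sketch of that pattern (names illustrative, not the NetBird API):

package main

import "fmt"

func main() {
    var eventsToStore []func()

    // inside the transaction: capture the event, do not emit it yet
    oldName, newName := "devs", "developers"
    if oldName != newName {
        eventsToStore = append(eventsToStore, func() {
            meta := map[string]any{"old_name": oldName, "new_name": newName}
            fmt.Println("GroupUpdated", meta)
        })
    }

    // after a successful commit: fire everything that was captured
    for _, storeEvent := range eventsToStore {
        storeEvent()
    }
}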
View File

@@ -7,8 +7,10 @@ import (
"net"
"net/netip"
"os"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
pb "github.com/golang/protobuf/proto" // nolint
@@ -44,6 +46,9 @@ import (
const (
envLogBlockedPeers = "NB_LOG_BLOCKED_PEERS"
envBlockPeers = "NB_BLOCK_SAME_PEERS"
envConcurrentSyncs = "NB_MAX_CONCURRENT_SYNCS"
defaultSyncLim = 1000
)
// GRPCServer an instance of a Management gRPC API server
@@ -63,6 +68,9 @@ type GRPCServer struct {
logBlockedPeers bool
blockPeersWithSameConfig bool
integratedPeerValidator integrated_validator.IntegratedValidator
syncSem atomic.Int32
syncLim int32
}
// NewServer creates a new Management server
@@ -96,6 +104,17 @@ func NewServer(
logBlockedPeers := strings.ToLower(os.Getenv(envLogBlockedPeers)) == "true"
blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true"
syncLim := int32(defaultSyncLim)
if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" {
syncLimParsed, err := strconv.Atoi(syncLimStr)
if err != nil {
log.Errorf("invalid value for %s: %v using %d", envConcurrentSyncs, err, defaultSyncLim)
} else {
//nolint:gosec
syncLim = int32(syncLimParsed)
}
}
return &GRPCServer{
wgKey: key,
// peerKey -> event channel
@@ -110,6 +129,8 @@ func NewServer(
logBlockedPeers: logBlockedPeers,
blockPeersWithSameConfig: blockPeersWithSameConfig,
integratedPeerValidator: integratedPeerValidator,
syncLim: syncLim,
}, nil
}
@@ -151,6 +172,11 @@ func getRealIP(ctx context.Context) net.IP {
// Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and
// notifies the connected peer of any updates (e.g. new peers under the same account)
func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error {
if s.syncSem.Load() >= s.syncLim {
return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later")
}
s.syncSem.Add(1)
reqStart := time.Now()
ctx := srv.Context()
@@ -158,6 +184,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
syncReq := &proto.SyncRequest{}
peerKey, err := s.parseRequest(ctx, req, syncReq)
if err != nil {
s.syncSem.Add(-1)
return err
}
realIP := getRealIP(ctx)
@@ -172,6 +199,7 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Warnf("peer %s with meta hash %d is blocked from syncing", peerKey.String(), metahashed)
}
if s.blockPeersWithSameConfig {
s.syncSem.Add(-1)
return mapError(ctx, internalStatus.ErrPeerAlreadyLoggedIn)
}
}
@@ -183,27 +211,34 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
 	// nolint:staticcheck
 	ctx = context.WithValue(ctx, nbContext.PeerIDKey, peerKey.String())
+	unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
+	defer func() {
+		if unlock != nil {
+			unlock()
+		}
+	}()
 	accountID, err := s.accountManager.GetAccountIDForPeerKey(ctx, peerKey.String())
 	if err != nil {
 		// nolint:staticcheck
 		ctx = context.WithValue(ctx, nbContext.AccountIDKey, "UNKNOWN")
 		log.WithContext(ctx).Tracef("peer %s is not registered", peerKey.String())
 		if errStatus, ok := internalStatus.FromError(err); ok && errStatus.Type() == internalStatus.NotFound {
+			s.syncSem.Add(-1)
 			return status.Errorf(codes.PermissionDenied, "peer is not registered")
 		}
+		s.syncSem.Add(-1)
 		return err
 	}
+	log.WithContext(ctx).Debugf("Sync: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
 	// nolint:staticcheck
 	ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
-	start := time.Now()
-	unlock := s.acquirePeerLockByUID(ctx, peerKey.String())
-	defer func() {
-		if unlock != nil {
-			unlock()
-		}
-	}()
-	log.WithContext(ctx).Tracef("acquired peer lock for peer %s took %v", peerKey.String(), time.Since(start))
+	log.WithContext(ctx).Debugf("Sync: acquirePeerLockByUID since start %v", time.Since(reqStart))
 	log.WithContext(ctx).Debugf("Sync request from peer [%s] [%s]", req.WgPubKey, sRealIP)
if syncReq.GetMeta() == nil {
@@ -213,21 +248,32 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
peer, netMap, postureChecks, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP)
if err != nil {
log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Sync: SyncAndMarkPeer since start %v", time.Since(reqStart))
err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv)
if err != nil {
log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err)
s.syncSem.Add(-1)
return err
}
log.WithContext(ctx).Debugf("Sync: sendInitialSync since start %v", time.Since(reqStart))
updates := s.peersUpdateManager.CreateChannel(ctx, peer.ID)
log.WithContext(ctx).Debugf("Sync: CreateChannel since start %v", time.Since(reqStart))
s.ephemeralManager.OnPeerConnected(ctx, peer)
log.WithContext(ctx).Debugf("Sync: OnPeerConnected since start %v", time.Since(reqStart))
s.secretsManager.SetupRefresh(ctx, accountID, peer.ID)
log.WithContext(ctx).Debugf("Sync: SetupRefresh since start %v", time.Since(reqStart))
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestDuration(time.Since(reqStart), accountID)
}
@@ -237,6 +283,8 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
log.WithContext(ctx).Debugf("Sync: took %v", time.Since(reqStart))
s.syncSem.Add(-1)
return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv)
}
@@ -509,10 +557,16 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
//nolint
ctx = context.WithValue(ctx, nbContext.AccountIDKey, accountID)
log.WithContext(ctx).Debugf("Login: GetAccountIDForPeerKey since start %v", time.Since(reqStart))
defer func() {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart), accountID)
}
took := time.Since(reqStart)
if took > 7*time.Second {
log.WithContext(ctx).Debugf("Login: took %v", time.Since(reqStart))
}
}()
if loginReq.GetMeta() == nil {
@@ -546,9 +600,12 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, mapError(ctx, err)
}
log.WithContext(ctx).Debugf("Login: LoginPeer since start %v", time.Since(reqStart))
// if the login request contains setup key then it is a registration request
if loginReq.GetSetupKey() != "" {
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
log.WithContext(ctx).Debugf("Login: OnPeerDisconnected since start %v", time.Since(reqStart))
}
loginResp, err := s.prepareLoginResponse(ctx, peer, netMap, postureChecks)
@@ -557,6 +614,8 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p
return nil, status.Errorf(codes.Internal, "failed logging in peer")
}
log.WithContext(ctx).Debugf("Login: prepareLoginResponse since start %v", time.Since(reqStart))
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, loginResp)
if err != nil {
log.WithContext(ctx).Warnf("failed encrypting peer %s message", peer.ID)
@@ -822,10 +881,12 @@ func (s *GRPCServer) sendInitialSync(ctx context.Context, peerKey wgtypes.Key, p
return status.Errorf(codes.Internal, "error handling request")
}
sendStart := time.Now()
err = srv.Send(&proto.EncryptedMessage{
WgPubKey: s.wgKey.PublicKey().String(),
Body: encryptedResp,
})
log.WithContext(ctx).Debugf("sendInitialSync: sending response took %s", time.Since(sendStart))
if err != nil {
log.WithContext(ctx).Errorf("failed sending SyncResponse %v", err)

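One observation on the new concurrency gate in Sync: Load followed by Add is not atomic as a pair, so under heavy contention a burst of Syncs can all pass the Load before any Add lands, briefly admitting slightly more than syncLim streams; the Add(-1) calls scattered over every return path would also collapse into a single deferred release once a slot is taken. If a strict cap were ever needed, the usual pattern is a compare-and-swap loop; a sketch with hypothetical names:

package main

import "sync/atomic"

type syncGate struct {
    cur   atomic.Int32
    limit int32
}

// tryAcquire claims a slot without ever exceeding limit.
func (g *syncGate) tryAcquire() bool {
    for {
        n := g.cur.Load()
        if n >= g.limit {
            return false
        }
        if g.cur.CompareAndSwap(n, n+1) {
            return true
        }
    }
}

func (g *syncGate) release() { g.cur.Add(-1) }

func main() {
    g := &syncGate{limit: 1000}
    if g.tryAcquire() {
        defer g.release() // one release, regardless of return path
        // ... handle the sync stream ...
    }
}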
View File

@@ -0,0 +1,39 @@
package server
import (
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) {
a := am.holder.GetAccount(account.Id)
if a == nil {
am.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
account.NetworkMapCache.UpdateAccountPointer(account)
am.holder.AddAccount(account)
}
func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account {
return am.holder.GetAccount(accountID)
}
func (am *DefaultAccountManager) getAccountFromHolderOrInit(accountID string) *types.Account {
a := am.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := am.holder.LoadOrStoreFunc(accountID, am.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) {
am.holder.AddAccount(account)
}

View File

@@ -4,9 +4,13 @@ import (
"context"
"fmt"
"net/http"
"os"
"strconv"
"time"
"github.com/gorilla/mux"
"github.com/rs/cors"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/management-integrations/integrations"
@@ -38,7 +42,12 @@ import (
"github.com/netbirdio/netbird/management/server/telemetry"
)
-const apiPrefix = "/api"
+const (
+	apiPrefix              = "/api"
+	rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
+	rateLimitingBurstKey   = "NB_API_RATE_LIMITING_BURST"
+	rateLimitingRPMKey     = "NB_API_RATE_LIMITING_RPM"
+)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
func NewAPIHandler(
@@ -58,11 +67,42 @@ func NewAPIHandler(
settingsManager settings.Manager,
) (http.Handler, error) {
var rateLimitingConfig *middleware.RateLimiterConfig
if os.Getenv(rateLimitingEnabledKey) == "true" {
rpm := 6
if v := os.Getenv(rateLimitingRPMKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm)
} else {
rpm = value
}
}
burst := 500
if v := os.Getenv(rateLimitingBurstKey); v != "" {
value, err := strconv.Atoi(v)
if err != nil {
log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst)
} else {
burst = value
}
}
rateLimitingConfig = &middleware.RateLimiterConfig{
RequestsPerMinute: float64(rpm),
Burst: burst,
CleanupInterval: 6 * time.Hour,
LimiterTTL: 24 * time.Hour,
}
}
authMiddleware := middleware.NewAuthMiddleware(
authManager,
accountManager.GetAccountIDFromUserAuth,
accountManager.SyncUserJWTGroups,
accountManager.GetUserFromUserAuth,
rateLimitingConfig,
)
corsMiddleware := cors.AllowAll()

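Rate limiting stays off unless NB_API_RATE_LIMITING_ENABLED is exactly "true"; the other two variables are optional and fall back to the defaults visible above (6 requests per minute, burst 500, with fixed 6h cleanup interval and 24h limiter TTL). For a local run, an environment like this would enable it (illustrative values):

NB_API_RATE_LIMITING_ENABLED=true
NB_API_RATE_LIMITING_RPM=60
NB_API_RATE_LIMITING_BURST=100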
View File

@@ -29,6 +29,7 @@ type AuthMiddleware struct {
ensureAccount EnsureAccountFunc
getUserFromUserAuth GetUserFromUserAuthFunc
syncUserJWTGroups SyncUserJWTGroupsFunc
rateLimiter *APIRateLimiter
}
// NewAuthMiddleware instance constructor
@@ -37,12 +38,19 @@ func NewAuthMiddleware(
ensureAccount EnsureAccountFunc,
syncUserJWTGroups SyncUserJWTGroupsFunc,
getUserFromUserAuth GetUserFromUserAuthFunc,
rateLimiterConfig *RateLimiterConfig,
) *AuthMiddleware {
var rateLimiter *APIRateLimiter
if rateLimiterConfig != nil {
rateLimiter = NewAPIRateLimiter(rateLimiterConfig)
}
return &AuthMiddleware{
authManager: authManager,
ensureAccount: ensureAccount,
syncUserJWTGroups: syncUserJWTGroups,
getUserFromUserAuth: getUserFromUserAuth,
rateLimiter: rateLimiter,
}
}
@@ -76,7 +84,11 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
 		request, err := m.checkPATFromRequest(r, auth)
 		if err != nil {
 			log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
-			util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
+			// Check if it's a status error, otherwise default to Unauthorized
+			if _, ok := status.FromError(err); !ok {
+				err = status.Errorf(status.Unauthorized, "token invalid")
+			}
+			util.WriteError(r.Context(), err, w)
 			return
 		}
 		h.ServeHTTP(w, request)
@@ -145,6 +157,12 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, auth []string) (*h
return r, fmt.Errorf("error extracting token: %w", err)
}
if m.rateLimiter != nil {
if !m.rateLimiter.Allow(token) {
return r, status.Errorf(status.TooManyRequests, "too many requests")
}
}
ctx := r.Context()
user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token)
if err != nil {

View File

@@ -27,7 +27,9 @@ const (
domainCategory = "domainCategory"
userID = "userID"
tokenID = "tokenID"
tokenID2 = "tokenID2"
PAT = "nbp_PAT"
PAT2 = "nbp_PAT2"
JWT = "JWT"
wrongToken = "wrongToken"
)
@@ -49,6 +51,15 @@ var testAccount = &types.Account{
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
tokenID2: {
ID: tokenID2,
Name: "My second token",
HashedToken: "someHash2",
ExpirationDate: util.ToPtr(time.Now().UTC().AddDate(0, 0, 7)),
CreatedBy: userID,
CreatedAt: time.Now().UTC(),
LastUsed: util.ToPtr(time.Now().UTC()),
},
},
},
},
@@ -58,6 +69,9 @@ func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *types.Use
if token == PAT {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
if token == PAT2 {
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID2], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}
@@ -81,7 +95,7 @@ func mockValidateAndParseToken(_ context.Context, token string) (nbcontext.UserA
 }
 func mockMarkPATUsed(_ context.Context, token string) error {
-	if token == tokenID {
+	if token == tokenID || token == tokenID2 {
 		return nil
 	}
 	return fmt.Errorf("Should never get reached")
@@ -192,6 +206,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
handlerToTest := authMiddleware.Handler(nextHandler)
@@ -221,6 +236,273 @@ func TestAuthMiddleware_Handler(t *testing.T) {
}
}
func TestAuthMiddleware_RateLimiting(t *testing.T) {
mockAuth := &auth.MockManager{
ValidateAndParseTokenFunc: mockValidateAndParseToken,
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
MarkPATUsedFunc: mockMarkPATUsed,
GetPATInfoFunc: mockGetAccountInfoFromPAT,
}
t.Run("PAT Token Rate Limiting - Burst Works", func(t *testing.T) {
// Configure rate limiter: 10 requests per minute with burst of 5
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 10,
Burst: 5,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make burst requests - all should succeed
successCount := 0
for i := 0; i < 5; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 5, successCount, "All burst requests should succeed")
// The 6th request should fail (exceeded burst)
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Request beyond burst should be rate limited")
})
t.Run("PAT Token Rate Limiting - Rate Limit Enforced", func(t *testing.T) {
// Configure very low rate limit: 1 request per minute
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request should fail (rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
})
t.Run("Bearer Token Not Rate Limited", func(t *testing.T) {
// Configure strict rate limit
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Make multiple requests with Bearer token - all should succeed
successCount := 0
for i := 0; i < 10; i++ {
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code == http.StatusOK {
successCount++
}
}
assert.Equal(t, 10, successCount, "All Bearer token requests should succeed (not rate limited)")
})
t.Run("PAT Token Rate Limiting Per Token", func(t *testing.T) {
// Configure rate limiter
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 1,
Burst: 1,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// Use first PAT token
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT should succeed")
// Second request with same token should fail
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with same PAT should be rate limited")
// Use second PAT token - should succeed because it has independent rate limit
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request with PAT2 should succeed (independent rate limit)")
// Second request with PAT2 should also be rate limited
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT2)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request with PAT2 should be rate limited")
// JWT should still work (not rate limited)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Bearer "+JWT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "JWT request should succeed (not rate limited)")
})
t.Run("Rate Limiter Cleanup", func(t *testing.T) {
// Configure rate limiter with short cleanup interval and TTL for testing
rateLimitConfig := &RateLimiterConfig{
RequestsPerMinute: 60,
Burst: 1,
CleanupInterval: 100 * time.Millisecond,
LimiterTTL: 200 * time.Millisecond,
}
authMiddleware := NewAuthMiddleware(
mockAuth,
func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
return userAuth.AccountId, userAuth.UserId, nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) error {
return nil
},
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
rateLimitConfig,
)
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
// First request - should succeed
req := httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "First request should succeed")
// Second request immediately - should fail (burst exhausted)
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request should be rate limited")
// Wait for limiter to be cleaned up (TTL + cleanup interval + buffer)
time.Sleep(400 * time.Millisecond)
// After cleanup, the limiter should be removed and recreated with full burst capacity
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "Request after cleanup should succeed (new limiter with full burst)")
// Verify it's a fresh limiter by checking burst is reset
req = httptest.NewRequest("GET", "http://testing/test", nil)
req.Header.Set("Authorization", "Token "+PAT)
rec = httptest.NewRecorder()
handler.ServeHTTP(rec, req)
assert.Equal(t, http.StatusTooManyRequests, rec.Code, "Second request after cleanup should be rate limited again")
})
}
func TestAuthMiddleware_Handler_Child(t *testing.T) {
tt := []struct {
name string
@@ -297,6 +579,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) {
return &types.User{}, nil
},
nil,
)
for _, tc := range tt {

View File

@@ -0,0 +1,146 @@
package middleware
import (
"context"
"sync"
"time"
"golang.org/x/time/rate"
)
// RateLimiterConfig holds configuration for the API rate limiter
type RateLimiterConfig struct {
// RequestsPerMinute defines the rate at which tokens are replenished
RequestsPerMinute float64
// Burst defines the maximum number of requests that can be made in a burst
Burst int
// CleanupInterval defines how often to clean up old limiters (how often garbage collection runs)
CleanupInterval time.Duration
// LimiterTTL defines how long a limiter should be kept after last use (age threshold for removal)
LimiterTTL time.Duration
}
// DefaultRateLimiterConfig returns a default configuration
func DefaultRateLimiterConfig() *RateLimiterConfig {
return &RateLimiterConfig{
RequestsPerMinute: 100,
Burst: 120,
CleanupInterval: 5 * time.Minute,
LimiterTTL: 10 * time.Minute,
}
}
// limiterEntry holds a rate limiter and its last access time
type limiterEntry struct {
limiter *rate.Limiter
lastAccess time.Time
}
// APIRateLimiter manages rate limiting for API tokens
type APIRateLimiter struct {
config *RateLimiterConfig
limiters map[string]*limiterEntry
mu sync.RWMutex
stopChan chan struct{}
}
// NewAPIRateLimiter creates a new API rate limiter with the given configuration
func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter {
if config == nil {
config = DefaultRateLimiterConfig()
}
rl := &APIRateLimiter{
config: config,
limiters: make(map[string]*limiterEntry),
stopChan: make(chan struct{}),
}
go rl.cleanupLoop()
return rl
}
// Allow checks if a request for the given key (token) is allowed
func (rl *APIRateLimiter) Allow(key string) bool {
limiter := rl.getLimiter(key)
return limiter.Allow()
}
// Wait blocks until the rate limiter allows another request for the given key
// Returns an error if the context is canceled
func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error {
limiter := rl.getLimiter(key)
return limiter.Wait(ctx)
}
// getLimiter retrieves or creates a rate limiter for the given key
func (rl *APIRateLimiter) getLimiter(key string) *rate.Limiter {
rl.mu.RLock()
entry, exists := rl.limiters[key]
rl.mu.RUnlock()
if exists {
rl.mu.Lock()
entry.lastAccess = time.Now()
rl.mu.Unlock()
return entry.limiter
}
rl.mu.Lock()
defer rl.mu.Unlock()
if entry, exists := rl.limiters[key]; exists {
entry.lastAccess = time.Now()
return entry.limiter
}
requestsPerSecond := rl.config.RequestsPerMinute / 60.0
limiter := rate.NewLimiter(rate.Limit(requestsPerSecond), rl.config.Burst)
rl.limiters[key] = &limiterEntry{
limiter: limiter,
lastAccess: time.Now(),
}
return limiter
}
// cleanupLoop periodically removes old limiters that haven't been used recently
func (rl *APIRateLimiter) cleanupLoop() {
ticker := time.NewTicker(rl.config.CleanupInterval)
defer ticker.Stop()
for {
select {
case <-ticker.C:
rl.cleanup()
case <-rl.stopChan:
return
}
}
}
// cleanup removes limiters that haven't been used within the TTL period
func (rl *APIRateLimiter) cleanup() {
rl.mu.Lock()
defer rl.mu.Unlock()
now := time.Now()
for key, entry := range rl.limiters {
if now.Sub(entry.lastAccess) > rl.config.LimiterTTL {
delete(rl.limiters, key)
}
}
}
// Stop stops the cleanup goroutine
func (rl *APIRateLimiter) Stop() {
close(rl.stopChan)
}
// Reset removes the rate limiter for a specific key
func (rl *APIRateLimiter) Reset(key string) {
rl.mu.Lock()
defer rl.mu.Unlock()
delete(rl.limiters, key)
}

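A minimal usage sketch of the limiter, assuming it sits in the same package as the file above (in the middleware the key is the raw PAT token string):

func exampleRateLimiterUsage() {
    rl := NewAPIRateLimiter(&RateLimiterConfig{
        RequestsPerMinute: 60, // refills one request per second
        Burst:             10, // up to 10 requests may land at once
        CleanupInterval:   5 * time.Minute,
        LimiterTTL:        10 * time.Minute,
    })
    defer rl.Stop() // stops the background cleanup goroutine

    if rl.Allow("nbp_exampletoken") { // hypothetical token value
        // handle the request
    } else {
        // reply with 429 Too Many Requests
    }
}

One design note: getLimiter takes the write lock on every cache hit just to refresh lastAccess, so all allowed requests briefly serialize on the mutex. That is fine at API-token volumes, but worth remembering if the limiter is ever reused on a hotter path.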
View File

@@ -125,9 +125,10 @@ type MockAccountManager struct {
 	UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
 	GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error)
-	AllowSyncFunc func(string, uint64) bool
-	UpdateAccountPeersFunc func(ctx context.Context, accountID string)
-	BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
+	AllowSyncFunc func(string, uint64) bool
+	UpdateAccountPeersFunc func(ctx context.Context, accountID string)
+	BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string)
+	RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
 }
func (am *MockAccountManager) CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error {
@@ -986,3 +987,10 @@ func (am *MockAccountManager) AllowSync(key string, hash uint64) bool {
}
return true
}
func (am *MockAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountID string) error {
if am.RecalculateNetworkMapCacheFunc != nil {
return am.RecalculateNetworkMapCacheFunc(ctx, accountID)
}
return nil
}

View File

@@ -83,6 +83,9 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -134,6 +137,9 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -177,6 +183,9 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -0,0 +1,137 @@
package server
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
)
func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
am.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (am *DefaultAccountManager) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
customZone nbdns.CustomZone,
metrics *telemetry.AccountManagerMetrics,
resourcePolicies map[string][]*types.Policy,
routers map[string]map[string]*routerTypes.NetworkRouter,
) *types.NetworkMap {
account := am.getAccountFromHolderOrInit(accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
legacyMap := account.GetPeerNetworkMap(ctx, peerId, customZone, validatedPeers, resourcePolicies, routers, nil)
go func() {
expMap := account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics)
am.compareAndSaveNetworkMaps(ctx, accountId, peerId, expMap, legacyMap)
}()
return legacyMap
}
func (am *DefaultAccountManager) compareAndSaveNetworkMaps(ctx context.Context, accountId, peerId string, expMap, legacyMap *types.NetworkMap) {
expBytes, err := json.Marshal(expMap)
if err != nil {
log.WithContext(ctx).Warnf("failed to marshal experimental network map: %v", err)
return
}
legacyBytes, err := json.Marshal(legacyMap)
if err != nil {
log.WithContext(ctx).Warnf("failed to marshal legacy network map: %v", err)
return
}
if len(expBytes) == len(legacyBytes) {
log.WithContext(ctx).Debugf("network maps are equal for peer %s in account %s (size: %d bytes)", peerId, accountId, len(expBytes))
return
}
timestamp := time.Now().UnixMicro()
baseDir := filepath.Join("debug_networkmaps", accountId, peerId)
if err := os.MkdirAll(baseDir, 0o755); err != nil {
log.WithContext(ctx).Warnf("failed to create debug directory %s: %v", baseDir, err)
return
}
expFile := filepath.Join(baseDir, fmt.Sprintf("exp_networkmap_%d.json", timestamp))
if err := os.WriteFile(expFile, expBytes, 0o644); err != nil {
log.WithContext(ctx).Warnf("failed to write experimental network map to %s: %v", expFile, err)
return
}
legacyFile := filepath.Join(baseDir, fmt.Sprintf("legacy_networkmap_%d.json", timestamp))
if err := os.WriteFile(legacyFile, legacyBytes, 0o644); err != nil {
log.WithContext(ctx).Warnf("failed to write legacy network map to %s: %v", legacyFile, err)
return
}
log.WithContext(ctx).Infof("network maps differ for peer %s in account %s - saved to %s (exp: %d bytes, legacy: %d bytes)", peerId, accountId, baseDir, len(expBytes), len(legacyBytes))
}
func (am *DefaultAccountManager) onPeerAddedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerAddedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
am.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (am *DefaultAccountManager) updatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := am.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (am *DefaultAccountManager) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
am.updateAccountInHolder(account)
}
func (am *DefaultAccountManager) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if am.experimentalNetworkMap(accountId) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
am.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (am *DefaultAccountManager) experimentalNetworkMap(accountId string) bool {
_, ok := am.expNewNetworkMapAIDs[accountId]
return am.expNewNetworkMap || ok
}

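One caveat in compareAndSaveNetworkMaps above: len(expBytes) == len(legacyBytes) only shows the two JSON documents are the same size, so maps that differ but happen to serialize to equal lengths are logged as equal and never dumped. Since encoding/json emits map keys in sorted order, comparing the bytes themselves is stable for map-keyed fields (slice ordering still has to match); a stricter check is a one-liner, sketched here as a hypothetical helper:

import "bytes"

// networkMapsEqual reports whether the two serialized maps are byte-identical.
func networkMapsEqual(expBytes, legacyBytes []byte) bool {
    return bytes.Equal(expBytes, legacyBytes)
}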
View File

@@ -177,6 +177,9 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -157,6 +157,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -257,6 +260,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, resource.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
return resource, nil
@@ -331,6 +337,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -119,6 +119,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -183,6 +186,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
if err := m.accountManager.RecalculateNetworkMapCache(ctx, router.AccountID); err != nil {
return nil, err
}
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
return router, nil
@@ -217,6 +223,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
if err := m.accountManager.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
go m.accountManager.UpdateAccountPeers(ctx, accountID)
return nil

View File

@@ -106,11 +106,6 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
 // MarkPeerConnected marks peer as connected (true) or disconnected (false)
 func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error {
-	start := time.Now()
-	defer func() {
-		log.WithContext(ctx).Debugf("MarkPeerConnected: took %v", time.Since(start))
-	}()
 	var peer *nbpeer.Peer
 	var settings *types.Settings
 	var expired bool
@@ -145,6 +140,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that the connection is now allowed again.
am.BufferUpdateAccountPeers(ctx, accountID)
@@ -321,6 +319,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
if peerLabelChanged || requiresPeerUpdates {
am.UpdateAccountPeers(ctx, accountID)
} else if sshChanged {
@@ -381,6 +383,18 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
storeEvent()
}
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
if err := am.onPeerDeletedUpdNetworkMapCache(account, peerID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", peerID, err)
}
}
if userID != activity.SystemInitiator {
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -417,7 +431,16 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin
return nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(peer.AccountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil, resourcePolicies, routers)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, resourcePolicies, routers, nil)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -690,6 +713,17 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
if am.experimentalNetworkMap(accountID) {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if err := am.onPeerAddedUpdNetworkMapCache(account, newPeer.ID); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
}
am.BufferUpdateAccountPeers(ctx, accountID)
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
@@ -708,11 +742,6 @@ func getPeerIPDNSLabel(ip net.IP, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("SyncPeer: took %v", time.Since(start))
}()
var peer *nbpeer.Peer
var peerNotValid bool
var isStatusChanged bool
@@ -776,6 +805,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
am.BufferUpdateAccountPeers(ctx, accountID)
}
@@ -831,6 +863,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
startTransaction := time.Now()
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, login.WireGuardPubKey)
if err != nil {
@@ -900,8 +933,15 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("LoginPeer: transaction took %v", time.Since(startTransaction))
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
startBuffer := time.Now()
am.BufferUpdateAccountPeers(ctx, accountID)
log.WithContext(ctx).Debugf("LoginPeer: BufferUpdateAccountPeers took %v", time.Since(startBuffer))
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
@@ -997,11 +1037,6 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("getValidatedPeerWithMap: took %s", time.Since(start))
}()
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
@@ -1014,9 +1049,17 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return peer, emptyMap, nil, nil
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
}
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
@@ -1024,10 +1067,12 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}
startPosture := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, peer.ID)
if err != nil {
return nil, nil, nil, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
customZone := account.GetPeersCustomZone(ctx, am.GetDNSDomain(account.Settings))
@@ -1037,7 +1082,16 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
return nil, nil, nil, err
}
networkMap := account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics())
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
var networkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1167,11 +1221,18 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string) {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
var (
account *types.Account
err error
)
if am.experimentalNetworkMap(accountID) {
account = am.getAccountFromHolderOrInit(accountID)
} else {
account, err = am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
}
globalStart := time.Now()
@@ -1204,6 +1265,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
if am.experimentalNetworkMap(accountID) {
am.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := am.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -1241,7 +1306,13 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountID) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
@@ -1257,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update})
}(peer)
}
@@ -1351,7 +1422,13 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
return
}
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
var remotePeerNetworkMap *types.NetworkMap
if am.experimentalNetworkMap(accountId) {
remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics(), resourcePolicies, routers)
} else {
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -1368,7 +1445,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), dnsForwarderPortMinVersion)
update := toSyncResponse(ctx, nil, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
}
// getNextPeerExpiration returns the minimum duration in which the next peer of the account will expire if it was found.
@@ -1580,7 +1657,6 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
},
},
},
NetworkMap: &types.NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
peerDeletedEvents = append(peerDeletedEvents, func() {

View File

@@ -168,6 +168,15 @@ func TestPeer_SessionExpired(t *testing.T) {
}
func TestAccountManager_GetNetworkMap(t *testing.T) {
testGetNetworkMapGeneral(t)
}
func TestAccountManager_GetNetworkMap_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testGetNetworkMapGeneral(t)
}
func testGetNetworkMapGeneral(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
@@ -1003,7 +1012,16 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
}
}
func TestUpdateAccountPeers_Experimental(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
testUpdateAccountPeers(t)
}
func TestUpdateAccountPeers(t *testing.T) {
testUpdateAccountPeers(t)
}
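The pattern above — one wrapper per mode delegating to a shared test body — repeats for every migrated test. A hedged sketch of a reusable helper expressing the same thing:

// Sketch of a shared helper for the dual-mode pattern above: the body runs
// once per builder, with the experimental one enabled via the env toggle.
func runWithBothNetworkMapBuilders(t *testing.T, body func(t *testing.T)) {
	t.Run("legacy", body)
	t.Run("experimental", func(t *testing.T) {
		t.Setenv(envNewNetworkMapBuilder, "true")
		body(t)
	})
}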
func testUpdateAccountPeers(t *testing.T) {
testCases := []struct {
name string
peers int
@@ -1043,8 +1061,8 @@ func TestUpdateAccountPeers(t *testing.T) {
for _, channel := range peerChannels {
update := <-channel
assert.Nil(t, update.Update.NetbirdConfig)
assert.Equal(t, tc.peers, len(update.NetworkMap.Peers))
assert.Equal(t, tc.peers*2, len(update.NetworkMap.FirewallRules))
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
}
})
}
@@ -1548,6 +1566,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
}
func Test_LoginPeer(t *testing.T) {
t.Setenv(envNewNetworkMapBuilder, "true")
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")
}

View File

@@ -77,6 +77,9 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -120,6 +123,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -266,7 +266,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
@@ -274,7 +274,103 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
PolicyID: "RuleDefault",
},
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.250.202",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "all",
Port: "",
PolicyID: "RuleDefault",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "all",
@@ -833,10 +929,58 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// We expect one permissive firewall rule per reachable peer, each allowing all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1)
assert.Len(t, firewallRules, 7)
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "0.0.0.0",
PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.14.88",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.62.5",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.32.206",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.13.186",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.29.55",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "80",
PolicyID: "RuleSwarm",
},
{
PeerIP: "100.65.21.56",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",

View File

@@ -80,6 +80,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -192,6 +192,9 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return nil, err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -246,6 +249,9 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}
@@ -289,6 +295,9 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
if err := am.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return err
}
am.UpdateAccountPeers(ctx, accountID)
}

View File

@@ -5,6 +5,9 @@ package settings
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/integrations/extra_settings"
@@ -45,6 +48,11 @@ func (m *managerImpl) GetExtraSettingsManager() extra_settings.Manager {
}
func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) (*types.Settings, error) {
start := time.Now()
defer func() {
log.WithContext(ctx).Debugf("GetSettings took %s", time.Since(start))
}()
if userID != activity.SystemInitiator {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Read)
if err != nil {

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,951 @@
package store
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sort"
"sync"
"testing"
"time"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"github.com/jackc/pgx/v5/pgxpool"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
)
func (s *SqlStore) GetAccountSlow(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
var account types.Account
result := s.db.Model(&account).
Omit("GroupsG").
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
Preload(clause.Associations).
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
// we have to manually preload policy rules as it seems that gorm preloading doesn't do it for us
for i, policy := range account.Policies {
var rules []*types.PolicyRule
err := s.db.Model(&types.PolicyRule{}).Find(&rules, "policy_id = ?", policy.ID).Error
if err != nil {
return nil, status.Errorf(status.NotFound, "rule not found")
}
account.Policies[i].Rules = rules
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
account.SetupKeys[key.Key] = key.Copy()
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = peer.Copy()
}
account.PeersG = nil
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
user.PATs[pat.ID] = pat.Copy()
}
account.Users[user.Id] = user.Copy()
}
account.UsersG = nil
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
account.Groups[group.ID] = group.Copy()
}
account.GroupsG = nil
var groupPeers []types.GroupPeer
s.db.Model(&types.GroupPeer{}).Where("account_id = ?", accountID).
Find(&groupPeers)
for _, groupPeer := range groupPeers {
if group, ok := account.Groups[groupPeer.GroupID]; ok {
group.Peers = append(group.Peers, groupPeer.PeerID)
} else {
log.WithContext(ctx).Warnf("group %s not found for group peer %s in account %s", groupPeer.GroupID, groupPeer.PeerID, accountID)
}
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = route.Copy()
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
account.NameServerGroups[ns.ID] = ns.Copy()
}
account.NameServerGroupsG = nil
return &account, nil
}
func (s *SqlStore) GetAccountGormOpt(ctx context.Context, accountID string) (*types.Account, error) {
start := time.Now()
defer func() {
elapsed := time.Since(start)
if elapsed > 1*time.Second {
log.WithContext(ctx).Tracef("GetAccount for account %s exceeded 1s, took: %v", accountID, elapsed)
}
}()
var account types.Account
result := s.db.Model(&account).
Preload("UsersG.PATsG"). // have to be specified as this is nested reference
Preload("Policies.Rules").
Preload("SetupKeysG").
Preload("PeersG").
Preload("UsersG").
Preload("GroupsG.GroupPeers").
Preload("RoutesG").
Preload("NameServerGroupsG").
Preload("PostureChecks").
Preload("Networks").
Preload("NetworkRouters").
Preload("NetworkResources").
Preload("Onboarding").
Take(&account, idQueryCondition, accountID)
if result.Error != nil {
log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for _, key := range account.SetupKeysG {
if key.UpdatedAt.IsZero() {
key.UpdatedAt = key.CreatedAt
}
if key.AutoGroups == nil {
key.AutoGroups = []string{}
}
account.SetupKeys[key.Key] = &key
}
account.SetupKeysG = nil
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for _, peer := range account.PeersG {
account.Peers[peer.ID] = &peer
}
account.PeersG = nil
account.Users = make(map[string]*types.User, len(account.UsersG))
for _, user := range account.UsersG {
user.PATs = make(map[string]*types.PersonalAccessToken, len(user.PATsG))
for _, pat := range user.PATsG {
pat.UserID = ""
user.PATs[pat.ID] = &pat
}
if user.AutoGroups == nil {
user.AutoGroups = []string{}
}
account.Users[user.Id] = &user
user.PATsG = nil
}
account.UsersG = nil
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for _, group := range account.GroupsG {
group.Peers = make([]string, len(group.GroupPeers))
for i, gp := range group.GroupPeers {
group.Peers[i] = gp.PeerID
}
if group.Resources == nil {
group.Resources = []types.Resource{}
}
account.Groups[group.ID] = group
}
account.GroupsG = nil
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for _, route := range account.RoutesG {
account.Routes[route.ID] = &route
}
account.RoutesG = nil
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for _, ns := range account.NameServerGroupsG {
ns.AccountID = ""
if ns.NameServers == nil {
ns.NameServers = []nbdns.NameServer{}
}
if ns.Groups == nil {
ns.Groups = []string{}
}
if ns.Domains == nil {
ns.Domains = []string{}
}
account.NameServerGroups[ns.ID] = &ns
}
account.NameServerGroupsG = nil
return &account, nil
}
func connectDBforTest(ctx context.Context, dsn string) (*pgxpool.Pool, error) {
config, err := pgxpool.ParseConfig(dsn)
if err != nil {
return nil, fmt.Errorf("unable to parse database config: %w", err)
}
config.MaxConns = 12
config.MinConns = 2
config.MaxConnLifetime = time.Hour
config.HealthCheckPeriod = time.Minute
pool, err := pgxpool.NewWithConfig(ctx, config)
if err != nil {
return nil, fmt.Errorf("unable to create connection pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
pool.Close()
return nil, fmt.Errorf("unable to ping database: %w", err)
}
return pool, nil
}
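The pool above is deliberately small (12 connections, 1h lifetime, 1m health checks) since it is shared across test goroutines. A minimal usage sketch; the "peers" table name follows gorm's default pluralization and is an assumption here:

// Usage sketch only; table name is assumed from gorm naming conventions.
func countPeers(ctx context.Context, pool *pgxpool.Pool, accountID string) (int, error) {
	var n int
	err := pool.QueryRow(ctx, "SELECT COUNT(*) FROM peers WHERE account_id = $1", accountID).Scan(&n)
	return n, err
}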
func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
if err != nil {
b.Fatalf("failed to create test container: %v", err)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
b.Fatalf("failed to connect database: %v", err)
}
pool, err := connectDBforTest(context.Background(), dsn)
if err != nil {
b.Fatalf("failed to connect database: %v", err)
}
models := []interface{}{
&types.Account{}, &types.SetupKey{}, &nbpeer.Peer{}, &types.User{},
&types.PersonalAccessToken{}, &types.Group{}, &types.GroupPeer{},
&types.Policy{}, &types.PolicyRule{}, &route.Route{},
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{},
}
for i := len(models) - 1; i >= 0; i-- {
err := db.Migrator().DropTable(models[i])
if err != nil {
b.Fatalf("failed to drop table: %v", err)
}
}
err = db.AutoMigrate(models...)
if err != nil {
b.Fatalf("failed to migrate database: %v", err)
}
store := &SqlStore{
db: db,
pool: pool,
}
const (
accountID = "benchmark-account-id"
numUsers = 20
numPatsPerUser = 3
numSetupKeys = 25
numPeers = 200
numGroups = 30
numPolicies = 50
numRulesPerPolicy = 10
numRoutes = 40
numNSGroups = 10
numPostureChecks = 15
numNetworks = 5
numNetworkRouters = 5
numNetworkResources = 10
)
_, ipNet, _ := net.ParseCIDR("100.64.0.0/10")
acc := types.Account{
Id: accountID,
CreatedBy: "benchmark-user",
CreatedAt: time.Now(),
Domain: "benchmark.com",
IsDomainPrimaryAccount: true,
Network: &types.Network{
Identifier: "benchmark-net",
Net: *ipNet,
Serial: 1,
},
DNSSettings: types.DNSSettings{
DisabledManagementGroups: []string{"group-disabled-1"},
},
Settings: &types.Settings{},
}
if err := db.Create(&acc).Error; err != nil {
b.Fatalf("create account: %v", err)
}
var setupKeys []types.SetupKey
for i := 0; i < numSetupKeys; i++ {
setupKeys = append(setupKeys, types.SetupKey{
Id: fmt.Sprintf("keyid-%d", i),
AccountID: accountID,
Key: fmt.Sprintf("key-%d", i),
Name: fmt.Sprintf("Benchmark Key %d", i),
ExpiresAt: &time.Time{},
})
}
if err := db.Create(&setupKeys).Error; err != nil {
b.Fatalf("create setup keys: %v", err)
}
var peers []nbpeer.Peer
for i := 0; i < numPeers; i++ {
peers = append(peers, nbpeer.Peer{
ID: fmt.Sprintf("peer-%d", i),
AccountID: accountID,
Key: fmt.Sprintf("peerkey-%d", i),
IP: net.ParseIP(fmt.Sprintf("100.64.0.%d", i+1)),
Name: fmt.Sprintf("peer-name-%d", i),
Status: &nbpeer.PeerStatus{Connected: i%2 == 0, LastSeen: time.Now()},
})
}
if err := db.Create(&peers).Error; err != nil {
b.Fatalf("create peers: %v", err)
}
for i := 0; i < numUsers; i++ {
userID := fmt.Sprintf("user-%d", i)
user := types.User{Id: userID, AccountID: accountID}
if err := db.Create(&user).Error; err != nil {
b.Fatalf("create user %s: %v", userID, err)
}
var pats []types.PersonalAccessToken
for j := 0; j < numPatsPerUser; j++ {
pats = append(pats, types.PersonalAccessToken{
ID: fmt.Sprintf("pat-%d-%d", i, j),
UserID: userID,
Name: fmt.Sprintf("PAT %d for User %d", j, i),
})
}
if err := db.Create(&pats).Error; err != nil {
b.Fatalf("create pats for user %s: %v", userID, err)
}
}
var groups []*types.Group
for i := 0; i < numGroups; i++ {
groups = append(groups, &types.Group{
ID: fmt.Sprintf("group-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Group %d", i),
})
}
if err := db.Create(&groups).Error; err != nil {
b.Fatalf("create groups: %v", err)
}
for i := 0; i < numPolicies; i++ {
policyID := fmt.Sprintf("policy-%d", i)
policy := types.Policy{ID: policyID, AccountID: accountID, Name: fmt.Sprintf("Policy %d", i), Enabled: true}
if err := db.Create(&policy).Error; err != nil {
b.Fatalf("create policy %s: %v", policyID, err)
}
var rules []*types.PolicyRule
for j := 0; j < numRulesPerPolicy; j++ {
rules = append(rules, &types.PolicyRule{
ID: fmt.Sprintf("rule-%d-%d", i, j),
PolicyID: policyID,
Name: fmt.Sprintf("Rule %d for Policy %d", j, i),
Enabled: true,
Protocol: "all",
})
}
if err := db.Create(&rules).Error; err != nil {
b.Fatalf("create rules for policy %s: %v", policyID, err)
}
}
var routes []route.Route
for i := 0; i < numRoutes; i++ {
routes = append(routes, route.Route{
ID: route.ID(fmt.Sprintf("route-%d", i)),
AccountID: accountID,
Description: fmt.Sprintf("Route %d", i),
Network: netip.MustParsePrefix(fmt.Sprintf("192.168.%d.0/24", i)),
Enabled: true,
})
}
if err := db.Create(&routes).Error; err != nil {
b.Fatalf("create routes: %v", err)
}
var nsGroups []nbdns.NameServerGroup
for i := 0; i < numNSGroups; i++ {
nsGroups = append(nsGroups, nbdns.NameServerGroup{
ID: fmt.Sprintf("nsg-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("NS Group %d", i),
Description: "Benchmark NS Group",
Enabled: true,
})
}
if err := db.Create(&nsGroups).Error; err != nil {
b.Fatalf("create nsgroups: %v", err)
}
var postureChecks []*posture.Checks
for i := 0; i < numPostureChecks; i++ {
postureChecks = append(postureChecks, &posture.Checks{
ID: fmt.Sprintf("pc-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Posture Check %d", i),
})
}
if err := db.Create(&postureChecks).Error; err != nil {
b.Fatalf("create posture checks: %v", err)
}
var networks []*networkTypes.Network
for i := 0; i < numNetworks; i++ {
networks = append(networks, &networkTypes.Network{
ID: fmt.Sprintf("nettype-%d", i),
AccountID: accountID,
Name: fmt.Sprintf("Network Type %d", i),
})
}
if err := db.Create(&networks).Error; err != nil {
b.Fatalf("create networks: %v", err)
}
var networkRouters []*routerTypes.NetworkRouter
for i := 0; i < numNetworkRouters; i++ {
networkRouters = append(networkRouters, &routerTypes.NetworkRouter{
ID: fmt.Sprintf("router-%d", i),
AccountID: accountID,
NetworkID: networks[i%numNetworks].ID,
Peer: peers[i%numPeers].ID,
})
}
if err := db.Create(&networkRouters).Error; err != nil {
b.Fatalf("create network routers: %v", err)
}
var networkResources []*resourceTypes.NetworkResource
for i := 0; i < numNetworkResources; i++ {
networkResources = append(networkResources, &resourceTypes.NetworkResource{
ID: fmt.Sprintf("resource-%d", i),
AccountID: accountID,
NetworkID: networks[i%numNetworks].ID,
Name: fmt.Sprintf("Resource %d", i),
})
}
if err := db.Create(&networkResources).Error; err != nil {
b.Fatalf("create network resources: %v", err)
}
onboarding := types.AccountOnboarding{
AccountID: accountID,
OnboardingFlowPending: true,
}
if err := db.Create(&onboarding).Error; err != nil {
b.Fatalf("create onboarding: %v", err)
}
return store, cleanup, accountID
}
func BenchmarkGetAccount(b *testing.B) {
store, cleanup, accountID := setupBenchmarkDB(b)
defer cleanup()
ctx := context.Background()
b.ResetTimer()
b.ReportAllocs()
b.Run("old", func(b *testing.B) {
for range b.N {
_, err := store.GetAccountSlow(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountSlow failed: %v", err)
}
}
})
b.Run("gorm opt", func(b *testing.B) {
for range b.N {
_, err := store.GetAccountGormOpt(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountFast failed: %v", err)
}
}
})
b.Run("raw", func(b *testing.B) {
for range b.N {
_, err := store.GetAccount(ctx, accountID)
if err != nil {
b.Fatalf("GetAccountPureSQL failed: %v", err)
}
}
})
store.pool.Close()
}
func TestAccountEquivalence(t *testing.T) {
store, cleanup, accountID := setupBenchmarkDB(t)
defer cleanup()
ctx := context.Background()
type getAccountFunc func(context.Context, string) (*types.Account, error)
tests := []struct {
name string
expectedF getAccountFunc
actualF getAccountFunc
}{
{"old vs new", store.GetAccountSlow, store.GetAccountGormOpt},
{"old vs raw", store.GetAccountSlow, store.GetAccount},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expected, errOld := tt.expectedF(ctx, accountID)
assert.NoError(t, errOld, "expected function should not return an error")
assert.NotNil(t, expected, "expected should not be nil")
actual, errNew := tt.actualF(ctx, accountID)
assert.NoError(t, errNew, "actual function should not return an error")
assert.NotNil(t, actual, "actual should not be nil")
testAccountEquivalence(t, expected, actual)
})
}
}
func testAccountEquivalence(t *testing.T, expected, actual *types.Account) {
assert.Equal(t, expected.Id, actual.Id, "Account IDs should be equal")
assert.Equal(t, expected.CreatedBy, actual.CreatedBy, "Account CreatedBy fields should be equal")
assert.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second, "Account CreatedAt timestamps should be within a second")
assert.Equal(t, expected.Domain, actual.Domain, "Account Domains should be equal")
assert.Equal(t, expected.DomainCategory, actual.DomainCategory, "Account DomainCategories should be equal")
assert.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount, "Account IsDomainPrimaryAccount flags should be equal")
assert.Equal(t, expected.Network, actual.Network, "Embedded Account Network structs should be equal")
assert.Equal(t, expected.DNSSettings, actual.DNSSettings, "Embedded Account DNSSettings structs should be equal")
assert.Equal(t, expected.Onboarding, actual.Onboarding, "Embedded Account Onboarding structs should be equal")
assert.Len(t, actual.SetupKeys, len(expected.SetupKeys), "SetupKeys maps should have the same number of elements")
for key, oldVal := range expected.SetupKeys {
newVal, ok := actual.SetupKeys[key]
assert.True(t, ok, "SetupKey with key '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "SetupKey with key '%s' should be equal", key)
}
assert.Len(t, actual.Peers, len(expected.Peers), "Peers maps should have the same number of elements")
for key, oldVal := range expected.Peers {
newVal, ok := actual.Peers[key]
assert.True(t, ok, "Peer with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "Peer with ID '%s' should be equal", key)
}
assert.Len(t, actual.Users, len(expected.Users), "Users maps should have the same number of elements")
for key, oldUser := range expected.Users {
newUser, ok := actual.Users[key]
assert.True(t, ok, "User with ID '%s' should exist in new account", key)
assert.Len(t, newUser.PATs, len(oldUser.PATs), "PATs map for user '%s' should have the same size", key)
for patKey, oldPAT := range oldUser.PATs {
newPAT, patOk := newUser.PATs[patKey]
assert.True(t, patOk, "PAT with ID '%s' for user '%s' should exist in new user object", patKey, key)
assert.Equal(t, *oldPAT, *newPAT, "PAT with ID '%s' for user '%s' should be equal", patKey, key)
}
oldUser.PATs = nil
newUser.PATs = nil
assert.Equal(t, *oldUser, *newUser, "User struct for ID '%s' (without PATs) should be equal", key)
}
assert.Len(t, actual.Groups, len(expected.Groups), "Groups maps should have the same number of elements")
for key, oldVal := range expected.Groups {
newVal, ok := actual.Groups[key]
assert.True(t, ok, "Group with ID '%s' should exist in new account", key)
sort.Strings(oldVal.Peers)
sort.Strings(newVal.Peers)
assert.Equal(t, *oldVal, *newVal, "Group with ID '%s' should be equal", key)
}
assert.Len(t, actual.Routes, len(expected.Routes), "Routes maps should have the same number of elements")
for key, oldVal := range expected.Routes {
newVal, ok := actual.Routes[key]
assert.True(t, ok, "Route with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "Route with ID '%s' should be equal", key)
}
assert.Len(t, actual.NameServerGroups, len(expected.NameServerGroups), "NameServerGroups maps should have the same number of elements")
for key, oldVal := range expected.NameServerGroups {
newVal, ok := actual.NameServerGroups[key]
assert.True(t, ok, "NameServerGroup with ID '%s' should exist in new account", key)
assert.Equal(t, *oldVal, *newVal, "NameServerGroup with ID '%s' should be equal", key)
}
assert.Len(t, actual.Policies, len(expected.Policies), "Policies slices should have the same number of elements")
sort.Slice(expected.Policies, func(i, j int) bool { return expected.Policies[i].ID < expected.Policies[j].ID })
sort.Slice(actual.Policies, func(i, j int) bool { return actual.Policies[i].ID < actual.Policies[j].ID })
for i := range expected.Policies {
sort.Slice(expected.Policies[i].Rules, func(j, k int) bool { return expected.Policies[i].Rules[j].ID < expected.Policies[i].Rules[k].ID })
sort.Slice(actual.Policies[i].Rules, func(j, k int) bool { return actual.Policies[i].Rules[j].ID < actual.Policies[i].Rules[k].ID })
assert.Equal(t, *expected.Policies[i], *actual.Policies[i], "Policy with ID '%s' should be equal", expected.Policies[i].ID)
}
assert.Len(t, actual.PostureChecks, len(expected.PostureChecks), "PostureChecks slices should have the same number of elements")
sort.Slice(expected.PostureChecks, func(i, j int) bool { return expected.PostureChecks[i].ID < expected.PostureChecks[j].ID })
sort.Slice(actual.PostureChecks, func(i, j int) bool { return actual.PostureChecks[i].ID < actual.PostureChecks[j].ID })
for i := range expected.PostureChecks {
assert.Equal(t, *expected.PostureChecks[i], *actual.PostureChecks[i], "PostureCheck with ID '%s' should be equal", expected.PostureChecks[i].ID)
}
assert.Len(t, actual.Networks, len(expected.Networks), "Networks slices should have the same number of elements")
sort.Slice(expected.Networks, func(i, j int) bool { return expected.Networks[i].ID < expected.Networks[j].ID })
sort.Slice(actual.Networks, func(i, j int) bool { return actual.Networks[i].ID < actual.Networks[j].ID })
for i := range expected.Networks {
assert.Equal(t, *expected.Networks[i], *actual.Networks[i], "Network with ID '%s' should be equal", expected.Networks[i].ID)
}
assert.Len(t, actual.NetworkRouters, len(expected.NetworkRouters), "NetworkRouters slices should have the same number of elements")
sort.Slice(expected.NetworkRouters, func(i, j int) bool { return expected.NetworkRouters[i].ID < expected.NetworkRouters[j].ID })
sort.Slice(actual.NetworkRouters, func(i, j int) bool { return actual.NetworkRouters[i].ID < actual.NetworkRouters[j].ID })
for i := range expected.NetworkRouters {
assert.Equal(t, *expected.NetworkRouters[i], *actual.NetworkRouters[i], "NetworkRouter with ID '%s' should be equal", expected.NetworkRouters[i].ID)
}
assert.Len(t, actual.NetworkResources, len(expected.NetworkResources), "NetworkResources slices should have the same number of elements")
sort.Slice(expected.NetworkResources, func(i, j int) bool { return expected.NetworkResources[i].ID < expected.NetworkResources[j].ID })
sort.Slice(actual.NetworkResources, func(i, j int) bool { return actual.NetworkResources[i].ID < actual.NetworkResources[j].ID })
for i := range expected.NetworkResources {
assert.Equal(t, *expected.NetworkResources[i], *actual.NetworkResources[i], "NetworkResource with ID '%s' should be equal", expected.NetworkResources[i].ID)
}
}
func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) {
account, err := s.getAccount(ctx, accountID)
if err != nil {
return nil, err
}
var wg sync.WaitGroup
errChan := make(chan error, 12)
wg.Add(1)
go func() {
defer wg.Done()
keys, err := s.getSetupKeys(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.SetupKeysG = keys
}()
wg.Add(1)
go func() {
defer wg.Done()
peers, err := s.getPeers(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PeersG = peers
}()
wg.Add(1)
go func() {
defer wg.Done()
users, err := s.getUsers(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.UsersG = users
}()
wg.Add(1)
go func() {
defer wg.Done()
groups, err := s.getGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.GroupsG = groups
}()
wg.Add(1)
go func() {
defer wg.Done()
policies, err := s.getPolicies(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Policies = policies
}()
wg.Add(1)
go func() {
defer wg.Done()
routes, err := s.getRoutes(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.RoutesG = routes
}()
wg.Add(1)
go func() {
defer wg.Done()
nsgs, err := s.getNameServerGroups(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NameServerGroupsG = nsgs
}()
wg.Add(1)
go func() {
defer wg.Done()
checks, err := s.getPostureChecks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.PostureChecks = checks
}()
wg.Add(1)
go func() {
defer wg.Done()
networks, err := s.getNetworks(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.Networks = networks
}()
wg.Add(1)
go func() {
defer wg.Done()
routers, err := s.getNetworkRouters(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkRouters = routers
}()
wg.Add(1)
go func() {
defer wg.Done()
resources, err := s.getNetworkResources(ctx, accountID)
if err != nil {
errChan <- err
return
}
account.NetworkResources = resources
}()
wg.Add(1)
go func() {
defer wg.Done()
err := s.getAccountOnboarding(ctx, accountID, account)
if err != nil {
errChan <- err
return
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
var userIDs []string
for _, u := range account.UsersG {
userIDs = append(userIDs, u.Id)
}
var policyIDs []string
for _, p := range account.Policies {
policyIDs = append(policyIDs, p.ID)
}
var groupIDs []string
for _, g := range account.GroupsG {
groupIDs = append(groupIDs, g.ID)
}
wg.Add(3)
errChan = make(chan error, 3)
var pats []types.PersonalAccessToken
go func() {
defer wg.Done()
var err error
pats, err = s.getPersonalAccessTokens(ctx, userIDs)
if err != nil {
errChan <- err
}
}()
var rules []*types.PolicyRule
go func() {
defer wg.Done()
var err error
rules, err = s.getPolicyRules(ctx, policyIDs)
if err != nil {
errChan <- err
}
}()
var groupPeers []types.GroupPeer
go func() {
defer wg.Done()
var err error
groupPeers, err = s.getGroupPeers(ctx, groupIDs)
if err != nil {
errChan <- err
}
}()
wg.Wait()
close(errChan)
for e := range errChan {
if e != nil {
return nil, e
}
}
patsByUserID := make(map[string][]*types.PersonalAccessToken)
for i := range pats {
pat := &pats[i]
patsByUserID[pat.UserID] = append(patsByUserID[pat.UserID], pat)
pat.UserID = ""
}
rulesByPolicyID := make(map[string][]*types.PolicyRule)
for _, rule := range rules {
rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule)
}
peersByGroupID := make(map[string][]string)
for _, gp := range groupPeers {
peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID)
}
account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG))
for i := range account.SetupKeysG {
key := &account.SetupKeysG[i]
account.SetupKeys[key.Key] = key
}
account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG))
for i := range account.PeersG {
peer := &account.PeersG[i]
account.Peers[peer.ID] = peer
}
account.Users = make(map[string]*types.User, len(account.UsersG))
for i := range account.UsersG {
user := &account.UsersG[i]
user.PATs = make(map[string]*types.PersonalAccessToken)
if userPats, ok := patsByUserID[user.Id]; ok {
for j := range userPats {
pat := userPats[j]
user.PATs[pat.ID] = pat
}
}
account.Users[user.Id] = user
}
for i := range account.Policies {
policy := account.Policies[i]
if policyRules, ok := rulesByPolicyID[policy.ID]; ok {
policy.Rules = policyRules
}
}
account.Groups = make(map[string]*types.Group, len(account.GroupsG))
for i := range account.GroupsG {
group := account.GroupsG[i]
if peerIDs, ok := peersByGroupID[group.ID]; ok {
group.Peers = peerIDs
}
account.Groups[group.ID] = group
}
account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG))
for i := range account.RoutesG {
route := &account.RoutesG[i]
account.Routes[route.ID] = route
}
account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG))
for i := range account.NameServerGroupsG {
nsg := &account.NameServerGroupsG[i]
nsg.AccountID = ""
account.NameServerGroups[nsg.ID] = nsg
}
account.SetupKeysG = nil
account.PeersG = nil
account.UsersG = nil
account.GroupsG = nil
account.RoutesG = nil
account.NameServerGroupsG = nil
return account, nil
}
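
The WaitGroup-and-error-channel fan-out above could also be expressed with golang.org/x/sync/errgroup, which additionally cancels outstanding fetches once one fails (provided the getters honor the context). A sketch of the same shape, not a drop-in from this changeset:

// Alternative sketch using errgroup; assumes the s.get* helpers accept a
// cancellable context, as they do above.
g, gctx := errgroup.WithContext(ctx)
g.Go(func() error {
	keys, err := s.getSetupKeys(gctx, accountID)
	if err != nil {
		return err
	}
	account.SetupKeysG = keys
	return nil
})
// ...one g.Go per association, mirroring the goroutines above...
if err := g.Wait(); err != nil {
	return nil, err
}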

View File

@@ -468,6 +468,9 @@ func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind types.Engine)
closeConnection := func() {
cleanup()
store.Close(ctx)
if store.pool != nil {
store.pool.Close()
}
}
return store, closeConnection, nil
@@ -487,12 +490,18 @@ func newReusedPostgresStore(ctx context.Context, store *SqlStore, kind types.Eng
return nil, nil, fmt.Errorf("%s is not set", postgresDsnEnv)
}
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}
dsn, cleanup, err := createRandomDB(dsn, db, kind)
sqlDB, _ := db.DB()
if sqlDB != nil {
sqlDB.Close()
}
if err != nil {
return nil, nil, err
}
@@ -519,12 +528,22 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return nil, nil, fmt.Errorf("%s is not set", mysqlDsnEnv)
}
db, err := gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
db, err := openDBWithRetry(dsn, kind, 5)
if err != nil {
return nil, nil, fmt.Errorf("failed to open mysql connection: %v", err)
}
sqlDB, err := db.DB()
if err != nil {
return nil, nil, fmt.Errorf("failed to get underlying sql.DB: %v", err)
}
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(1)
dsn, cleanup, err := createRandomDB(dsn, db, kind)
sqlDB.Close()
if err != nil {
return nil, nil, err
}
@@ -537,6 +556,31 @@ func newReusedMysqlStore(ctx context.Context, store *SqlStore, kind types.Engine
return store, cleanup, nil
}
func openDBWithRetry(dsn string, engine types.Engine, maxRetries int) (*gorm.DB, error) {
var db *gorm.DB
var err error
for i := range maxRetries {
switch engine {
case types.PostgresStoreEngine:
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
case types.MysqlStoreEngine:
db, err = gorm.Open(mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{})
}
if err == nil {
return db, nil
}
if i < maxRetries-1 {
waitTime := time.Duration(100*(i+1)) * time.Millisecond
time.Sleep(waitTime)
}
}
return nil, err
}
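The retry helper waits linearly (100ms, 200ms, ... for up to maxRetries-1 waits) rather than exponentially, keeping worst-case test startup bounded. A typical call, mirroring the postgres setup path above:

// Example invocation, as used by the postgres and mysql setup paths.
db, err := openDBWithRetry(dsn, types.PostgresStoreEngine, 5)
if err != nil {
	return nil, nil, fmt.Errorf("failed to open postgres connection: %v", err)
}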
func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(), error) {
dbName := fmt.Sprintf("test_db_%s", strings.ReplaceAll(uuid.New().String(), "-", "_"))
@@ -544,21 +588,63 @@ func createRandomDB(dsn string, db *gorm.DB, engine types.Engine) (string, func(
return "", nil, fmt.Errorf("failed to create database: %v", err)
}
var err error
originalDSN := dsn
cleanup := func() {
var dropDB *gorm.DB
var err error
switch engine {
case types.PostgresStoreEngine:
err = db.Exec(fmt.Sprintf("DROP DATABASE %s WITH (FORCE)", dbName)).Error
dropDB, err = gorm.Open(postgres.Open(originalDSN), &gorm.Config{
SkipDefaultTransaction: true,
PrepareStmt: false,
})
if err != nil {
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
return
}
defer func() {
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.Close()
}
}()
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(0)
sqlDB.SetConnMaxLifetime(time.Second)
}
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE)", dbName)).Error
case types.MysqlStoreEngine:
// err = killMySQLConnections(dsn, dbName)
err = db.Exec(fmt.Sprintf("DROP DATABASE %s", dbName)).Error
dropDB, err = gorm.Open(mysql.Open(originalDSN+"?charset=utf8&parseTime=True&loc=Local"), &gorm.Config{
SkipDefaultTransaction: true,
PrepareStmt: false,
})
if err != nil {
log.Errorf("failed to connect for dropping database %s: %v", dbName, err)
return
}
defer func() {
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.Close()
}
}()
if sqlDB, _ := dropDB.DB(); sqlDB != nil {
sqlDB.SetMaxOpenConns(1)
sqlDB.SetMaxIdleConns(0)
sqlDB.SetConnMaxLifetime(time.Second)
}
err = dropDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", dbName)).Error
}
if err != nil {
log.Errorf("failed to drop database %s: %v", dbName, err)
panic(err)
}
if sqlDB, _ := db.DB(); sqlDB != nil {
_ = sqlDB.Close()
}
}
return replaceDBName(dsn, dbName), cleanup, nil

View File

@@ -8,6 +8,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
"github.com/hashicorp/go-multierror"
@@ -87,6 +88,13 @@ type Account struct {
NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"`
NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"`
Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"`
NetworkMapCache *NetworkMapBuilder `gorm:"-"`
nmapInitOnce *sync.Once `gorm:"-"`
}
func (a *Account) InitOnce() {
a.nmapInitOnce = &sync.Once{}
}
// this type is used by gorm only
@@ -257,7 +265,6 @@ func (a *Account) GetPeerNetworkMap(
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
start := time.Now()
peer := a.Peers[peerID]
if peer == nil {
return &NetworkMap{
@@ -890,6 +897,8 @@ func (a *Account) Copy() *Account {
NetworkRouters: networkRouters,
NetworkResources: networkResources,
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
}
}
@@ -1049,14 +1058,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
rules := make([]*FirewallRule, 0)
peers := make([]*nbpeer.Peer, 0)
all, err := a.GetGroupAll()
if err != nil {
log.WithContext(ctx).Errorf("failed to get group all: %v", err)
all = &Group{}
}
return func(rule *PolicyRule, groupPeers []*nbpeer.Peer, direction int) {
isAll := (len(all.Peers) - 1) == len(groupPeers)
for _, peer := range groupPeers {
if peer == nil {
continue
@@ -1075,10 +1077,6 @@ func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer
Protocol: string(rule.Protocol),
}
if isAll {
fr.PeerIP = "0.0.0.0"
}
ruleID := rule.ID + fr.PeerIP + strconv.Itoa(direction) +
fr.Protocol + fr.Action + strings.Join(rule.Ports, ",")
if _, ok := rulesExists[ruleID]; ok {

View File

@@ -0,0 +1,43 @@
package types
import (
"context"
"sync"
)
type Holder struct {
mu sync.RWMutex
accounts map[string]*Account
}
func NewHolder() *Holder {
return &Holder{
accounts: make(map[string]*Account),
}
}
func (h *Holder) GetAccount(id string) *Account {
h.mu.RLock()
defer h.mu.RUnlock()
return h.accounts[id]
}
func (h *Holder) AddAccount(account *Account) {
h.mu.Lock()
defer h.mu.Unlock()
h.accounts[account.Id] = account
}
func (h *Holder) LoadOrStoreFunc(id string, accGetter func(context.Context, string) (*Account, error)) (*Account, error) {
h.mu.Lock()
defer h.mu.Unlock()
if acc, ok := h.accounts[id]; ok {
return acc, nil
}
account, err := accGetter(context.Background(), id)
if err != nil {
return nil, err
}
h.accounts[id] = account
return account, nil
}
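
LoadOrStoreFunc holds the write lock across accGetter, so a slow load of one account blocks loads of all others, and it calls the getter with context.Background(), so caller cancellation does not propagate; a per-key singleflight would avoid the first at the cost of complexity. A usage sketch against the store getter (names assumed from the surrounding changeset):

// Usage sketch: resolve-or-load an account through the holder.
account, err := holder.LoadOrStoreFunc(accountID, func(ctx context.Context, id string) (*types.Account, error) {
	return store.GetAccount(ctx, id)
})
if err != nil {
	return err
}
_ = account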

View File

@@ -0,0 +1,58 @@
package types
import (
"context"
nbdns "github.com/netbirdio/netbird/dns"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
)
func (a *Account) initNetworkMapBuilder(validatedPeers map[string]struct{}) {
if a.NetworkMapCache != nil {
return
}
a.nmapInitOnce.Do(func() {
a.NetworkMapCache = NewNetworkMapBuilder(a, validatedPeers)
})
}
func (a *Account) InitNetworkMapBuilderIfNeeded(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
func (a *Account) GetPeerNetworkMapExp(
ctx context.Context,
peerID string,
peersCustomZone nbdns.CustomZone,
validatedPeers map[string]struct{},
metrics *telemetry.AccountManagerMetrics,
) *NetworkMap {
a.initNetworkMapBuilder(validatedPeers)
return a.NetworkMapCache.GetPeerNetworkMap(ctx, peerID, peersCustomZone, validatedPeers, metrics)
}
func (a *Account) OnPeerAddedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerAddedIncremental(peerId)
}
func (a *Account) OnPeerDeletedUpdNetworkMapCache(peerId string) error {
if a.NetworkMapCache == nil {
return nil
}
return a.NetworkMapCache.OnPeerDeleted(peerId)
}
func (a *Account) UpdatePeerInNetworkMapCache(peer *nbpeer.Peer) {
if a.NetworkMapCache == nil {
return
}
a.NetworkMapCache.UpdatePeer(peer)
}
func (a *Account) RecalculateNetworkMapCache(validatedPeers map[string]struct{}) {
a.initNetworkMapBuilder(validatedPeers)
}
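
InitOnce must run before the first cache access, since nmapInitOnce is a pointer that Do would otherwise dereference as nil; Copy shares both NetworkMapCache and nmapInitOnce, so every copy sees the same lazily built cache. The intended call order, as assumed from the code above:

// Assumed call order (sketch): the loader creates the Once, and the first
// network-map consumer builds the cache exactly once.
account := loadAccount() // hypothetical loader
account.InitOnce()       // allocate the shared sync.Once
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
nm := account.GetPeerNetworkMapExp(ctx, peerID, customZone, validatedPeers, nil)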

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -7,16 +7,14 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/proto"
)
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *types.NetworkMap
Update *proto.SyncResponse
}
type PeersUpdateManager struct {

View File

@@ -595,7 +595,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
// prepareUserUpdateEvents prepares a list user update events based on the changes between the old and new user data.
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool) []func() {
func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, accountID string, initiatorUserID string, oldUser, newUser *types.User, transferredOwnerRole bool, removedGroupIDs, addedGroupIDs []string, tx store.Store) []func() {
var eventsToStore []func()
if oldUser.IsBlocked() != newUser.IsBlocked() {
@@ -621,6 +621,35 @@ func (am *DefaultAccountManager) prepareUserUpdateEvents(ctx context.Context, ac
})
}
addedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, addedGroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get added groups for user %s update event: %v", oldUser.Id, err)
}
for _, group := range addedGroups {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
}
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupAddedToUser, meta)
})
}
removedGroups, err := tx.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, removedGroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get removed groups for user %s update event: %v", oldUser.Id, err)
}
for _, group := range removedGroups {
meta := map[string]any{
"group": group.Name, "group_id": group.ID,
"is_service_user": oldUser.IsServiceUser, "user_name": oldUser.ServiceUserName,
}
eventsToStore = append(eventsToStore, func() {
am.StoreEvent(ctx, oldUser.Id, oldUser.Id, accountID, activity.GroupRemovedFromUser, meta)
})
}
return eventsToStore
}
@@ -667,9 +696,10 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
peersToExpire = userPeers
}
var removedGroups, addedGroups []string
if update.AutoGroups != nil && settings.GroupsPropagationEnabled {
removedGroups := util.Difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups := util.Difference(update.AutoGroups, oldUser.AutoGroups)
removedGroups = util.Difference(oldUser.AutoGroups, update.AutoGroups)
addedGroups = util.Difference(update.AutoGroups, oldUser.AutoGroups)
for _, peer := range userPeers {
for _, groupID := range removedGroups {
if err := transaction.RemovePeerFromGroup(ctx, peer.ID, groupID); err != nil {
@@ -685,7 +715,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
}
updateAccountPeers := len(userPeers) > 0
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole)
userEventsToAdd := am.prepareUserUpdateEvents(ctx, updatedUser.AccountID, initiatorUserId, oldUser, updatedUser, transferredOwnerRole, removedGroups, addedGroups, transaction)
return updateAccountPeers, updatedUser, peersToExpire, userEventsToAdd, nil
}
@@ -961,6 +991,10 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
peer.UserID, peer.ID, accountID,
activity.PeerLoginExpired, peer.EventMeta(dnsDomain),
)
if am.experimentalNetworkMap(accountID) {
am.updatePeerInNetworkMapCache(peer.AccountID, peer)
}
}
if len(peerIDs) != 0 {

View File

@@ -124,6 +124,7 @@ func (r *Route) EventMeta() map[string]any {
func (r *Route) Copy() *Route {
route := &Route{
ID: r.ID,
AccountID: r.AccountID,
Description: r.Description,
NetID: r.NetID,
Network: r.Network,

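Adding AccountID here fixes copies that silently dropped the field. A minimal regression check for the fix might look like:

// Hedged regression sketch for the Copy fix above (package route).
func TestRouteCopyKeepsAccountID(t *testing.T) {
	r := &Route{ID: "r1", AccountID: "acc1"}
	if got := r.Copy(); got.AccountID != r.AccountID {
		t.Fatalf("Copy dropped AccountID: got %q, want %q", got.AccountID, r.AccountID)
	}
}
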
View File

@@ -106,6 +106,8 @@ func WriteError(ctx context.Context, err error, w http.ResponseWriter) {
httpStatus = http.StatusUnauthorized
case status.BadRequest:
httpStatus = http.StatusBadRequest
case status.TooManyRequests:
httpStatus = http.StatusTooManyRequests
default:
}
msg = strings.ToLower(err.Error())

View File

@@ -37,6 +37,9 @@ const (
// Unauthenticated indicates that user is not authenticated due to absence of valid credentials
Unauthenticated Type = 10
// TooManyRequests indicates that the user has sent too many requests in a given amount of time (rate limiting)
TooManyRequests Type = 11
)
// Type is a type of the Error
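
Combined with the WriteError mapping above, a handler can now surface rate limiting as HTTP 429. A sketch using the existing status.Errorf constructor; the limiter and the WriteError package path are assumptions:

// Sketch: limiter (e.g. golang.org/x/time/rate) and handler plumbing are
// assumed; status.Errorf and the 429 mapping are from this changeset.
if !limiter.Allow() {
	util.WriteError(r.Context(), status.Errorf(status.TooManyRequests, "too many requests"), w)
	return
}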