Compare commits

...

16 Commits

Author SHA1 Message Date
Maycon Santos
5f5d597c59 fix GetSetupKeyByID 2025-06-26 13:54:04 +02:00
Maycon Santos
b85aad07d4 add getTXWithLockStrength method 2025-06-26 13:44:00 +02:00
Maycon Santos
7398836c2e Stop using locking share for read calls to avoid deadlocks
Added peer.userID index
2025-06-26 12:24:42 +02:00
Krzysztof Nazarewski (kdn)
34ac4e4b5a [misc] fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG (#4055)
* fix: self-hosting: the wrong default for NETBIRD_AUTH_PKCE_LOGIN_FLAG

fixes https://github.com/netbirdio/netbird/issues/4054

* un-quote the number

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>

---------

Co-authored-by: Maycon Santos <mlsmaycon@gmail.com>
2025-06-26 10:45:00 +02:00
Pascal Fischer
52ff9d9602 [management] remove unused transaction (#4053) 2025-06-26 01:34:22 +02:00
Pascal Fischer
1b73fae46e [management] add breakdown of network map calculation metrics (#4020) 2025-06-25 11:46:35 +02:00
Viktor Liu
d897365abc [client] Don't open cmd.exe during MSI actions (#4041) 2025-06-24 21:32:37 +02:00
Viktor Liu
f37aa2cc9d [misc] Specify netbird binary location in Dockerfiles (#4024) 2025-06-23 10:09:02 +02:00
Maycon Santos
5343bee7b2 [management] check and log on new management version (#4029)
This PR enhances the version checker to send a custom User-Agent header when polling for updates, and configures both the management CLI and client UI to use distinct agents. 

- NewUpdate now takes an `httpAgent` string to set the User-Agent header.
- `fetchVersion` builds a custom HTTP request (instead of `http.Get`) and sets the User-Agent.
- Management CLI and client UI now pass `"nb/management"` and `"nb/client-ui"` respectively to NewUpdate.
- Tests updated to supply an `httpAgent` constant.
- Logs if there is a new version available for management
2025-06-22 16:44:33 +02:00
Maycon Santos
870e29db63 [misc] add additional metrics (#4028)
* add additional metrics

we are collecting active rosenpass, ssh from the client side
we are also collecting active user peers and active users

* remove duplicated
2025-06-22 13:44:25 +02:00
Maycon Santos
08e9b05d51 [client] close windows when process needs to exit (#4027)
This PR fixes a bug by ensuring that the advanced settings and re-authentication windows are closed appropriately when the main GUI process exits.

- Updated runSelfCommand calls throughout the UI to pass a context parameter.
- Modified runSelfCommand’s signature and its internal command invocation to use exec.CommandContext for proper cancellation handling.
2025-06-22 10:33:04 +02:00
hakansa
3581648071 [client] Refactor showLoginURL to improve error handling and connection status checks (#4026)
This PR refactors showLoginURL to improve error handling and connection status checks by delaying the login fetch until user interaction and closing the pop-up if already connected.

- Moved s.login(false) call into the click handler to defer network I/O.
- Added a conn.Status check after opening the URL to skip reconnection if already connected.
- Enhanced error logs for missing verification URLs and service status failures.
2025-06-22 10:03:58 +02:00
Viktor Liu
2a51609436 [client] Handle lazy routing peers that are part of HA groups (#3943)
* Activate new lazy routing peers if the HA group is active
* Prevent lazy peers going to idle if HA group members are active (#3948)
2025-06-20 18:07:19 +02:00
Pascal Fischer
83457f8b99 [management] add transaction for integrated validator groups update and primary account update (#4014) 2025-06-20 12:13:24 +02:00
Pascal Fischer
b45284f086 [management] export ephemeral peer flag on api (#4004) 2025-06-19 16:46:56 +02:00
Bethuel Mmbaga
e9016aecea [management] Add backward compatibility for older clients without firewall rules port range support (#4003)
Adds backward compatibility for clients with versions prior to v0.48.0 that do not support port range firewall rules.

- Skips generation of firewall rules with multi-port ranges for older clients
- Preserves support for single-port ranges by treating them as individual port rules, ensuring compatibility with older clients
2025-06-19 13:07:06 +03:00
54 changed files with 920 additions and 690 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request: pull_request:
env: env:
SIGN_PIPE_VER: "v0.0.18" SIGN_PIPE_VER: "v0.0.19"
GORELEASER_VER: "v2.3.2" GORELEASER_VER: "v2.3.2"
PRODUCT_NAME: "NetBird" PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH" COPYRIGHT: "NetBird GmbH"

View File

@@ -1,6 +1,9 @@
FROM alpine:3.21.3 FROM alpine:3.21.3
# iproute2: busybox doesn't display ip rules properly # iproute2: busybox doesn't display ip rules properly
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
ENV NB_FOREGROUND_MODE=true ENV NB_FOREGROUND_MODE=true
ENTRYPOINT [ "/usr/local/bin/netbird","up"] ENTRYPOINT [ "/usr/local/bin/netbird","up"]
COPY netbird /usr/local/bin/netbird

View File

@@ -1,6 +1,7 @@
FROM alpine:3.21.0 FROM alpine:3.21.0
COPY netbird /usr/local/bin/netbird ARG NETBIRD_BINARY=netbird
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
RUN apk add --no-cache ca-certificates \ RUN apk add --no-cache ca-certificates \
&& adduser -D -h /var/lib/netbird netbird && adduser -D -h /var/lib/netbird netbird

View File

@@ -175,7 +175,7 @@ func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Co
PeerConnID: conn.ConnID(), PeerConnID: conn.ConnID(),
Log: conn.Log, Log: conn.Log,
} }
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg) excluded, err := e.lazyConnMgr.AddPeer(e.lazyCtx, lazyPeerCfg)
if err != nil { if err != nil {
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err) conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
if err := conn.Open(ctx); err != nil { if err := conn.Open(ctx); err != nil {

View File

@@ -68,3 +68,8 @@ func (i *Monitor) PauseTimer() {
func (i *Monitor) ResetTimer() { func (i *Monitor) ResetTimer() {
i.timer.Reset(i.inactivityThreshold) i.timer.Reset(i.inactivityThreshold)
} }
func (i *Monitor) ResetMonitor(ctx context.Context, timeoutChan chan peer.ConnID) {
i.Stop()
go i.Start(ctx, timeoutChan)
}

View File

@@ -58,7 +58,7 @@ type Manager struct {
// Route HA group management // Route HA group management
peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to peerToHAGroups map[string][]route.HAUniqueID // peer ID -> HA groups they belong to
haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group haGroupToPeers map[route.HAUniqueID][]string // HA group -> peer IDs in the group
routesMu sync.RWMutex // protects route mappings routesMu sync.RWMutex
onInactive chan peerid.ConnID onInactive chan peerid.ConnID
} }
@@ -146,7 +146,7 @@ func (m *Manager) Start(ctx context.Context) {
case peerConnID := <-m.activityManager.OnActivityChan: case peerConnID := <-m.activityManager.OnActivityChan:
m.onPeerActivity(ctx, peerConnID) m.onPeerActivity(ctx, peerConnID)
case peerConnID := <-m.onInactive: case peerConnID := <-m.onInactive:
m.onPeerInactivityTimedOut(peerConnID) m.onPeerInactivityTimedOut(ctx, peerConnID)
} }
} }
} }
@@ -197,7 +197,7 @@ func (m *Manager) ExcludePeer(ctx context.Context, peerConfigs []lazyconn.PeerCo
return added return added
} }
func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) { func (m *Manager) AddPeer(ctx context.Context, peerCfg lazyconn.PeerConfig) (bool, error) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -225,6 +225,13 @@ func (m *Manager) AddPeer(peerCfg lazyconn.PeerConfig) (bool, error) {
peerCfg: &peerCfg, peerCfg: &peerCfg,
expectedWatcher: watcherActivity, expectedWatcher: watcherActivity,
} }
// Check if this peer should be activated because its HA group peers are active
if group, ok := m.shouldActivateNewPeer(peerCfg.PublicKey); ok {
peerCfg.Log.Debugf("peer belongs to active HA group %s, will activate immediately", group)
m.activateNewPeerInActiveGroup(ctx, peerCfg)
}
return false, nil return false, nil
} }
@@ -315,36 +322,38 @@ func (m *Manager) activateSinglePeer(ctx context.Context, cfg *lazyconn.PeerConf
// activateHAGroupPeers activates all peers in HA groups that the given peer belongs to // activateHAGroupPeers activates all peers in HA groups that the given peer belongs to
func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) { func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string) {
var peersToActivate []string
m.routesMu.RLock() m.routesMu.RLock()
haGroups := m.peerToHAGroups[triggerPeerID] haGroups := m.peerToHAGroups[triggerPeerID]
m.routesMu.RUnlock()
if len(haGroups) == 0 { if len(haGroups) == 0 {
m.routesMu.RUnlock()
log.Debugf("peer %s is not part of any HA groups", triggerPeerID) log.Debugf("peer %s is not part of any HA groups", triggerPeerID)
return return
} }
activatedCount := 0
for _, haGroup := range haGroups { for _, haGroup := range haGroups {
m.routesMu.RLock()
peers := m.haGroupToPeers[haGroup] peers := m.haGroupToPeers[haGroup]
m.routesMu.RUnlock()
for _, peerID := range peers { for _, peerID := range peers {
if peerID == triggerPeerID { if peerID != triggerPeerID {
continue peersToActivate = append(peersToActivate, peerID)
} }
}
}
m.routesMu.RUnlock()
cfg, mp := m.getPeerForActivation(peerID) activatedCount := 0
if cfg == nil { for _, peerID := range peersToActivate {
continue cfg, mp := m.getPeerForActivation(peerID)
} if cfg == nil {
continue
}
if m.activateSinglePeer(ctx, cfg, mp) { if m.activateSinglePeer(ctx, cfg, mp) {
activatedCount++ activatedCount++
cfg.Log.Infof("activated peer as part of HA group %s (triggered by %s)", haGroup, triggerPeerID) cfg.Log.Infof("activated peer as part of HA group (triggered by %s)", triggerPeerID)
m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, cfg.PublicKey)
}
} }
} }
@@ -354,6 +363,51 @@ func (m *Manager) activateHAGroupPeers(ctx context.Context, triggerPeerID string
} }
} }
// shouldActivateNewPeer checks if a newly added peer should be activated
// because other peers in its HA groups are already active
func (m *Manager) shouldActivateNewPeer(peerID string) (route.HAUniqueID, bool) {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return "", false
}
for _, haGroup := range haGroups {
peers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range peers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
if mp, ok := m.managedPeersByConnID[cfg.PeerConnID]; ok && mp.expectedWatcher == watcherInactivity {
return haGroup, true
}
}
}
return "", false
}
// activateNewPeerInActiveGroup activates a newly added peer that should be active due to HA group
func (m *Manager) activateNewPeerInActiveGroup(ctx context.Context, peerCfg lazyconn.PeerConfig) {
mp, ok := m.managedPeersByConnID[peerCfg.PeerConnID]
if !ok {
return
}
if !m.activateSinglePeer(ctx, &peerCfg, mp) {
return
}
peerCfg.Log.Infof("activated newly added peer due to active HA group peers")
m.peerStore.PeerConnOpen(m.engineCtx, peerCfg.PublicKey)
}
func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error { func (m *Manager) addActivePeer(ctx context.Context, peerCfg lazyconn.PeerConfig) error {
if _, ok := m.managedPeers[peerCfg.PublicKey]; ok { if _, ok := m.managedPeers[peerCfg.PublicKey]; ok {
peerCfg.Log.Warnf("peer already managed") peerCfg.Log.Warnf("peer already managed")
@@ -415,6 +469,48 @@ func (m *Manager) close() {
log.Infof("lazy connection manager closed") log.Infof("lazy connection manager closed")
} }
// shouldDeferIdleForHA checks if peer should stay connected due to HA group requirements
func (m *Manager) shouldDeferIdleForHA(peerID string) bool {
m.routesMu.RLock()
defer m.routesMu.RUnlock()
haGroups := m.peerToHAGroups[peerID]
if len(haGroups) == 0 {
return false
}
for _, haGroup := range haGroups {
groupPeers := m.haGroupToPeers[haGroup]
for _, groupPeerID := range groupPeers {
if groupPeerID == peerID {
continue
}
cfg, ok := m.managedPeers[groupPeerID]
if !ok {
continue
}
groupMp, ok := m.managedPeersByConnID[cfg.PeerConnID]
if !ok {
continue
}
if groupMp.expectedWatcher != watcherInactivity {
continue
}
// Other member is still connected, defer idle
if peer, ok := m.peerStore.PeerConn(groupPeerID); ok && peer.IsConnected() {
return true
}
}
}
return false
}
func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) { func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -441,7 +537,7 @@ func (m *Manager) onPeerActivity(ctx context.Context, peerConnID peerid.ConnID)
m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey) m.peerStore.PeerConnOpen(m.engineCtx, mp.peerCfg.PublicKey)
} }
func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) { func (m *Manager) onPeerInactivityTimedOut(ctx context.Context, peerConnID peerid.ConnID) {
m.managedPeersMu.Lock() m.managedPeersMu.Lock()
defer m.managedPeersMu.Unlock() defer m.managedPeersMu.Unlock()
@@ -456,6 +552,17 @@ func (m *Manager) onPeerInactivityTimedOut(peerConnID peerid.ConnID) {
return return
} }
if m.shouldDeferIdleForHA(mp.peerCfg.PublicKey) {
iw, ok := m.inactivityMonitors[peerConnID]
if ok {
mp.peerCfg.Log.Debugf("resetting inactivity timer due to HA group requirements")
iw.ResetMonitor(ctx, m.onInactive)
} else {
mp.peerCfg.Log.Errorf("inactivity monitor not found for HA defer reset")
}
return
}
mp.peerCfg.Log.Infof("connection timed out") mp.peerCfg.Log.Infof("connection timed out")
// this is blocking operation, potentially can be optimized // this is blocking operation, potentially can be optimized
@@ -489,7 +596,7 @@ func (m *Manager) onPeerConnected(peerConnID peerid.ConnID) {
iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID] iw, ok := m.inactivityMonitors[mp.peerCfg.PeerConnID]
if !ok { if !ok {
mp.peerCfg.Log.Errorf("inactivity monitor not found for peer") mp.peerCfg.Log.Warnf("inactivity monitor not found for peer")
return return
} }

View File

@@ -317,12 +317,12 @@ func (conn *Conn) WgConfig() WgConfig {
return conn.config.WgConfig return conn.config.WgConfig
} }
// IsConnected unit tests only // IsConnected returns true if the peer is connected
// refactor unit test to use status recorder use refactor status recorded to manage connection status in peer.Conn
func (conn *Conn) IsConnected() bool { func (conn *Conn) IsConnected() bool {
conn.mu.Lock() conn.mu.Lock()
defer conn.mu.Unlock() defer conn.mu.Unlock()
return conn.currentConnPriority != conntype.None
return conn.evalStatus() == StatusConnected
} }
func (conn *Conn) GetKey() string { func (conn *Conn) GetKey() string {

View File

@@ -1,8 +1,10 @@
<Wix <Wix
xmlns="http://wixtoolset.org/schemas/v4/wxs"> xmlns="http://wixtoolset.org/schemas/v4/wxs"
xmlns:util="http://wixtoolset.org/schemas/v4/wxs/util">
<Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="NetBird GmbH" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf" <Package Name="NetBird" Version="$(env.NETBIRD_VERSION)" Manufacturer="NetBird GmbH" Language="1033" UpgradeCode="6456ec4e-3ad6-4b9b-a2be-98e81cb21ccf"
InstallerVersion="500" Compressed="yes" Codepage="utf-8" > InstallerVersion="500" Compressed="yes" Codepage="utf-8" >
<MediaTemplate EmbedCab="yes" /> <MediaTemplate EmbedCab="yes" />
<Feature Id="NetbirdFeature" Title="Netbird" Level="1"> <Feature Id="NetbirdFeature" Title="Netbird" Level="1">
@@ -46,29 +48,10 @@
<ComponentRef Id="NetbirdFiles" /> <ComponentRef Id="NetbirdFiles" />
</ComponentGroup> </ComponentGroup>
<Property Id="cmd" Value="cmd.exe"/> <util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
<util:CloseApplication Id="CloseNetBirdUI" CloseMessage="no" Target="netbird-ui.exe" RebootPrompt="no" />
<CustomAction Id="KillDaemon"
ExeCommand='/c "taskkill /im netbird.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<CustomAction Id="KillUI"
ExeCommand='/c "taskkill /im netbird-ui.exe"'
Execute="deferred"
Property="cmd"
Impersonate="no"
Return="ignore"
/>
<InstallExecuteSequence>
<!-- For Uninstallation -->
<Custom Action="KillDaemon" Before="RemoveFiles" Condition="Installed"/>
<Custom Action="KillUI" After="KillDaemon" Condition="Installed"/>
</InstallExecuteSequence>
<!-- Icons --> <!-- Icons -->
<Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" /> <Icon Id="NetbirdIcon" SourceFile=".\client\ui\assets\netbird.ico" />

View File

@@ -280,7 +280,7 @@ func newServiceClient(addr string, logFile string, a fyne.App, showSettings bool
showAdvancedSettings: showSettings, showAdvancedSettings: showSettings,
showNetworks: showNetworks, showNetworks: showNetworks,
update: version.NewUpdate(), update: version.NewUpdate("nb/client-ui"),
} }
s.eventHandler = newEventHandler(s) s.eventHandler = newEventHandler(s)
@@ -879,7 +879,7 @@ func (s *serviceClient) onUpdateAvailable() {
func (s *serviceClient) onSessionExpire() { func (s *serviceClient) onSessionExpire() {
s.sendNotification = true s.sendNotification = true
if s.sendNotification { if s.sendNotification {
s.eventHandler.runSelfCommand("login-url", "true") s.eventHandler.runSelfCommand(s.ctx, "login-url", "true")
s.sendNotification = false s.sendNotification = false
} }
} }
@@ -992,21 +992,6 @@ func (s *serviceClient) restartClient(loginRequest *proto.LoginRequest) error {
// showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL. // showLoginURL creates a borderless window styled like a pop-up in the top-right corner using s.wLoginURL.
func (s *serviceClient) showLoginURL() { func (s *serviceClient) showLoginURL() {
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
resIcon := fyne.NewStaticResource("netbird.png", iconAbout) resIcon := fyne.NewStaticResource("netbird.png", iconAbout)
if s.wLoginURL == nil { if s.wLoginURL == nil {
@@ -1025,6 +1010,21 @@ func (s *serviceClient) showLoginURL() {
return return
} }
resp, err := s.login(false)
if err != nil {
log.Errorf("failed to fetch login URL: %v", err)
return
}
verificationURL := resp.VerificationURIComplete
if verificationURL == "" {
verificationURL = resp.VerificationURI
}
if verificationURL == "" {
log.Error("no verification URL provided in the login response")
return
}
if err := openURL(verificationURL); err != nil { if err := openURL(verificationURL); err != nil {
log.Errorf("failed to open login URL: %v", err) log.Errorf("failed to open login URL: %v", err)
return return
@@ -1038,7 +1038,19 @@ func (s *serviceClient) showLoginURL() {
} }
label.SetText("Re-authentication successful.\nReconnecting") label.SetText("Re-authentication successful.\nReconnecting")
time.Sleep(300 * time.Millisecond) status, err := conn.Status(s.ctx, &proto.StatusRequest{})
if err != nil {
log.Errorf("get service status: %v", err)
return
}
if status.Status == string(internal.StatusConnected) {
label.SetText("Already connected.\nClosing this window.")
time.Sleep(2 * time.Second)
s.wLoginURL.Close()
return
}
_, err = conn.Up(s.ctx, &proto.UpRequest{}) _, err = conn.Up(s.ctx, &proto.UpRequest{})
if err != nil { if err != nil {
label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.") label.SetText("Reconnecting failed, please create \na debug bundle in the settings and contact support.")

View File

@@ -122,7 +122,7 @@ func (h *eventHandler) handleAdvancedSettingsClick() {
go func() { go func() {
defer h.client.mAdvancedSettings.Enable() defer h.client.mAdvancedSettings.Enable()
defer h.client.getSrvConfig() defer h.client.getSrvConfig()
h.runSelfCommand("settings", "true") h.runSelfCommand(h.client.ctx, "settings", "true")
}() }()
} }
@@ -130,7 +130,7 @@ func (h *eventHandler) handleCreateDebugBundleClick() {
h.client.mCreateDebugBundle.Disable() h.client.mCreateDebugBundle.Disable()
go func() { go func() {
defer h.client.mCreateDebugBundle.Enable() defer h.client.mCreateDebugBundle.Enable()
h.runSelfCommand("debug", "true") h.runSelfCommand(h.client.ctx, "debug", "true")
}() }()
} }
@@ -154,7 +154,7 @@ func (h *eventHandler) handleNetworksClick() {
h.client.mNetworks.Disable() h.client.mNetworks.Disable()
go func() { go func() {
defer h.client.mNetworks.Enable() defer h.client.mNetworks.Enable()
h.runSelfCommand("networks", "true") h.runSelfCommand(h.client.ctx, "networks", "true")
}() }()
} }
@@ -172,14 +172,14 @@ func (h *eventHandler) updateConfigWithErr() {
} }
} }
func (h *eventHandler) runSelfCommand(command, arg string) { func (h *eventHandler) runSelfCommand(ctx context.Context, command, arg string) {
proc, err := os.Executable() proc, err := os.Executable()
if err != nil { if err != nil {
log.Errorf("error getting executable path: %v", err) log.Errorf("error getting executable path: %v", err)
return return
} }
cmd := exec.Command(proc, cmd := exec.CommandContext(ctx, proc,
fmt.Sprintf("--%s=%s", command, arg), fmt.Sprintf("--%s=%s", command, arg),
fmt.Sprintf("--daemon-addr=%s", h.client.addr), fmt.Sprintf("--daemon-addr=%s", h.client.addr),
) )

View File

@@ -60,7 +60,7 @@ NETBIRD_TOKEN_SOURCE=${NETBIRD_TOKEN_SOURCE:-accessToken}
NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"} NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS=${NETBIRD_AUTH_PKCE_REDIRECT_URL_PORTS:-"53000"}
NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false} NETBIRD_AUTH_PKCE_USE_ID_TOKEN=${NETBIRD_AUTH_PKCE_USE_ID_TOKEN:-false}
NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false} NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN=${NETBIRD_AUTH_PKCE_DISABLE_PROMPT_LOGIN:-false}
NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-1} NETBIRD_AUTH_PKCE_LOGIN_FLAG=${NETBIRD_AUTH_PKCE_LOGIN_FLAG:-0}
NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE NETBIRD_AUTH_PKCE_AUDIENCE=$NETBIRD_AUTH_AUDIENCE
# Dashboard # Dashboard

View File

@@ -357,6 +357,13 @@ var (
log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String()) log.WithContext(ctx).Infof("running HTTP server and gRPC server on the same port: %s", listener.Addr().String())
serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled) serveGRPCWithHTTP(ctx, listener, rootHandler, tlsEnabled)
update := version.NewUpdate("nb/management")
update.SetDaemonVersion(version.NetbirdVersion())
update.SetOnUpdateListener(func() {
log.WithContext(ctx).Infof("your management version, \"%s\", is outdated, a new management version is available. Learn more here: https://github.com/netbirdio/netbird/releases", version.NetbirdVersion())
})
defer update.StopWatch()
SetupCloseHandler() SetupCloseHandler()
<-stopCh <-stopCh

View File

@@ -370,7 +370,7 @@ func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, tra
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
} }
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil { if err != nil {
return err return err
} }
@@ -708,7 +708,7 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
// AccountExists checks if an account exists. // AccountExists checks if an account exists.
func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) { func (am *DefaultAccountManager) AccountExists(ctx context.Context, accountID string) (bool, error) {
return am.Store.AccountExists(ctx, store.LockingStrengthShare, accountID) return am.Store.AccountExists(ctx, store.LockingStrengthNone, accountID)
} }
// GetAccountIDByUserID retrieves the account ID based on the userID provided. // GetAccountIDByUserID retrieves the account ID based on the userID provided.
@@ -720,7 +720,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI
return "", status.Errorf(status.NotFound, "no valid userID provided") return "", status.Errorf(status.NotFound, "no valid userID provided")
} }
accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) account, err := am.GetOrCreateAccountByUser(ctx, userID, domain)
@@ -775,7 +775,7 @@ func (am *DefaultAccountManager) loadAccount(ctx context.Context, accountID any)
log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID) log.WithContext(ctx).Debugf("account %s not found in cache, reloading", accountID)
accountIDString := fmt.Sprintf("%v", accountID) accountIDString := fmt.Sprintf("%v", accountID)
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountIDString) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountIDString)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@@ -829,7 +829,7 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil // lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) { func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -859,7 +859,7 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP, // add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta // or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID) log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err return nil, err
@@ -1010,7 +1010,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlockAccount() defer unlockAccount()
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, accountID) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return err return err
@@ -1020,7 +1020,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx
return nil return nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting user: %v", err) log.WithContext(ctx).Errorf("error getting user: %v", err)
return err return err
@@ -1185,7 +1185,7 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountMeta(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
@@ -1205,7 +1205,7 @@ func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, u
return "", "", err return "", "", err
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
// this is not really possible because we got an account by user ID // this is not really possible because we got an account by user ID
return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId) return "", "", status.Errorf(status.NotFound, "user %s not found", userAuth.UserId)
@@ -1237,7 +1237,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return err return err
} }
@@ -1263,12 +1263,12 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
var hasChanges bool var hasChanges bool
var user *types.User var user *types.User
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user: %w", err) return fmt.Errorf("error getting user: %w", err)
} }
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1298,7 +1298,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
// Propagate changes to peers if group propagation is enabled // Propagate changes to peers if group propagation is enabled
if settings.GroupsPropagationEnabled { if settings.GroupsPropagationEnabled {
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, userAuth.AccountId) groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return fmt.Errorf("error getting account groups: %w", err) return fmt.Errorf("error getting account groups: %w", err)
} }
@@ -1308,7 +1308,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
groupsMap[group.ID] = group groupsMap[group.ID] = group
} }
peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, userAuth.AccountId, userAuth.UserId) peers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, userAuth.AccountId, userAuth.UserId)
if err != nil { if err != nil {
return fmt.Errorf("error getting user peers: %w", err) return fmt.Errorf("error getting user peers: %w", err)
} }
@@ -1340,7 +1340,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range addNewGroups { for _, g := range addNewGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1353,7 +1353,7 @@ func (am *DefaultAccountManager) SyncUserJWTGroups(ctx context.Context, userAuth
} }
for _, g := range removeOldGroups { for _, g := range removeOldGroups {
group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthShare, userAuth.AccountId, g) group, err := am.Store.GetGroupByID(ctx, store.LockingStrengthNone, userAuth.AccountId, g)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId) log.WithContext(ctx).Debugf("group %s not found while saving user activity event of account %s", g, userAuth.AccountId)
} else { } else {
@@ -1414,7 +1414,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
} }
if userAuth.IsChild { if userAuth.IsChild {
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, userAuth.AccountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil || !exists { if err != nil || !exists {
return "", err return "", err
} }
@@ -1438,7 +1438,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return "", err return "", err
} }
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1459,7 +1459,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context
return am.addNewPrivateAccount(ctx, domainAccountID, userAuth) return am.addNewPrivateAccount(ctx, domainAccountID, userAuth)
} }
func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) { func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Context, domain string) (string, context.CancelFunc, error) {
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1474,7 +1474,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
// check again if the domain has a primary account because of simultaneous requests // check again if the domain has a primary account because of simultaneous requests
domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) domainAccountID, err = am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
cancel() cancel()
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
@@ -1485,7 +1485,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont
} }
func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) { func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, userAuth nbcontext.UserAuth) (string, error) {
userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) userAccountID, err := am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err)
return "", err return "", err
@@ -1495,7 +1495,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId) return "", fmt.Errorf("user %s is not part of the account id %s", userAuth.UserId, userAuth.AccountId)
} }
accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, userAuth.AccountId) accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, userAuth.AccountId)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) log.WithContext(ctx).Errorf("error getting account domain and category: %v", err)
return "", err return "", err
@@ -1506,7 +1506,7 @@ func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context
} }
// We checked if the domain has a primary account already // We checked if the domain has a primary account already
domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, userAuth.Domain) domainAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, userAuth.Domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err) log.WithContext(ctx).Errorf(errorGettingDomainAccIDFmt, err)
return "", err return "", err
@@ -1636,7 +1636,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee
} }
func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) { func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction store.Store, peer *nbpeer.Peer, settings *types.Settings) (bool, error) {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -1658,7 +1658,7 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, transaction
} }
func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) { func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, s store.Store, accountID string, peerHostName string) (string, error) {
existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) existingLabels, err := s.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get peer dns labels: %w", err) return "", fmt.Errorf("failed to get peer dns labels: %w", err)
} }
@@ -1684,7 +1684,7 @@ func (am *DefaultAccountManager) GetAccountSettings(ctx context.Context, account
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
} }
// newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id // newAccountWithId creates a new Account with a default SetupKey (doesn't store in a Store) and provided id
@@ -1770,7 +1770,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
cancel := am.Store.AcquireGlobalLock(ctx) cancel := am.Store.AcquireGlobalLock(ctx)
defer cancel() defer cancel()
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, domain) existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain)
if handleNotFound(err) != nil { if handleNotFound(err) != nil {
return nil, false, err return nil, false, err
} }
@@ -1790,7 +1790,7 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
for range 2 { for range 2 {
accountId := xid.New().String() accountId := xid.New().String()
exists, err := am.Store.AccountExists(ctx, store.LockingStrengthShare, accountId) exists, err := am.Store.AccountExists(ctx, store.LockingStrengthNone, accountId)
if err != nil || exists { if err != nil || exists {
continue continue
} }
@@ -1853,47 +1853,56 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C
} }
func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) {
account, err := am.Store.GetAccount(ctx, accountId) var account *types.Account
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var err error
account, err = transaction.GetAccount(ctx, accountId)
if err != nil {
return err
}
if account.IsDomainPrimaryAccount {
return nil
}
existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := transaction.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return status.Errorf(status.Internal, "failed to update account to primary")
}
return nil
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
if account.IsDomainPrimaryAccount {
return account, nil
}
existingPrimaryAccountID, err := am.Store.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthShare, account.Domain)
// error is not a not found error
if handleNotFound(err) != nil {
return nil, err
}
// a primary account already exists for this private domain
if err == nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
"existingAccountId": existingPrimaryAccountID,
}).Errorf("cannot update account to primary, another account already exists as primary for the same domain")
return nil, status.Errorf(status.Internal, "cannot update account to primary")
}
account.IsDomainPrimaryAccount = true
if err := am.Store.SaveAccount(ctx, account); err != nil {
log.WithContext(ctx).WithFields(log.Fields{
"accountId": accountId,
}).Errorf("failed to update account to primary: %v", err)
return nil, status.Errorf(status.Internal, "failed to update account to primary")
}
return account, nil return account, nil
} }
// propagateUserGroupMemberships propagates all account users' group memberships to their peers. // propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error. // Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -1903,7 +1912,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsMap[group.ID] = group groupsMap[group.ID] = group
} }
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }
@@ -1911,7 +1920,7 @@ func propagateUserGroupMemberships(ctx context.Context, transaction store.Store,
groupsToUpdate := make(map[string]*types.Group) groupsToUpdate := make(map[string]*types.Group)
for _, user := range users { for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, user.Id)
if err != nil { if err != nil {
return false, false, err return false, false, err
} }

View File

@@ -782,7 +782,7 @@ func TestAccountManager_GetAccountByUserID(t *testing.T) {
return return
} }
exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthShare, accountID) exists, err := manager.Store.AccountExists(context.Background(), store.LockingStrengthNone, accountID)
assert.NoError(t, err) assert.NoError(t, err)
assert.True(t, exists, "expected to get existing account after creation using userid") assert.True(t, exists, "expected to get existing account after creation using userid")
@@ -899,11 +899,11 @@ func TestAccountManager_DeleteAccount(t *testing.T) {
t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount)) t.Fatal(fmt.Errorf("expected to get an error when trying to get deleted account, got %v", getAccount))
} }
pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, "service-user-1") pats, err := manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, "service-user-1")
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthShare, userId) pats, err = manager.Store.GetUserPATs(context.Background(), store.LockingStrengthNone, userId)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, pats, 0) assert.Len(t, pats, 0)
} }
@@ -1775,7 +1775,7 @@ func TestDefaultAccountManager_DefaultAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.NotNil(t, settings) assert.NotNil(t, settings)
@@ -1960,7 +1960,7 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
assert.False(t, updatedSettings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthNone, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
assert.False(t, settings.PeerLoginExpirationEnabled) assert.False(t, settings.PeerLoginExpirationEnabled)
@@ -2643,7 +2643,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced") assert.Len(t, user.AutoGroups, 0, "JWT groups should not be synced")
}) })
@@ -2657,7 +2657,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Empty(t, user.AutoGroups, "auto groups must be empty") assert.Empty(t, user.AutoGroups, "auto groups must be empty")
}) })
@@ -2671,11 +2671,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err := manager.SyncUserJWTGroups(context.Background(), claims) err := manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0) assert.Len(t, user.AutoGroups, 0)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2692,11 +2692,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1) assert.Len(t, user.AutoGroups, 1)
group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthShare, "accountID", "group1") group1, err := manager.Store.GetGroupByID(context.Background(), store.LockingStrengthNone, "accountID", "group1")
assert.NoError(t, err, "unable to get group") assert.NoError(t, err, "unable to get group")
assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued") assert.Equal(t, group1.Issued, types.GroupIssuedAPI, "group should be api issued")
}) })
@@ -2710,7 +2710,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2724,7 +2724,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 2, "groups count should not be change") assert.Len(t, user.AutoGroups, 2, "groups count should not be change")
}) })
@@ -2738,11 +2738,11 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, "accountID") groups, err := manager.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, "accountID")
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, groups, 3, "new group3 should be added") assert.Len(t, groups, 3, "new group3 should be added")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "new group should be added") assert.Len(t, user.AutoGroups, 1, "new group should be added")
}) })
@@ -2756,7 +2756,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user1") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user1")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain") assert.Len(t, user.AutoGroups, 1, "only non-JWT groups should remain")
assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present") assert.Contains(t, user.AutoGroups, "group1", "group1 should still be present")
@@ -2771,7 +2771,7 @@ func TestAccount_SetJWTGroups(t *testing.T) {
err = manager.SyncUserJWTGroups(context.Background(), claims) err = manager.SyncUserJWTGroups(context.Background(), claims)
assert.NoError(t, err, "unable to sync jwt groups") assert.NoError(t, err, "unable to sync jwt groups")
user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthShare, "user2") user, err := manager.Store.GetUserByUserID(context.Background(), store.LockingStrengthNone, "user2")
assert.NoError(t, err, "unable to get user") assert.NoError(t, err, "unable to get user")
assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed") assert.Len(t, user.AutoGroups, 0, "all JWT groups should be removed")
}) })
@@ -3354,7 +3354,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1)) require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID) user.AutoGroups = append(user.AutoGroups, group1.ID)
@@ -3365,7 +3365,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, account.Id, group1.ID)
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1") assert.Contains(t, group.Peers, "peer1")
@@ -3376,7 +3376,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2)) require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID) user.AutoGroups = append(user.AutoGroups, group2.ID)
@@ -3403,7 +3403,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.True(t, groupsUpdated) assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers) assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)
@@ -3420,7 +3420,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
}) })
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorId)
require.NoError(t, err) require.NoError(t, err)
user.AutoGroups = []string{"group1"} user.AutoGroups = []string{"group1"}
@@ -3431,7 +3431,7 @@ func TestPropagateUserGroupMemberships(t *testing.T) {
assert.False(t, groupsUpdated) assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers) assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthNone, account.Id, []string{"group1", "group2"})
require.NoError(t, err) require.NoError(t, err)
for _, group := range groups { for _, group := range groups {
assert.Len(t, group.Peers, 2) assert.Len(t, group.Peers, 2)

View File

@@ -73,7 +73,7 @@ func (m *manager) EnsureUserAccessByJWTGroups(ctx context.Context, userAuth nbco
return userAuth, nil return userAuth, nil
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, userAuth.AccountId) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, userAuth.AccountId)
if err != nil { if err != nil {
return userAuth, err return userAuth, err
} }
@@ -104,7 +104,7 @@ func (am *manager) GetPATInfo(ctx context.Context, token string) (user *types.Us
return nil, nil, "", "", err return nil, nil, "", "", err
} }
domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthShare, user.AccountID) domain, category, err = am.store.GetAccountDomainAndCategory(ctx, store.LockingStrengthNone, user.AccountID)
if err != nil { if err != nil {
return nil, nil, "", "", err return nil, nil, "", "", err
} }
@@ -142,12 +142,12 @@ func (am *manager) extractPATFromToken(ctx context.Context, token string) (*type
var pat *types.PersonalAccessToken var pat *types.PersonalAccessToken
err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthShare, encodedHashedToken) pat, err = transaction.GetPATByHashedToken(ctx, store.LockingStrengthNone, encodedHashedToken)
if err != nil { if err != nil {
return err return err
} }
user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthShare, pat.ID) user, err = transaction.GetUserByPATID(ctx, store.LockingStrengthNone, pat.ID)
return err return err
}) })
if err != nil { if err != nil {

View File

@@ -72,7 +72,7 @@ func (am *DefaultAccountManager) GetDNSSettings(ctx context.Context, accountID s
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
} }
// SaveDNSSettings validates a user role and updates the account's DNS settings // SaveDNSSettings validates a user role and updates the account's DNS settings
@@ -139,7 +139,7 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for dns settings events: %v", err)
return nil return nil
@@ -195,7 +195,7 @@ func validateDNSSettings(ctx context.Context, transaction store.Store, accountID
return nil return nil
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, settings.DisabledManagementGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, settings.DisabledManagementGroups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -122,7 +122,7 @@ func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.
} }
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthShare) peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err) log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return return

View File

@@ -103,7 +103,7 @@ func (am *DefaultAccountManager) fillEventsWithUserInfo(ctx context.Context, eve
} }
func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) { func (am *DefaultAccountManager) getEventsUserInfo(ctx context.Context, events []*activity.Event, accountId string, userId string) (map[string]eventUserInfo, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountId) accountUsers, err := am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -154,7 +154,7 @@ func (am *DefaultAccountManager) getEventsExternalUserInfo(ctx context.Context,
continue continue
} }
externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) externalUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
if err != nil { if err != nil {
// @todo consider logging // @todo consider logging
continue continue

View File

@@ -49,7 +49,7 @@ func (am *DefaultAccountManager) GetGroup(ctx context.Context, accountID, groupI
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) return am.Store.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
} }
// GetAllGroups returns all groups in an account // GetAllGroups returns all groups in an account
@@ -57,12 +57,12 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil { if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err return nil, err
} }
return am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
} }
// GetGroupByName filters all groups in an account by name and returns the one with the most peers // GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) { func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
return am.Store.GetGroupByName(ctx, store.LockingStrengthShare, accountID, groupName) return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
} }
// SaveGroup object of the peers // SaveGroup object of the peers
@@ -140,7 +140,7 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
addedPeers := make([]string, 0) addedPeers := make([]string, 0)
removedPeers := make([]string, 0) removedPeers := make([]string, 0)
oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, newGroup.ID) oldGroup, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, newGroup.ID)
if err == nil && oldGroup != nil { if err == nil && oldGroup != nil {
addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers) addedPeers = util.Difference(newGroup.Peers, oldGroup.Peers)
removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers) removedPeers = util.Difference(oldGroup.Peers, newGroup.Peers)
@@ -152,13 +152,13 @@ func (am *DefaultAccountManager) prepareGroupEvents(ctx context.Context, transac
} }
modifiedPeers := slices.Concat(addedPeers, removedPeers) modifiedPeers := slices.Concat(addedPeers, removedPeers)
peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, modifiedPeers) peers, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, modifiedPeers)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err) log.WithContext(ctx).Debugf("failed to get peers for group events: %v", err)
return nil return nil
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err) log.WithContext(ctx).Debugf("failed to get account settings for group events: %v", err)
return nil return nil
@@ -431,7 +431,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI { if newGroup.ID == "" && newGroup.Issued == types.GroupIssuedAPI {
existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthShare, accountID, newGroup.Name) existingGroup, err := transaction.GetGroupByName(ctx, store.LockingStrengthNone, accountID, newGroup.Name)
if err != nil { if err != nil {
if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound { if s, ok := status.FromError(err); !ok || s.Type() != status.NotFound {
return err return err
@@ -448,7 +448,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
} }
for _, peerID := range newGroup.Peers { for _, peerID := range newGroup.Peers {
_, err := transaction.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) _, err := transaction.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID \"%s\" not found", peerID)
} }
@@ -460,7 +460,7 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error { func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string) error {
// disable a deleting integration group if the initiator is not an admin service user // disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration { if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, userID) executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get user") return status.Errorf(status.Internal, "failed to get user")
} }
@@ -506,7 +506,7 @@ func validateDeleteGroup(ctx context.Context, transaction store.Store, group *ty
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. // checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, group.AccountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings") return status.Errorf(status.Internal, "failed to get DNS settings")
} }
@@ -515,7 +515,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
return &GroupLinkError{"disabled DNS management groups", group.Name} return &GroupLinkError{"disabled DNS management groups", group.Name}
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, group.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil { if err != nil {
return status.Errorf(status.Internal, "failed to get account settings") return status.Errorf(status.Internal, "failed to get account settings")
} }
@@ -529,7 +529,7 @@ func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, gr
// isGroupLinkedToRoute checks if a group is linked to any route in the account. // isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil return false, nil
@@ -549,7 +549,7 @@ func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountI
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. // isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil return false, nil
@@ -567,7 +567,7 @@ func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, account
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. // isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil return false, nil
@@ -586,7 +586,7 @@ func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. // isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil return false, nil
@@ -602,7 +602,7 @@ func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accou
// isGroupLinkedToUser checks if a group is linked to any user in the account. // isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil return false, nil
@@ -618,7 +618,7 @@ func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. // isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil return false, nil
@@ -638,7 +638,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil return false, nil
} }
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthShare, accountID) dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -666,7 +666,7 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources. // anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -49,7 +49,7 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string
return nil, err return nil, err
} }
groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := m.store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -96,13 +96,13 @@ func (m *managerImpl) AddResourceToGroupInTransaction(ctx context.Context, trans
return nil, fmt.Errorf("error adding resource to group: %w", err) return nil, fmt.Errorf("error adding resource to group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resource.ID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resource.ID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }
@@ -120,13 +120,13 @@ func (m *managerImpl) RemoveResourceFromGroupInTransaction(ctx context.Context,
return nil, fmt.Errorf("error removing resource from group: %w", err) return nil, fmt.Errorf("error removing resource from group: %w", err)
} }
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthShare, accountID, groupID) group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting group: %w", err) return nil, fmt.Errorf("error getting group: %w", err)
} }
// TODO: at some point, this will need to become a switch statement // TODO: at some point, this will need to become a switch statement
networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) networkResource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting network resource: %w", err) return nil, fmt.Errorf("error getting network resource: %w", err)
} }

View File

@@ -426,6 +426,10 @@ components:
items: items:
type: string type: string
example: "stage-host-1" example: "stage-host-1"
ephemeral:
description: Indicates whether the peer is ephemeral or not
type: boolean
example: false
required: required:
- city_name - city_name
- connected - connected
@@ -450,6 +454,7 @@ components:
- approval_required - approval_required
- serial_number - serial_number
- extra_dns_labels - extra_dns_labels
- ephemeral
AccessiblePeer: AccessiblePeer:
allOf: allOf:
- $ref: '#/components/schemas/PeerMinimum' - $ref: '#/components/schemas/PeerMinimum'

View File

@@ -1016,6 +1016,9 @@ type Peer struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`
@@ -1097,6 +1100,9 @@ type PeerBatch struct {
// DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud // DnsLabel Peer's DNS label is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's domain to the peer label. e.g. peer-dns-label.netbird.cloud
DnsLabel string `json:"dns_label"` DnsLabel string `json:"dns_label"`
// Ephemeral Indicates whether the peer is ephemeral or not
Ephemeral bool `json:"ephemeral"`
// ExtraDnsLabels Extra DNS labels added to the peer // ExtraDnsLabels Extra DNS labels added to the peer
ExtraDnsLabels []string `json:"extra_dns_labels"` ExtraDnsLabels []string `json:"extra_dns_labels"`

View File

@@ -365,6 +365,7 @@ func toSinglePeerResponse(peer *nbpeer.Peer, groupsInfo []api.GroupMinimum, dnsD
CityName: peer.Location.CityName, CityName: peer.Location.CityName,
SerialNumber: peer.Meta.SystemSerialNumber, SerialNumber: peer.Meta.SystemSerialNumber,
InactivityExpirationEnabled: peer.InactivityExpirationEnabled, InactivityExpirationEnabled: peer.InactivityExpirationEnabled,
Ephemeral: peer.Ephemeral,
} }
} }

View File

@@ -37,21 +37,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidatorGroups(ctx context.Con
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
a, err := am.Store.GetAccountByUser(ctx, userID) return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err != nil { a, err := transaction.GetAccountByUser(ctx, userID)
return err if err != nil {
} return err
}
var extra *types.ExtraSettings var extra *types.ExtraSettings
if a.Settings.Extra != nil { if a.Settings.Extra != nil {
extra = a.Settings.Extra extra = a.Settings.Extra
} else { } else {
extra = &types.ExtraSettings{} extra = &types.ExtraSettings{}
a.Settings.Extra = extra a.Settings.Extra = extra
} }
extra.IntegratedValidatorGroups = groups extra.IntegratedValidatorGroups = groups
return am.Store.SaveAccount(ctx, a) return transaction.SaveAccount(ctx, a)
})
} }
func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) { func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID string, groupIDs []string) (bool, error) {
@@ -61,7 +63,7 @@ func (am *DefaultAccountManager) GroupValidation(ctx context.Context, accountID
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs { for _, groupID := range groupIDs {
_, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthShare, accountID, groupID) _, err := transaction.GetGroupByID(context.Background(), store.LockingStrengthNone, accountID, groupID)
if err != nil { if err != nil {
return err return err
} }
@@ -81,20 +83,17 @@ func (am *DefaultAccountManager) GetValidatedPeers(ctx context.Context, accountI
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
var settings *types.Settings var settings *types.Settings
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { groups, err = am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
groups, err = transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return err
}
peers, err = transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
return err
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) peers, err = am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
if err != nil {
return nil, err
}
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -641,7 +641,7 @@ func testSyncStatusRace(t *testing.T) {
} }
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerWithInvalidStatus.PublicKey().String())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return

View File

@@ -184,7 +184,9 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
ephemeralPeersSKs int ephemeralPeersSKs int
ephemeralPeersSKUsage int ephemeralPeersSKUsage int
activePeersLastDay int activePeersLastDay int
activeUserPeersLastDay int
osPeers map[string]int osPeers map[string]int
activeUsersLastDay map[string]struct{}
userPeers int userPeers int
rules int rules int
rulesProtocol map[string]int rulesProtocol map[string]int
@@ -203,6 +205,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
version string version string
peerActiveVersions []string peerActiveVersions []string
osUIClients map[string]int osUIClients map[string]int
rosenpassEnabled int
) )
start := time.Now() start := time.Now()
metricsProperties := make(properties) metricsProperties := make(properties)
@@ -210,6 +213,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
osUIClients = make(map[string]int) osUIClients = make(map[string]int)
rulesProtocol = make(map[string]int) rulesProtocol = make(map[string]int)
rulesDirection = make(map[string]int) rulesDirection = make(map[string]int)
activeUsersLastDay = make(map[string]struct{})
uptime = time.Since(w.startupTime).Seconds() uptime = time.Since(w.startupTime).Seconds()
connections := w.connManager.GetAllConnectedPeers() connections := w.connManager.GetAllConnectedPeers()
version = nbversion.NetbirdVersion() version = nbversion.NetbirdVersion()
@@ -277,10 +281,14 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
for _, peer := range account.Peers { for _, peer := range account.Peers {
peers++ peers++
if peer.SSHEnabled { if peer.SSHEnabled || peer.Meta.Flags.ServerSSHAllowed {
peersSSHEnabled++ peersSSHEnabled++
} }
if peer.Meta.Flags.RosenpassEnabled {
rosenpassEnabled++
}
if peer.UserID != "" { if peer.UserID != "" {
userPeers++ userPeers++
} }
@@ -299,6 +307,10 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
_, connected := connections[peer.ID] _, connected := connections[peer.ID]
if connected || peer.Status.LastSeen.After(w.lastRun) { if connected || peer.Status.LastSeen.After(w.lastRun) {
activePeersLastDay++ activePeersLastDay++
if peer.UserID != "" {
activeUserPeersLastDay++
activeUsersLastDay[peer.UserID] = struct{}{}
}
osActiveKey := osKey + "_active" osActiveKey := osKey + "_active"
osActiveCount := osPeers[osActiveKey] osActiveCount := osPeers[osActiveKey]
osPeers[osActiveKey] = osActiveCount + 1 osPeers[osActiveKey] = osActiveCount + 1
@@ -320,6 +332,8 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs metricsProperties["ephemeral_peers_setup_keys"] = ephemeralPeersSKs
metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage metricsProperties["ephemeral_peers_setup_keys_usage"] = ephemeralPeersSKUsage
metricsProperties["active_peers_last_day"] = activePeersLastDay metricsProperties["active_peers_last_day"] = activePeersLastDay
metricsProperties["active_user_peers_last_day"] = activeUserPeersLastDay
metricsProperties["active_users_last_day"] = len(activeUsersLastDay)
metricsProperties["user_peers"] = userPeers metricsProperties["user_peers"] = userPeers
metricsProperties["rules"] = rules metricsProperties["rules"] = rules
metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks metricsProperties["rules_with_src_posture_checks"] = rulesWithSrcPostureChecks
@@ -338,6 +352,7 @@ func (w *Worker) generateProperties(ctx context.Context) properties {
metricsProperties["ui_clients"] = uiClient metricsProperties["ui_clients"] = uiClient
metricsProperties["idp_manager"] = w.idpManager metricsProperties["idp_manager"] = w.idpManager
metricsProperties["store_engine"] = w.dataSource.GetStoreEngine() metricsProperties["store_engine"] = w.dataSource.GetStoreEngine()
metricsProperties["rosenpass_enabled"] = rosenpassEnabled
for protocol, count := range rulesProtocol { for protocol, count := range rulesProtocol {
metricsProperties["rules_protocol_"+protocol] = count metricsProperties["rules_protocol_"+protocol] = count

View File

@@ -47,8 +47,8 @@ func (mockDatasource) GetAllAccounts(_ context.Context) []*types.Account {
"1": { "1": {
ID: "1", ID: "1",
UserID: "test", UserID: "test",
SSHEnabled: true, SSHEnabled: false,
Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1"}, Meta: nbpeer.PeerSystemMeta{GoOS: "linux", WtVersion: "0.0.1", Flags: nbpeer.Flags{ServerSSHAllowed: true, RosenpassEnabled: true}},
}, },
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
@@ -312,7 +312,19 @@ func TestGenerateProperties(t *testing.T) {
} }
if properties["posture_checks"] != 2 { if properties["posture_checks"] != 2 {
t.Errorf("expected 1 posture_checks, got %d", properties["posture_checks"]) t.Errorf("expected 2 posture_checks, got %d", properties["posture_checks"])
}
if properties["rosenpass_enabled"] != 1 {
t.Errorf("expected 1 rosenpass_enabled, got %d", properties["rosenpass_enabled"])
}
if properties["active_user_peers_last_day"] != 2 {
t.Errorf("expected 2 active_user_peers_last_day, got %d", properties["active_user_peers_last_day"])
}
if properties["active_users_last_day"] != 1 {
t.Errorf("expected 1 active_users_last_day, got %d", properties["active_users_last_day"])
} }
} }

View File

@@ -32,7 +32,7 @@ func (am *DefaultAccountManager) GetNameServerGroup(ctx context.Context, account
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupID) return am.Store.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupID)
} }
// CreateNameServerGroup creates and saves a new nameserver group // CreateNameServerGroup creates and saves a new nameserver group
@@ -112,7 +112,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
var updateAccountPeers bool var updateAccountPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthShare, accountID, nsGroupToSave.ID) oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -202,7 +202,7 @@ func (am *DefaultAccountManager) ListNameServerGroups(ctx context.Context, accou
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
} }
func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error { func validateNameServerGroup(ctx context.Context, transaction store.Store, accountID string, nameserverGroup *nbdns.NameServerGroup) error {
@@ -216,7 +216,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthShare, accountID) nsServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -226,7 +226,7 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return err return err
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, nameserverGroup.Groups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, nameserverGroup.Groups)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -56,7 +56,7 @@ func (m *managerImpl) GetAllNetworks(ctx context.Context, accountID, userID stri
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetAccountNetworks(ctx, store.LockingStrengthShare, accountID) return m.store.GetAccountNetworks(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) CreateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {
@@ -92,7 +92,7 @@ func (m *managerImpl) GetNetwork(ctx context.Context, accountID, userID, network
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) { func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network *types.Network) (*types.Network, error) {

View File

@@ -57,7 +57,7 @@ func (m *managerImpl) GetAllResourcesInNetwork(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) { func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, userID string) ([]*types.NetworkResource, error) {
@@ -69,7 +69,7 @@ func (m *managerImpl) GetAllResourcesInAccount(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) return m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
} }
func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) { func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID, userID string) (map[string][]string, error) {
@@ -81,7 +81,7 @@ func (m *managerImpl) GetAllResourceIDsInAccount(ctx context.Context, accountID,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthShare, accountID) resources, err := m.store.GetNetworkResourcesByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resources: %w", err) return nil, fmt.Errorf("failed to get network resources: %w", err)
} }
@@ -113,7 +113,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func() var eventsToStore []func()
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil { if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
} }
@@ -174,7 +174,7 @@ func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networ
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthShare, accountID, resourceID) resource, err := m.store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, accountID, resourceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network resource: %w", err) return nil, fmt.Errorf("failed to get network resource: %w", err)
} }
@@ -218,17 +218,17 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID) return status.NewResourceNotPartOfNetworkError(resource.ID, resource.NetworkID)
} }
_, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) _, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }
oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthShare, resource.AccountID, resource.Name) oldResource, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil && oldResource.ID != resource.ID { if err == nil && oldResource.ID != resource.ID {
return status.Errorf(status.InvalidArgument, "new resource name already exists") return status.Errorf(status.InvalidArgument, "new resource name already exists")
} }
oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthShare, resource.AccountID, resource.ID) oldResource, err = transaction.GetNetworkResourceByID(ctx, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network resource: %w", err) return fmt.Errorf("failed to get network resource: %w", err)
} }

View File

@@ -54,7 +54,7 @@ func (m *managerImpl) GetAllRoutersInNetwork(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthShare, accountID, networkID) return m.store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
} }
func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) { func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, userID string) (map[string][]*types.NetworkRouter, error) {
@@ -66,7 +66,7 @@ func (m *managerImpl) GetAllRoutersInAccount(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := m.store.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network routers: %w", err) return nil, fmt.Errorf("failed to get network routers: %w", err)
} }
@@ -93,7 +93,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -136,7 +136,7 @@ func (m *managerImpl) GetRouter(ctx context.Context, accountID, userID, networkI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthShare, accountID, routerID) router, err := m.store.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network router: %w", err) return nil, fmt.Errorf("failed to get network router: %w", err)
} }
@@ -162,7 +162,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network var network *networkTypes.Network
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthShare, router.AccountID, router.NetworkID) network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get network: %w", err) return fmt.Errorf("failed to get network: %w", err)
} }
@@ -232,7 +232,7 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
} }
func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) { func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction store.Store, accountID, userID, networkID, routerID string) (func(), error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthShare, accountID, networkID) network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, accountID, networkID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get network: %w", err) return nil, fmt.Errorf("failed to get network: %w", err)
} }

View File

@@ -35,7 +35,7 @@ import (
// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if // GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if
// the current user is not an admin. // the current user is not an admin.
func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -45,7 +45,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return nil, status.NewPermissionValidationError(err) return nil, status.NewPermissionValidationError(err)
} }
accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, nameFilter, ipFilter) accountPeers, err := am.Store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, nameFilter, ipFilter)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -55,7 +55,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return accountPeers, nil return accountPeers, nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get account settings: %w", err) return nil, fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -92,7 +92,7 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc
// fetch all the peers that have access to the user's peers // fetch all the peers that have access to the user's peers
for _, peer := range peers { for _, peer := range peers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, peer.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, peer, approvedPeersMap)
for _, p := range aclPeers { for _, p := range aclPeers {
peersMap[p.ID] = p peersMap[p.ID] = p
} }
@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
} }
if peer.AddedWithSSOLogin() { if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -216,7 +216,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
return err return err
} }
settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -335,7 +335,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) peerAccountID, err := am.Store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
if err != nil { if err != nil {
return err return err
} }
@@ -468,7 +468,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
addedByUser := false addedByUser := false
if len(userID) > 0 { if len(userID) > 0 {
addedByUser = true addedByUser = true
accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthShare, userID) accountID, err = am.Store.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
} else { } else {
accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey)
} }
@@ -488,7 +488,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
// and the peer disconnects with a timeout and tries to register again. // and the peer disconnects with a timeout and tries to register again.
// We just check if this machine has been registered before and reject the second registration. // We just check if this machine has been registered before and reject the second registration.
// The connecting peer should be able to recover with a retry. // The connecting peer should be able to recover with a retry.
_, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, peer.Key) _, err = am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peer.Key)
if err == nil { if err == nil {
return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered")
} }
@@ -584,7 +584,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
ExtraDNSLabels: peer.ExtraDNSLabels, ExtraDNSLabels: peer.ExtraDNSLabels,
AllowExtraDNSLabels: allowExtraDNSLabels, AllowExtraDNSLabels: allowExtraDNSLabels,
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -674,7 +674,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
} }
func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) { func getFreeIP(ctx context.Context, transaction store.Store, accountID string) (net.IP, error) {
takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthShare, accountID) takenIps, err := transaction.GetTakenIPs(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get taken IPs: %w", err) return nil, fmt.Errorf("failed to get taken IPs: %w", err)
} }
@@ -706,7 +706,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
var err error var err error
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -718,7 +718,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
} }
if peer.UserID != "" { if peer.UserID != "" {
user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, peer.UserID) user, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, peer.UserID)
if err != nil { if err != nil {
return err return err
} }
@@ -821,7 +821,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
var isPeerUpdated bool var isPeerUpdated bool
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -906,7 +906,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
// getPeerPostureChecks returns the posture checks for the peer. // getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) { func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -930,7 +930,7 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...) peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
} }
peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, peerPostureChecksIDs) peerPostureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, peerPostureChecksIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -945,7 +945,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
continue continue
} }
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, rule.Sources) sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -970,7 +970,7 @@ func processPeerPostureChecks(ctx context.Context, transaction store.Store, poli
// with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired
// and before starting the engine, we do the checks without an account lock to avoid piling up requests. // and before starting the engine, we do the checks without an account lock to avoid piling up requests.
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error { func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login types.PeerLogin) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthShare, login.WireGuardPubKey) peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, login.WireGuardPubKey)
if err != nil { if err != nil {
return err return err
} }
@@ -981,7 +981,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil return nil
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1000,7 +1000,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is
}() }()
if isRequiresApproval { if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) network, err := am.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, nil, nil, err return nil, nil, nil, err
} }
@@ -1062,7 +1062,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, transact
log.WithContext(ctx).Debugf("failed to update user last login: %v", err) log.WithContext(ctx).Debugf("failed to update user last login: %v", err)
} }
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, peer.AccountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, peer.AccountID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get account settings: %w", err) return fmt.Errorf("failed to get account settings: %w", err)
} }
@@ -1104,7 +1104,7 @@ func peerLoginExpired(ctx context.Context, peer *nbpeer.Peer, settings *types.Se
// GetPeer for a given accountID, peerID and userID error if not found. // GetPeer for a given accountID, peerID and userID error if not found.
func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) {
peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) peer, err := am.Store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1117,7 +1117,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
return peer, nil return peer, nil
} }
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1143,13 +1143,13 @@ func (am *DefaultAccountManager) checkIfUserOwnsPeer(ctx context.Context, accoun
// it is also possible that user doesn't own the peer but some of his peers have access to it, // it is also possible that user doesn't own the peer but some of his peers have access to it,
// this is a valid case, show the peer as well. // this is a valid case, show the peer as well.
userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) userPeers, err := am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, p := range userPeers { for _, p := range userPeers {
aclPeers, _ := account.GetPeerConnectionResources(ctx, p.ID, approvedPeersMap) aclPeers, _ := account.GetPeerConnectionResources(ctx, p, approvedPeersMap)
for _, aclPeer := range aclPeers { for _, aclPeer := range aclPeers {
if aclPeer.ID == peer.ID { if aclPeer.ID == peer.ID {
return peer, nil return peer, nil
@@ -1169,7 +1169,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
start := time.Now() globalStart := time.Now()
approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra) approvedPeersMap, err := am.integratedPeerValidator.GetValidatedPeers(account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil { if err != nil {
@@ -1204,18 +1204,27 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
defer wg.Done() defer wg.Done()
defer func() { <-semaphore }() defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := am.getPeerPostureChecks(account, p.ID) postureChecks, err := am.getPeerPostureChecks(account, p.ID)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err) log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", peer.ID, err)
return return
} }
am.metrics.UpdateChannelMetrics().CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics())
am.metrics.UpdateChannelMetrics().CountCalcPeerNetworkMapDuration(time.Since(start))
start = time.Now()
proxyNetworkMap, ok := proxyNetworkMaps[p.ID] proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok { if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap) remotePeerNetworkMap.Merge(proxyNetworkMap)
} }
am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start))
extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil { if err != nil {
@@ -1223,7 +1232,10 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
return return
} }
start = time.Now()
update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting)
am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start))
am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap}) am.peersUpdateManager.SendUpdate(ctx, p.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap})
}(peer) }(peer)
} }
@@ -1232,7 +1244,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
wg.Wait() wg.Wait()
if am.metrics != nil { if am.metrics != nil {
am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(start)) am.metrics.AccountManagerMetrics().CountUpdateAccountPeersDuration(time.Since(globalStart))
} }
} }
@@ -1316,7 +1328,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are connected. // This function only considers peers that haven't been expired yet and that are connected.
func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err) log.WithContext(ctx).Errorf("failed to get peers with expiration: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1326,7 +1338,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1360,7 +1372,7 @@ func (am *DefaultAccountManager) getNextPeerExpiration(ctx context.Context, acco
// If there is no peer that expires this function returns false and a duration of 0. // If there is no peer that expires this function returns false and a duration of 0.
// This function only considers peers that haven't been expired yet and that are not connected. // This function only considers peers that haven't been expired yet and that are not connected.
func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) { func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Context, accountID string) (time.Duration, bool) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err) log.WithContext(ctx).Errorf("failed to get peers with inactivity: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1370,7 +1382,7 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
return 0, false return 0, false
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get account settings: %v", err) log.WithContext(ctx).Errorf("failed to get account settings: %v", err)
return peerSchedulerRetryInterval, true return peerSchedulerRetryInterval, true
@@ -1401,12 +1413,12 @@ func (am *DefaultAccountManager) getNextInactivePeerExpiration(ctx context.Conte
// getExpiredPeers returns peers that have been expired. // getExpiredPeers returns peers that have been expired.
func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthShare, accountID) peersWithExpiry, err := am.Store.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1424,12 +1436,12 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID
// getInactivePeers returns peers that have been expired by inactivity // getInactivePeers returns peers that have been expired by inactivity
func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) { func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID string) ([]*nbpeer.Peer, error) {
peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthShare, accountID) peersWithInactivity, err := am.Store.GetAccountPeersWithInactivity(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1447,12 +1459,12 @@ func (am *DefaultAccountManager) getInactivePeers(ctx context.Context, accountID
// GetPeerGroups returns groups that the peer is part of. // GetPeerGroups returns groups that the peer is part of.
func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) { func (am *DefaultAccountManager) GetPeerGroups(ctx context.Context, accountID, peerID string) ([]*types.Group, error) {
return am.Store.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) return am.Store.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
} }
// getPeerGroupIDs returns the IDs of the groups that the peer is part of. // getPeerGroupIDs returns the IDs of the groups that the peer is part of.
func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) { func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID string, peerID string) ([]string, error) {
groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthShare, accountID, peerID) groups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1466,7 +1478,7 @@ func getPeerGroupIDs(ctx context.Context, transaction store.Store, accountID str
} }
func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) { func getPeerDNSLabels(ctx context.Context, transaction store.Store, accountID string) (types.LookupMap, error) {
dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthShare, accountID) dnsLabels, err := transaction.GetPeerLabelsInAccount(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1493,7 +1505,7 @@ func isPeerInActiveGroup(ctx context.Context, transaction store.Store, accountID
func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) { func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction store.Store, accountID, userID string, peers []*nbpeer.Peer) ([]func(), error) {
var peerDeletedEvents []func() var peerDeletedEvents []func()
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1504,7 +1516,7 @@ func deletePeers(ctx context.Context, am *DefaultAccountManager, transaction sto
return nil, err return nil, err
} }
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthShare, accountID) network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1565,7 +1577,7 @@ func (am *DefaultAccountManager) validatePeerDelete(ctx context.Context, transac
// isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account. // isPeerLinkedToNetworkRouter checks if a peer is linked to any network router in the account.
func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) { func isPeerLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, peerID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthShare, accountID) routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err) log.WithContext(ctx).Errorf("error retrieving network routers while checking peer linkage: %v", err)
return false, nil return false, nil

View File

@@ -31,7 +31,7 @@ type Peer struct {
// Status peer's management connection status // Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer // The user ID that registered the peer
UserID string UserID string `gorm:"index"`
// SSHKey is a public SSH key of the peer // SSHKey is a public SSH key of the peer
SSHKey string SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer // SSHEnabled indicates whether SSH server is enabled on the peer

View File

@@ -1301,7 +1301,7 @@ func Test_RegisterPeerByUser(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels) assert.Equal(t, newPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels)
peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, addedPeer.Key) peer, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, addedPeer.Key)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, peer.AccountID, existingAccountID) assert.Equal(t, peer.AccountID, existingAccountID)
assert.Equal(t, peer.UserID, existingUserID) assert.Equal(t, peer.UserID, existingUserID)
@@ -1423,7 +1423,7 @@ func Test_RegisterPeerBySetupKey(t *testing.T) {
assert.NotNil(t, addedPeer, "addedPeer should not be nil on success") assert.NotNil(t, addedPeer, "addedPeer should not be nil on success")
assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch") assert.Equal(t, currentPeer.ExtraDNSLabels, addedPeer.ExtraDNSLabels, "ExtraDNSLabels mismatch")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, currentPeer.Key) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, currentPeer.Key)
require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key) require.NoError(t, err, "Failed to get peer by pub key: %s", currentPeer.Key)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store") assert.Equal(t, currentPeer.ExtraDNSLabels, peerFromStore.ExtraDNSLabels, "ExtraDNSLabels mismatch for peer from store")
@@ -1505,7 +1505,7 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) {
_, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer)
require.Error(t, err) require.Error(t, err)
_, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, newPeer.Key) _, err = s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, newPeer.Key)
require.Error(t, err) require.Error(t, err)
account, err := s.GetAccount(context.Background(), existingAccountID) account, err := s.GetAccount(context.Background(), existingAccountID)
@@ -1671,7 +1671,7 @@ func Test_LoginPeer(t *testing.T) {
assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer") assert.Equal(t, existingAccountID, loggedinPeer.AccountID, "AccountID mismatch for logged peer")
peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthShare, loginInput.WireGuardPubKey) peerFromStore, err := s.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, loginInput.WireGuardPubKey)
require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey) require.NoError(t, err, "Failed to get peer by pub key: %s", loginInput.WireGuardPubKey)
assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store") assert.Equal(t, existingAccountID, peerFromStore.AccountID, "AccountID mismatch for peer from store")
assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore") assert.Equal(t, loggedinPeer.ID, peerFromStore.ID, "Peer ID mismatch between loggedinPeer and peerFromStore")

View File

@@ -42,7 +42,7 @@ func (m *managerImpl) GetPeer(ctx context.Context, accountID, userID, peerID str
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return m.store.GetPeerByID(ctx, store.LockingStrengthShare, accountID, peerID) return m.store.GetPeerByID(ctx, store.LockingStrengthNone, accountID, peerID)
} }
func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) { func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string) ([]*peer.Peer, error) {
@@ -52,12 +52,12 @@ func (m *managerImpl) GetAllPeers(ctx context.Context, accountID, userID string)
} }
if !allowed { if !allowed {
return m.store.GetUserPeers(ctx, store.LockingStrengthShare, accountID, userID) return m.store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
} }
return m.store.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") return m.store.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "")
} }
func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) { func (m *managerImpl) GetPeerAccountID(ctx context.Context, peerID string) (string, error) {
return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthShare, peerID) return m.store.GetAccountIDByPeerID(ctx, store.LockingStrengthNone, peerID)
} }

View File

@@ -45,7 +45,7 @@ func (m *managerImpl) ValidateUserPermissions(
return true, nil return true, nil
} }
user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return false, err return false, err
} }

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, polic
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policyID) return am.Store.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policyID)
} }
// SavePolicy in the store // SavePolicy in the store
@@ -142,13 +142,13 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
} }
// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. // arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) {
if isUpdate { if isUpdate {
existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -173,7 +173,7 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a
// validatePolicy validates the policy and its rules. // validatePolicy validates the policy and its rules.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error {
if policy.ID != "" { if policy.ID != "" {
_, err := transaction.GetPolicyByID(ctx, store.LockingStrengthShare, accountID, policy.ID) _, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID)
if err != nil { if err != nil {
return err return err
} }
@@ -182,12 +182,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri
policy.AccountID = accountID policy.AccountID = accountID
} }
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, policy.RuleGroups()) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups())
if err != nil { if err != nil {
return err return err
} }
postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthShare, accountID, policy.SourcePostureChecks) postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -27,6 +27,7 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
ID: "peerB", ID: "peerB",
IP: net.ParseIP("100.65.80.39"), IP: net.ParseIP("100.65.80.39"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.48.0"},
}, },
"peerC": { "peerC": {
ID: "peerC", ID: "peerC",
@@ -63,6 +64,12 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
IP: net.ParseIP("100.65.31.2"), IP: net.ParseIP("100.65.31.2"),
Status: &nbpeer.PeerStatus{}, Status: &nbpeer.PeerStatus{},
}, },
"peerK": {
ID: "peerK",
IP: net.ParseIP("100.32.80.1"),
Status: &nbpeer.PeerStatus{},
Meta: nbpeer.PeerSystemMeta{WtVersion: "0.30.0"},
},
}, },
Groups: map[string]*types.Group{ Groups: map[string]*types.Group{
"GroupAll": { "GroupAll": {
@@ -111,6 +118,13 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
"peerI", "peerI",
}, },
}, },
"GroupWorkflow": {
ID: "GroupWorkflow",
Name: "workflow",
Peers: []string{
"peerK",
},
},
}, },
Policies: []*types.Policy{ Policies: []*types.Policy{
{ {
@@ -189,6 +203,39 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
}, },
}, },
}, },
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Enabled: true,
Rules: []*types.PolicyRule{
{
ID: "RuleWorkflow",
Name: "Workflow",
Description: "No description",
Bidirectional: true,
Enabled: true,
Protocol: types.PolicyRuleProtocolTCP,
Action: types.PolicyTrafficActionAccept,
PortRanges: []types.RulePortRange{
{
Start: 8088,
End: 8088,
},
{
Start: 9090,
End: 9095,
},
},
Sources: []string{
"GroupWorkflow",
},
Destinations: []string{
"GroupDMZ",
},
},
},
},
}, },
} }
@@ -199,14 +246,14 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
t.Run("check that all peers get map", func(t *testing.T) { t.Run("check that all peers get map", func(t *testing.T) {
for _, p := range account.Peers { for _, p := range account.Peers {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p.ID, validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), p, validatedPeers)
assert.GreaterOrEqual(t, len(peers), 2, "minimum number peers should present") assert.GreaterOrEqual(t, len(peers), 1, "minimum number peers should present")
assert.GreaterOrEqual(t, len(firewallRules), 2, "minimum number of firewall rules should present") assert.GreaterOrEqual(t, len(firewallRules), 1, "minimum number of firewall rules should present")
} }
}) })
t.Run("check first peer map details", func(t *testing.T) { t.Run("check first peer map details", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", validatedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], validatedPeers)
assert.Len(t, peers, 8) assert.Len(t, peers, 8)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
@@ -364,6 +411,32 @@ func TestAccount_getPeersByPolicy(t *testing.T) {
assert.True(t, contains, "rule not found in expected rules %#v", rule) assert.True(t, contains, "rule not found in expected rules %#v", rule)
} }
}) })
t.Run("check port ranges support for older peers", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerK"], validatedPeers)
assert.Len(t, peers, 1)
assert.Contains(t, peers, account.Peers["peerI"])
expectedFirewallRules := []*types.FirewallRule{
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionIN,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
{
PeerIP: "100.65.31.2",
Direction: types.FirewallRuleDirectionOUT,
Action: "accept",
Protocol: "tcp",
Port: "8088",
PolicyID: "RuleWorkflow",
},
}
assert.ElementsMatch(t, firewallRules, expectedFirewallRules)
})
} }
func TestAccount_getPeersByPolicyDirect(t *testing.T) { func TestAccount_getPeersByPolicyDirect(t *testing.T) {
@@ -466,10 +539,10 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
} }
t.Run("check first peer map", func(t *testing.T) { t.Run("check first peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -487,19 +560,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map", func(t *testing.T) { t.Run("check second peer map", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -517,21 +590,21 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
account.Policies[1].Rules[0].Bidirectional = false account.Policies[1].Rules[0].Bidirectional = false
t.Run("check first peer map directional only", func(t *testing.T) { t.Run("check first peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.254.139", PeerIP: "100.65.254.139",
Direction: types.FirewallRuleDirectionOUT, Direction: types.FirewallRuleDirectionOUT,
@@ -541,19 +614,19 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
t.Run("check second peer map directional only", func(t *testing.T) { t.Run("check second peer map directional only", func(t *testing.T) {
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Contains(t, peers, account.Peers["peerB"]) assert.Contains(t, peers, account.Peers["peerB"])
epectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
{ {
PeerIP: "100.65.80.39", PeerIP: "100.65.80.39",
Direction: types.FirewallRuleDirectionIN, Direction: types.FirewallRuleDirectionIN,
@@ -563,11 +636,11 @@ func TestAccount_getPeersByPolicyDirect(t *testing.T) {
PolicyID: "RuleSwarm", PolicyID: "RuleSwarm",
}, },
} }
assert.Len(t, firewallRules, len(epectedFirewallRules)) assert.Len(t, firewallRules, len(expectedFirewallRules))
slices.SortFunc(epectedFirewallRules, sortFunc()) slices.SortFunc(expectedFirewallRules, sortFunc())
slices.SortFunc(firewallRules, sortFunc()) slices.SortFunc(firewallRules, sortFunc())
for i := range firewallRules { for i := range firewallRules {
assert.Equal(t, epectedFirewallRules[i], firewallRules[i]) assert.Equal(t, expectedFirewallRules[i], firewallRules[i])
} }
}) })
} }
@@ -748,7 +821,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
t.Run("verify peer's network map with default group peer list", func(t *testing.T) { t.Run("verify peer's network map with default group peer list", func(t *testing.T) {
// peerB doesn't fulfill the NB posture check but is included in the destination group Swarm, // peerB doesn't fulfill the NB posture check but is included in the destination group Swarm,
// will establish a connection with all source peers satisfying the NB posture check. // will establish a connection with all source peers satisfying the NB posture check.
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -758,7 +831,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, 1) assert.Len(t, firewallRules, 1)
expectedFirewallRules := []*types.FirewallRule{ expectedFirewallRules := []*types.FirewallRule{
@@ -775,7 +848,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -785,7 +858,7 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm, // peerI doesn't fulfill the OS version posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 4) assert.Len(t, peers, 4)
assert.Len(t, firewallRules, 4) assert.Len(t, firewallRules, 4)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
@@ -800,19 +873,19 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's // peerB doesn't satisfy the NB posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules := account.GetPeerConnectionResources(context.Background(), "peerB", approvedPeers) peers, firewallRules := account.GetPeerConnectionResources(context.Background(), account.Peers["peerB"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's // peerI doesn't satisfy the OS version posture check, and doesn't exist in destination group peer's
// no connection should be established to any peer of destination group // no connection should be established to any peer of destination group
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerI", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerI"], approvedPeers)
assert.Len(t, peers, 0) assert.Len(t, peers, 0)
assert.Len(t, firewallRules, 0) assert.Len(t, firewallRules, 0)
// peerC satisfy the NB posture check, should establish connection to all destination group peer's // peerC satisfy the NB posture check, should establish connection to all destination group peer's
// We expect a single permissive firewall rule which all outgoing connections // We expect a single permissive firewall rule which all outgoing connections
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerC", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerC"], approvedPeers)
assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, peers, len(account.Groups["GroupSwarm"].Peers))
assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers)) assert.Len(t, firewallRules, len(account.Groups["GroupSwarm"].Peers))
@@ -827,14 +900,14 @@ func TestAccount_getPeersByPolicyPostureChecks(t *testing.T) {
// peerE doesn't fulfill the NB posture check and exists in only destination group Swarm, // peerE doesn't fulfill the NB posture check and exists in only destination group Swarm,
// all source group peers satisfying the NB posture check should establish connection // all source group peers satisfying the NB posture check should establish connection
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerE", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerE"], approvedPeers)
assert.Len(t, peers, 3) assert.Len(t, peers, 3)
assert.Len(t, firewallRules, 3) assert.Len(t, firewallRules, 3)
assert.Contains(t, peers, account.Peers["peerA"]) assert.Contains(t, peers, account.Peers["peerA"])
assert.Contains(t, peers, account.Peers["peerC"]) assert.Contains(t, peers, account.Peers["peerC"])
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])
peers, firewallRules = account.GetPeerConnectionResources(context.Background(), "peerA", approvedPeers) peers, firewallRules = account.GetPeerConnectionResources(context.Background(), account.Peers["peerA"], approvedPeers)
assert.Len(t, peers, 5) assert.Len(t, peers, 5)
// assert peers from Group Swarm // assert peers from Group Swarm
assert.Contains(t, peers, account.Peers["peerD"]) assert.Contains(t, peers, account.Peers["peerD"])

View File

@@ -24,20 +24,12 @@ func sanitizeVersion(version string) string {
} }
func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) { func (n *NBVersionCheck) Check(ctx context.Context, peer nbpeer.Peer) (bool, error) {
peerVersion := sanitizeVersion(peer.Meta.WtVersion) meetsMin, err := MeetsMinVersion(n.MinVersion, peer.Meta.WtVersion)
minVersion := sanitizeVersion(n.MinVersion)
peerNBVersion, err := version.NewVersion(peerVersion)
if err != nil { if err != nil {
return false, err return false, err
} }
constraints, err := version.NewConstraint(">= " + minVersion) if meetsMin {
if err != nil {
return false, err
}
if constraints.Check(peerNBVersion) {
return true, nil return true, nil
} }
@@ -60,3 +52,21 @@ func (n *NBVersionCheck) Validate() error {
} }
return nil return nil
} }
// MeetsMinVersion checks if the peer's version meets or exceeds the minimum required version
func MeetsMinVersion(minVer, peerVer string) (bool, error) {
peerVer = sanitizeVersion(peerVer)
minVer = sanitizeVersion(minVer)
peerNBVer, err := version.NewVersion(peerVer)
if err != nil {
return false, err
}
constraints, err := version.NewConstraint(">= " + minVer)
if err != nil {
return false, err
}
return constraints.Check(peerNBVer), nil
}

View File

@@ -139,3 +139,68 @@ func TestNBVersionCheck_Validate(t *testing.T) {
}) })
} }
} }
func TestMeetsMinVersion(t *testing.T) {
tests := []struct {
name string
minVer string
peerVer string
want bool
wantErr bool
}{
{
name: "Peer version greater than min version",
minVer: "0.26.0",
peerVer: "0.60.1",
want: true,
wantErr: false,
},
{
name: "Peer version equals min version",
minVer: "1.0.0",
peerVer: "1.0.0",
want: true,
wantErr: false,
},
{
name: "Peer version less than min version",
minVer: "1.0.0",
peerVer: "0.9.9",
want: false,
wantErr: false,
},
{
name: "Peer version with pre-release tag greater than min version",
minVer: "1.0.0",
peerVer: "1.0.1-alpha",
want: true,
wantErr: false,
},
{
name: "Invalid peer version format",
minVer: "1.0.0",
peerVer: "dev",
want: false,
wantErr: true,
},
{
name: "Invalid min version format",
minVer: "invalid.version",
peerVer: "1.0.0",
want: false,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MeetsMinVersion(tt.minVer, tt.peerVer)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -27,7 +27,7 @@ func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) return am.Store.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
} }
// SavePostureChecks saves a posture check. // SavePostureChecks saves a posture check.
@@ -101,7 +101,7 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun
var postureChecks *posture.Checks var postureChecks *posture.Checks
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecksID) postureChecks, err = transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecksID)
if err != nil { if err != nil {
return err return err
} }
@@ -135,7 +135,7 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
} }
// getPeerPostureChecks returns the posture checks applied for a given peer. // getPeerPostureChecks returns the posture checks applied for a given peer.
@@ -161,7 +161,7 @@ func (am *DefaultAccountManager) getPeerPostureChecks(account *types.Account, pe
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers. // arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) { func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return false, err return false, err
} }
@@ -190,14 +190,14 @@ func validatePostureChecks(ctx context.Context, transaction store.Store, account
// If the posture check already has an ID, verify its existence in the store. // If the posture check already has an ID, verify its existence in the store.
if postureChecks.ID != "" { if postureChecks.ID != "" {
if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthShare, accountID, postureChecks.ID); err != nil { if _, err := transaction.GetPostureChecksByID(ctx, store.LockingStrengthNone, accountID, postureChecks.ID); err != nil {
return err return err
} }
return nil return nil
} }
// For new posture checks, ensure no duplicates by name. // For new posture checks, ensure no duplicates by name.
checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthShare, accountID) checks, err := transaction.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -259,7 +259,7 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
// isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy. // isPostureCheckLinkedToPolicy checks whether the posture check is linked to any account policy.
func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error { func isPostureCheckLinkedToPolicy(ctx context.Context, transaction store.Store, postureChecksID, accountID string) error {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthShare, accountID) policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -30,7 +30,7 @@ func (am *DefaultAccountManager) GetRoute(ctx context.Context, accountID string,
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetRouteByID(ctx, store.LockingStrengthShare, accountID, string(routeID)) return am.Store.GetRouteByID(ctx, store.LockingStrengthNone, accountID, string(routeID))
} }
// checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups. // checkRoutePrefixOrDomainsExistForPeers checks if a route with a given prefix exists for a single peer or multiple peer groups.
@@ -59,7 +59,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
seenPeers[string(prefixRoute.ID)] = true seenPeers[string(prefixRoute.ID)] = true
} }
peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, prefixRoute.PeerGroups) peerGroupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, prefixRoute.PeerGroups)
if err != nil { if err != nil {
return err return err
} }
@@ -83,7 +83,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
if peerID := checkRoute.Peer; peerID != "" { if peerID := checkRoute.Peer; peerID != "" {
// check that peerID exists and is not in any route as single peer or part of the group // check that peerID exists and is not in any route as single peer or part of the group
_, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthShare, accountID, peerID) _, err = transaction.GetPeerByID(context.Background(), store.LockingStrengthNone, accountID, peerID)
if err != nil { if err != nil {
return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID) return status.Errorf(status.InvalidArgument, "peer with ID %s not found", peerID)
} }
@@ -104,7 +104,7 @@ func checkRoutePrefixOrDomainsExistForPeers(ctx context.Context, transaction sto
} }
// check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix // check that the peers from peerGroupIDs groups are not the same peers we saw in routesWithPrefix
peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthShare, accountID, group.Peers) peersMap, err := transaction.GetPeersByIDs(ctx, store.LockingStrengthNone, accountID, group.Peers)
if err != nil { if err != nil {
return err return err
} }
@@ -310,7 +310,7 @@ func (am *DefaultAccountManager) ListRoutes(ctx context.Context, accountID, user
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
} }
func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error { func validateRoute(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) error {
@@ -353,7 +353,7 @@ func validateRoute(ctx context.Context, transaction store.Store, accountID strin
// validateRouteGroups validates the route groups and returns the validated groups map. // validateRouteGroups validates the route groups and returns the validated groups map.
func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) { func validateRouteGroups(ctx context.Context, transaction store.Store, accountID string, routeToSave *route.Route) (map[string]*types.Group, error) {
groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups) groupsToValidate := slices.Concat(routeToSave.Groups, routeToSave.PeerGroups, routeToSave.AccessControlGroups)
groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, groupsToValidate) groupsMap, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupsToValidate)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -494,7 +494,7 @@ func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, ro
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix // GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) { func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthShare, accountID) accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1100,7 +1100,7 @@ func TestGetNetworkMap_RouteSyncPeerGroups(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route") assert.Len(t, peer4Routes.Routes, 1, "HA route should have 1 server route")
groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthShare, account.Id) groups, err := am.Store.GetAccountGroups(context.Background(), store.LockingStrengthNone, account.Id)
require.NoError(t, err) require.NoError(t, err)
var groupHA1, groupHA2 *types.Group var groupHA1, groupHA2 *types.Group
for _, group := range groups { for _, group := range groups {

View File

@@ -60,7 +60,7 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string)
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }
@@ -82,7 +82,7 @@ func (m *managerImpl) GetExtraSettings(ctx context.Context, accountID string) (*
return nil, fmt.Errorf("get extra settings: %w", err) return nil, fmt.Errorf("get extra settings: %w", err)
} }
settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := m.store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("get account settings: %w", err) return nil, fmt.Errorf("get account settings: %w", err)
} }

View File

@@ -127,7 +127,7 @@ func (am *DefaultAccountManager) SaveSetupKey(ctx context.Context, accountID str
return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err) return status.Errorf(status.InvalidArgument, "invalid auto groups: %v", err)
} }
oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyToSave.Id) oldKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyToSave.Id)
if err != nil { if err != nil {
return err return err
} }
@@ -175,7 +175,7 @@ func (am *DefaultAccountManager) ListSetupKeys(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
} }
// GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found.
@@ -188,7 +188,7 @@ func (am *DefaultAccountManager) GetSetupKey(ctx context.Context, accountID, use
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) setupKey, err := am.Store.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -214,7 +214,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
var deletedSetupKey *types.SetupKey var deletedSetupKey *types.SetupKey
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthShare, accountID, keyID) deletedSetupKey, err = transaction.GetSetupKeyByID(ctx, store.LockingStrengthNone, accountID, keyID)
if err != nil { if err != nil {
return err return err
} }
@@ -231,7 +231,7 @@ func (am *DefaultAccountManager) DeleteSetupKey(ctx context.Context, accountID,
} }
func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error { func validateSetupKeyAutoGroups(ctx context.Context, transaction store.Store, accountID string, autoGroupIDs []string) error {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, autoGroupIDs) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, autoGroupIDs)
if err != nil { if err != nil {
return err return err
} }
@@ -255,7 +255,7 @@ func (am *DefaultAccountManager) prepareSetupKeyEvents(ctx context.Context, tran
var eventsToStore []func() var eventsToStore []func()
modifiedGroups := slices.Concat(addedGroups, removedGroups) modifiedGroups := slices.Concat(addedGroups, removedGroups)
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthShare, accountID, modifiedGroups) groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, modifiedGroups)
if err != nil { if err != nil {
log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err) log.WithContext(ctx).Debugf("failed to get groups for setup key events: %v", err)
return nil return nil

View File

@@ -478,7 +478,7 @@ func (s *SqlStore) DeleteTokenID2UserIDIndex(tokenID string) error {
} }
func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) { func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string) (*types.Account, error) {
accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthShare, domain) accountID, err := s.GetAccountIDByPrivateDomain(ctx, LockingStrengthNone, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -488,11 +488,7 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
} }
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) { func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.Account{}).Select("id"). result := tx.Model(&types.Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?", Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -542,11 +538,7 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
} }
func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) { func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx. result := tx.
Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id"). Joins("JOIN personal_access_tokens ON personal_access_tokens.user_id = users.id").
@@ -563,11 +555,7 @@ func (s *SqlStore) GetUserByPATID(ctx context.Context, lockStrength LockingStren
} }
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) { func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx.First(&user, idQueryCondition, userID) result := tx.First(&user, idQueryCondition, userID)
if result.Error != nil { if result.Error != nil {
@@ -600,11 +588,7 @@ func (s *SqlStore) DeleteUser(ctx context.Context, lockStrength LockingStrength,
} }
func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) { func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var users []*types.User var users []*types.User
result := tx.Find(&users, accountIDCondition, accountID) result := tx.Find(&users, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -619,11 +603,7 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) { func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.User, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var user types.User var user types.User
result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner) result := tx.First(&user, "account_id = ? AND role = ?", accountID, types.UserRoleOwner)
if result.Error != nil { if result.Error != nil {
@@ -637,11 +617,7 @@ func (s *SqlStore) GetAccountOwner(ctx context.Context, lockStrength LockingStre
} }
func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) { func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
result := tx.Find(&groups, accountIDCondition, accountID) result := tx.Find(&groups, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -656,11 +632,7 @@ func (s *SqlStore) GetAccountGroups(ctx context.Context, lockStrength LockingStr
} }
func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) { func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
likePattern := `%"ID":"` + resourceID + `"%` likePattern := `%"ID":"` + resourceID + `"%`
@@ -706,11 +678,7 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
} }
func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) { func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountMeta types.AccountMeta var accountMeta types.AccountMeta
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountMeta, idQueryCondition, accountID) First(&accountMeta, idQueryCondition, accountID)
@@ -880,11 +848,7 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
} }
func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.User{}). result := tx.Model(&types.User{}).
Select("account_id").Where(idQueryCondition, userID).First(&accountID) Select("account_id").Where(idQueryCondition, userID).First(&accountID)
@@ -899,11 +863,7 @@ func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength Lockin
} }
func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) { func (s *SqlStore) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Select("account_id").Where(idQueryCondition, peerID).First(&accountID) Select("account_id").Where(idQueryCondition, peerID).First(&accountID)
@@ -936,11 +896,7 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
} }
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var ipJSONStrings []string var ipJSONStrings []string
// Fetch the IP addresses as JSON strings // Fetch the IP addresses as JSON strings
@@ -968,11 +924,7 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
} }
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var labels []string var labels []string
result := tx.Model(&nbpeer.Peer{}). result := tx.Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID). Where("account_id = ?", accountID).
@@ -990,11 +942,7 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) { func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountNetwork types.AccountNetwork var accountNetwork types.AccountNetwork
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1006,11 +954,7 @@ func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingSt
} }
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer nbpeer.Peer var peer nbpeer.Peer
result := tx.First(&peer, GetKeyQueryCondition(s), peerKey) result := tx.First(&peer, GetKeyQueryCondition(s), peerKey)
@@ -1025,11 +969,7 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) { func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.Settings, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountSettings types.AccountSettings var accountSettings types.AccountSettings
if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if err := tx.Model(&types.Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
@@ -1041,11 +981,7 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
} }
func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) { func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var createdBy string var createdBy string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("created_by").First(&createdBy, idQueryCondition, accountID) Select("created_by").First(&createdBy, idQueryCondition, accountID)
@@ -1243,11 +1179,7 @@ func NewMysqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore, dsn s
} }
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey types.SetupKey var setupKey types.SetupKey
result := tx. result := tx.
First(&setupKey, GetKeyQueryCondition(s), key) First(&setupKey, GetKeyQueryCondition(s), key)
@@ -1391,11 +1323,7 @@ func (s *SqlStore) RemoveResourceFromGroup(ctx context.Context, accountId string
// GetPeerGroups retrieves all groups assigned to a specific peer in a given account. // GetPeerGroups retrieves all groups assigned to a specific peer in a given account.
func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) { func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStrength, accountId string, peerId string) ([]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
query := tx. query := tx.
Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId)) Find(&groups, "account_id = ? AND peers LIKE ?", accountId, fmt.Sprintf(`%%"%s"%%`, peerId))
@@ -1409,8 +1337,9 @@ func (s *SqlStore) GetPeerGroups(ctx context.Context, lockStrength LockingStreng
// GetAccountPeers retrieves peers for an account. // GetAccountPeers retrieves peers for an account.
func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStrength, accountID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) {
tx := s.getTXWithLockStrength(lockStrength)
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
query := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Where(accountIDCondition, accountID) query := tx.Where(accountIDCondition, accountID)
if nameFilter != "" { if nameFilter != "" {
query = query.Where("name LIKE ?", "%"+nameFilter+"%") query = query.Where("name LIKE ?", "%"+nameFilter+"%")
@@ -1429,11 +1358,7 @@ func (s *SqlStore) GetAccountPeers(ctx context.Context, lockStrength LockingStre
// GetUserPeers retrieves peers for a user. // GetUserPeers retrieves peers for a user.
func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrength, accountID, userID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
// Exclude peers added via setup keys, as they are not user-specific and have an empty user_id. // Exclude peers added via setup keys, as they are not user-specific and have an empty user_id.
@@ -1461,11 +1386,7 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, lockStrength LockingStr
// GetPeerByID retrieves a peer by its ID and account ID. // GetPeerByID retrieves a peer by its ID and account ID.
func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID, peerID string) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peer *nbpeer.Peer var peer *nbpeer.Peer
result := tx. result := tx.
First(&peer, accountAndIDQueryCondition, accountID, peerID) First(&peer, accountAndIDQueryCondition, accountID, peerID)
@@ -1481,11 +1402,7 @@ func (s *SqlStore) GetPeerByID(ctx context.Context, lockStrength LockingStrength
// GetPeersByIDs retrieves peers by their IDs and account ID. // GetPeersByIDs retrieves peers by their IDs and account ID.
func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) { func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs) result := tx.Find(&peers, accountAndIDsQueryCondition, accountID, peerIDs)
if result.Error != nil { if result.Error != nil {
@@ -1503,11 +1420,7 @@ func (s *SqlStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStreng
// GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user. // GetAccountPeersWithExpiration retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx. result := tx.
Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1522,11 +1435,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng
// GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user. // GetAccountPeersWithInactivity retrieves a list of peers that have login expiration enabled and added by a user.
func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var peers []*nbpeer.Peer var peers []*nbpeer.Peer
result := tx. result := tx.
Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). Where("inactivity_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true).
@@ -1541,11 +1450,7 @@ func (s *SqlStore) GetAccountPeersWithInactivity(ctx context.Context, lockStreng
// GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing. // GetAllEphemeralPeers retrieves all peers with Ephemeral set to true across all accounts, optimized for batch processing.
func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) { func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var allEphemeralPeers, batchPeers []*nbpeer.Peer var allEphemeralPeers, batchPeers []*nbpeer.Peer
result := tx. result := tx.
Where("ephemeral = ?", true). Where("ephemeral = ?", true).
@@ -1623,11 +1528,7 @@ func (s *SqlStore) GetDB() *gorm.DB {
} }
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) { func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.DNSSettings, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountDNSSettings types.AccountDNSSettings var accountDNSSettings types.AccountDNSSettings
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
First(&accountDNSSettings, idQueryCondition, accountID) First(&accountDNSSettings, idQueryCondition, accountID)
@@ -1643,11 +1544,7 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID. // AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) { func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var accountID string var accountID string
result := tx.Model(&types.Account{}). result := tx.Model(&types.Account{}).
Select("id").First(&accountID, idQueryCondition, id) Select("id").First(&accountID, idQueryCondition, id)
@@ -1663,11 +1560,7 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID. // GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) { func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var account types.Account var account types.Account
result := tx.Model(&types.Account{}).Select("domain", "domain_category"). result := tx.Model(&types.Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account) Where(idQueryCondition, accountID).First(&account)
@@ -1683,11 +1576,7 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
// GetGroupByID retrieves a group by ID and account ID. // GetGroupByID retrieves a group by ID and account ID.
func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) { func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group *types.Group var group *types.Group
result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID) result := tx.First(&group, accountAndIDQueryCondition, accountID, groupID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -1703,11 +1592,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, lockStrength LockingStrengt
// GetGroupByName retrieves a group by name and account ID. // GetGroupByName retrieves a group by name and account ID.
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) { func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var group types.Group var group types.Group
// TODO: This fix is accepted for now, but if we need to handle this more frequently // TODO: This fix is accepted for now, but if we need to handle this more frequently
@@ -1736,11 +1621,7 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
// GetGroupsByIDs retrieves groups by their IDs and account ID. // GetGroupsByIDs retrieves groups by their IDs and account ID.
func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) { func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var groups []*types.Group var groups []*types.Group
result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) result := tx.Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs)
if result.Error != nil { if result.Error != nil {
@@ -1796,11 +1677,7 @@ func (s *SqlStore) DeleteGroups(ctx context.Context, strength LockingStrength, a
// GetAccountPolicies retrieves policies for an account. // GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) { func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Policy, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var policies []*types.Policy var policies []*types.Policy
result := tx. result := tx.
Preload(clause.Associations).Find(&policies, accountIDCondition, accountID) Preload(clause.Associations).Find(&policies, accountIDCondition, accountID)
@@ -1814,11 +1691,7 @@ func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingS
// GetPolicyByID retrieves a policy by its ID and account ID. // GetPolicyByID retrieves a policy by its ID and account ID.
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) { func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, accountID, policyID string) (*types.Policy, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var policy *types.Policy var policy *types.Policy
result := tx.Preload(clause.Associations). result := tx.Preload(clause.Associations).
@@ -1880,11 +1753,7 @@ func (s *SqlStore) DeletePolicy(ctx context.Context, lockStrength LockingStrengt
// GetAccountPostureChecks retrieves posture checks for an account. // GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) { func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountIDCondition, accountID) result := tx.Find(&postureChecks, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -1897,10 +1766,7 @@ func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength Loc
// GetPostureChecksByID retrieves posture checks by their ID and account ID. // GetPostureChecksByID retrieves posture checks by their ID and account ID.
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) (*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureCheck *posture.Checks var postureCheck *posture.Checks
result := tx. result := tx.
@@ -1918,11 +1784,7 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. // GetPostureChecksByIDs retrieves posture checks by their IDs and account ID.
func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var postureChecks []*posture.Checks var postureChecks []*posture.Checks
result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) result := tx.Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs)
if result.Error != nil { if result.Error != nil {
@@ -1967,9 +1829,9 @@ func (s *SqlStore) DeletePostureChecks(ctx context.Context, lockStrength Locking
// GetAccountRoutes retrieves network routes for an account. // GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) { func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var routes []*route.Route var routes []*route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.Find(&routes, accountIDCondition, accountID)
Find(&routes, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err) log.WithContext(ctx).Errorf("failed to get routes from the store: %s", err)
return nil, status.Errorf(status.Internal, "failed to get routes from store") return nil, status.Errorf(status.Internal, "failed to get routes from store")
@@ -1980,9 +1842,9 @@ func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStr
// GetRouteByID retrieves a route by its ID and account ID. // GetRouteByID retrieves a route by its ID and account ID.
func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) { func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrength, accountID string, routeID string) (*route.Route, error) {
tx := s.getTXWithLockStrength(lockStrength)
var route *route.Route var route *route.Route
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.First(&route, accountAndIDQueryCondition, accountID, routeID)
First(&route, accountAndIDQueryCondition, accountID, routeID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewRouteNotFoundError(routeID) return nil, status.NewRouteNotFoundError(routeID)
@@ -2023,11 +1885,7 @@ func (s *SqlStore) DeleteRoute(ctx context.Context, lockStrength LockingStrength
// GetAccountSetupKeys retrieves setup keys for an account. // GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) { func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKeys []*types.SetupKey var setupKeys []*types.SetupKey
result := tx. result := tx.
Find(&setupKeys, accountIDCondition, accountID) Find(&setupKeys, accountIDCondition, accountID)
@@ -2041,14 +1899,9 @@ func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength Locking
// GetSetupKeyByID retrieves a setup key by its ID and account ID. // GetSetupKeyByID retrieves a setup key by its ID and account ID.
func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) { func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStrength, accountID, setupKeyID string) (*types.SetupKey, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var setupKey *types.SetupKey var setupKey *types.SetupKey
result := tx.Clauses(clause.Locking{Strength: string(lockStrength)}). result := tx.First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
First(&setupKey, accountAndIDQueryCondition, accountID, setupKeyID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewSetupKeyNotFoundError(setupKeyID) return nil, status.NewSetupKeyNotFoundError(setupKeyID)
@@ -2088,11 +1941,7 @@ func (s *SqlStore) DeleteSetupKey(ctx context.Context, lockStrength LockingStren
// GetAccountNameServerGroups retrieves name server groups for an account. // GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) { func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var nsGroups []*nbdns.NameServerGroup var nsGroups []*nbdns.NameServerGroup
result := tx.Find(&nsGroups, accountIDCondition, accountID) result := tx.Find(&nsGroups, accountIDCondition, accountID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -2105,11 +1954,7 @@ func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength
// GetNameServerGroupByID retrieves a name server group by its ID and account ID. // GetNameServerGroupByID retrieves a name server group by its ID and account ID.
func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) { func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, nsGroupID string) (*nbdns.NameServerGroup, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var nsGroup *nbdns.NameServerGroup var nsGroup *nbdns.NameServerGroup
result := tx. result := tx.
First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID) First(&nsGroup, accountAndIDQueryCondition, accountID, nsGroupID)
@@ -2182,11 +2027,7 @@ func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var networks []*networkTypes.Network var networks []*networkTypes.Network
result := tx.Find(&networks, accountIDCondition, accountID) result := tx.Find(&networks, accountIDCondition, accountID)
if result.Error != nil { if result.Error != nil {
@@ -2198,11 +2039,7 @@ func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingS
} }
func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) { func (s *SqlStore) GetNetworkByID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) (*networkTypes.Network, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var network *networkTypes.Network var network *networkTypes.Network
result := tx. result := tx.
First(&network, accountAndIDQueryCondition, accountID, networkID) First(&network, accountAndIDQueryCondition, accountID, networkID)
@@ -2244,11 +2081,7 @@ func (s *SqlStore) DeleteNetwork(ctx context.Context, lockStrength LockingStreng
} }
func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength LockingStrength, accountID, netID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouters []*routerTypes.NetworkRouter var netRouters []*routerTypes.NetworkRouter
result := tx. result := tx.
Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID) Find(&netRouters, "account_id = ? AND network_id = ?", accountID, netID)
@@ -2261,11 +2094,7 @@ func (s *SqlStore) GetNetworkRoutersByNetID(ctx context.Context, lockStrength Lo
} }
func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouters []*routerTypes.NetworkRouter var netRouters []*routerTypes.NetworkRouter
result := tx. result := tx.
Find(&netRouters, accountIDCondition, accountID) Find(&netRouters, accountIDCondition, accountID)
@@ -2278,11 +2107,7 @@ func (s *SqlStore) GetNetworkRoutersByAccountID(ctx context.Context, lockStrengt
} }
func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) { func (s *SqlStore) GetNetworkRouterByID(ctx context.Context, lockStrength LockingStrength, accountID, routerID string) (*routerTypes.NetworkRouter, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netRouter *routerTypes.NetworkRouter var netRouter *routerTypes.NetworkRouter
result := tx. result := tx.
First(&netRouter, accountAndIDQueryCondition, accountID, routerID) First(&netRouter, accountAndIDQueryCondition, accountID, routerID)
@@ -2323,11 +2148,7 @@ func (s *SqlStore) DeleteNetworkRouter(ctx context.Context, lockStrength Locking
} }
func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength LockingStrength, accountID, networkID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources []*resourceTypes.NetworkResource var netResources []*resourceTypes.NetworkResource
result := tx. result := tx.
Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID) Find(&netResources, "account_id = ? AND network_id = ?", accountID, networkID)
@@ -2340,11 +2161,7 @@ func (s *SqlStore) GetNetworkResourcesByNetID(ctx context.Context, lockStrength
} }
func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources []*resourceTypes.NetworkResource var netResources []*resourceTypes.NetworkResource
result := tx. result := tx.
Find(&netResources, accountIDCondition, accountID) Find(&netResources, accountIDCondition, accountID)
@@ -2357,11 +2174,7 @@ func (s *SqlStore) GetNetworkResourcesByAccountID(ctx context.Context, lockStren
} }
func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) (*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, accountAndIDQueryCondition, accountID, resourceID) First(&netResources, accountAndIDQueryCondition, accountID, resourceID)
@@ -2377,11 +2190,7 @@ func (s *SqlStore) GetNetworkResourceByID(ctx context.Context, lockStrength Lock
} }
func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) { func (s *SqlStore) GetNetworkResourceByName(ctx context.Context, lockStrength LockingStrength, accountID, resourceName string) (*resourceTypes.NetworkResource, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var netResources *resourceTypes.NetworkResource var netResources *resourceTypes.NetworkResource
result := tx. result := tx.
First(&netResources, "account_id = ? AND name = ?", accountID, resourceName) First(&netResources, "account_id = ? AND name = ?", accountID, resourceName)
@@ -2423,10 +2232,7 @@ func (s *SqlStore) DeleteNetworkResource(ctx context.Context, lockStrength Locki
// GetPATByHashedToken returns a PersonalAccessToken by its hashed token. // GetPATByHashedToken returns a PersonalAccessToken by its hashed token.
func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) { func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken string) (*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx.First(&pat, "hashed_token = ?", hashedToken) result := tx.First(&pat, "hashed_token = ?", hashedToken)
@@ -2443,11 +2249,7 @@ func (s *SqlStore) GetPATByHashedToken(ctx context.Context, lockStrength Locking
// GetPATByID retrieves a personal access token by its ID and user ID. // GetPATByID retrieves a personal access token by its ID and user ID.
func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) { func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength, userID string, patID string) (*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pat types.PersonalAccessToken var pat types.PersonalAccessToken
result := tx. result := tx.
First(&pat, "id = ? AND user_id = ?", patID, userID) First(&pat, "id = ? AND user_id = ?", patID, userID)
@@ -2464,11 +2266,7 @@ func (s *SqlStore) GetPATByID(ctx context.Context, lockStrength LockingStrength,
// GetUserPATs retrieves personal access tokens for a user. // GetUserPATs retrieves personal access tokens for a user.
func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) { func (s *SqlStore) GetUserPATs(ctx context.Context, lockStrength LockingStrength, userID string) ([]*types.PersonalAccessToken, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
var pats []*types.PersonalAccessToken var pats []*types.PersonalAccessToken
result := tx.Find(&pats, "user_id = ?", userID) result := tx.Find(&pats, "user_id = ?", userID)
if err := result.Error; err != nil { if err := result.Error; err != nil {
@@ -2528,11 +2326,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, lockStrength LockingStrength,
} }
func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) { func (s *SqlStore) GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) {
tx := s.db tx := s.getTXWithLockStrength(lockStrength)
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
jsonValue := fmt.Sprintf(`"%s"`, ip.String()) jsonValue := fmt.Sprintf(`"%s"`, ip.String())
var peer nbpeer.Peer var peer nbpeer.Peer
@@ -2559,3 +2353,11 @@ func (s *SqlStore) CountAccountsByPrivateDomain(ctx context.Context, domain stri
return count, nil return count, nil
} }
func (s *SqlStore) getTXWithLockStrength(lockStrength LockingStrength) *gorm.DB {
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
return tx
}

View File

@@ -18,6 +18,10 @@ type UpdateChannelMetrics struct {
getAllConnectedPeersDurationMicro metric.Int64Histogram getAllConnectedPeersDurationMicro metric.Int64Histogram
getAllConnectedPeers metric.Int64Histogram getAllConnectedPeers metric.Int64Histogram
hasChannelDurationMicro metric.Int64Histogram hasChannelDurationMicro metric.Int64Histogram
calcPostureChecksDurationMicro metric.Int64Histogram
calcPeerNetworkMapDurationMicro metric.Int64Histogram
mergeNetworkMapDurationMicro metric.Int64Histogram
toSyncResponseDurationMicro metric.Int64Histogram
ctx context.Context ctx context.Context
} }
@@ -89,6 +93,38 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
return nil, err return nil, err
} }
calcPostureChecksDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.posturechecks.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to get the posture checks for a peer"),
)
if err != nil {
return nil, err
}
calcPeerNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.calc.networkmap.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to calculate the network map for a peer"),
)
if err != nil {
return nil, err
}
mergeNetworkMapDurationMicro, err := meter.Int64Histogram("management.updatechannel.merge.networkmap.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to merge the network maps for a peer"),
)
if err != nil {
return nil, err
}
toSyncResponseDurationMicro, err := meter.Int64Histogram("management.updatechannel.tosyncresponse.duration.micro",
metric.WithUnit("microseconds"),
metric.WithDescription("Duration of how long it takes to convert the network map to sync response"),
)
if err != nil {
return nil, err
}
return &UpdateChannelMetrics{ return &UpdateChannelMetrics{
createChannelDurationMicro: createChannelDurationMicro, createChannelDurationMicro: createChannelDurationMicro,
closeChannelDurationMicro: closeChannelDurationMicro, closeChannelDurationMicro: closeChannelDurationMicro,
@@ -98,6 +134,10 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro, getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
getAllConnectedPeers: getAllConnectedPeers, getAllConnectedPeers: getAllConnectedPeers,
hasChannelDurationMicro: hasChannelDurationMicro, hasChannelDurationMicro: hasChannelDurationMicro,
calcPostureChecksDurationMicro: calcPostureChecksDurationMicro,
calcPeerNetworkMapDurationMicro: calcPeerNetworkMapDurationMicro,
mergeNetworkMapDurationMicro: mergeNetworkMapDurationMicro,
toSyncResponseDurationMicro: toSyncResponseDurationMicro,
ctx: ctx, ctx: ctx,
}, nil }, nil
} }
@@ -137,3 +177,19 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) { func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds()) metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
} }
func (metrics *UpdateChannelMetrics) CountCalcPostureChecksDuration(duration time.Duration) {
metrics.calcPostureChecksDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountCalcPeerNetworkMapDuration(duration time.Duration) {
metrics.calcPeerNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountMergeNetworkMapDuration(duration time.Duration) {
metrics.mergeNetworkMapDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
func (metrics *UpdateChannelMetrics) CountToSyncResponseDuration(duration time.Duration) {
metrics.toSyncResponseDurationMicro.Record(metrics.ctx, duration.Microseconds())
}

View File

@@ -36,6 +36,9 @@ const (
PublicCategory = "public" PublicCategory = "public"
PrivateCategory = "private" PrivateCategory = "private"
UnknownCategory = "unknown" UnknownCategory = "unknown"
// firewallRuleMinPortRangesVer defines the minimum peer version that supports port range rules.
firewallRuleMinPortRangesVer = "0.48.0"
) )
type LookupMap map[string]struct{} type LookupMap map[string]struct{}
@@ -248,7 +251,7 @@ func (a *Account) GetPeerNetworkMap(
} }
} }
aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peerID, validatedPeersMap) aclPeers, firewallRules := a.GetPeerConnectionResources(ctx, peer, validatedPeersMap)
// exclude expired peers // exclude expired peers
var peersToConnect []*nbpeer.Peer var peersToConnect []*nbpeer.Peer
var expiredPeers []*nbpeer.Peer var expiredPeers []*nbpeer.Peer
@@ -961,8 +964,9 @@ func (a *Account) UserGroupsRemoveFromPeers(userID string, groups ...string) map
// GetPeerConnectionResources for a given peer // GetPeerConnectionResources for a given peer
// //
// This function returns the list of peers and firewall rules that are applicable to a given peer. // This function returns the list of peers and firewall rules that are applicable to a given peer.
func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) { func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.Peer, validatedPeersMap map[string]struct{}) ([]*nbpeer.Peer, []*FirewallRule) {
generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx) generateResources, getAccumulatedResources := a.connResourcesGenerator(ctx, peer)
for _, policy := range a.Policies { for _, policy := range a.Policies {
if !policy.Enabled { if !policy.Enabled {
continue continue
@@ -973,8 +977,8 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
continue continue
} }
sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peerID, policy.SourcePostureChecks, validatedPeersMap) sourcePeers, peerInSources := a.getAllPeersFromGroups(ctx, rule.Sources, peer.ID, policy.SourcePostureChecks, validatedPeersMap)
destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peerID, nil, validatedPeersMap) destinationPeers, peerInDestinations := a.getAllPeersFromGroups(ctx, rule.Destinations, peer.ID, nil, validatedPeersMap)
if rule.Bidirectional { if rule.Bidirectional {
if peerInSources { if peerInSources {
@@ -1003,7 +1007,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peerID string,
// The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer. // The generator function is used to generate the list of peers and firewall rules that are applicable to a given peer.
// It safe to call the generator function multiple times for same peer and different rules no duplicates will be // It safe to call the generator function multiple times for same peer and different rules no duplicates will be
// generated. The accumulator function returns the result of all the generator calls. // generated. The accumulator function returns the result of all the generator calls.
func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) { func (a *Account) connResourcesGenerator(ctx context.Context, targetPeer *nbpeer.Peer) (func(*PolicyRule, []*nbpeer.Peer, int), func() ([]*nbpeer.Peer, []*FirewallRule)) {
rulesExists := make(map[string]struct{}) rulesExists := make(map[string]struct{})
peersExists := make(map[string]struct{}) peersExists := make(map[string]struct{})
rules := make([]*FirewallRule, 0) rules := make([]*FirewallRule, 0)
@@ -1051,17 +1055,7 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule,
continue continue
} }
for _, port := range rule.Ports { rules = append(rules, expandPortsAndRanges(fr, rule, targetPeer)...)
pr := fr // clone rule and add set new port
pr.Port = port
rules = append(rules, &pr)
}
for _, portRange := range rule.PortRanges {
pr := fr
pr.PortRange = portRange
rules = append(rules, &pr)
}
} }
}, func() ([]*nbpeer.Peer, []*FirewallRule) { }, func() ([]*nbpeer.Peer, []*FirewallRule) {
return peers, rules return peers, rules
@@ -1590,3 +1584,45 @@ func (a *Account) AddAllGroup() error {
} }
return nil return nil
} }
// expandPortsAndRanges expands Ports and PortRanges of a rule into individual firewall rules
func expandPortsAndRanges(base FirewallRule, rule *PolicyRule, peer *nbpeer.Peer) []*FirewallRule {
var expanded []*FirewallRule
if len(rule.Ports) > 0 {
for _, port := range rule.Ports {
fr := base
fr.Port = port
expanded = append(expanded, &fr)
}
return expanded
}
supportPortRanges := peerSupportsPortRanges(peer.Meta.WtVersion)
for _, portRange := range rule.PortRanges {
fr := base
if supportPortRanges {
fr.PortRange = portRange
} else {
// Peer doesn't support port ranges, only allow single-port ranges
if portRange.Start != portRange.End {
continue
}
fr.Port = strconv.FormatUint(uint64(portRange.Start), 10)
}
expanded = append(expanded, &fr)
}
return expanded
}
// peerSupportsPortRanges checks if the peer version supports port ranges.
func peerSupportsPortRanges(peerVer string) bool {
if strings.Contains(peerVer, "dev") {
return true
}
meetMinVer, err := posture.MeetsMinVersion(firewallRuleMinPortRangesVer, peerVer)
return err == nil && meetMinVer
}

View File

@@ -76,7 +76,6 @@ func generateRouteFirewallRules(ctx context.Context, route *nbroute.Route, rule
rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPortRanges(baseRule, rule, rulesExists)...)
} else { } else {
rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...) rules = append(rules, generateRulesWithPorts(ctx, baseRule, rule, rulesExists)...)
} }
// TODO: generate IPv6 rules for dynamic routes // TODO: generate IPv6 rules for dynamic routes

View File

@@ -95,14 +95,14 @@ func (am *DefaultAccountManager) inviteNewUser(ctx context.Context, accountID, u
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
inviterID := userID inviterID := userID
if initiatorUser.IsServiceUser { if initiatorUser.IsServiceUser {
createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthShare, accountID) createdBy, err := am.Store.GetAccountCreatedBy(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -178,13 +178,13 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID
} }
func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) {
return am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, id) return am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, id)
} }
// GetUser looks up a user by provided nbContext.UserAuths. // GetUser looks up a user by provided nbContext.UserAuths.
// Expects account to have been created already. // Expects account to have been created already.
func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) { func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAuth nbContext.UserAuth) (*types.User, error) {
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userAuth.UserId) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userAuth.UserId)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -209,7 +209,7 @@ func (am *DefaultAccountManager) GetUserFromUserAuth(ctx context.Context, userAu
// ListUsers returns lists of all users under the account. // ListUsers returns lists of all users under the account.
// It doesn't populate user information such as email or name. // It doesn't populate user information such as email or name.
func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) { func (am *DefaultAccountManager) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) {
return am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) return am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
} }
func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error { func (am *DefaultAccountManager) deleteServiceUser(ctx context.Context, accountID string, initiatorUserID string, targetUser *types.User) error {
@@ -230,7 +230,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -243,7 +243,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -347,12 +347,12 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -390,12 +390,12 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -404,7 +404,7 @@ func (am *DefaultAccountManager) DeletePAT(ctx context.Context, accountID string
return status.NewAdminPermissionError() return status.NewAdminPermissionError()
} }
pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) pat, err := am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
if err != nil { if err != nil {
return err return err
} }
@@ -429,12 +429,12 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -443,7 +443,7 @@ func (am *DefaultAccountManager) GetPAT(ctx context.Context, accountID string, i
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetPATByID(ctx, store.LockingStrengthShare, targetUserID, tokenID) return am.Store.GetPATByID(ctx, store.LockingStrengthNone, targetUserID, tokenID)
} }
// GetAllPATs returns all PATs for a user // GetAllPATs returns all PATs for a user
@@ -456,12 +456,12 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -470,7 +470,7 @@ func (am *DefaultAccountManager) GetAllPATs(ctx context.Context, accountID strin
return nil, status.NewAdminPermissionError() return nil, status.NewAdminPermissionError()
} }
return am.Store.GetUserPATs(ctx, store.LockingStrengthShare, targetUserID) return am.Store.GetUserPATs(ctx, store.LockingStrengthNone, targetUserID)
} }
// SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error. // SaveUser saves updates to the given user. If the user doesn't exist, it will throw status.NotFound error.
@@ -511,7 +511,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
if !allowed { if !allowed {
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -521,7 +521,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var addUserEvents []func() var addUserEvents []func()
var usersToSave = make([]*types.User, 0, len(updates)) var usersToSave = make([]*types.User, 0, len(updates))
groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) groups, err := am.Store.GetAccountGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, fmt.Errorf("error getting account groups: %w", err) return nil, fmt.Errorf("error getting account groups: %w", err)
} }
@@ -533,7 +533,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
var initiatorUser *types.User var initiatorUser *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -695,7 +695,7 @@ func (am *DefaultAccountManager) processUserUpdate(ctx context.Context, transact
// getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist. // getUserOrCreateIfNotExists retrieves the existing user or creates a new one if it doesn't exist.
func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) { func getUserOrCreateIfNotExists(ctx context.Context, transaction store.Store, accountID string, update *types.User, addIfNotExists bool) (*types.User, error) {
existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthShare, update.Id) existingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, update.Id)
if err != nil { if err != nil {
if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound { if sErr, ok := status.FromError(err); ok && sErr.Type() == status.NotFound {
if !addIfNotExists { if !addIfNotExists {
@@ -830,7 +830,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
var user *types.User var user *types.User
if initiatorUserID != activity.SystemInitiator { if initiatorUserID != activity.SystemInitiator {
result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) result, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get user: %w", err) return nil, fmt.Errorf("failed to get user: %w", err)
} }
@@ -840,7 +840,7 @@ func (am *DefaultAccountManager) GetUsersFromAccount(ctx context.Context, accoun
accountUsers := []*types.User{} accountUsers := []*types.User{}
switch { switch {
case allowed: case allowed:
accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) accountUsers, err = am.Store.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -933,7 +933,7 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
// expireAndUpdatePeers expires all peers of the given user and updates them in the account // expireAndUpdatePeers expires all peers of the given user and updates them in the account
func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error { func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accountID string, peers []*nbpeer.Peer) error {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return err return err
} }
@@ -1003,7 +1003,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
return status.NewPermissionDeniedError() return status.NewPermissionDeniedError()
} }
initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorUserID) initiatorUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, initiatorUserID)
if err != nil { if err != nil {
return err return err
} }
@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
continue continue
} }
targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserID) targetUser, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserID)
if err != nil { if err != nil {
allErrors = errors.Join(allErrors, err) allErrors = errors.Join(allErrors, err)
continue continue
@@ -1081,12 +1081,12 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var err error var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthShare, targetUserInfo.ID) targetUser, err = transaction.GetUserByUserID(ctx, store.LockingStrengthNone, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user to delete: %w", err) return fmt.Errorf("failed to get user to delete: %w", err)
} }
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, targetUserInfo.ID) userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthNone, accountID, targetUserInfo.ID)
if err != nil { if err != nil {
return fmt.Errorf("failed to get user peers: %w", err) return fmt.Errorf("failed to get user peers: %w", err)
} }
@@ -1120,7 +1120,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
// GetOwnerInfo retrieves the owner information for a given account ID. // GetOwnerInfo retrieves the owner information for a given account ID.
func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) { func (am *DefaultAccountManager) GetOwnerInfo(ctx context.Context, accountID string) (*types.UserInfo, error) {
owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthShare, accountID) owner, err := am.Store.GetAccountOwner(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1257,7 +1257,7 @@ func validateUserInvite(invite *types.UserInfo) error {
func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) { func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) {
accountID, userID := userAuth.AccountId, userAuth.UserId accountID, userID := userAuth.AccountId, userAuth.UserId
user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) user, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1274,7 +1274,7 @@ func (am *DefaultAccountManager) GetCurrentUserInfo(ctx context.Context, userAut
return nil, err return nil, err
} }
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthShare, accountID) settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -88,7 +88,7 @@ func TestUser_CreatePAT_ForSameUser(t *testing.T) {
assert.Equal(t, pat.ID, tokenID) assert.Equal(t, pat.ID, tokenID)
user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthShare, tokenID) user, err := am.Store.GetUserByPATID(context.Background(), store.LockingStrengthNone, tokenID)
if err != nil { if err != nil {
t.Fatalf("Error when getting user by token ID: %s", err) t.Fatalf("Error when getting user by token ID: %s", err)
} }
@@ -1521,7 +1521,7 @@ func TestSaveOrAddUser_PreventAccountSwitch(t *testing.T) {
_, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true) _, err = am.SaveOrAddUser(context.Background(), "account2", "ownerAccount2", account1.Users[targetId], true)
assert.Error(t, err, "update user to another account should fail") assert.Error(t, err, "update user to another account should fail")
user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthShare, targetId) user, err := s.GetUserByUserID(context.Background(), store.LockingStrengthNone, targetId)
require.NoError(t, err) require.NoError(t, err)
assert.Equal(t, account1.Users[targetId].Id, user.Id) assert.Equal(t, account1.Users[targetId].Id, user.Id)
assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID) assert.Equal(t, account1.Users[targetId].AccountID, user.AccountID)

View File

@@ -26,7 +26,7 @@ func NewManager(store store.Store) Manager {
} }
func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) { func (m *managerImpl) GetUser(ctx context.Context, userID string) (*types.User, error) {
return m.store.GetUserByUserID(ctx, store.LockingStrengthShare, userID) return m.store.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
} }
func NewManagerMock() Manager { func NewManagerMock() Manager {

View File

@@ -21,6 +21,7 @@ var (
// Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the // Update fetch the version info periodically and notify the onUpdateListener in case the UI version or the
// daemon version are deprecated // daemon version are deprecated
type Update struct { type Update struct {
httpAgent string
uiVersion *goversion.Version uiVersion *goversion.Version
daemonVersion *goversion.Version daemonVersion *goversion.Version
latestAvailable *goversion.Version latestAvailable *goversion.Version
@@ -34,7 +35,7 @@ type Update struct {
} }
// NewUpdate instantiate Update and start to fetch the new version information // NewUpdate instantiate Update and start to fetch the new version information
func NewUpdate() *Update { func NewUpdate(httpAgent string) *Update {
currentVersion, err := goversion.NewVersion(version) currentVersion, err := goversion.NewVersion(version)
if err != nil { if err != nil {
currentVersion, _ = goversion.NewVersion("0.0.0") currentVersion, _ = goversion.NewVersion("0.0.0")
@@ -43,6 +44,7 @@ func NewUpdate() *Update {
latestAvailable, _ := goversion.NewVersion("0.0.0") latestAvailable, _ := goversion.NewVersion("0.0.0")
u := &Update{ u := &Update{
httpAgent: httpAgent,
latestAvailable: latestAvailable, latestAvailable: latestAvailable,
uiVersion: currentVersion, uiVersion: currentVersion,
fetchTicker: time.NewTicker(fetchPeriod), fetchTicker: time.NewTicker(fetchPeriod),
@@ -112,7 +114,15 @@ func (u *Update) startFetcher() {
func (u *Update) fetchVersion() bool { func (u *Update) fetchVersion() bool {
log.Debugf("fetching version info from %s", versionURL) log.Debugf("fetching version info from %s", versionURL)
resp, err := http.Get(versionURL) req, err := http.NewRequest("GET", versionURL, nil)
if err != nil {
log.Errorf("failed to create request for version info: %s", err)
return false
}
req.Header.Set("User-Agent", u.httpAgent)
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
log.Errorf("failed to fetch version info: %s", err) log.Errorf("failed to fetch version info: %s", err)
return false return false

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
) )
const httpAgent = "pkg/test"
func TestNewUpdate(t *testing.T) { func TestNewUpdate(t *testing.T) {
version = "1.0.0" version = "1.0.0"
svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -21,7 +23,7 @@ func TestNewUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true
@@ -46,7 +48,7 @@ func TestDoNotUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true
@@ -71,7 +73,7 @@ func TestDaemonUpdate(t *testing.T) {
wg.Add(1) wg.Add(1)
onUpdate := false onUpdate := false
u := NewUpdate() u := NewUpdate(httpAgent)
defer u.StopWatch() defer u.StopWatch()
u.SetOnUpdateListener(func() { u.SetOnUpdateListener(func() {
onUpdate = true onUpdate = true