Compare commits

...

12 Commits

Author SHA1 Message Date
Pascal Fischer
84f229f051 fix cancelRefresh 2024-11-06 20:47:25 +01:00
Pascal Fischer
8ac4a0a1e1 fix bool 2024-11-06 20:34:06 +01:00
Pascal Fischer
22126d0484 add session id to update channel 2024-11-06 20:29:59 +01:00
Maycon Santos
b952d8693d Fix cached device flow oauth (#2833)
This change removes the cached device flow oauth info when a down command is called

Removing the need for the agent to be restarted
2024-11-05 14:51:17 +01:00
Maycon Santos
5b46cc8e9c Avoid failing all other matrix tests if one fails (#2839) 2024-11-05 13:28:42 +01:00
Pascal Fischer
a9d06b883f add all group to add peer affected peers network map check (#2830) 2024-11-01 22:09:08 +01:00
Viktor Liu
5f06b202c3 [client] Log windows panics (#2829) 2024-11-01 15:08:22 +01:00
Zoltan Papp
0eb99c266a Fix unused servers cleanup (#2826)
The cleanup loop did not manage those situations well when a connection failed or 
the connection success but the code did not add a peer connection to it yet.

- in the cleanup loop check if a connection failed to a server
- after adding a foreign server connection force to keep it a minimum 5 sec
2024-11-01 12:33:29 +01:00
Pascal Fischer
bac95ace18 [management] Add DB access duration to logs for context cancel (#2781) 2024-11-01 10:58:39 +01:00
Zoltan Papp
9812de853b Allocate new buffer for every package (#2823) 2024-11-01 00:33:25 +01:00
Zoltan Papp
ad4f0a6fdf [client] Nil check on ICE remote conn (#2806) 2024-10-31 23:18:35 +01:00
Pascal Fischer
4c758c6e52 [management] remove network map diff calculations (#2820) 2024-10-31 19:24:15 +01:00
34 changed files with 620 additions and 1055 deletions

View File

@@ -13,6 +13,7 @@ concurrency:
jobs:
test:
strategy:
fail-fast: false
matrix:
arch: [ '386','amd64' ]
store: [ 'sqlite', 'postgres']

View File

@@ -104,8 +104,8 @@ func (p *ProxyBind) proxyToLocal(ctx context.Context) {
}
}()
buf := make([]byte, 1500)
for {
buf := make([]byte, 1500)
n, err := p.remoteConn.Read(buf)
if err != nil {
if ctx.Err() != nil {

View File

@@ -309,6 +309,11 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
return
}
if remoteConnNil(conn.log, iceConnInfo.RemoteConn) {
conn.log.Errorf("remote ICE connection is nil")
return
}
conn.log.Debugf("ICE connection is ready")
if conn.currentConnPriority > priority {

View File

@@ -0,0 +1,21 @@
package peer
import (
"net"
log "github.com/sirupsen/logrus"
)
func remoteConnNil(log *log.Entry, conn net.Conn) bool {
if conn == nil {
log.Errorf("ice conn is nil")
return true
}
if conn.RemoteAddr() == nil {
log.Errorf("ICE remote address is nil")
return true
}
return false
}

View File

@@ -0,0 +1,7 @@
//go:build !windows
package server
func handlePanicLog() error {
return nil
}

View File

@@ -0,0 +1,83 @@
package server
import (
"fmt"
"os"
"path/filepath"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
const (
windowsPanicLogEnvVar = "NB_WINDOWS_PANIC_LOG"
// STD_ERROR_HANDLE ((DWORD)-12) = 4294967284
stdErrorHandle = ^uintptr(11)
)
var (
kernel32 = syscall.NewLazyDLL("kernel32.dll")
// https://learn.microsoft.com/en-us/windows/console/setstdhandle
setStdHandleFn = kernel32.NewProc("SetStdHandle")
)
func handlePanicLog() error {
logPath := os.Getenv(windowsPanicLogEnvVar)
if logPath == "" {
return nil
}
// Ensure the directory exists
logDir := filepath.Dir(logPath)
if err := os.MkdirAll(logDir, 0750); err != nil {
return fmt.Errorf("create panic log directory: %w", err)
}
if err := util.EnforcePermission(logPath); err != nil {
return fmt.Errorf("enforce permission on panic log file: %w", err)
}
// Open log file with append mode
f, err := os.OpenFile(logPath, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644)
if err != nil {
return fmt.Errorf("open panic log file: %w", err)
}
// Redirect stderr to the file
if err = redirectStderr(f); err != nil {
if closeErr := f.Close(); closeErr != nil {
log.Warnf("failed to close file after redirect error: %v", closeErr)
}
return fmt.Errorf("redirect stderr: %w", err)
}
log.Infof("successfully configured panic logging to: %s", logPath)
return nil
}
// redirectStderr redirects stderr to the provided file
func redirectStderr(f *os.File) error {
// Get the current process's stderr handle
if err := setStdHandle(f); err != nil {
return fmt.Errorf("failed to set stderr handle: %w", err)
}
// Also set os.Stderr for Go's standard library
os.Stderr = f
return nil
}
func setStdHandle(f *os.File) error {
handle := f.Fd()
r0, _, e1 := setStdHandleFn.Call(stdErrorHandle, handle)
if r0 == 0 {
if e1 != nil {
return e1
}
return syscall.EINVAL
}
return nil
}

View File

@@ -97,6 +97,10 @@ func (s *Server) Start() error {
defer s.mutex.Unlock()
state := internal.CtxGetState(s.rootCtx)
if err := handlePanicLog(); err != nil {
log.Warnf("failed to redirect stderr: %v", err)
}
if err := restoreResidualState(s.rootCtx); err != nil {
log.Warnf(errRestoreResidualState, err)
}
@@ -622,6 +626,8 @@ func (s *Server) Down(ctx context.Context, _ *proto.DownRequest) (*proto.DownRes
s.mutex.Lock()
defer s.mutex.Unlock()
s.oauthAuthFlow = oauthAuthFlow{}
if s.actCancel == nil {
return nil, fmt.Errorf("service is not up")
}

3
go.mod
View File

@@ -71,7 +71,6 @@ require (
github.com/pion/transport/v3 v3.0.1
github.com/pion/turn/v3 v3.0.1
github.com/prometheus/client_golang v1.19.1
github.com/r3labs/diff/v3 v3.0.1
github.com/rs/xid v1.3.0
github.com/shirou/gopsutil/v3 v3.24.4
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
@@ -211,8 +210,6 @@ require (
github.com/tklauser/go-sysconf v0.3.14 // indirect
github.com/tklauser/numcpus v0.8.0 // indirect
github.com/vishvananda/netns v0.0.4 // indirect
github.com/vmihailenco/msgpack/v5 v5.3.5 // indirect
github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect
github.com/yuin/goldmark v1.7.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
go.opencensus.io v0.24.0 // indirect

6
go.sum
View File

@@ -605,8 +605,6 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a
github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U=
github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek=
github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk=
github.com/r3labs/diff/v3 v3.0.1 h1:CBKqf3XmNRHXKmdU7mZP1w7TV0pDyVCis1AUHtA4Xtg=
github.com/r3labs/diff/v3 v3.0.1/go.mod h1:f1S9bourRbiM66NskseyUdo0fTmEE0qKrikYJX63dgo=
github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
@@ -699,10 +697,6 @@ github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhg
github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0=
github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8=
github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM=
github.com/vmihailenco/msgpack/v5 v5.3.5 h1:5gO0H1iULLWGhs2H5tbAHIZTV8/cYafcFOr9znI5mJU=
github.com/vmihailenco/msgpack/v5 v5.3.5/go.mod h1:7xyJ9e+0+9SaZT0Wt1RGleJXzli6Q/V5KbhBonMG9jc=
github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=

View File

@@ -1147,14 +1147,14 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
require.NoError(t, err)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
message := <-updMsg.channel
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
@@ -1174,14 +1174,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
message := <-updMsg.channel
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
@@ -1210,7 +1210,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
policy := Policy{
Enabled: true,
@@ -1230,7 +1230,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
go func() {
defer wg.Done()
message := <-updMsg
message := <-updMsg.channel
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
@@ -1277,14 +1277,14 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
}
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
message := <-updMsg.channel
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 1 {
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
@@ -1303,7 +1303,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
group := group.Group{
ID: "groupA",
@@ -1339,7 +1339,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
go func() {
defer wg.Done()
message := <-updMsg
message := <-updMsg.channel
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))

View File

@@ -1,82 +0,0 @@
package differs
import (
"fmt"
"net/netip"
"reflect"
"github.com/r3labs/diff/v3"
)
// NetIPAddr is a custom differ for netip.Addr
type NetIPAddr struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPAddr) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Addr{}))
}
func (differ NetIPAddr) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromAddr, ok1 := a.Interface().(netip.Addr)
toAddr, ok2 := b.Interface().(netip.Addr)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromAddr.String() != toAddr.String() {
cl.Add(diff.UPDATE, path, fromAddr.String(), toAddr.String())
}
return nil
}
func (differ NetIPAddr) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}
// NetIPPrefix is a custom differ for netip.Prefix
type NetIPPrefix struct {
DiffFunc func(path []string, a, b reflect.Value, p interface{}) error
}
func (differ NetIPPrefix) Match(a, b reflect.Value) bool {
return diff.AreType(a, b, reflect.TypeOf(netip.Prefix{}))
}
func (differ NetIPPrefix) Diff(_ diff.DiffType, _ diff.DiffFunc, cl *diff.Changelog, path []string, a, b reflect.Value, _ interface{}) error {
if a.Kind() == reflect.Invalid {
cl.Add(diff.CREATE, path, nil, b.Interface())
return nil
}
if b.Kind() == reflect.Invalid {
cl.Add(diff.DELETE, path, a.Interface(), nil)
return nil
}
fromPrefix, ok1 := a.Interface().(netip.Prefix)
toPrefix, ok2 := b.Interface().(netip.Prefix)
if !ok1 || !ok2 {
return fmt.Errorf("invalid type for netip.Addr")
}
if fromPrefix.String() != toPrefix.String() {
cl.Add(diff.UPDATE, path, fromPrefix.String(), toPrefix.String())
}
return nil
}
func (differ NetIPPrefix) InsertParentDiffer(dfunc func(path []string, a, b reflect.Value, p interface{}) error) {
differ.DiffFunc = dfunc //nolint
}

View File

@@ -8,9 +8,10 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -498,14 +499,14 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates
t.Run("saving dns setting with unused groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -521,29 +522,70 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
}
})
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
// Creating DNS settings with groups that have no peers should not update account peers or send peer update
t.Run("creating dns setting with unused groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group", "ns-group", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupB"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Creating DNS settings with groups that have peers should update account peers and send peer update
t.Run("creating dns setting with used groups", func(t *testing.T) {
err = manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
})
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{
IP: netip.MustParseAddr(peer1.IP.String()),
NSType: dns.UDPNameServerType,
Port: dns.DefaultDNSPort,
}},
[]string{"groupA"},
true, []string{}, true, userID, false,
)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})
// Saving DNS settings with groups that have peers should update account peers and send peer update
t.Run("saving dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -559,32 +601,11 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
}
})
// Saving unchanged DNS settings with used groups should update account peers and not send peer update
// since there is no change in the network map
t.Run("saving unchanged dns setting with used groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{
DisabledManagementGroups: []string{"groupA", "groupB"},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing group with no peers from DNS settings should not trigger updates to account peers or send peer updates
t.Run("removing group with no peers from dns settings", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -604,7 +625,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) {
t.Run("removing group with peers from dns settings", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -8,12 +8,13 @@ import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
nbgroup "github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/status"
"github.com/netbirdio/netbird/route"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@@ -417,14 +418,14 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Saving a group that is not linked to any resource should not update account peers
t.Run("saving unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -447,7 +448,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -466,7 +467,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("removing peer from unliked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -484,7 +485,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("deleting group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -518,7 +519,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("saving linked group to policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -536,34 +537,11 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
}
})
// Saving an unchanged group should trigger account peers update and not send peer update
// since there is no change in the network map
t.Run("saving unchanged group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{
ID: "groupA",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID},
})
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// adding peer to a used group should update account peers and send peer update
t.Run("adding peer to linked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -581,7 +559,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
t.Run("removing peer from linked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -610,7 +588,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -651,7 +629,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -678,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -194,31 +194,31 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi
}
// handleUpdates sends updates to the connected peer until the updates channel is closed.
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerUpdates *PeerUpdateChannel, srv proto.ManagementService_SyncServer) error {
for {
select {
// condition when there are some updates
case update, open := <-updates:
case update, open := <-peerUpdates.channel:
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1)
s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(peerUpdates.channel) + 1)
}
if !open {
log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
return nil
}
log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String())
if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil {
if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerUpdates.sessionID, update, srv); err != nil {
return err
}
// condition when client <-> server connection has been terminated
case <-srv.Context().Done():
// happens when connection drops, e.g. client disconnects
log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String())
s.cancelPeerRoutines(ctx, accountID, peer)
log.WithContext(ctx).Debugf("stream of peer %s with session %s has been closed", peerKey.String(), peerUpdates.sessionID)
s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID)
return srv.Context().Err()
}
}
@@ -226,10 +226,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe
// sendUpdate encrypts the update message using the peer key and the server's wireguard key,
// then sends the encrypted message to the connected peer via the sync server.
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, sessionID string, update *UpdateMessage, srv proto.ManagementService_SyncServer) error {
encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update)
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
return status.Errorf(codes.Internal, "failed processing update message")
}
err = srv.SendMsg(&proto.EncryptedMessage{
@@ -237,18 +237,20 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w
Body: encryptedResp,
})
if err != nil {
s.cancelPeerRoutines(ctx, accountID, peer)
s.cancelPeerRoutines(ctx, accountID, peer, sessionID)
return status.Errorf(codes.Internal, "failed sending update message")
}
log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String())
return nil
}
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) {
s.peersUpdateManager.CloseChannel(ctx, peer.ID)
s.secretsManager.CancelRefresh(peer.ID)
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, sessionID string) {
ok := s.peersUpdateManager.CloseChannel(ctx, peer.ID, sessionID)
if ok {
_ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key)
s.secretsManager.CancelRefresh(peer.ID)
s.ephemeralManager.OnPeerDisconnected(ctx, peer)
}
}
func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) {

View File

@@ -960,7 +960,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Creating a nameserver group with a distribution group no peers should not update account peers
@@ -968,7 +968,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -995,7 +995,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1013,7 +1013,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1039,7 +1039,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1065,41 +1065,11 @@ func TestNameServerAccountPeersUpdate(t *testing.T) {
}
})
// saving unchanged nameserver group should update account peers and not send peer update
t.Run("saving unchanged nameserver group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
newNameServerGroupB.NameServers = []nbdns.NameServer{
{
IP: netip.MustParseAddr("1.1.1.2"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
},
}
err = manager.SaveNameServerGroup(context.Background(), account.Id, userID, newNameServerGroupB)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting a nameserver group should update account peers and send peer update
t.Run("deleting nameserver group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -41,9 +41,9 @@ type Network struct {
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64 `diff:"-"`
Serial uint64
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
mu sync.Mutex `json:"-" gorm:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0

View File

@@ -313,7 +313,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
},
NetworkMap: &NetworkMap{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.peersUpdateManager.CloseChannel(ctx, peer.ID, SessionIdForceOverwrite)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
}
@@ -589,6 +589,12 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
allGroup, err := account.GetGroupAll()
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting all group ID: %w", err)
}
groupsToAdd = append(groupsToAdd, allGroup.ID)
if areGroupChangesAffectPeers(account, groupsToAdd) {
am.updateAccountPeers(ctx, account)
}

View File

@@ -20,33 +20,33 @@ type Peer struct {
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
// The user ID that registered the peer
UserID string `diff:"-"`
UserID string
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool `diff:"-"`
LoginExpirationEnabled bool
InactivityExpirationEnabled bool `diff:"-"`
InactivityExpirationEnabled bool
// LastLogin the time when peer performed last login operation
LastLogin time.Time `diff:"-"`
LastLogin time.Time
// CreatedAt records the time the peer was created
CreatedAt time.Time `diff:"-"`
CreatedAt time.Time
// Indicate ephemeral peer attribute
Ephemeral bool `diff:"-"`
Ephemeral bool
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
Location Location `gorm:"embedded;embeddedPrefix:location_"`
}
type PeerStatus struct { //nolint:revive

View File

@@ -864,10 +864,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
b.Fatalf("Failed to get account: %v", err)
}
peerChannels := make(map[string]chan *UpdateMessage)
peerChannels := make(map[string]*PeerUpdateChannel)
for peerID := range account.Peers {
peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize)
peerChannels[peerID] = &PeerUpdateChannel{
peerID: peerID,
channel: make(chan *UpdateMessage, channelBufferSize),
sessionID: xid.New().String(),
}
}
manager.peersUpdateManager.peerChannels = peerChannels
@@ -1315,14 +1319,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update
t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1340,7 +1344,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1365,7 +1369,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1383,7 +1387,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("updating peer label", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1417,7 +1421,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1443,7 +1447,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with linked group to policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1481,7 +1485,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1507,7 +1511,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1536,7 +1540,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1562,7 +1566,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -405,7 +405,9 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policy.ID, accountID, activity.PolicyRemoved, policy.EventMeta())
am.updateAccountPeers(ctx, account)
if anyGroupHasPeers(account, policy.ruleGroups()) {
am.updateAccountPeers(ctx, account)
}
return nil
}

View File

@@ -854,14 +854,9 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
})
assert.NoError(t, err)
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
})
updMsg2 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Saving policy with rule groups with no peers should not update account's peers and not send peer update
@@ -883,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -918,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -953,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg2)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -987,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1021,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1056,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1090,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1104,46 +1099,13 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
}
})
// Saving unchanged policy should trigger account peers update but not send peer update
t.Run("saving unchanged policy", func(t *testing.T) {
policy := Policy{
ID: "policy-source-destination-peers",
Enabled: true,
Rules: []*PolicyRule{
{
ID: xid.New().String(),
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupD"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
close(done)
}()
err := manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting policy should trigger account peers update and send peer update
t.Run("deleting policy with source and destination groups with peers", func(t *testing.T) {
policyID := "policy-source-destination-peers"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1164,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
policyID := "policy-destination-has-peers-source-none"
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg2)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1180,10 +1142,10 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
// Deleting policy with no peers in groups should not update account's peers and not send peer update
t.Run("deleting policy with no peers in groups", func(t *testing.T) {
policyID := "policy-rule-groups-no-peers" // Deleting the policy created in Case 2
policyID := "policy-rule-groups-no-peers"
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg1)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -5,10 +5,11 @@ import (
"testing"
"time"
"github.com/netbirdio/netbird/management/server/group"
"github.com/rs/xid"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/group"
"github.com/netbirdio/netbird/management/server/posture"
)
@@ -146,7 +147,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
postureCheck := posture.Checks{
@@ -164,7 +165,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Run("saving unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -182,7 +183,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Run("updating unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -221,7 +222,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Run("linking posture check to policy with peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -250,7 +251,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -264,30 +265,11 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
// Saving unchanged posture check should not trigger account peers update and not send peer update
// since there is no change in the network map
t.Run("saving unchanged posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Removing posture check from policy should trigger account peers update and send peer update
t.Run("removing posture check from policy", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -307,7 +289,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Run("deleting unused posture check", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -346,7 +328,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -370,7 +352,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) {
updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID, updMsg1.sessionID)
})
policy = Policy{
ID: "policyB",
@@ -393,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg1)
peerShouldReceiveUpdate(t, updMsg1.channel)
close(done)
}()
@@ -412,8 +394,8 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
}
})
// Updating linked posture check to policy where source has peers but destination does not,
// should not trigger account peers update or send peer update
// Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{
ID: "policyB",
@@ -434,48 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
postureCheck.Checks = posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{
MinVersion: "0.29.0",
},
}
err := manager.SavePostureChecks(context.Background(), account.Id, userID, &postureCheck)
assert.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Updating linked client posture check to policy where source has peers but destination does not,
// should trigger account peers update and send peer update
t.Run("updating linked client posture check to policy where source has peers but destination does not", func(t *testing.T) {
policy = Policy{
ID: "policyB",
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"groupA"},
Destinations: []string{"groupB"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
SourcePostureChecks: []string{postureCheck.ID},
}
err = manager.SavePolicy(context.Background(), account.Id, userID, &policy, true)
assert.NoError(t, err)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -1807,7 +1807,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID, updMsg.sessionID)
})
// Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update
@@ -1827,7 +1827,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1863,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1899,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
t.Run("creating route with a routing peer", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1924,7 +1924,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1938,31 +1938,11 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
}
})
// Updating unchanged route should update account peers and not send peer update
t.Run("updating unchanged route", func(t *testing.T) {
baseRoute.Groups = []string{routeGroup1, routeGroup2}
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
err := manager.SaveRoute(context.Background(), account.Id, userID, &baseRoute)
require.NoError(t, err)
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting the route should update account peers and send peer update
t.Run("deleting route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1998,7 +1978,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -2038,7 +2018,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -408,7 +408,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
var setupKey *SetupKey
@@ -417,7 +417,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
t.Run("creating setup key", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -435,7 +435,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) {
t.Run("saving setup key", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()

View File

@@ -292,6 +292,8 @@ func (s *SqlStore) GetInstallationID() string {
}
func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's to prevent unintended updates to other fields.
peerCopy := peer.Copy()
peerCopy.AccountID = accountID
@@ -317,6 +319,9 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
})
if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return err
}
@@ -324,6 +329,8 @@ func (s *SqlStore) SavePeer(ctx context.Context, accountID string, peer *nbpeer.
}
func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error {
startTime := time.Now()
accountCopy := Account{
Domain: domain,
DomainCategory: category,
@@ -336,6 +343,9 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
Where(idQueryCondition, accountID).
Updates(&accountCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -347,6 +357,8 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID
}
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
startTime := time.Now()
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
@@ -359,6 +371,9 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
Where(accountAndIDQueryCondition, accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -370,6 +385,8 @@ func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.Pe
}
func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.Peer) error {
startTime := time.Now()
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
@@ -381,6 +398,9 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
Updates(peerCopy)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return result.Error
}
@@ -394,6 +414,8 @@ func (s *SqlStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer.P
// SaveUsers saves the given list of users to the database.
// It updates existing users if a conflict occurs.
func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
startTime := time.Now()
usersToSave := make([]User, 0, len(users))
for _, user := range users {
user.AccountID = accountID
@@ -403,15 +425,28 @@ func (s *SqlStore) SaveUsers(accountID string, users map[string]*User) error {
}
usersToSave = append(usersToSave, *user)
}
return s.db.Session(&gorm.Session{FullSaveAssociations: true}).
err := s.db.Session(&gorm.Session{FullSaveAssociations: true}).
Clauses(clause.OnConflict{UpdateAll: true}).
Create(&usersToSave).Error
if err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save users to store: %v", err)
}
return nil
}
// SaveUser saves the given user to the database.
func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, user *User) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(user)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save user to store: %v", result.Error)
}
return nil
@@ -419,12 +454,17 @@ func (s *SqlStore) SaveUser(ctx context.Context, lockStrength LockingStrength, u
// SaveGroups saves the given list of groups to the database.
func (s *SqlStore) SaveGroups(ctx context.Context, lockStrength LockingStrength, groups []*nbgroup.Group) error {
startTime := time.Now()
if len(groups) == 0 {
return nil
}
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Save(&groups)
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "failed to save groups to store: %v", result.Error)
}
return nil
@@ -451,6 +491,8 @@ func (s *SqlStore) GetAccountByPrivateDomain(ctx context.Context, domain string)
}
func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength LockingStrength, domain string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("id").
Where("domain = ? and is_domain_primary_account = ? and domain_category = ?",
@@ -460,6 +502,9 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: provided domain is not registered or is not private")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting account from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -468,12 +513,17 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, lockStrength
}
func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) {
startTime := time.Now()
var key SetupKey
result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, setupKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
@@ -485,12 +535,17 @@ func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*
}
func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken string) (string, error) {
startTime := time.Now()
var token PersonalAccessToken
result := s.db.First(&token, "hashed_token = ?", hashedToken)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -499,12 +554,17 @@ func (s *SqlStore) GetTokenIDByHashedToken(ctx context.Context, hashedToken stri
}
func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) {
startTime := time.Now()
var token PersonalAccessToken
result := s.db.First(&token, idQueryCondition, tokenID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting token from the store: %s", result.Error)
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -528,6 +588,8 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
}
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
Preload(clause.Associations).First(&user, idQueryCondition, userID)
@@ -535,6 +597,9 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetUserFromStoreError()
}
@@ -542,12 +607,17 @@ func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStre
}
func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*User, error) {
startTime := time.Now()
var users []*User
result := s.db.Find(&users, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting users from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting users from store")
}
@@ -556,12 +626,17 @@ func (s *SqlStore) GetAccountUsers(ctx context.Context, accountID string) ([]*Us
}
func (s *SqlStore) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) {
startTime := time.Now()
var groups []*nbgroup.Group
result := s.db.Find(&groups, accountIDCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "accountID not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting groups from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting groups from store")
}
@@ -661,12 +736,17 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account,
}
func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) {
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -678,12 +758,17 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun
}
func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) {
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -695,13 +780,17 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco
}
func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) {
var peer nbpeer.Peer
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewGetAccountFromStoreError(result.Error)
}
@@ -713,6 +802,8 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (
}
func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) {
startTime := time.Now()
var peer nbpeer.Peer
var accountID string
result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID)
@@ -720,6 +811,9 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -727,12 +821,17 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string)
}
func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewGetAccountFromStoreError(result.Error)
}
@@ -740,12 +839,17 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) {
}
func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) {
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, setupKey).First(&accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", status.Errorf(status.NotFound, "account not found: index lookup failed")
}
if errors.Is(result.Error, context.Canceled) {
return "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", status.NewSetupKeyNotFoundError(result.Error)
}
@@ -757,6 +861,8 @@ func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string)
}
func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) {
startTime := time.Now()
var ipJSONStrings []string
// Fetch the IP addresses as JSON strings
@@ -767,6 +873,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting IPs from store: %s", result.Error)
}
@@ -784,8 +893,9 @@ func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength
}
func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) {
var labels []string
startTime := time.Now()
var labels []string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}).
Where("account_id = ?", accountID).
Pluck("dns_label", &labels)
@@ -794,6 +904,9 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "no peers found for the account")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error)
return nil, status.Errorf(status.Internal, "issue getting dns labels from store: %s", result.Error)
}
@@ -802,24 +915,33 @@ func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength Lock
}
func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) {
var accountNetwork AccountNetwork
startTime := time.Now()
var accountNetwork AccountNetwork
if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.NewAccountNotFoundError(accountID)
}
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting network from store: %s", err)
}
return accountNetwork.Network, nil
}
func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) {
startTime := time.Now()
var peer nbpeer.Peer
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "peer not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting peer from store: %s", result.Error)
}
@@ -827,11 +949,16 @@ func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength Locking
}
func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) {
startTime := time.Now()
var accountSettings AccountSettings
if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "settings not found")
}
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "issue getting settings from store: %s", err)
}
return accountSettings.Settings, nil
@@ -839,13 +966,17 @@ func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingS
// SaveUserLastLogin stores the last login time for a user in DB.
func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error {
var user User
startTime := time.Now()
var user User
result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.NewUserNotFoundError(userID)
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.NewGetUserFromStoreError()
}
user.LastLogin = lastLogin
@@ -854,6 +985,8 @@ func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID stri
}
func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {
startTime := time.Now()
definitionJSON, err := json.Marshal(checks)
if err != nil {
return nil, err
@@ -862,6 +995,9 @@ func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *p
var postureCheck posture.Checks
err = s.db.Where("account_id = ? AND checks = ?", accountID, string(definitionJSON)).First(&postureCheck).Error
if err != nil {
if errors.Is(err, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, err
}
@@ -971,6 +1107,8 @@ func NewPostgresqlStoreFromSqlStore(ctx context.Context, sqliteStore *SqlStore,
}
func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) {
startTime := time.Now()
var setupKey SetupKey
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&setupKey, keyQueryCondition, key)
@@ -978,12 +1116,17 @@ func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength Locking
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "setup key not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.NewSetupKeyNotFoundError(result.Error)
}
return &setupKey, nil
}
func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Model(&SetupKey{}).
Where(idQueryCondition, setupKeyID).
Updates(map[string]interface{}{
@@ -992,6 +1135,9 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
})
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error)
}
@@ -1003,13 +1149,17 @@ func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string
}
func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error {
var group nbgroup.Group
startTime := time.Now()
var group nbgroup.Group
result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group 'All' not found for account")
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group 'All': %s", result.Error)
}
@@ -1022,6 +1172,9 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
group.Peers = append(group.Peers, peerID)
if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group 'All': %s", err)
}
@@ -1029,13 +1182,17 @@ func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peer
}
func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error {
var group nbgroup.Group
startTime := time.Now()
var group nbgroup.Group
result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return status.Errorf(status.NotFound, "group not found for account")
}
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue finding group: %s", result.Error)
}
@@ -1048,6 +1205,9 @@ func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId
group.Peers = append(group.Peers, peerId)
if err := s.db.Save(&group).Error; err != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue updating group: %s", err)
}
@@ -1060,7 +1220,12 @@ func (s *SqlStore) GetUserPeers(ctx context.Context, lockStrength LockingStrengt
}
func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error {
startTime := time.Now()
if err := s.db.WithContext(ctx).Create(peer).Error; err != nil {
if errors.Is(err, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue adding peer to account: %s", err)
}
@@ -1068,8 +1233,13 @@ func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) erro
}
func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error {
startTime := time.Now()
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1"))
if result.Error != nil {
if errors.Is(result.Error, context.Canceled) {
return status.NewStoreContextCanceledError(time.Since(startTime))
}
return status.Errorf(status.Internal, "issue incrementing network serial count: %s", result.Error)
}
return nil
@@ -1100,14 +1270,18 @@ func (s *SqlStore) GetDB() *gorm.DB {
}
func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) {
var accountDNSSettings AccountDNSSettings
startTime := time.Now()
var accountDNSSettings AccountDNSSettings
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
First(&accountDNSSettings, idQueryCondition, accountID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "dns settings not found")
}
if errors.Is(result.Error, context.Canceled) {
return nil, status.NewStoreContextCanceledError(time.Since(startTime))
}
return nil, status.Errorf(status.Internal, "failed to get dns settings from store: %v", result.Error)
}
return &accountDNSSettings.DNSSettings, nil
@@ -1115,14 +1289,18 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var accountID string
startTime := time.Now()
var accountID string
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Select("id").First(&accountID, idQueryCondition, id)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return false, nil
}
if errors.Is(result.Error, context.Canceled) {
return false, status.NewStoreContextCanceledError(time.Since(startTime))
}
return false, result.Error
}
@@ -1131,14 +1309,18 @@ func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStreng
// GetAccountDomainAndCategory retrieves the Domain and DomainCategory fields for an account based on the given accountID.
func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) {
var account Account
startTime := time.Now()
var account Account
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Select("domain", "domain_category").
Where(idQueryCondition, accountID).First(&account)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return "", "", status.Errorf(status.NotFound, "account not found")
}
if errors.Is(result.Error, context.Canceled) {
return "", "", status.NewStoreContextCanceledError(time.Since(startTime))
}
return "", "", status.Errorf(status.Internal, "failed to get domain category from store: %v", result.Error)
}

View File

@@ -3,6 +3,7 @@ package status
import (
"errors"
"fmt"
"time"
)
const (
@@ -115,6 +116,11 @@ func NewGetUserFromStoreError() error {
return Errorf(Internal, "issue getting user from store")
}
// NewStoreContextCanceledError creates a new Error with Internal type for a canceled store context
func NewStoreContextCanceledError(duration time.Duration) error {
return Errorf(Internal, "store access: context canceled after %v", duration)
}
// NewInvalidKeyIDError creates a new Error with InvalidArgument type for an issue getting a setup key
func NewInvalidKeyIDError() error {
return Errorf(InvalidArgument, "invalid key ID")

View File

@@ -18,7 +18,6 @@ type UpdateChannelMetrics struct {
getAllConnectedPeersDurationMicro metric.Int64Histogram
getAllConnectedPeers metric.Int64Histogram
hasChannelDurationMicro metric.Int64Histogram
networkMapDiffDurationMicro metric.Int64Histogram
ctx context.Context
}
@@ -64,11 +63,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
return nil, err
}
networkMapDiffDurationMicro, err := meter.Int64Histogram("management.updatechannel.networkmap.diff.duration.micro")
if err != nil {
return nil, err
}
return &UpdateChannelMetrics{
createChannelDurationMicro: createChannelDurationMicro,
closeChannelDurationMicro: closeChannelDurationMicro,
@@ -78,7 +72,6 @@ func NewUpdateChannelMetrics(ctx context.Context, meter metric.Meter) (*UpdateCh
getAllConnectedPeersDurationMicro: getAllConnectedPeersDurationMicro,
getAllConnectedPeers: getAllConnectedPeers,
hasChannelDurationMicro: hasChannelDurationMicro,
networkMapDiffDurationMicro: networkMapDiffDurationMicro,
ctx: ctx,
}, nil
}
@@ -118,8 +111,3 @@ func (metrics *UpdateChannelMetrics) CountGetAllConnectedPeersDuration(duration
func (metrics *UpdateChannelMetrics) CountHasChannelDuration(duration time.Duration) {
metrics.hasChannelDurationMicro.Record(metrics.ctx, duration.Microseconds())
}
// CountNetworkMapDiffDurationMicro counts the duration of the NetworkMapDiff method
func (metrics *UpdateChannelMetrics) CountNetworkMapDiffDurationMicro(duration time.Duration) {
metrics.networkMapDiffDurationMicro.Record(metrics.ctx, duration.Microseconds())
}

View File

@@ -104,7 +104,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) {
loop:
for timeout := time.After(5 * time.Second); ; {
select {
case update := <-updateChannel:
case update := <-updateChannel.channel:
updates = append(updates, update)
case <-timeout:
break loop

View File

@@ -2,31 +2,33 @@ package server
import (
"context"
"fmt"
"runtime/debug"
"sync"
"time"
"github.com/r3labs/diff/v3"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/differs"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const channelBufferSize = 100
const SessionIdForceOverwrite = "FORCE"
type UpdateMessage struct {
Update *proto.SyncResponse
NetworkMap *NetworkMap
}
type PeerUpdateChannel struct {
peerID string
sessionID string
channel chan *UpdateMessage
}
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
peerUpdateMessage map[string]*UpdateMessage
// peerChannels is a map of peerID to the channel used to deliver updates relevant to the peer
peerChannels map[string]*PeerUpdateChannel
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
@@ -36,10 +38,9 @@ type PeersUpdateManager struct {
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
peerUpdateMessage: make(map[string]*UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
peerChannels: make(map[string]*PeerUpdateChannel),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
@@ -48,15 +49,6 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
start := time.Now()
var found, dropped bool
// skip sending sync update to the peer if there is no change in update message,
// it will not check on turn credential refresh as we do not send network map or client posture checks
if update.NetworkMap != nil {
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
if !updated {
return
}
}
p.channelsMux.Lock()
defer func() {
@@ -66,24 +58,14 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
}()
if update.NetworkMap != nil {
lastSentUpdate := p.peerUpdateMessage[peerID]
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() > update.Update.NetworkMap.GetSerial() {
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
return
}
p.peerUpdateMessage[peerID] = update
}
if channel, ok := p.peerChannels[peerID]; ok {
if peerUpdates, ok := p.peerChannels[peerID]; ok {
found = true
select {
case channel <- update:
case peerUpdates.channel <- update:
log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID)
default:
dropped = true
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel))
log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(peerUpdates.channel))
}
} else {
log.WithContext(ctx).Debugf("peer %s has no channel", peerID)
@@ -91,7 +73,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
// CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer.
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage {
func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *PeerUpdateChannel {
start := time.Now()
closed := false
@@ -107,26 +89,39 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
if channel, ok := p.peerChannels[peerID]; ok {
closed = true
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
close(channel.channel)
log.WithContext(ctx).Debugf("overwriting existing channel for peer %s", peerID)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
p.peerChannels[peerID] = channel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID)
peerUpdateChannel := &PeerUpdateChannel{
peerID: peerID,
sessionID: uuid.New().String(),
// mbragin: todo shouldn't it be more? or configurable?
channel: make(chan *UpdateMessage, channelBufferSize),
}
return channel
p.peerChannels[peerID] = peerUpdateChannel
log.WithContext(ctx).Debugf("opened updates channel for a peer %s and session %s", peerID, peerUpdateChannel.sessionID)
return peerUpdateChannel
}
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string, sessionID string) bool {
if peerUpdates, ok := p.peerChannels[peerID]; ok {
if peerUpdates.sessionID == sessionID || sessionID == SessionIdForceOverwrite {
delete(p.peerChannels, peerID)
close(peerUpdates.channel)
log.WithContext(ctx).Debugf("closed updates channel of a peer %s and session %s", peerID, sessionID)
return true
}
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but current session is %s", peerID, sessionID, peerUpdates.sessionID)
return false
}
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but no channel found", peerID, sessionID)
return true
}
// CloseChannels closes updates channel for each given peer
@@ -142,12 +137,12 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string
}()
for _, id := range peerIDs {
p.closeChannel(ctx, id)
p.closeChannel(ctx, id, SessionIdForceOverwrite)
}
}
// CloseChannel closes updates channel of a given peer
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string, sessionID string) bool {
start := time.Now()
p.channelsMux.Lock()
@@ -158,7 +153,7 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) {
}
}()
p.closeChannel(ctx, peerID)
return p.closeChannel(ctx, peerID, sessionID)
}
// GetAllConnectedPeers returns a copy of the connected peers map
@@ -200,79 +195,3 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok
}
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
p.channelsMux.RLock()
lastSentUpdate := p.peerUpdateMessage[peerID]
p.channelsMux.RUnlock()
if lastSentUpdate != nil {
updated, err := isNewPeerUpdateMessage(ctx, lastSentUpdate, update, p.metrics)
if err != nil {
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
return true
}
if !updated {
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
return false
}
}
return true
}
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
func isNewPeerUpdateMessage(ctx context.Context, lastSentUpdate, currUpdateToSend *UpdateMessage, metric telemetry.AppMetrics) (isNew bool, err error) {
startTime := time.Now()
defer func() {
if r := recover(); r != nil {
log.WithContext(ctx).Panicf("comparing peer update messages. Trace: %s", debug.Stack())
isNew, err = true, nil
}
}()
if lastSentUpdate.Update.NetworkMap.GetSerial() > currUpdateToSend.Update.NetworkMap.GetSerial() {
return false, nil
}
differ, err := diff.NewDiffer(
diff.CustomValueDiffers(&differs.NetIPAddr{}),
diff.CustomValueDiffers(&differs.NetIPPrefix{}),
)
if err != nil {
return false, fmt.Errorf("failed to create differ: %v", err)
}
lastSentFiles := getChecksFiles(lastSentUpdate.Update.Checks)
currFiles := getChecksFiles(currUpdateToSend.Update.Checks)
changelog, err := differ.Diff(lastSentFiles, currFiles)
if err != nil {
return false, fmt.Errorf("failed to diff checks: %v", err)
}
if len(changelog) > 0 {
return true, nil
}
changelog, err = differ.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
if err != nil {
return false, fmt.Errorf("failed to diff network map: %v", err)
}
if metric != nil {
metric.UpdateChannelMetrics().CountNetworkMapDiffDurationMicro(time.Since(startTime))
}
return len(changelog) > 0, nil
}
// getChecksFiles returns a list of files from the given checks.
func getChecksFiles(checks []*proto.Checks) []string {
files := make([]string, 0, len(checks))
for _, check := range checks {
files = append(files, check.GetFiles()...)
}
return files
}

View File

@@ -2,21 +2,12 @@ package server
import (
"context"
"net"
"net/netip"
"testing"
"time"
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/management/proto"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/telemetry"
nbroute "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/util"
)
// var peersUpdater *PeersUpdateManager
@@ -24,7 +15,7 @@ import (
func TestCreateChannel(t *testing.T) {
peer := "test-create"
peersUpdater := NewPeersUpdateManager(nil)
defer peersUpdater.CloseChannel(context.Background(), peer)
defer peersUpdater.CloseChannel(context.Background(), peer, "sessionID")
_ = peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
@@ -46,7 +37,7 @@ func TestSendUpdate(t *testing.T) {
}
peersUpdater.SendUpdate(context.Background(), peer, update1)
select {
case <-peersUpdater.peerChannels[peer]:
case <-peersUpdater.peerChannels[peer].channel:
default:
t.Error("Update wasn't send")
}
@@ -67,7 +58,7 @@ func TestSendUpdate(t *testing.T) {
select {
case <-timeout:
t.Error("timed out reading previously sent updates")
case updateReader := <-peersUpdater.peerChannels[peer]:
case updateReader := <-peersUpdater.peerChannels[peer].channel:
if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial {
t.Error("got the update that shouldn't have been sent")
}
@@ -76,486 +67,50 @@ func TestSendUpdate(t *testing.T) {
}
func TestCloseChannel(t *testing.T) {
func TestCloseChannel_WithCorrectSessionID(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager(nil)
_ = peersUpdater.CreateChannel(context.Background(), peer)
peerUpdates := peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
peersUpdater.CloseChannel(context.Background(), peer)
updateDB := peersUpdater.CloseChannel(context.Background(), peer, peerUpdates.sessionID)
if _, ok := peersUpdater.peerChannels[peer]; ok {
t.Error("Error closing the channel")
}
assert.Equal(t, true, updateDB)
}
func TestHandlePeerMessageUpdate(t *testing.T) {
tests := []struct {
name string
peerID string
existingUpdate *UpdateMessage
newUpdate *UpdateMessage
expectedResult bool
}{
{
name: "update message with turn credentials update",
peerID: "peer",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{},
},
},
expectedResult: true,
},
{
name: "update message for peer without existing update",
peerID: "peer1",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
expectedResult: true,
},
{
name: "update message with no changes in update",
peerID: "peer2",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
expectedResult: false,
},
{
name: "update message with changes in checks",
peerID: "peer3",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
Checks: []*proto.Checks{
{
Files: []string{"/usr/bin/netbird"},
},
},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
expectedResult: true,
},
{
name: "update message with lower serial number",
peerID: "peer4",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
expectedResult: false,
},
func TestCloseChannel_WithWrongSessionID(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager(nil)
peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
for _, tt := range tests {
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
if err != nil {
t.Fatal(err)
}
t.Run(tt.name, func(t *testing.T) {
p := NewPeersUpdateManager(metrics)
ctx := context.Background()
if tt.existingUpdate != nil {
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
}
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
assert.Equal(t, tt.expectedResult, result)
})
updateDB := peersUpdater.CloseChannel(context.Background(), peer, "wrongSessionID")
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Should not close channel with wrong session id")
}
assert.Equal(t, false, updateDB)
}
func TestIsNewPeerUpdateMessage(t *testing.T) {
t.Run("Unchanged value", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
func TestCloseChannel_WithForceOverwrite(t *testing.T) {
peer := "test-close"
peersUpdater := NewPeersUpdateManager(nil)
peersUpdater.CreateChannel(context.Background(), peer)
if _, ok := peersUpdater.peerChannels[peer]; !ok {
t.Error("Error creating the channel")
}
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.False(t, message)
})
t.Run("Unchanged value with serial incremented", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.False(t, message)
})
t.Run("Updating routes network", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Routes[0].Network = netip.MustParsePrefix("1.1.1.1/32")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating routes groups", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Routes[0].Groups = []string{"randomGroup1"}
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating network map peers", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newPeer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.4"),
SSHEnabled: true,
Key: "peer4-key",
DNSLabel: "peer4",
SSHKey: "peer4-ssh-key",
}
newUpdateMessage2.NetworkMap.Peers = append(newUpdateMessage2.NetworkMap.Peers, newPeer)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating process check", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.False(t, message)
newUpdateMessage3 := createMockUpdateMessage(t)
newUpdateMessage3.Update.Checks = []*proto.Checks{}
newUpdateMessage3.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage3, nil)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage4 := createMockUpdateMessage(t)
check := &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/local/netbird",
MacPath: "/usr/bin/netbird",
},
},
},
},
}
newUpdateMessage4.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage4.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage4, nil)
assert.NoError(t, err)
assert.True(t, message)
newUpdateMessage5 := createMockUpdateMessage(t)
check = &posture.Checks{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/local/netbird",
},
},
},
},
}
newUpdateMessage5.Update.Checks = []*proto.Checks{toProtocolCheck(check)}
newUpdateMessage5.Update.NetworkMap.Serial++
message, err = isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage5, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating DNS configuration", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newDomain := "newexample.com"
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains = append(
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].Domains,
newDomain,
)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating peer IP", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.Peers[0].IP = net.ParseIP("192.168.1.10")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating firewall rule", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.FirewallRules[0].Port = "443"
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Add new firewall rule", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newRule := &FirewallRule{
PeerIP: "192.168.1.3",
Direction: firewallRuleDirectionOUT,
Action: string(PolicyTrafficActionDrop),
Protocol: string(PolicyRuleProtocolUDP),
Port: "53",
}
newUpdateMessage2.NetworkMap.FirewallRules = append(newUpdateMessage2.NetworkMap.FirewallRules, newRule)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Removing nameserver", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers = make([]nbdns.NameServer, 0)
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating name server IP", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.NameServerGroups[0].NameServers[0].IP = netip.MustParseAddr("8.8.4.4")
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
t.Run("Updating custom DNS zone", func(t *testing.T) {
newUpdateMessage1 := createMockUpdateMessage(t)
newUpdateMessage2 := createMockUpdateMessage(t)
newUpdateMessage2.NetworkMap.DNSConfig.CustomZones[0].Records[0].RData = "100.64.0.2"
newUpdateMessage2.Update.NetworkMap.Serial++
message, err := isNewPeerUpdateMessage(context.Background(), newUpdateMessage1, newUpdateMessage2, nil)
assert.NoError(t, err)
assert.True(t, message)
})
updateDB := peersUpdater.CloseChannel(context.Background(), peer, SessionIdForceOverwrite)
if _, ok := peersUpdater.peerChannels[peer]; ok {
t.Error("Should close channel if forced")
}
}
func createMockUpdateMessage(t *testing.T) *UpdateMessage {
t.Helper()
_, ipNet, err := net.ParseCIDR("192.168.1.0/24")
if err != nil {
t.Fatal(err)
}
domainList, err := domain.FromStringList([]string{"example.com"})
if err != nil {
t.Fatal(err)
}
config := &Config{
Signal: &Host{
Proto: "https",
URI: "signal.uri",
Username: "",
Password: "",
},
Stuns: []*Host{{URI: "stun.uri", Proto: UDP}},
TURNConfig: &TURNConfig{
Turns: []*Host{{URI: "turn.uri", Proto: UDP, Username: "turn-user", Password: "turn-pass"}},
},
}
peer := &nbpeer.Peer{
IP: net.ParseIP("192.168.1.1"),
SSHEnabled: true,
Key: "peer-key",
DNSLabel: "peer1",
SSHKey: "peer1-ssh-key",
}
secretManager := NewTimeBasedAuthSecretsManager(
NewPeersUpdateManager(nil),
&TURNConfig{
TimeBasedCredentials: false,
CredentialsTTL: util.Duration{
Duration: defaultDuration,
},
Secret: "secret",
Turns: []*Host{TurnTestHost},
},
&Relay{
Addresses: []string{"localhost:0"},
CredentialsTTL: util.Duration{Duration: time.Hour},
Secret: "secret",
},
)
networkMap := &NetworkMap{
Network: &Network{Net: *ipNet, Serial: 1000},
Peers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.2"), Key: "peer2-key", DNSLabel: "peer2", SSHEnabled: true, SSHKey: "peer2-ssh-key"}},
OfflinePeers: []*nbpeer.Peer{{IP: net.ParseIP("192.168.1.3"), Key: "peer3-key", DNSLabel: "peer3", SSHEnabled: true, SSHKey: "peer3-ssh-key"}},
Routes: []*nbroute.Route{
{
ID: "route1",
Network: netip.MustParsePrefix("10.0.0.0/24"),
KeepRoute: true,
NetID: "route1",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"test1", "test2"},
},
{
ID: "route2",
Domains: domainList,
KeepRoute: true,
NetID: "route2",
Peer: "peer1",
NetworkType: 1,
Masquerade: true,
Metric: 9999,
Enabled: true,
Groups: []string{"test1", "test2"},
},
},
DNSConfig: nbdns.Config{
ServiceEnable: true,
NameServerGroups: []*nbdns.NameServerGroup{
{
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("8.8.8.8"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
{
ID: "ns1",
NameServers: []nbdns.NameServer{{
IP: netip.MustParseAddr("1.1.1.1"),
NSType: nbdns.UDPNameServerType,
Port: nbdns.DefaultDNSPort,
}},
Groups: []string{"group1"},
Primary: true,
Domains: []string{"example.com"},
Enabled: true,
SearchDomainsEnabled: true,
},
},
CustomZones: []nbdns.CustomZone{{Domain: "example.com", Records: []nbdns.SimpleRecord{{Name: "example.com", Type: 1, Class: "IN", TTL: 60, RData: "100.64.0.1"}}}},
},
FirewallRules: []*FirewallRule{
{PeerIP: "192.168.1.2", Direction: firewallRuleDirectionIN, Action: string(PolicyTrafficActionAccept), Protocol: string(PolicyRuleProtocolTCP), Port: "80"},
},
}
dnsName := "example.com"
checks := []*posture.Checks{
{
Checks: posture.ChecksDefinition{
ProcessCheck: &posture.ProcessCheck{
Processes: []posture.Process{
{
LinuxPath: "/usr/bin/netbird",
WindowsPath: "C:\\Program Files\\netbird\\netbird.exe",
MacPath: "/usr/bin/netbird",
},
},
},
},
},
}
dnsCache := &DNSConfigCache{}
turnToken, err := secretManager.GenerateTurnToken()
if err != nil {
t.Fatal(err)
}
relayToken, err := secretManager.GenerateRelayToken()
if err != nil {
t.Fatal(err)
}
return &UpdateMessage{
Update: toSyncResponse(context.Background(), config, peer, turnToken, relayToken, networkMap, dnsName, checks, dnsCache),
NetworkMap: networkMap,
}
assert.Equal(t, true, updateDB)
}

View File

@@ -10,13 +10,14 @@ import (
"github.com/eko/gocache/v3/cache"
cacheStore "github.com/eko/gocache/v3/store"
"github.com/google/go-cmp/cmp"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
gocache "github.com/patrickmn/go-cache"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
nbgroup "github.com/netbirdio/netbird/management/server/group"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/integration_reference"
@@ -1297,14 +1298,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID)
})
// Creating a new regular user should not update account peers and not send peer update
t.Run("creating new regular user with no groups", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1327,7 +1328,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("updating user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1350,7 +1351,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1387,7 +1388,7 @@ func TestUserAccountPeersUpdate(t *testing.T) {
t.Run("updating user with linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg.channel)
close(done)
}()
@@ -1408,14 +1409,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID)
t.Cleanup(func() {
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID)
manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID, peer4UpdMsg.sessionID)
})
// deleting user with linked peers should update account peers and send peer update
t.Run("deleting user with linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, peer4UpdMsg)
peerShouldReceiveUpdate(t, peer4UpdMsg.channel)
close(done)
}()

View File

@@ -16,6 +16,7 @@ import (
var (
relayCleanupInterval = 60 * time.Second
keepUnusedServerTime = 5 * time.Second
ErrRelayClientNotConnected = fmt.Errorf("relay client not connected")
)
@@ -27,10 +28,13 @@ type RelayTrack struct {
sync.RWMutex
relayClient *Client
err error
created time.Time
}
func NewRelayTrack() *RelayTrack {
return &RelayTrack{}
return &RelayTrack{
created: time.Now(),
}
}
type OnServerCloseListener func()
@@ -302,6 +306,18 @@ func (m *Manager) cleanUpUnusedRelays() {
for addr, rt := range m.relayClients {
rt.Lock()
// if the connection failed to the server the relay client will be nil
// but the instance will be kept in the relayClients until the next locking
if rt.err != nil {
rt.Unlock()
continue
}
if time.Since(rt.created) <= keepUnusedServerTime {
rt.Unlock()
continue
}
if rt.relayClient.HasConns() {
rt.Unlock()
continue

View File

@@ -288,8 +288,9 @@ func TestForeginAutoClose(t *testing.T) {
t.Fatalf("failed to close connection: %s", err)
}
t.Logf("waiting for relay cleanup: %s", relayCleanupInterval+1*time.Second)
time.Sleep(relayCleanupInterval + 1*time.Second)
timeout := relayCleanupInterval + keepUnusedServerTime + 1*time.Second
t.Logf("waiting for relay cleanup: %s", timeout)
time.Sleep(timeout)
if len(mgr.relayClients) != 0 {
t.Errorf("expected 0, got %d", len(mgr.relayClients))
}

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"testing"
"time"
)
func TestServerPicker_UnavailableServers(t *testing.T) {
@@ -13,7 +12,7 @@ func TestServerPicker_UnavailableServers(t *testing.T) {
PeerID: "test",
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), connectionTimeout+1)
defer cancel()
go func() {