diff --git a/.github/workflows/golang-test-linux.yml b/.github/workflows/golang-test-linux.yml index 263623bd1..2d5cf2856 100644 --- a/.github/workflows/golang-test-linux.yml +++ b/.github/workflows/golang-test-linux.yml @@ -49,7 +49,7 @@ jobs: run: git --no-pager diff --exit-code - name: Test - run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./... + run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 6m -p 1 ./... test_client_on_docker: runs-on: ubuntu-20.04 diff --git a/client/internal/config.go b/client/internal/config.go index 725703c43..1df1e0547 100644 --- a/client/internal/config.go +++ b/client/internal/config.go @@ -117,6 +117,11 @@ type Config struct { // ReadConfig read config file and return with Config. If it is not exists create a new with default values func ReadConfig(configPath string) (*Config, error) { if configFileIsExists(configPath) { + err := util.EnforcePermission(configPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } + config := &Config{} if _, err := util.ReadJson(configPath, config); err != nil { return nil, err @@ -159,13 +164,17 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) { if err != nil { return nil, err } - err = WriteOutConfig(input.ConfigPath, cfg) + err = util.WriteJsonWithRestrictedPermission(input.ConfigPath, cfg) return cfg, err } if isPreSharedKeyHidden(input.PreSharedKey) { input.PreSharedKey = nil } + err := util.EnforcePermission(input.ConfigPath) + if err != nil { + log.Errorf("failed to enforce permission on config dir: %v", err) + } return update(input) } diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index f4a701f7f..5327e31d2 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -89,8 
+89,8 @@ type Conn struct { onConnected func(remoteWireGuardKey string, remoteRosenpassPubKey []byte, wireGuardIP string, remoteRosenpassAddr string) onDisconnected func(remotePeer string, wgIP string) - statusRelay ConnStatus - statusICE ConnStatus + statusRelay *AtomicConnStatus + statusICE *AtomicConnStatus currentConnPriority ConnPriority opened bool // this flag is used to prevent close in case of not opened connection @@ -131,8 +131,8 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu signaler: signaler, relayManager: relayManager, allowedIPsIP: allowedIPsIP.String(), - statusRelay: StatusDisconnected, - statusICE: StatusDisconnected, + statusRelay: NewAtomicConnStatus(), + statusICE: NewAtomicConnStatus(), iCEDisconnected: make(chan bool, 1), relayDisconnected: make(chan bool, 1), } @@ -323,11 +323,11 @@ func (conn *Conn) reconnectLoopWithRetry() { } if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay == StatusDisconnected || conn.statusICE == StatusDisconnected { + if conn.statusRelay.Get() == StatusDisconnected || conn.statusICE.Get() == StatusDisconnected { conn.log.Tracef("connectivity guard timedout, relay state: %s, ice state: %s", conn.statusRelay, conn.statusICE) } } else { - if conn.statusICE == StatusDisconnected { + if conn.statusICE.Get() == StatusDisconnected { conn.log.Tracef("connectivity guard timedout, ice state: %s", conn.statusICE) } } @@ -419,7 +419,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon conn.log.Debugf("ICE connection is ready") - conn.statusICE = StatusConnected + conn.statusICE.Set(StatusConnected) defer conn.updateIceState(iceConnInfo) @@ -484,16 +484,16 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { // switch back to relay connection if conn.endpointRelay != nil && conn.currentConnPriority != connPriorityRelay { conn.log.Debugf("ICE disconnected, set Relay to active connection") - 
conn.workerRelay.EnableWgWatcher(conn.ctx) err := conn.configureWGEndpoint(conn.endpointRelay) if err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } + conn.workerRelay.EnableWgWatcher(conn.ctx) conn.currentConnPriority = connPriorityRelay } - changed := conn.statusICE != newState && newState != StatusConnecting - conn.statusICE = newState + changed := conn.statusICE.Get() != newState && newState != StatusConnecting + conn.statusICE.Set(newState) select { case conn.iCEDisconnected <- changed: @@ -522,7 +522,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { } conn.log.Debugf("Relay connection is ready to use") - conn.statusRelay = StatusConnected + conn.statusRelay.Set(StatusConnected) wgProxy := conn.wgProxyFactory.GetProxy(conn.ctx) endpoint, err := wgProxy.AddTurnConn(rci.relayedConn) @@ -538,7 +538,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { defer conn.updateRelayStatus(rci.relayedConn.RemoteAddr().String(), rci.rosenpassPubKey) if conn.currentConnPriority > connPriorityRelay { - if conn.statusICE == StatusConnected { + if conn.statusICE.Get() == StatusConnected { log.Debugf("do not switch to relay because current priority is: %v", conn.currentConnPriority) return } @@ -551,7 +551,6 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { } } - conn.workerRelay.EnableWgWatcher(conn.ctx) err = conn.configureWGEndpoint(endpointUdpAddr) if err != nil { if err := wgProxy.CloseConn(); err != nil { @@ -560,6 +559,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { conn.log.Errorf("Failed to update wg peer configuration: %v", err) return } + conn.workerRelay.EnableWgWatcher(conn.ctx) wgConfigWorkaround() if conn.wgProxyRelay != nil { @@ -594,8 +594,8 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { conn.wgProxyRelay = nil } - changed := conn.statusRelay != StatusDisconnected - conn.statusRelay = StatusDisconnected + changed := conn.statusRelay.Get() != StatusDisconnected + 
conn.statusRelay.Set(StatusDisconnected) select { case conn.relayDisconnected <- changed: @@ -661,8 +661,8 @@ func (conn *Conn) updateIceState(iceConnInfo ICEConnInfo) { } func (conn *Conn) setStatusToDisconnected() { - conn.statusRelay = StatusDisconnected - conn.statusICE = StatusDisconnected + conn.statusRelay.Set(StatusDisconnected) + conn.statusICE.Set(StatusDisconnected) peerState := State{ PubKey: conn.config.Key, @@ -706,7 +706,7 @@ func (conn *Conn) waitInitialRandomSleepTime() { } func (conn *Conn) isRelayed() bool { - if conn.statusRelay == StatusDisconnected && (conn.statusICE == StatusDisconnected || conn.statusICE == StatusConnecting) { + if conn.statusRelay.Get() == StatusDisconnected && (conn.statusICE.Get() == StatusDisconnected || conn.statusICE.Get() == StatusConnecting) { return false } @@ -718,11 +718,11 @@ func (conn *Conn) isRelayed() bool { } func (conn *Conn) evalStatus() ConnStatus { - if conn.statusRelay == StatusConnected || conn.statusICE == StatusConnected { + if conn.statusRelay.Get() == StatusConnected || conn.statusICE.Get() == StatusConnected { return StatusConnected } - if conn.statusRelay == StatusConnecting || conn.statusICE == StatusConnecting { + if conn.statusRelay.Get() == StatusConnecting || conn.statusICE.Get() == StatusConnecting { return StatusConnecting } @@ -733,12 +733,12 @@ func (conn *Conn) isConnected() bool { conn.mu.Lock() defer conn.mu.Unlock() - if conn.statusICE != StatusConnected && conn.statusICE != StatusConnecting { + if conn.statusICE.Get() != StatusConnected && conn.statusICE.Get() != StatusConnecting { return false } if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay != StatusConnected { + if conn.statusRelay.Get() != StatusConnected { return false } } diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 639117c89..3c747864f 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -1,6 +1,10 @@ 
package peer -import log "github.com/sirupsen/logrus" +import ( + "sync/atomic" + + log "github.com/sirupsen/logrus" +) const ( // StatusConnected indicate the peer is in connected state @@ -12,7 +16,34 @@ const ( ) // ConnStatus describe the status of a peer's connection -type ConnStatus int +type ConnStatus int32 + +// AtomicConnStatus is a thread-safe wrapper for ConnStatus +type AtomicConnStatus struct { + status atomic.Int32 +} + +// NewAtomicConnStatus creates a new AtomicConnStatus with the given initial status +func NewAtomicConnStatus() *AtomicConnStatus { + acs := &AtomicConnStatus{} + acs.Set(StatusDisconnected) + return acs +} + +// Get returns the current connection status +func (acs *AtomicConnStatus) Get() ConnStatus { + return ConnStatus(acs.status.Load()) +} + +// Set updates the connection status +func (acs *AtomicConnStatus) Set(status ConnStatus) { + acs.status.Store(int32(status)) +} + +// String returns the string representation of the current status +func (acs *AtomicConnStatus) String() string { + return acs.Get().String() +} func (s ConnStatus) String() string { switch s { diff --git a/client/internal/peer/conn_test.go b/client/internal/peer/conn_test.go index 59f249b82..80c25f63c 100644 --- a/client/internal/peer/conn_test.go +++ b/client/internal/peer/conn_test.go @@ -158,8 +158,13 @@ func TestConn_Status(t *testing.T) { for _, table := range tables { t.Run(table.name, func(t *testing.T) { - conn.statusICE = table.statusIce - conn.statusRelay = table.statusRelay + si := NewAtomicConnStatus() + si.Set(table.statusIce) + conn.statusICE = si + + sr := NewAtomicConnStatus() + sr.Set(table.statusRelay) + conn.statusRelay = sr got := conn.Status() assert.Equal(t, got, table.want, "they should be equal") diff --git a/client/internal/peer/status.go b/client/internal/peer/status.go index f116f3fef..915fa63f0 100644 --- a/client/internal/peer/status.go +++ b/client/internal/peer/status.go @@ -597,6 +597,8 @@ func (d *Status) 
DeleteResolvedDomainsStates(domain domain.Domain) { } func (d *Status) GetRosenpassState() RosenpassState { + d.mux.Lock() + defer d.mux.Unlock() return RosenpassState{ d.rosenpassEnabled, d.rosenpassPermissive, @@ -604,6 +606,8 @@ func (d *Status) GetRosenpassState() RosenpassState { } func (d *Status) GetManagementState() ManagementState { + d.mux.Lock() + defer d.mux.Unlock() return ManagementState{ d.mgmAddress, d.managementState, @@ -645,6 +649,8 @@ func (d *Status) IsLoginRequired() bool { } func (d *Status) GetSignalState() SignalState { + d.mux.Lock() + defer d.mux.Unlock() return SignalState{ d.signalAddress, d.signalState, @@ -654,6 +660,8 @@ func (d *Status) GetSignalState() SignalState { // GetRelayStates returns the stun/turn/permanent relay states func (d *Status) GetRelayStates() []relay.ProbeResult { + d.mux.Lock() + defer d.mux.Unlock() if d.relayMgr == nil { return d.relayStates } @@ -684,6 +692,8 @@ func (d *Status) GetRelayStates() []relay.ProbeResult { } func (d *Status) GetDNSStates() []NSGroupState { + d.mux.Lock() + defer d.mux.Unlock() return d.nsGroupStates } @@ -695,18 +705,19 @@ func (d *Status) GetResolvedDomainsStates() map[domain.Domain][]netip.Prefix { // GetFullStatus gets full status func (d *Status) GetFullStatus() FullStatus { - d.mux.Lock() - defer d.mux.Unlock() - fullStatus := FullStatus{ ManagementState: d.GetManagementState(), SignalState: d.GetSignalState(), - LocalPeerState: d.localPeer, Relays: d.GetRelayStates(), RosenpassState: d.GetRosenpassState(), NSGroupStates: d.GetDNSStates(), } + d.mux.Lock() + defer d.mux.Unlock() + + fullStatus.LocalPeerState = d.localPeer + for _, status := range d.peers { fullStatus.Peers = append(fullStatus.Peers, status) } diff --git a/client/internal/peer/worker_relay.go b/client/internal/peer/worker_relay.go index 3457faa46..6bb385d3e 100644 --- a/client/internal/peer/worker_relay.go +++ b/client/internal/peer/worker_relay.go @@ -109,10 +109,10 @@ func (w *WorkerRelay) EnableWgWatcher(ctx 
context.Context) { } ctx, ctxCancel := context.WithCancel(ctx) - w.wgStateCheck(ctx) w.ctxWgWatch = ctx w.ctxCancelWgWatch = ctxCancel + w.wgStateCheck(ctx, ctxCancel) } func (w *WorkerRelay) DisableWgWatcher() { @@ -158,21 +158,22 @@ func (w *WorkerRelay) CloseConn() { } // wgStateCheck help to check the state of the WireGuard handshake and relay connection -func (w *WorkerRelay) wgStateCheck(ctx context.Context) { +func (w *WorkerRelay) wgStateCheck(ctx context.Context, ctxCancel context.CancelFunc) { + w.log.Debugf("WireGuard watcher started") lastHandshake, err := w.wgState() if err != nil { - w.log.Errorf("failed to read wg stats: %v", err) + w.log.Warnf("failed to read wg stats: %v", err) lastHandshake = time.Time{} } go func(lastHandshake time.Time) { timer := time.NewTimer(wgHandshakeOvertime) defer timer.Stop() + defer ctxCancel() for { select { case <-timer.C: - handshake, err := w.wgState() if err != nil { w.log.Errorf("failed to read wg stats: %v", err) diff --git a/management/cmd/root.go b/management/cmd/root.go index 9b6a96b82..86155a956 100644 --- a/management/cmd/root.go +++ b/management/cmd/root.go @@ -54,7 +54,7 @@ func Execute() error { func init() { stopCh = make(chan int) mgmtCmd.Flags().IntVar(&mgmtPort, "port", 80, "server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") - mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 8081, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") + mgmtCmd.Flags().IntVar(&mgmtMetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") mgmtCmd.Flags().StringVar(&mgmtDataDir, "datadir", defaultMgmtDataDir, "server data directory location") mgmtCmd.Flags().StringVar(&mgmtConfig, "config", defaultMgmtConfig, "Netbird config file location. Config params specified via command line (e.g. 
datadir) have a precedence over configuration from this file") mgmtCmd.Flags().StringVar(&mgmtLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") diff --git a/management/server/account.go b/management/server/account.go index 84bdc629f..5d934d56e 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -264,6 +264,11 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +// Subclass used in gorm to only load network and not whole account +type AccountNetwork struct { + Network *Network `gorm:"embedded;embeddedPrefix:network_"` +} + type UserPermissions struct { DashboardView string `json:"dashboard_view"` } @@ -701,14 +706,6 @@ func (a *Account) GetPeerGroupsList(peerID string) []string { return grps } -func (a *Account) getUserGroups(userID string) ([]string, error) { - user, err := a.FindUser(userID) - if err != nil { - return nil, err - } - return user.AutoGroups, nil -} - func (a *Account) getPeerDNSManagementStatus(peerID string) bool { peerGroups := a.getPeerGroups(peerID) enabled := true @@ -735,14 +732,6 @@ func (a *Account) getPeerGroups(peerID string) lookupMap { return groupList } -func (a *Account) getSetupKeyGroups(setupKey string) ([]string, error) { - key, err := a.FindSetupKey(setupKey) - if err != nil { - return nil, err - } - return key.AutoGroups, nil -} - func (a *Account) getTakenIPs() []net.IP { var takenIps []net.IP for _, existingPeer := range a.Peers { @@ -2083,7 +2072,7 @@ func (am *DefaultAccountManager) GetAccountIDForPeerKey(ctx context.Context, pee } func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpeer.Peer, settings *Settings) (bool, error) { - user, err := am.Store.GetUserByUserID(ctx, peer.UserID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, peer.UserID) if err != nil { return 
false, err } @@ -2104,6 +2093,25 @@ func (am *DefaultAccountManager) handleUserPeer(ctx context.Context, peer *nbpee return false, nil } +func (am *DefaultAccountManager) getFreeDNSLabel(ctx context.Context, store Store, accountID string, peerHostName string) (string, error) { + existingLabels, err := store.GetPeerLabelsInAccount(ctx, LockingStrengthShare, accountID) + if err != nil { + return "", fmt.Errorf("failed to get peer dns labels: %w", err) + } + + labelMap := ConvertSliceToMap(existingLabels) + newLabel, err := getPeerHostLabel(peerHostName, labelMap) + if err != nil { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + if newLabel == "" { + return "", fmt.Errorf("failed to get new host label: %w", err) + } + + return newLabel, nil +} + // addAllGroup to account object if it doesn't exist func addAllGroup(account *Account) error { if len(account.Groups) == 0 { diff --git a/management/server/activity/sqlite/crypt.go b/management/server/activity/sqlite/crypt.go index 852d9bc4a..096f49ea3 100644 --- a/management/server/activity/sqlite/crypt.go +++ b/management/server/activity/sqlite/crypt.go @@ -7,7 +7,6 @@ import ( "crypto/rand" "encoding/base64" "errors" - "fmt" ) var iv = []byte{10, 22, 13, 79, 05, 8, 52, 91, 87, 98, 88, 98, 35, 25, 13, 05} @@ -115,12 +114,23 @@ func pkcs5Padding(ciphertext []byte) []byte { padText := bytes.Repeat([]byte{byte(padding)}, padding) return append(ciphertext, padText...) 
} - func pkcs5UnPadding(src []byte) ([]byte, error) { srcLen := len(src) - paddingLen := int(src[srcLen-1]) - if paddingLen >= srcLen || paddingLen > aes.BlockSize { - return nil, fmt.Errorf("padding size error") + if srcLen == 0 { + return nil, errors.New("input data is empty") } + + paddingLen := int(src[srcLen-1]) + if paddingLen == 0 || paddingLen > aes.BlockSize || paddingLen > srcLen { + return nil, errors.New("invalid padding size") + } + + // Verify that all padding bytes are the same + for i := 0; i < paddingLen; i++ { + if src[srcLen-1-i] != byte(paddingLen) { + return nil, errors.New("invalid padding") + } + } + return src[:srcLen-paddingLen], nil } diff --git a/management/server/activity/sqlite/crypt_test.go b/management/server/activity/sqlite/crypt_test.go index 1033ab6ed..aff3a08b1 100644 --- a/management/server/activity/sqlite/crypt_test.go +++ b/management/server/activity/sqlite/crypt_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "bytes" "testing" ) @@ -95,3 +96,215 @@ func TestCorruptKey(t *testing.T) { t.Fatalf("incorrect decryption, the result is: %s", res) } } + +func TestEncryptDecrypt(t *testing.T) { + // Generate a key for encryption/decryption + key, err := GenerateKey() + if err != nil { + t.Fatalf("Failed to generate key: %v", err) + } + + // Initialize the FieldEncrypt with the generated key + ec, err := NewFieldEncrypt(key) + if err != nil { + t.Fatalf("Failed to create FieldEncrypt: %v", err) + } + + // Test cases + testCases := []struct { + name string + input string + }{ + { + name: "Empty String", + input: "", + }, + { + name: "Short String", + input: "Hello", + }, + { + name: "String with Spaces", + input: "Hello, World!", + }, + { + name: "Long String", + input: "The quick brown fox jumps over the lazy dog.", + }, + { + name: "Unicode Characters", + input: "こんにちは世界", + }, + { + name: "Special Characters", + input: "!@#$%^&*()_+-=[]{}|;':\",./<>?", + }, + { + name: "Numeric String", + input: "1234567890", + }, + { + name: 
"Repeated Characters", + input: "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + }, + { + name: "Multi-block String", + input: "This is a longer string that will span multiple blocks in the encryption algorithm.", + }, + { + name: "Non-ASCII and ASCII Mix", + input: "Hello 世界 123", + }, + } + + for _, tc := range testCases { + t.Run(tc.name+" - Legacy", func(t *testing.T) { + // Legacy Encryption + encryptedLegacy := ec.LegacyEncrypt(tc.input) + if encryptedLegacy == "" { + t.Errorf("LegacyEncrypt returned empty string for input '%s'", tc.input) + } + + // Legacy Decryption + decryptedLegacy, err := ec.LegacyDecrypt(encryptedLegacy) + if err != nil { + t.Errorf("LegacyDecrypt failed for input '%s': %v", tc.input, err) + } + + // Verify that the decrypted value matches the original input + if decryptedLegacy != tc.input { + t.Errorf("LegacyDecrypt output '%s' does not match original input '%s'", decryptedLegacy, tc.input) + } + }) + + t.Run(tc.name+" - New", func(t *testing.T) { + // New Encryption + encryptedNew, err := ec.Encrypt(tc.input) + if err != nil { + t.Errorf("Encrypt failed for input '%s': %v", tc.input, err) + } + if encryptedNew == "" { + t.Errorf("Encrypt returned empty string for input '%s'", tc.input) + } + + // New Decryption + decryptedNew, err := ec.Decrypt(encryptedNew) + if err != nil { + t.Errorf("Decrypt failed for input '%s': %v", tc.input, err) + } + + // Verify that the decrypted value matches the original input + if decryptedNew != tc.input { + t.Errorf("Decrypt output '%s' does not match original input '%s'", decryptedNew, tc.input) + } + }) + } +} + +func TestPKCS5UnPadding(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + expectError bool + }{ + { + name: "Valid Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{4}, 4)...), + expected: []byte("Hello, World!"), + }, + { + name: "Empty Input", + input: []byte{}, + expectError: true, + }, + { + name: "Padding Length Zero", + input: 
append([]byte("Hello, World!"), bytes.Repeat([]byte{0}, 4)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Block Size", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{17}, 17)...), + expectError: true, + }, + { + name: "Padding Length Exceeds Input Length", + input: []byte{5, 5, 5}, + expectError: true, + }, + { + name: "Invalid Padding Bytes", + input: append([]byte("Hello, World!"), []byte{2, 3, 4, 5}...), + expectError: true, + }, + { + name: "Valid Single Byte Padding", + input: append([]byte("Hello, World!"), byte(1)), + expected: []byte("Hello, World!"), + }, + { + name: "Invalid Mixed Padding Bytes", + input: append([]byte("Hello, World!"), []byte{3, 3, 2}...), + expectError: true, + }, + { + name: "Valid Full Block Padding", + input: append([]byte("Hello, World!"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("Hello, World!"), + }, + { + name: "Non-Padding Byte at End", + input: append([]byte("Hello, World!"), []byte{4, 4, 4, 5}...), + expectError: true, + }, + { + name: "Valid Padding with Different Text Length", + input: append([]byte("Test"), bytes.Repeat([]byte{12}, 12)...), + expected: []byte("Test"), + }, + { + name: "Padding Length Equal to Input Length", + input: bytes.Repeat([]byte{8}, 8), + expected: []byte{}, + }, + { + name: "Invalid Padding Length Zero (Again)", + input: append([]byte("Test"), byte(0)), + expectError: true, + }, + { + name: "Padding Length Greater Than Input", + input: []byte{10}, + expectError: true, + }, + { + name: "Input Length Not Multiple of Block Size", + input: append([]byte("Invalid Length"), byte(1)), + expected: []byte("Invalid Length"), + }, + { + name: "Valid Padding with Non-ASCII Characters", + input: append([]byte("こんにちは"), bytes.Repeat([]byte{2}, 2)...), + expected: []byte("こんにちは"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs5UnPadding(tt.input) + if tt.expectError { + if err == nil { + t.Errorf("Expected error but 
got nil") + } + } else { + if err != nil { + t.Errorf("Did not expect error but got: %v", err) + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("Expected output %v, got %v", tt.expected, result) + } + } + }) + } +} diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 36c88f1d1..1390352a5 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -7,6 +7,7 @@ import ( "time" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/status" ) type MockStore struct { @@ -24,7 +25,7 @@ func (s *MockStore) GetAccountByPeerID(_ context.Context, peerId string) (*Accou return s.account, nil } - return nil, fmt.Errorf("account not found") + return nil, status.NewPeerNotFoundError(peerId) } type MocAccountManager struct { diff --git a/management/server/file_store.go b/management/server/file_store.go index ed1bc3d09..434a7710d 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -2,6 +2,8 @@ package server import ( "context" + "errors" + "net" "os" "path/filepath" "strings" @@ -46,6 +48,158 @@ type FileStore struct { metrics telemetry.AppMetrics `json:"-"` } +func (s *FileStore) ExecuteInTransaction(ctx context.Context, f func(store Store) error) error { + return f(s) +} + +func (s *FileStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKeyID)] + if !ok { + return status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + + account.SetupKeys[setupKeyID].UsedTimes++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return err + } + 
+ allGroup, err := account.GetGroupAll() + if err != nil || allGroup == nil { + return errors.New("all group not found") + } + + allGroup.Peers = append(allGroup.Peers, peerID) + + return nil +} + +func (s *FileStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountId) + if err != nil { + return err + } + + account.Groups[groupID].Peers = append(account.Groups[groupID].Peers, peerId) + + return nil +} + +func (s *FileStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[peer.AccountID] + if !ok { + return status.NewAccountNotFoundError(peer.AccountID) + } + + account.Peers[peer.ID] = peer + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + s.mux.Lock() + defer s.mux.Unlock() + + account, ok := s.Accounts[accountId] + if !ok { + return status.NewAccountNotFoundError(accountId) + } + + account.Network.Serial++ + + return s.SaveAccount(ctx, account) +} + +func (s *FileStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + s.mux.Lock() + defer s.mux.Unlock() + + accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(key)] + if !ok { + return nil, status.NewSetupKeyNotFoundError() + } + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + setupKey, ok := account.SetupKeys[key] + if !ok { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + + return setupKey, nil +} + +func (s *FileStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + var takenIps []net.IP + for _, existingPeer := range 
account.Peers { + takenIps = append(takenIps, existingPeer.IP) + } + + return takenIps, nil +} + +func (s *FileStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + existingLabels := []string{} + for _, peer := range account.Peers { + if peer.DNSLabel != "" { + existingLabels = append(existingLabels, peer.DNSLabel) + } + } + return existingLabels, nil +} + +func (s *FileStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + s.mux.Lock() + defer s.mux.Unlock() + + account, err := s.getAccount(accountID) + if err != nil { + return nil, err + } + + return account.Network, nil +} + type StoredAccount struct{} // NewFileStore restores a store from the file located in the datadir @@ -422,7 +576,7 @@ func (s *FileStore) GetAccountBySetupKey(_ context.Context, setupKey string) (*A accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return nil, status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return nil, status.NewSetupKeyNotFoundError() } account, err := s.getAccount(accountID) @@ -469,7 +623,7 @@ func (s *FileStore) GetUserByTokenID(_ context.Context, tokenID string) (*User, return account.Users[userID].Copy(), nil } -func (s *FileStore) GetUserByUserID(_ context.Context, userID string) (*User, error) { +func (s *FileStore) GetUserByUserID(_ context.Context, _ LockingStrength, userID string) (*User, error) { accountID, ok := s.UserID2AccountID[userID] if !ok { return nil, status.Errorf(status.NotFound, "accountID not found: provided userID doesn't exists") @@ -513,7 +667,7 @@ func (s *FileStore) GetAllAccounts(_ context.Context) (all []*Account) { func (s *FileStore) getAccount(accountID string) (*Account, error) { account, ok := s.Accounts[accountID] if !ok { 
- return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return account, nil @@ -639,13 +793,13 @@ func (s *FileStore) GetAccountIDBySetupKey(_ context.Context, setupKey string) ( accountID, ok := s.SetupKeyID2AccountID[strings.ToUpper(setupKey)] if !ok { - return "", status.Errorf(status.NotFound, "account not found: provided setup key doesn't exists") + return "", status.NewSetupKeyNotFoundError() } return accountID, nil } -func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, _ LockingStrength, peerKey string) (*nbpeer.Peer, error) { s.mux.Lock() defer s.mux.Unlock() @@ -668,7 +822,7 @@ func (s *FileStore) GetPeerByPeerPubKey(_ context.Context, peerKey string) (*nbp return nil, status.NewPeerNotFoundError(peerKey) } -func (s *FileStore) GetAccountSettings(_ context.Context, accountID string) (*Settings, error) { +func (s *FileStore) GetAccountSettings(_ context.Context, _ LockingStrength, accountID string) (*Settings, error) { s.mux.Lock() defer s.mux.Unlock() @@ -758,7 +912,7 @@ func (s *FileStore) SavePeerLocation(accountID string, peerWithLocation *nbpeer. } // SaveUserLastLogin stores the last login time for a user in memory. It doesn't attempt to persist data to speed up things. 
-func (s *FileStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *FileStore) SaveUserLastLogin(_ context.Context, accountID, userID string, lastLogin time.Time) error { s.mux.Lock() defer s.mux.Unlock() diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 00ee4bda2..ff09129bd 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -627,7 +627,7 @@ func testSyncStatusRace(t *testing.T) { } time.Sleep(10 * time.Millisecond) - peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), peerWithInvalidStatus.PublicKey().String()) + peer, err := am.Store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, peerWithInvalidStatus.PublicKey().String()) if err != nil { t.Fatal(err) return @@ -638,8 +638,8 @@ func testSyncStatusRace(t *testing.T) { } func Test_LoginPerformance(t *testing.T) { - if os.Getenv("CI") == "true" { - t.Skip("Skipping on CI") + if os.Getenv("CI") == "true" || runtime.GOOS == "windows" { + t.Skip("Skipping test on CI or Windows") } t.Setenv("NETBIRD_STORE_ENGINE", "sqlite") @@ -655,7 +655,7 @@ func Test_LoginPerformance(t *testing.T) { // {"M", 250, 1}, // {"L", 500, 1}, // {"XL", 750, 1}, - {"XXL", 2000, 1}, + {"XXL", 5000, 1}, } log.SetOutput(io.Discard) @@ -700,15 +700,18 @@ func Test_LoginPerformance(t *testing.T) { } defer mgmtServer.GracefulStop() + t.Logf("management setup complete, start registering peers") + var counter int32 var counterStart int32 - var wg sync.WaitGroup + var wgAccount sync.WaitGroup var mu sync.Mutex messageCalls := []func() error{} for j := 0; j < bc.accounts; j++ { - wg.Add(1) + wgAccount.Add(1) + var wgPeer sync.WaitGroup go func(j int, counter *int32, counterStart *int32) { - defer wg.Done() + defer wgAccount.Done() account, err := createAccount(am, fmt.Sprintf("account-%d", j), fmt.Sprintf("user-%d", j), fmt.Sprintf("domain-%d", j)) if err != nil { @@ -722,7 
+725,9 @@ func Test_LoginPerformance(t *testing.T) { return } + startTime := time.Now() for i := 0; i < bc.peers; i++ { + wgPeer.Add(1) key, err := wgtypes.GeneratePrivateKey() if err != nil { t.Logf("failed to generate key: %v", err) @@ -763,21 +768,29 @@ func Test_LoginPerformance(t *testing.T) { mu.Lock() messageCalls = append(messageCalls, login) mu.Unlock() - _, _, _, err = am.LoginPeer(context.Background(), peerLogin) - if err != nil { - t.Logf("failed to login peer: %v", err) - return - } - atomic.AddInt32(counterStart, 1) - if *counterStart%100 == 0 { - t.Logf("registered %d peers", *counterStart) - } + go func(peerLogin PeerLogin, counterStart *int32) { + defer wgPeer.Done() + _, _, _, err = am.LoginPeer(context.Background(), peerLogin) + if err != nil { + t.Logf("failed to login peer: %v", err) + return + } + + atomic.AddInt32(counterStart, 1) + if *counterStart%100 == 0 { + t.Logf("registered %d peers", *counterStart) + } + }(peerLogin, counterStart) + } + wgPeer.Wait() + + t.Logf("Time for registration: %s", time.Since(startTime)) }(j, &counter, &counterStart) } - wg.Wait() + wgAccount.Wait() t.Logf("prepared %d login calls", len(messageCalls)) testLoginPerformance(t, messageCalls) diff --git a/management/server/peer.go b/management/server/peer.go index 26e27617d..da9586734 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -11,6 +11,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/posture" "github.com/netbirdio/netbird/management/proto" @@ -371,164 +372,175 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s } }() - var account *Account - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - if 
strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { - if am.idpManager != nil { - userdata, err := am.lookupUserInCache(ctx, userID, account) - if err == nil && userdata != nil { - peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) - } - } - } - // This is a handling for the case when the same machine (with the same WireGuard pub key) tries to register twice. // Such case is possible when AddPeer function takes long time to finish after AcquireWriteLockByUID (e.g., database is slow) // and the peer disconnects with a timeout and tries to register again. // We just check if this machine has been registered before and reject the second registration. // The connecting peer should be able to recover with a retry. - _, err = account.FindPeerByPubKey(peer.Key) + _, err = am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, peer.Key) if err == nil { return nil, nil, nil, status.Errorf(status.PreconditionFailed, "peer has been already registered") } opEvent := &activity.Event{ Timestamp: time.Now().UTC(), - AccountID: account.Id, + AccountID: accountID, } - var ephemeral bool - setupKeyName := "" - if !addedByUser { - // validate the setup key if adding with a key - sk, err := account.FindSetupKey(upperKey) - if err != nil { - return nil, nil, nil, err - } + var newPeer *nbpeer.Peer - if !sk.IsValid() { - return nil, nil, nil, status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") - } - - account.SetupKeys[sk.Key] = sk.IncrementUsage() - opEvent.InitiatorID = sk.Id - opEvent.Activity = activity.PeerAddedWithSetupKey - ephemeral = sk.Ephemeral - setupKeyName = sk.Name - } else { - opEvent.InitiatorID = userID - opEvent.Activity = activity.PeerAddedByUser - } - - takenIps := account.getTakenIPs() - existingLabels := account.getPeerDNSLabels() - - newLabel, err := getPeerHostLabel(peer.Meta.Hostname, existingLabels) - if err != nil { 
- return nil, nil, nil, err - } - - peer.DNSLabel = newLabel - network := account.Network - nextIp, err := AllocatePeerIP(network.Net, takenIps) - if err != nil { - return nil, nil, nil, err - } - - registrationTime := time.Now().UTC() - - newPeer := &nbpeer.Peer{ - ID: xid.New().String(), - Key: peer.Key, - SetupKey: upperKey, - IP: nextIp, - Meta: peer.Meta, - Name: peer.Meta.Hostname, - DNSLabel: newLabel, - UserID: userID, - Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, - SSHEnabled: false, - SSHKey: peer.SSHKey, - LastLogin: registrationTime, - CreatedAt: registrationTime, - LoginExpirationEnabled: addedByUser, - Ephemeral: ephemeral, - Location: peer.Location, - } - - if am.geo != nil && newPeer.Location.ConnectionIP != nil { - location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error { + var groupsToAdd []string + var setupKeyID string + var setupKeyName string + var ephemeral bool + if addedByUser { + user, err := transaction.GetUserByUserID(ctx, LockingStrengthUpdate, userID) + if err != nil { + return fmt.Errorf("failed to get user groups: %w", err) + } + groupsToAdd = user.AutoGroups + opEvent.InitiatorID = userID + opEvent.Activity = activity.PeerAddedByUser } else { - newPeer.Location.CountryCode = location.Country.ISOCode - newPeer.Location.CityName = location.City.Names.En - newPeer.Location.GeoNameID = location.City.GeonameID - } - } + // Validate the setup key + sk, err := transaction.GetSetupKeyBySecret(ctx, LockingStrengthUpdate, upperKey) + if err != nil { + return fmt.Errorf("failed to get setup key: %w", err) + } - // add peer to 'All' group - group, err := account.GetGroupAll() - if err != nil { - return nil, nil, nil, err - } - group.Peers = append(group.Peers, newPeer.ID) + if !sk.IsValid() { 
+ return status.Errorf(status.PreconditionFailed, "couldn't add peer: setup key is invalid") + } - var groupsToAdd []string - if addedByUser { - groupsToAdd, err = account.getUserGroups(userID) - if err != nil { - return nil, nil, nil, err + opEvent.InitiatorID = sk.Id + opEvent.Activity = activity.PeerAddedWithSetupKey + groupsToAdd = sk.AutoGroups + ephemeral = sk.Ephemeral + setupKeyID = sk.Id + setupKeyName = sk.Name } - } else { - groupsToAdd, err = account.getSetupKeyGroups(upperKey) - if err != nil { - return nil, nil, nil, err - } - } - if len(groupsToAdd) > 0 { - for _, s := range groupsToAdd { - if g, ok := account.Groups[s]; ok && g.Name != "All" { - g.Peers = append(g.Peers, newPeer.ID) + if strings.ToLower(peer.Meta.Hostname) == "iphone" || strings.ToLower(peer.Meta.Hostname) == "ipad" && userID != "" { + if am.idpManager != nil { + userdata, err := am.idpManager.GetUserDataByID(ctx, userID, idp.AppMetadata{WTAccountID: accountID}) + if err == nil && userdata != nil { + peer.Meta.Hostname = fmt.Sprintf("%s-%s", peer.Meta.Hostname, strings.Split(userdata.Email, "@")[0]) + } } } - } - newPeer = am.integratedPeerValidator.PreparePeer(ctx, account.Id, newPeer, account.GetPeerGroupsList(newPeer.ID), account.Settings.Extra) - - if addedByUser { - user, err := account.FindUser(userID) + freeLabel, err := am.getFreeDNSLabel(ctx, transaction, accountID, peer.Meta.Hostname) if err != nil { - return nil, nil, nil, status.Errorf(status.Internal, "couldn't find user") + return fmt.Errorf("failed to get free DNS label: %w", err) } - user.updateLastLogin(newPeer.LastLogin) - } - account.Peers[newPeer.ID] = newPeer - account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) + freeIP, err := am.getFreeIP(ctx, transaction, accountID) + if err != nil { + return fmt.Errorf("failed to get free IP: %w", err) + } + + registrationTime := time.Now().UTC() + newPeer = &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: accountID, + Key: peer.Key, + SetupKey: 
upperKey, + IP: freeIP, + Meta: peer.Meta, + Name: peer.Meta.Hostname, + DNSLabel: freeLabel, + UserID: userID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: registrationTime}, + SSHEnabled: false, + SSHKey: peer.SSHKey, + LastLogin: registrationTime, + CreatedAt: registrationTime, + LoginExpirationEnabled: addedByUser, + Ephemeral: ephemeral, + Location: peer.Location, + } + opEvent.TargetID = newPeer.ID + opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) + if !addedByUser { + opEvent.Meta["setup_key_name"] = setupKeyName + } + + if am.geo != nil && newPeer.Location.ConnectionIP != nil { + location, err := am.geo.Lookup(newPeer.Location.ConnectionIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for new peer realip: [%s]: %v", newPeer.Location.ConnectionIP.String(), err) + } else { + newPeer.Location.CountryCode = location.Country.ISOCode + newPeer.Location.CityName = location.City.Names.En + newPeer.Location.GeoNameID = location.City.GeonameID + } + } + + settings, err := transaction.GetAccountSettings(ctx, LockingStrengthShare, accountID) + if err != nil { + return fmt.Errorf("failed to get account settings: %w", err) + } + newPeer = am.integratedPeerValidator.PreparePeer(ctx, accountID, newPeer, groupsToAdd, settings.Extra) + + err = transaction.AddPeerToAllGroup(ctx, accountID, newPeer.ID) + if err != nil { + return fmt.Errorf("failed adding peer to All group: %w", err) + } + + if len(groupsToAdd) > 0 { + for _, g := range groupsToAdd { + err = transaction.AddPeerToGroup(ctx, accountID, newPeer.ID, g) + if err != nil { + return err + } + } + } + + err = transaction.AddPeerToAccount(ctx, newPeer) + if err != nil { + return fmt.Errorf("failed to add peer to account: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + if addedByUser { + err := transaction.SaveUserLastLogin(ctx, accountID, userID, 
newPeer.LastLogin) + if err != nil { + return fmt.Errorf("failed to update user last login: %w", err) + } + } else { + err = transaction.IncrementSetupKeyUsage(ctx, setupKeyID) + if err != nil { + return fmt.Errorf("failed to increment setup key usage: %w", err) + } + } + + log.WithContext(ctx).Debugf("Peer %s added to account %s", newPeer.ID, accountID) + return nil + }) + if err != nil { - return nil, nil, nil, err + return nil, nil, nil, fmt.Errorf("failed to add peer to database: %w", err) } - // Account is saved, we can release the lock - unlock() - unlock = nil - - opEvent.TargetID = newPeer.ID - opEvent.Meta = newPeer.EventMeta(am.GetDNSDomain()) - if !addedByUser { - opEvent.Meta["setup_key_name"] = setupKeyName + if newPeer == nil { + return nil, nil, nil, fmt.Errorf("new peer is nil") } am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) + unlock() + unlock = nil + + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, fmt.Errorf("error getting account: %w", err) + } + am.updateAccountPeers(ctx, account) approvedPeersMap, err := am.GetValidatedPeers(account) @@ -536,12 +548,31 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - postureChecks := am.getPeerPostureChecks(account, peer) + postureChecks := am.getPeerPostureChecks(account, newPeer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) return newPeer, networkMap, postureChecks, nil } +func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { + takenIps, err := store.GetTakenIPs(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get taken IPs: %w", err) + } + + network, err := 
store.GetAccountNetwork(ctx, LockingStrengthUpdate, accountID) + if err != nil { + return nil, fmt.Errorf("failed getting network: %w", err) + } + + nextIp, err := AllocatePeerIP(network.Net, takenIps) + if err != nil { + return nil, fmt.Errorf("failed to allocate new peer ip: %w", err) + } + + return nextIp, nil +} + // SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, account *Account) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { peer, err := account.FindPeerByPubKey(sync.WireGuardPubKey) @@ -647,12 +678,12 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) } }() - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthUpdate, login.WireGuardPubKey) if err != nil { return nil, nil, nil, err } - settings, err := am.Store.GetAccountSettings(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return nil, nil, nil, err } @@ -730,7 +761,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) // with no JWT token and usually no setup-key. As the client can send up to two login request to check if it is expired // and before starting the engine, we do the checks without an account lock to avoid piling up requests. 
func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Context, accountID string, login PeerLogin) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, login.WireGuardPubKey) + peer, err := am.Store.GetPeerByPeerPubKey(ctx, LockingStrengthShare, login.WireGuardPubKey) if err != nil { return err } @@ -741,7 +772,7 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } - settings, err := am.Store.GetAccountSettings(ctx, accountID) + settings, err := am.Store.GetAccountSettings(ctx, LockingStrengthShare, accountID) if err != nil { return err } @@ -786,7 +817,7 @@ func (am *DefaultAccountManager) handleExpiredPeer(ctx context.Context, user *Us return err } - err = am.Store.SaveUserLastLogin(user.AccountID, user.Id, peer.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, user.AccountID, user.Id, peer.LastLogin) if err != nil { return err } @@ -969,3 +1000,11 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } + +func ConvertSliceToMap(existingLabels []string) map[string]struct{} { + labelMap := make(map[string]struct{}, len(existingLabels)) + for _, label := range existingLabels { + labelMap[label] = struct{}{} + } + return labelMap +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 448e83a08..4b2ec66c6 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -7,20 +7,24 @@ import ( "net" "net/netip" "os" + "runtime" "testing" "time" "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/domain" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/activity" nbgroup "github.com/netbirdio/netbird/management/server/group" nbpeer 
"github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/management/server/telemetry" nbroute "github.com/netbirdio/netbird/route" ) @@ -995,3 +999,184 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) } + +func Test_RegisterPeerByUser(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingUserID := "edafee4e-63fb-11ec-90d6-0242ac120003" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + UserID: existingUserID, + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + LastLogin: time.Now(), + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), "", existingUserID, newPeer) + require.NoError(t, err) + + peer, err := store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, addedPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.UserID, existingUserID) + + account, err := 
store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Equal(t, peer.Meta.Hostname, newPeer.Meta.Hostname) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) + + assert.Equal(t, uint64(1), account.Network.Serial) + + lastLogin, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastLogin, account.Users[existingUserID].LastLogin) +} + +func Test_RegisterPeerBySetupKey(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + existingSetupKeyID := "A2C8E62B-38F5-4553-B31E-DD66C696CEBB" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + } + + addedPeer, _, _, err := am.AddPeer(context.Background(), existingSetupKeyID, "", newPeer) + + require.NoError(t, err) + + peer, err := store.GetPeerByPeerPubKey(context.Background(), 
LockingStrengthShare, newPeer.Key) + require.NoError(t, err) + assert.Equal(t, peer.AccountID, existingAccountID) + assert.Equal(t, peer.SetupKey, existingSetupKeyID) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.Contains(t, account.Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g2g"].Peers, addedPeer.ID) + assert.Contains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, addedPeer.ID) + + assert.Equal(t, uint64(1), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.NotEqual(t, lastUsed, account.SetupKeys[existingSetupKeyID].LastUsed) + assert.Equal(t, 1, account.SetupKeys[existingSetupKeyID].UsedTimes) + +} + +func Test_RegisterPeerRollbackOnFailure(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + + eventStore := &activity.InMemoryEventStore{} + + metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + assert.NoError(t, err) + + am, err := BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.cloud", eventStore, nil, false, MocIntegratedValidator{}, metrics) + assert.NoError(t, err) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + faultyKey := "A2C8E62B-38F5-4553-B31E-DD66C696CEBC" + + _, err = store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + newPeer := &nbpeer.Peer{ + ID: xid.New().String(), + AccountID: existingAccountID, + Key: "newPeerKey", + SetupKey: "existingSetupKey", + UserID: "", + IP: net.IP{123, 123, 123, 123}, + Meta: nbpeer.PeerSystemMeta{ + Hostname: "newPeer", + GoOS: "linux", + }, + Name: "newPeerName", + DNSLabel: "newPeer.test", + Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now()}, + SSHEnabled: false, + 
} + + _, _, _, err = am.AddPeer(context.Background(), faultyKey, "", newPeer) + require.Error(t, err) + + _, err = store.GetPeerByPeerPubKey(context.Background(), LockingStrengthShare, newPeer.Key) + require.Error(t, err) + + account, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + assert.NotContains(t, account.Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g3g"].Peers, newPeer.ID) + assert.NotContains(t, account.Groups["cfefqs706sqkneg59g4g"].Peers, newPeer.ID) + + assert.Equal(t, uint64(0), account.Network.Serial) + + lastUsed, err := time.Parse("2006-01-02T15:04:05Z", "0001-01-01T00:00:00Z") + assert.NoError(t, err) + assert.Equal(t, lastUsed, account.SetupKeys[faultyKey].LastUsed) + assert.Equal(t, 0, account.SetupKeys[faultyKey].UsedTimes) +} diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 49a5ddeb4..b644b6db1 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "net" "os" "path/filepath" "runtime" @@ -33,6 +34,7 @@ import ( const ( storeSqliteFileName = "store.db" idQueryCondition = "id = ?" + keyQueryCondition = "key = ?" accountAndIDQueryCondition = "account_id = ? and id = ?" 
peerNotFoundFMT = "peer %s not found" ) @@ -428,13 +430,12 @@ func (s *SqlStore) GetAccountIDByPrivateDomain(ctx context.Context, domain strin func (s *SqlStore) GetAccountBySetupKey(ctx context.Context, setupKey string) (*Account, error) { var key SetupKey - result := s.db.Select("account_id").First(&key, "key = ?", strings.ToUpper(setupKey)) + result := s.db.WithContext(ctx).Select("account_id").First(&key, keyQueryCondition, strings.ToUpper(setupKey)) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting setup key from store") + return nil, status.NewSetupKeyNotFoundError() } if key.AccountID == "" { @@ -487,15 +488,15 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, return &user, nil } -func (s *SqlStore) GetUserByUserID(ctx context.Context, userID string) (*User, error) { +func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User - result := s.db.First(&user, idQueryCondition, userID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). 
+ First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "user not found: index lookup failed") + return nil, status.NewUserNotFoundError(userID) } - log.WithContext(ctx).Errorf("error when getting user from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "issue getting user from store") + return nil, status.NewGetUserFromStoreError() } return &user, nil @@ -548,7 +549,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, if result.Error != nil { log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return nil, status.Errorf(status.NotFound, "account not found") + return nil, status.NewAccountNotFoundError(accountID) } return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -608,7 +609,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Account, error) { var user User - result := s.db.Select("account_id").First(&user, idQueryCondition, userID) + result := s.db.WithContext(ctx).Select("account_id").First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") @@ -625,12 +626,11 @@ func (s *SqlStore) GetAccountByUser(ctx context.Context, userID string) (*Accoun func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, idQueryCondition, peerID) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, idQueryCondition, peerID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, 
status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -644,12 +644,11 @@ func (s *SqlStore) GetAccountByPeerID(ctx context.Context, peerID string) (*Acco func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) { var peer nbpeer.Peer - result := s.db.Select("account_id").First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Select("account_id").First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting account from store") } @@ -663,12 +662,11 @@ func (s *SqlStore) GetAccountByPeerPubKey(ctx context.Context, peerKey string) ( func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) { var peer nbpeer.Peer var accountID string - result := s.db.Model(&peer).Select("account_id").Where("key = ?", peerKey).First(&accountID) + result := s.db.WithContext(ctx).Model(&peer).Select("account_id").Where(keyQueryCondition, peerKey).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return "", status.Errorf(status.Internal, "issue getting account from store") } @@ -690,61 +688,117 @@ func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { } func (s *SqlStore) GetAccountIDBySetupKey(ctx context.Context, setupKey string) (string, error) { - var key SetupKey var accountID 
string - result := s.db.Model(&key).Select("account_id").Where("key = ?", strings.ToUpper(setupKey)).First(&accountID) + result := s.db.WithContext(ctx).Model(&SetupKey{}).Select("account_id").Where(keyQueryCondition, strings.ToUpper(setupKey)).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } - log.WithContext(ctx).Errorf("error when getting setup key from the store: %s", result.Error) - return "", status.Errorf(status.Internal, "issue getting setup key from store") + return "", status.NewSetupKeyNotFoundError() + } + + if accountID == "" { + return "", status.Errorf(status.NotFound, "account not found: index lookup failed") } return accountID, nil } -func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) { +func (s *SqlStore) GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountID string) ([]net.IP, error) { + var ipJSONStrings []string + + // Fetch the IP addresses as JSON strings + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). 
+ Pluck("ip", &ipJSONStrings) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + return nil, status.Errorf(status.Internal, "issue getting IPs from store") + } + + // Convert the JSON strings to net.IP objects + ips := make([]net.IP, len(ipJSONStrings)) + for i, ipJSON := range ipJSONStrings { + var ip net.IP + if err := json.Unmarshal([]byte(ipJSON), &ip); err != nil { + return nil, status.Errorf(status.Internal, "issue parsing IP JSON from store") + } + ips[i] = ip + } + + return ips, nil +} + +func (s *SqlStore) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountID string) ([]string, error) { + var labels []string + + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbpeer.Peer{}). + Where("account_id = ?", accountID). + Pluck("dns_label", &labels) + + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "no peers found for the account") + } + log.WithContext(ctx).Errorf("error when getting dns labels from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "issue getting dns labels from store") + } + + return labels, nil +} + +func (s *SqlStore) GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountID string) (*Network, error) { + var accountNetwork AccountNetwork + + if err := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountNetwork).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.Errorf(status.Internal, "issue getting network from store") + } + return accountNetwork.Network, nil +} + +func (s *SqlStore) GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) { var 
peer nbpeer.Peer - result := s.db.First(&peer, "key = ?", peerKey) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).First(&peer, keyQueryCondition, peerKey) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "peer not found") } - log.WithContext(ctx).Errorf("error when getting peer from the store: %s", result.Error) return nil, status.Errorf(status.Internal, "issue getting peer from store") } return &peer, nil } -func (s *SqlStore) GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) { +func (s *SqlStore) GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) { var accountSettings AccountSettings - if err := s.db.Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { + if err := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Where(idQueryCondition, accountID).First(&accountSettings).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "settings not found") } - log.WithContext(ctx).Errorf("error when getting settings from the store: %s", err) return nil, status.Errorf(status.Internal, "issue getting settings from store") } return accountSettings.Settings, nil } // SaveUserLastLogin stores the last login time for a user in DB. 
-func (s *SqlStore) SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error { +func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User - result := s.db.First(&user, accountAndIDQueryCondition, accountID, userID) + result := s.db.WithContext(ctx).First(&user, accountAndIDQueryCondition, accountID, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { - return status.Errorf(status.NotFound, "user %s not found", userID) + return status.NewUserNotFoundError(userID) } - return status.Errorf(status.Internal, "issue getting user from store") + return status.NewGetUserFromStoreError() } - user.LastLogin = lastLogin - return s.db.Save(user).Error + return s.db.Save(&user).Error } func (s *SqlStore) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { @@ -863,3 +917,123 @@ func NewPostgresqlStoreFromFileStore(ctx context.Context, fileStore *FileStore, return store, nil } + +func (s *SqlStore) GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) { + var setupKey SetupKey + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + First(&setupKey, keyQueryCondition, strings.ToUpper(key)) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "setup key not found") + } + return nil, status.NewSetupKeyNotFoundError() + } + return &setupKey, nil +} + +func (s *SqlStore) IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error { + result := s.db.WithContext(ctx).Model(&SetupKey{}). + Where(idQueryCondition, setupKeyID). 
+ Updates(map[string]interface{}{ + "used_times": gorm.Expr("used_times + 1"), + "last_used": time.Now(), + }) + + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing setup key usage count: %s", result.Error) + } + + if result.RowsAffected == 0 { + return status.Errorf(status.NotFound, "setup key not found") + } + + return nil +} + +func (s *SqlStore) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where("account_id = ? AND name = ?", accountID, "All").First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group 'All' not found for account") + } + return status.Errorf(status.Internal, "issue finding group 'All'") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerID { + return nil + } + } + + group.Peers = append(group.Peers, peerID) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group 'All'") + } + + return nil +} + +func (s *SqlStore) AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error { + var group nbgroup.Group + + result := s.db.WithContext(ctx).Where(accountAndIDQueryCondition, accountId, groupID).First(&group) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return status.Errorf(status.NotFound, "group not found for account") + } + return status.Errorf(status.Internal, "issue finding group") + } + + for _, existingPeerID := range group.Peers { + if existingPeerID == peerId { + return nil + } + } + + group.Peers = append(group.Peers, peerId) + + if err := s.db.Save(&group).Error; err != nil { + return status.Errorf(status.Internal, "issue updating group") + } + + return nil +} + +func (s *SqlStore) AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error { + if err := 
s.db.WithContext(ctx).Create(peer).Error; err != nil { + return status.Errorf(status.Internal, "issue adding peer to account") + } + + return nil +} + +func (s *SqlStore) IncrementNetworkSerial(ctx context.Context, accountId string) error { + result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, accountId).Update("network_serial", gorm.Expr("network_serial + 1")) + if result.Error != nil { + return status.Errorf(status.Internal, "issue incrementing network serial count") + } + return nil +} + +func (s *SqlStore) ExecuteInTransaction(ctx context.Context, operation func(store Store) error) error { + tx := s.db.WithContext(ctx).Begin() + if tx.Error != nil { + return tx.Error + } + repo := s.withTx(tx) + err := operation(repo) + if err != nil { + tx.Rollback() + return err + } + return tx.Commit().Error +} + +func (s *SqlStore) withTx(tx *gorm.DB) Store { + return &SqlStore{ + db: tx, + } +} diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index ce4ee531a..64ef36831 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -1003,3 +1003,163 @@ func TestPostgresql_GetUserByTokenID(t *testing.T) { require.NoError(t, err) require.Equal(t, id, user.PATs[id].ID) } + +func TestSqlite_GetTakenIPs(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + takenIPs, err := store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []net.IP{}, takenIPs) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + IP: net.IP{1, 1, 1, 1}, + } + err = 
store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip1 := net.IP{1, 1, 1, 1}.To16() + assert.Equal(t, []net.IP{ip1}, takenIPs) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + IP: net.IP{2, 2, 2, 2}, + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + takenIPs, err = store.GetTakenIPs(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip2 := net.IP{2, 2, 2, 2}.To16() + assert.Equal(t, []net.IP{ip1, ip2}, takenIPs) + +} + +func TestSqlite_GetPeerLabelsInAccount(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + labels, err := store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{}, labels) + + peer1 := &nbpeer.Peer{ + ID: "peer1", + AccountID: existingAccountID, + DNSLabel: "peer1.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer1) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test"}, labels) + + peer2 := &nbpeer.Peer{ + ID: "peer2", + AccountID: existingAccountID, + DNSLabel: "peer2.domain.test", + } + err = store.AddPeerToAccount(context.Background(), peer2) + require.NoError(t, err) + + labels, err = store.GetPeerLabelsInAccount(context.Background(), LockingStrengthShare, existingAccountID) 
+ require.NoError(t, err) + assert.Equal(t, []string{"peer1.domain.test", "peer2.domain.test"}, labels) +} + +func TestSqlite_GetAccountNetwork(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + network, err := store.GetAccountNetwork(context.Background(), LockingStrengthShare, existingAccountID) + require.NoError(t, err) + ip := net.IP{100, 64, 0, 0}.To16() + assert.Equal(t, ip, network.Net.IP) + assert.Equal(t, net.IPMask{255, 255, 0, 0}, network.Net.Mask) + assert.Equal(t, "", network.Dns) + assert.Equal(t, "af1c8024-ha40-4ce2-9418-34653101fc3c", network.Identifier) + assert.Equal(t, uint64(0), network.Serial) +} + +func TestSqlite_GetSetupKeyBySecret(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, "testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", setupKey.Key) + assert.Equal(t, "bf1c8084-ba50-4ce7-9439-34653001fc3b", setupKey.AccountID) + assert.Equal(t, "Default key", setupKey.Name) +} + +func TestSqlite_incrementSetupKeyUsage(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("The SQLite store is not properly supported by Windows yet") + } + store := newSqliteStoreFromFile(t, 
"testdata/extended-store.json") + defer store.Close(context.Background()) + + existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + _, err := store.GetAccount(context.Background(), existingAccountID) + require.NoError(t, err) + + setupKey, err := store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 0, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 1, setupKey.UsedTimes) + + err = store.IncrementSetupKeyUsage(context.Background(), setupKey.Id) + require.NoError(t, err) + + setupKey, err = store.GetSetupKeyBySecret(context.Background(), LockingStrengthShare, "A2C8E62B-38F5-4553-B31E-DD66C696CEBB") + require.NoError(t, err) + assert.Equal(t, 2, setupKey.UsedTimes) +} diff --git a/management/server/status/error.go b/management/server/status/error.go index 58b9a84a0..d7fde35b9 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -100,3 +100,13 @@ func NewPeerNotRegisteredError() error { func NewPeerLoginExpiredError() error { return Errorf(PermissionDenied, "peer login has expired, please log in once more") } + +// NewSetupKeyNotFoundError creates a new Error with NotFound type for a missing setup key +func NewSetupKeyNotFoundError() error { + return Errorf(NotFound, "setup key not found") +} + +// NewGetUserFromStoreError creates a new Error with Internal type for an issue getting user from store +func NewGetUserFromStoreError() error { + return Errorf(Internal, "issue getting user from store") +} diff --git a/management/server/store.go b/management/server/store.go index dcff80ee5..3bf15c1b5 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -27,6 +27,15 @@ 
import ( "github.com/netbirdio/netbird/route" ) +type LockingStrength string + +const ( + LockingStrengthUpdate LockingStrength = "UPDATE" // Strongest lock, preventing any changes by other transactions until your transaction completes. + LockingStrengthShare LockingStrength = "SHARE" // Allows reading but prevents changes by other transactions. + LockingStrengthNoKeyUpdate LockingStrength = "NO KEY UPDATE" // Similar to UPDATE but allows changes to related rows. + LockingStrengthKeyShare LockingStrength = "KEY SHARE" // Protects against changes to primary/unique keys but allows other updates. +) + type Store interface { GetAllAccounts(ctx context.Context) []*Account GetAccount(ctx context.Context, accountID string) (*Account, error) @@ -42,7 +51,7 @@ type Store interface { GetAccountIDByPrivateDomain(ctx context.Context, domain string) (string, error) GetTokenIDByHashedToken(ctx context.Context, secret string) (string, error) GetUserByTokenID(ctx context.Context, tokenID string) (*User, error) - GetUserByUserID(ctx context.Context, userID string) (*User, error) + GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) SaveAccount(ctx context.Context, account *Account) error @@ -61,14 +70,24 @@ type Store interface { SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(accountID, peerID string, status nbpeer.PeerStatus) error SavePeerLocation(accountID string, peer *nbpeer.Peer) error - SaveUserLastLogin(accountID, userID string, lastLogin time.Time) error + SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error // Close should close the store persisting all unsaved data. 
Close(ctx context.Context) error // GetStoreEngine should return StoreEngine of the current store implementation. // This is also a method of metrics.DataSource interface. GetStoreEngine() StoreEngine - GetPeerByPeerPubKey(ctx context.Context, peerKey string) (*nbpeer.Peer, error) - GetAccountSettings(ctx context.Context, accountID string) (*Settings, error) + GetPeerByPeerPubKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (*nbpeer.Peer, error) + GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) + GetSetupKeyBySecret(ctx context.Context, lockStrength LockingStrength, key string) (*SetupKey, error) + GetTakenIPs(ctx context.Context, lockStrength LockingStrength, accountId string) ([]net.IP, error) + IncrementSetupKeyUsage(ctx context.Context, setupKeyID string) error + AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error + GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) + AddPeerToGroup(ctx context.Context, accountId string, peerId string, groupID string) error + AddPeerToAccount(ctx context.Context, peer *nbpeer.Peer) error + IncrementNetworkSerial(ctx context.Context, accountId string) error + GetAccountNetwork(ctx context.Context, lockStrength LockingStrength, accountId string) (*Network, error) + ExecuteInTransaction(ctx context.Context, f func(store Store) error) error } type StoreEngine string diff --git a/management/server/testdata/extended-store.json b/management/server/testdata/extended-store.json new file mode 100644 index 000000000..7f96e57a8 --- /dev/null +++ b/management/server/testdata/extended-store.json @@ -0,0 +1,120 @@ +{ + "Accounts": { + "bf1c8084-ba50-4ce7-9439-34653001fc3b": { + "Id": "bf1c8084-ba50-4ce7-9439-34653001fc3b", + "CreatedBy": "", + "Domain": "test.com", + "DomainCategory": "private", + "IsDomainPrimaryAccount": true, + "SetupKeys": { + 
"A2C8E62B-38F5-4553-B31E-DD66C696CEBB": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBB", + "Name": "Default key", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["cfefqs706sqkneg59g2g"], + "UsageLimit": 0, + "Ephemeral": false + }, + "A2C8E62B-38F5-4553-B31E-DD66C696CEBC": { + "Id": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "AccountID": "", + "Key": "A2C8E62B-38F5-4553-B31E-DD66C696CEBC", + "Name": "Faulty key with non existing group", + "Type": "reusable", + "CreatedAt": "2021-08-19T20:46:20.005936822+02:00", + "ExpiresAt": "2321-09-18T20:46:20.005936822+02:00", + "UpdatedAt": "0001-01-01T00:00:00Z", + "Revoked": false, + "UsedTimes": 0, + "LastUsed": "0001-01-01T00:00:00Z", + "AutoGroups": ["abcd"], + "UsageLimit": 0, + "Ephemeral": false + } + }, + "Network": { + "id": "af1c8024-ha40-4ce2-9418-34653101fc3c", + "Net": { + "IP": "100.64.0.0", + "Mask": "//8AAA==" + }, + "Dns": "", + "Serial": 0 + }, + "Peers": {}, + "Users": { + "edafee4e-63fb-11ec-90d6-0242ac120003": { + "Id": "edafee4e-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "admin", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": ["cfefqs706sqkneg59g3g"], + "PATs": {}, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + }, + "f4f6d672-63fb-11ec-90d6-0242ac120003": { + "Id": "f4f6d672-63fb-11ec-90d6-0242ac120003", + "AccountID": "", + "Role": "user", + "IsServiceUser": false, + "ServiceUserName": "", + "AutoGroups": null, + "PATs": { + "9dj38s35-63fb-11ec-90d6-0242ac120003": { + "ID": "9dj38s35-63fb-11ec-90d6-0242ac120003", + "UserID": "", + "Name": "", + "HashedToken": "SoMeHaShEdToKeN", + "ExpirationDate": "2023-02-27T00:00:00Z", + "CreatedBy": "user", + "CreatedAt": 
"2023-01-01T00:00:00Z", + "LastUsed": "2023-02-01T00:00:00Z" + } + }, + "Blocked": false, + "LastLogin": "0001-01-01T00:00:00Z" + } + }, + "Groups": { + "cfefqs706sqkneg59g4g": { + "ID": "cfefqs706sqkneg59g4g", + "Name": "All", + "Peers": [] + }, + "cfefqs706sqkneg59g3g": { + "ID": "cfefqs706sqkneg59g3g", + "Name": "AwesomeGroup1", + "Peers": [] + }, + "cfefqs706sqkneg59g2g": { + "ID": "cfefqs706sqkneg59g2g", + "Name": "AwesomeGroup2", + "Peers": [] + } + }, + "Rules": null, + "Policies": [], + "Routes": null, + "NameServerGroups": null, + "DNSSettings": null, + "Settings": { + "PeerLoginExpirationEnabled": false, + "PeerLoginExpiration": 86400000000000, + "GroupsPropagationEnabled": false, + "JWTGroupsEnabled": false, + "JWTGroupsClaimName": "" + } + } + }, + "InstallationID": "" +} diff --git a/management/server/user.go b/management/server/user.go index 85840f408..5dce744a6 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -89,10 +89,6 @@ func (u *User) LastDashboardLoginChanged(LastLogin time.Time) bool { return LastLogin.After(u.LastLogin) && !u.LastLogin.IsZero() } -func (u *User) updateLastLogin(login time.Time) { - u.LastLogin = login -} - // HasAdminPower returns true if the user has admin or owner roles, false otherwise func (u *User) HasAdminPower() bool { return u.Role == UserRoleAdmin || u.Role == UserRoleOwner @@ -393,7 +389,7 @@ func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.A // server when user authenticates a device. And we need to separate the Dashboard login event from the Device login event. 
newLogin := user.LastDashboardLoginChanged(claims.LastLogin) - err = am.Store.SaveUserLastLogin(account.Id, claims.UserId, claims.LastLogin) + err = am.Store.SaveUserLastLogin(ctx, account.Id, claims.UserId, claims.LastLogin) if err != nil { log.WithContext(ctx).Errorf("failed saving user last login: %v", err) } diff --git a/relay/client/client.go b/relay/client/client.go index 6560c81e1..e431c029d 100644 --- a/relay/client/client.go +++ b/relay/client/client.go @@ -63,32 +63,53 @@ type connContainer struct { messages chan Msg msgChanLock sync.Mutex closed bool // flag to check if channel is closed + ctx context.Context + cancel context.CancelFunc } func newConnContainer(conn *Conn, messages chan Msg) *connContainer { + ctx, cancel := context.WithCancel(context.Background()) + return &connContainer{ conn: conn, messages: messages, + ctx: ctx, + cancel: cancel, } } func (cc *connContainer) writeMsg(msg Msg) { cc.msgChanLock.Lock() defer cc.msgChanLock.Unlock() + if cc.closed { + msg.Free() return } - cc.messages <- msg + + select { + case cc.messages <- msg: + case <-cc.ctx.Done(): + msg.Free() + } } func (cc *connContainer) close() { + cc.cancel() + cc.msgChanLock.Lock() defer cc.msgChanLock.Unlock() + if cc.closed { return } - close(cc.messages) + cc.closed = true + close(cc.messages) + + for msg := range cc.messages { + msg.Free() + } } // Client is a client for the relay server. 
It is responsible for establishing a connection to the relay server and @@ -121,7 +142,7 @@ type Client struct { func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.TokenStore, peerID string) *Client { hashedID, hashedStringId := messages.HashID(peerID) return &Client{ - log: log.WithField("client_id", hashedStringId), + log: log.WithFields(log.Fields{"client_id": hashedStringId, "relay": serverURL}), parentCtx: ctx, connectionURL: serverURL, authTokenStore: authTokenStore, @@ -138,7 +159,7 @@ func NewClient(ctx context.Context, serverURL string, authTokenStore *auth.Token // Connect establishes a connection to the relay server. It blocks until the connection is established or an error occurs. func (c *Client) Connect() error { - c.log.Infof("connecting to relay server: %s", c.connectionURL) + c.log.Infof("connecting to relay server") c.readLoopMutex.Lock() defer c.readLoopMutex.Unlock() @@ -159,7 +180,7 @@ func (c *Client) Connect() error { c.wgReadLoop.Add(1) go c.readLoop(c.relayConn) - c.log.Infof("relay connection established with: %s", c.connectionURL) + c.log.Infof("relay connection established") return nil } @@ -181,7 +202,7 @@ func (c *Client) OpenConn(dstPeerID string) (net.Conn, error) { return nil, ErrConnAlreadyExists } - log.Infof("open connection to peer: %s", hashedStringID) + c.log.Infof("open connection to peer: %s", hashedStringID) msgChannel := make(chan Msg, 2) conn := NewConn(c, hashedID, hashedStringID, msgChannel, c.instanceURL) @@ -229,7 +250,7 @@ func (c *Client) connect() error { if err != nil { cErr := conn.Close() if cErr != nil { - log.Errorf("failed to close connection: %s", cErr) + c.log.Errorf("failed to close connection: %s", cErr) } return err } @@ -240,19 +261,19 @@ func (c *Client) connect() error { func (c *Client) handShake() error { msg, err := messages.MarshalAuthMsg(c.hashedID, c.authTokenStore.TokenBinary()) if err != nil { - log.Errorf("failed to marshal auth message: %s", err) + 
c.log.Errorf("failed to marshal auth message: %s", err) return err } _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to send auth message: %s", err) + c.log.Errorf("failed to send auth message: %s", err) return err } buf := make([]byte, messages.MaxHandshakeRespSize) n, err := c.readWithTimeout(buf) if err != nil { - log.Errorf("failed to read auth response: %s", err) + c.log.Errorf("failed to read auth response: %s", err) return err } @@ -263,12 +284,12 @@ func (c *Client) handShake() error { msgType, err := messages.DetermineServerMessageType(buf[messages.SizeOfVersionByte:n]) if err != nil { - log.Errorf("failed to determine message type: %s", err) + c.log.Errorf("failed to determine message type: %s", err) return err } if msgType != messages.MsgTypeAuthResponse { - log.Errorf("unexpected message type: %s", msgType) + c.log.Errorf("unexpected message type: %s", msgType) return fmt.Errorf("unexpected message type") } @@ -285,7 +306,7 @@ func (c *Client) handShake() error { func (c *Client) readLoop(relayConn net.Conn) { internallyStoppedFlag := newInternalStopFlag() - hc := healthcheck.NewReceiver() + hc := healthcheck.NewReceiver(c.log) go c.listenForStopEvents(hc, relayConn, internallyStoppedFlag) var ( @@ -297,6 +318,7 @@ func (c *Client) readLoop(relayConn net.Conn) { buf := *bufPtr n, errExit = relayConn.Read(buf) if errExit != nil { + c.log.Infof("start to Relay read loop exit") c.mu.Lock() if c.serviceIsRunning && !internallyStoppedFlag.isSet() { c.log.Debugf("failed to read message from relay server: %s", errExit) @@ -343,7 +365,7 @@ func (c *Client) handleMsg(msgType messages.MsgType, buf []byte, bufPtr *[]byte, case messages.MsgTypeTransport: return c.handleTransportMsg(buf, bufPtr, internallyStoppedFlag) case messages.MsgTypeClose: - log.Debugf("relay connection close by server") + c.log.Debugf("relay connection close by server") c.bufPool.Put(bufPtr) return false } @@ -412,14 +434,14 @@ func (c *Client) writeTo(connReference 
*Conn, id string, dstID []byte, payload [ // todo: use buffer pool instead of create new transport msg. msg, err := messages.MarshalTransportMsg(dstID, payload) if err != nil { - log.Errorf("failed to marshal transport message: %s", err) + c.log.Errorf("failed to marshal transport message: %s", err) return 0, err } // the write always return with 0 length because the underling does not support the size feedback. _, err = c.relayConn.Write(msg) if err != nil { - log.Errorf("failed to write transport message: %s", err) + c.log.Errorf("failed to write transport message: %s", err) } return len(payload), err } @@ -438,7 +460,7 @@ func (c *Client) listenForStopEvents(hc *healthcheck.Receiver, conn net.Conn, in case <-c.parentCtx.Done(): err := c.close(true) if err != nil { - log.Errorf("failed to teardown connection: %s", err) + c.log.Errorf("failed to teardown connection: %s", err) } return } @@ -464,8 +486,8 @@ func (c *Client) closeConn(connReference *Conn, id string) error { if container.conn != connReference { return fmt.Errorf("conn reference mismatch") } - container.close() delete(c.conns, id) + container.close() return nil } @@ -478,10 +500,12 @@ func (c *Client) close(gracefullyExit bool) error { var err error if !c.serviceIsRunning { c.mu.Unlock() + c.log.Warn("relay connection was already marked as not running") return nil } c.serviceIsRunning = false + c.log.Infof("closing all peer connections") c.closeAllConns() if gracefullyExit { c.writeCloseMsg() @@ -489,8 +513,9 @@ func (c *Client) close(gracefullyExit bool) error { err = c.relayConn.Close() c.mu.Unlock() + c.log.Infof("waiting for read loop to close") c.wgReadLoop.Wait() - c.log.Infof("relay connection closed with: %s", c.connectionURL) + c.log.Infof("relay connection closed") return err } diff --git a/relay/client/client_test.go b/relay/client/client_test.go index b7f1a63ca..ef28203e9 100644 --- a/relay/client/client_test.go +++ b/relay/client/client_test.go @@ -618,6 +618,87 @@ func 
TestCloseByClient(t *testing.T) { } } +func TestCloseNotDrainedChannel(t *testing.T) { + ctx := context.Background() + idAlice := "alice" + idBob := "bob" + srvCfg := server.ListenerConfig{Address: serverListenAddr} + srv, err := server.NewServer(otel.Meter(""), serverURL, false, av) + if err != nil { + t.Fatalf("failed to create server: %s", err) + } + errChan := make(chan error, 1) + go func() { + err := srv.Listen(srvCfg) + if err != nil { + errChan <- err + } + }() + + defer func() { + err := srv.Shutdown(ctx) + if err != nil { + t.Errorf("failed to close server: %s", err) + } + }() + + // wait for servers to start + if err := waitForServerToStart(errChan); err != nil { + t.Fatalf("failed to start server: %s", err) + } + + clientAlice := NewClient(ctx, serverURL, hmacTokenStore, idAlice) + err = clientAlice.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientAlice.Close() + if err != nil { + t.Errorf("failed to close Alice client: %s", err) + } + }() + + clientBob := NewClient(ctx, serverURL, hmacTokenStore, idBob) + err = clientBob.Connect() + if err != nil { + t.Fatalf("failed to connect to server: %s", err) + } + defer func() { + err := clientBob.Close() + if err != nil { + t.Errorf("failed to close Bob client: %s", err) + } + }() + + connAliceToBob, err := clientAlice.OpenConn(idBob) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + connBobToAlice, err := clientBob.OpenConn(idAlice) + if err != nil { + t.Fatalf("failed to bind channel: %s", err) + } + + payload := "hello bob, I am alice" + // the internal channel buffer size is 2. 
So we should overflow it + for i := 0; i < 5; i++ { + _, err = connAliceToBob.Write([]byte(payload)) + if err != nil { + t.Fatalf("failed to write to channel: %s", err) + } + + } + + // wait for delivery + time.Sleep(1 * time.Second) + err = connBobToAlice.Close() + if err != nil { + t.Errorf("failed to close channel: %s", err) + } +} + func waitForServerToStart(errChan chan error) error { select { case err := <-errChan: diff --git a/relay/client/manager.go b/relay/client/manager.go index a9d294160..4554c7c0f 100644 --- a/relay/client/manager.go +++ b/relay/client/manager.go @@ -3,7 +3,6 @@ package client import ( "container/list" "context" - "errors" "fmt" "net" "reflect" @@ -17,8 +16,6 @@ import ( var ( relayCleanupInterval = 60 * time.Second - connectionTimeout = 30 * time.Second - maxConcurrentServers = 7 ErrRelayClientNotConnected = fmt.Errorf("relay client not connected") ) @@ -92,67 +89,23 @@ func (m *Manager) Serve() error { } log.Debugf("starting relay client manager with %v relay servers", m.serverURLs) - totalServers := len(m.serverURLs) - - successChan := make(chan *Client, 1) - errChan := make(chan error, len(m.serverURLs)) - - ctx, cancel := context.WithTimeout(m.ctx, connectionTimeout) - defer cancel() - - sem := make(chan struct{}, maxConcurrentServers) - - for _, url := range m.serverURLs { - sem <- struct{}{} - go func(url string) { - defer func() { <-sem }() - m.connect(m.ctx, url, successChan, errChan) - }(url) + sp := ServerPicker{ + TokenStore: m.tokenStore, + PeerID: m.peerID, } - var errCount int - - for { - select { - case client := <-successChan: - log.Infof("Successfully connected to relay server: %s", client.connectionURL) - - m.relayClient = client - - m.reconnectGuard = NewGuard(m.ctx, m.relayClient) - m.relayClient.SetOnDisconnectListener(func() { - m.onServerDisconnected(client.connectionURL) - }) - m.startCleanupLoop() - return nil - case err := <-errChan: - errCount++ - log.Warnf("Connection attempt failed: %v", err) - if errCount 
== totalServers { - return errors.New("failed to connect to any relay server: all attempts failed") - } - case <-ctx.Done(): - return fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) - } + client, err := sp.PickServer(m.ctx, m.serverURLs) + if err != nil { + return err } -} + m.relayClient = client -func (m *Manager) connect(ctx context.Context, serverURL string, successChan chan<- *Client, errChan chan<- error) { - // TODO: abort the connection if another connection was successful - relayClient := NewClient(ctx, serverURL, m.tokenStore, m.peerID) - if err := relayClient.Connect(); err != nil { - errChan <- fmt.Errorf("failed to connect to %s: %w", serverURL, err) - return - } - - select { - case successChan <- relayClient: - // This client was the first to connect successfully - default: - if err := relayClient.Close(); err != nil { - log.Debugf("failed to close relay client: %s", err) - } - } + m.reconnectGuard = NewGuard(m.ctx, m.relayClient) + m.relayClient.SetOnDisconnectListener(func() { + m.onServerDisconnected(client.connectionURL) + }) + m.startCleanupLoop() + return nil } // OpenConn opens a connection to the given peer key. 
If the peer is on the same relay server, the connection will be diff --git a/relay/client/picker.go b/relay/client/picker.go new file mode 100644 index 000000000..b0888a4a0 --- /dev/null +++ b/relay/client/picker.go @@ -0,0 +1,94 @@ +package client + +import ( + "context" + "errors" + "fmt" + "time" + + log "github.com/sirupsen/logrus" + + auth "github.com/netbirdio/netbird/relay/auth/hmac" +) + +const ( + connectionTimeout = 30 * time.Second + maxConcurrentServers = 7 +) + +type connResult struct { + RelayClient *Client + Url string + Err error +} + +type ServerPicker struct { + TokenStore *auth.TokenStore + PeerID string +} + +func (sp *ServerPicker) PickServer(parentCtx context.Context, urls []string) (*Client, error) { + ctx, cancel := context.WithTimeout(parentCtx, connectionTimeout) + defer cancel() + + totalServers := len(urls) + + connResultChan := make(chan connResult, totalServers) + successChan := make(chan connResult, 1) + + concurrentLimiter := make(chan struct{}, maxConcurrentServers) + for _, url := range urls { + concurrentLimiter <- struct{}{} + go func(url string) { + defer func() { <-concurrentLimiter }() + sp.startConnection(parentCtx, connResultChan, url) + }(url) + } + + go sp.processConnResults(connResultChan, successChan) + + select { + case cr, ok := <-successChan: + if !ok { + return nil, errors.New("failed to connect to any relay server: all attempts failed") + } + log.Infof("chosen home Relay server: %s", cr.Url) + return cr.RelayClient, nil + case <-ctx.Done(): + return nil, fmt.Errorf("failed to connect to any relay server: %w", ctx.Err()) + } +} + +func (sp *ServerPicker) startConnection(ctx context.Context, resultChan chan connResult, url string) { + log.Infof("try to connecting to relay server: %s", url) + relayClient := NewClient(ctx, url, sp.TokenStore, sp.PeerID) + err := relayClient.Connect() + resultChan <- connResult{ + RelayClient: relayClient, + Url: url, + Err: err, + } +} + +func (sp *ServerPicker) 
processConnResults(resultChan chan connResult, successChan chan connResult) { + var hasSuccess bool + for cr := range resultChan { + if cr.Err != nil { + log.Debugf("failed to connect to Relay server: %s: %v", cr.Url, cr.Err) + continue + } + log.Infof("connected to Relay server: %s", cr.Url) + + if hasSuccess { + log.Infof("closing unnecessary Relay connection to: %s", cr.Url) + if err := cr.RelayClient.Close(); err != nil { + log.Errorf("failed to close connection to %s: %v", cr.Url, err) + } + continue + } + + hasSuccess = true + successChan <- cr + } + close(successChan) +} diff --git a/relay/cmd/root.go b/relay/cmd/root.go index dcc1465d0..d603ff73b 100644 --- a/relay/cmd/root.go +++ b/relay/cmd/root.go @@ -23,15 +23,12 @@ import ( "github.com/netbirdio/netbird/util" ) -const ( - metricsPort = 9090 -) - type Config struct { ListenAddress string // in HA every peer connect to a common domain, the instance domain has been distributed during the p2p connection // it is a domain:port or ip:port ExposedAddress string + MetricsPort int LetsencryptEmail string LetsencryptDataDir string LetsencryptDomains []string @@ -80,6 +77,7 @@ func init() { cobraConfig = &Config{} rootCmd.PersistentFlags().StringVarP(&cobraConfig.ListenAddress, "listen-address", "l", ":443", "listen address") rootCmd.PersistentFlags().StringVarP(&cobraConfig.ExposedAddress, "exposed-address", "e", "", "instance domain address (or ip) and port, it will be distributes between peers") + rootCmd.PersistentFlags().IntVar(&cobraConfig.MetricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") rootCmd.PersistentFlags().StringVarP(&cobraConfig.LetsencryptDataDir, "letsencrypt-data-dir", "d", "", "a directory to store Let's Encrypt data. Required if Let's Encrypt is enabled.") rootCmd.PersistentFlags().StringSliceVarP(&cobraConfig.LetsencryptDomains, "letsencrypt-domains", "a", nil, "list of domains to issue Let's Encrypt certificate for. 
Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") rootCmd.PersistentFlags().StringVar(&cobraConfig.LetsencryptEmail, "letsencrypt-email", "", "email address to use for Let's Encrypt certificate registration") @@ -116,7 +114,7 @@ func execute(cmd *cobra.Command, args []string) error { return fmt.Errorf("failed to initialize log: %s", err) } - metricsServer, err := metrics.NewServer(metricsPort, "") + metricsServer, err := metrics.NewServer(cobraConfig.MetricsPort, "") if err != nil { log.Debugf("setup metrics: %v", err) return fmt.Errorf("setup metrics: %v", err) diff --git a/relay/healthcheck/receiver.go b/relay/healthcheck/receiver.go index 2b9c9e2e0..b3503d5db 100644 --- a/relay/healthcheck/receiver.go +++ b/relay/healthcheck/receiver.go @@ -3,10 +3,12 @@ package healthcheck import ( "context" "time" + + log "github.com/sirupsen/logrus" ) var ( - heartbeatTimeout = healthCheckInterval + 3*time.Second + heartbeatTimeout = healthCheckInterval + 10*time.Second ) // Receiver is a healthcheck receiver @@ -14,23 +16,26 @@ var ( // If the heartbeat is not received in a certain time, it will send a timeout signal and stop to work // The heartbeat timeout is a bit longer than the sender's healthcheck interval type Receiver struct { - OnTimeout chan struct{} - - ctx context.Context - ctxCancel context.CancelFunc - heartbeat chan struct{} - alive bool + OnTimeout chan struct{} + log *log.Entry + ctx context.Context + ctxCancel context.CancelFunc + heartbeat chan struct{} + alive bool + attemptThreshold int } // NewReceiver creates a new healthcheck receiver and start the timer in the background -func NewReceiver() *Receiver { +func NewReceiver(log *log.Entry) *Receiver { ctx, ctxCancel := context.WithCancel(context.Background()) r := &Receiver{ - OnTimeout: make(chan struct{}, 1), - ctx: ctx, - ctxCancel: ctxCancel, - heartbeat: make(chan struct{}, 1), + OnTimeout: make(chan struct{}, 1), + log: log, + ctx: ctx, + ctxCancel: 
ctxCancel, + heartbeat: make(chan struct{}, 1), + attemptThreshold: getAttemptThresholdFromEnv(), } go r.waitForHealthcheck() @@ -56,16 +61,23 @@ func (r *Receiver) waitForHealthcheck() { defer r.ctxCancel() defer close(r.OnTimeout) + failureCounter := 0 for { select { case <-r.heartbeat: r.alive = true + failureCounter = 0 case <-ticker.C: if r.alive { r.alive = false continue } + failureCounter++ + if failureCounter < r.attemptThreshold { + r.log.Warnf("healthcheck failed, attempt %d", failureCounter) + continue + } r.notifyTimeout() return case <-r.ctx.Done(): diff --git a/relay/healthcheck/receiver_test.go b/relay/healthcheck/receiver_test.go index 4b4123416..3b3e32fe6 100644 --- a/relay/healthcheck/receiver_test.go +++ b/relay/healthcheck/receiver_test.go @@ -1,13 +1,18 @@ package healthcheck import ( + "context" + "fmt" + "os" "testing" "time" + + log "github.com/sirupsen/logrus" ) func TestNewReceiver(t *testing.T) { heartbeatTimeout = 5 * time.Second - r := NewReceiver() + r := NewReceiver(log.WithContext(context.Background())) select { case <-r.OnTimeout: @@ -19,7 +24,7 @@ func TestNewReceiver(t *testing.T) { func TestNewReceiverNotReceive(t *testing.T) { heartbeatTimeout = 1 * time.Second - r := NewReceiver() + r := NewReceiver(log.WithContext(context.Background())) select { case <-r.OnTimeout: @@ -30,7 +35,7 @@ func TestNewReceiverNotReceive(t *testing.T) { func TestNewReceiverAck(t *testing.T) { heartbeatTimeout = 2 * time.Second - r := NewReceiver() + r := NewReceiver(log.WithContext(context.Background())) r.Heartbeat() @@ -40,3 +45,53 @@ func TestNewReceiverAck(t *testing.T) { case <-time.After(3 * time.Second): } } + +func TestReceiverHealthCheckAttemptThreshold(t *testing.T) { + testsCases := []struct { + name string + threshold int + resetCounterOnce bool + }{ + {"Default attempt threshold", defaultAttemptThreshold, false}, + {"Custom attempt threshold", 3, false}, + {"Should reset threshold once", 2, true}, + } + + for _, tc := range testsCases { 
+ t.Run(tc.name, func(t *testing.T) { + originalInterval := healthCheckInterval + originalTimeout := heartbeatTimeout + healthCheckInterval = 1 * time.Second + heartbeatTimeout = healthCheckInterval + 500*time.Millisecond + defer func() { + healthCheckInterval = originalInterval + heartbeatTimeout = originalTimeout + }() + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) + + receiver := NewReceiver(log.WithField("test_name", tc.name)) + + testTimeout := heartbeatTimeout*time.Duration(tc.threshold) + healthCheckInterval + + if tc.resetCounterOnce { + receiver.Heartbeat() + t.Logf("reset counter once") + } + + select { + case <-receiver.OnTimeout: + if tc.resetCounterOnce { + t.Fatalf("should not have timed out before %s", testTimeout) + } + case <-time.After(testTimeout): + if tc.resetCounterOnce { + return + } + t.Fatalf("should have timed out before %s", testTimeout) + } + + }) + } +} diff --git a/relay/healthcheck/sender.go b/relay/healthcheck/sender.go index ec0560ef2..57b3015ec 100644 --- a/relay/healthcheck/sender.go +++ b/relay/healthcheck/sender.go @@ -2,12 +2,21 @@ package healthcheck import ( "context" + "os" + "strconv" "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultAttemptThreshold = 1 + defaultAttemptThresholdEnv = "NB_RELAY_HC_ATTEMPT_THRESHOLD" ) var ( healthCheckInterval = 25 * time.Second - healthCheckTimeout = 5 * time.Second + healthCheckTimeout = 20 * time.Second ) // Sender is a healthcheck sender @@ -15,20 +24,25 @@ var ( // If the receiver does not receive the signal in a certain time, it will send a timeout signal and stop to work // It will also stop if the context is canceled type Sender struct { + log *log.Entry // HealthCheck is a channel to send health check signal to the peer HealthCheck chan struct{} // Timeout is a channel to the health check signal is not received in a certain time Timeout chan struct{} - ack chan struct{} + ack 
chan struct{} + alive bool + attemptThreshold int } // NewSender creates a new healthcheck sender -func NewSender() *Sender { +func NewSender(log *log.Entry) *Sender { hc := &Sender{ - HealthCheck: make(chan struct{}, 1), - Timeout: make(chan struct{}, 1), - ack: make(chan struct{}, 1), + log: log, + HealthCheck: make(chan struct{}, 1), + Timeout: make(chan struct{}, 1), + ack: make(chan struct{}, 1), + attemptThreshold: getAttemptThresholdFromEnv(), } return hc @@ -46,23 +60,51 @@ func (hc *Sender) StartHealthCheck(ctx context.Context) { ticker := time.NewTicker(healthCheckInterval) defer ticker.Stop() - timeoutTimer := time.NewTimer(healthCheckInterval + healthCheckTimeout) - defer timeoutTimer.Stop() + timeoutTicker := time.NewTicker(hc.getTimeoutTime()) + defer timeoutTicker.Stop() defer close(hc.HealthCheck) defer close(hc.Timeout) + failureCounter := 0 for { select { case <-ticker.C: hc.HealthCheck <- struct{}{} - case <-timeoutTimer.C: + case <-timeoutTicker.C: + if hc.alive { + hc.alive = false + continue + } + + failureCounter++ + if failureCounter < hc.attemptThreshold { + hc.log.Warnf("Health check failed attempt %d.", failureCounter) + continue + } hc.Timeout <- struct{}{} return case <-hc.ack: - timeoutTimer.Reset(healthCheckInterval + healthCheckTimeout) + failureCounter = 0 + hc.alive = true case <-ctx.Done(): return } } } + +func (hc *Sender) getTimeoutTime() time.Duration { + return healthCheckInterval + healthCheckTimeout +} + +func getAttemptThresholdFromEnv() int { + if attemptThreshold := os.Getenv(defaultAttemptThresholdEnv); attemptThreshold != "" { + threshold, err := strconv.ParseInt(attemptThreshold, 10, 64) + if err != nil { + log.Errorf("Failed to parse attempt threshold from environment variable \"%s\" should be an integer. 
Using default value", attemptThreshold) + return defaultAttemptThreshold + } + return int(threshold) + } + return defaultAttemptThreshold +} diff --git a/relay/healthcheck/sender_test.go b/relay/healthcheck/sender_test.go index 7a105c308..f21167025 100644 --- a/relay/healthcheck/sender_test.go +++ b/relay/healthcheck/sender_test.go @@ -2,9 +2,12 @@ package healthcheck import ( "context" + "fmt" "os" "testing" "time" + + log "github.com/sirupsen/logrus" ) func TestMain(m *testing.M) { @@ -18,7 +21,7 @@ func TestMain(m *testing.M) { func TestNewHealthPeriod(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender() + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) iterations := 0 @@ -38,7 +41,7 @@ func TestNewHealthPeriod(t *testing.T) { func TestNewHealthFailed(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender() + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) select { @@ -50,7 +53,7 @@ func TestNewHealthFailed(t *testing.T) { func TestNewHealthcheckStop(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - hc := NewSender() + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) time.Sleep(100 * time.Millisecond) @@ -75,7 +78,7 @@ func TestNewHealthcheckStop(t *testing.T) { func TestTimeoutReset(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := NewSender() + hc := NewSender(log.WithContext(ctx)) go hc.StartHealthCheck(ctx) iterations := 0 @@ -101,3 +104,102 @@ func TestTimeoutReset(t *testing.T) { t.Fatalf("is not exited") } } + +func TestSenderHealthCheckAttemptThreshold(t *testing.T) { + testsCases := []struct { + name string + threshold int + resetCounterOnce bool + }{ + {"Default attempt threshold", defaultAttemptThreshold, false}, + {"Custom attempt threshold", 3, false}, + {"Should reset threshold once", 2, true}, + } + + for _, tc := range 
testsCases { + t.Run(tc.name, func(t *testing.T) { + originalInterval := healthCheckInterval + originalTimeout := healthCheckTimeout + healthCheckInterval = 1 * time.Second + healthCheckTimeout = 500 * time.Millisecond + defer func() { + healthCheckInterval = originalInterval + healthCheckTimeout = originalTimeout + }() + + //nolint:tenv + os.Setenv(defaultAttemptThresholdEnv, fmt.Sprintf("%d", tc.threshold)) + defer os.Unsetenv(defaultAttemptThresholdEnv) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + sender := NewSender(log.WithField("test_name", tc.name)) + go sender.StartHealthCheck(ctx) + + go func() { + responded := false + for { + select { + case <-ctx.Done(): + return + case _, ok := <-sender.HealthCheck: + if !ok { + return + } + if tc.resetCounterOnce && !responded { + responded = true + sender.OnHCResponse() + } + } + } + }() + + testTimeout := sender.getTimeoutTime()*time.Duration(tc.threshold) + healthCheckInterval + + select { + case <-sender.Timeout: + if tc.resetCounterOnce { + t.Fatalf("should not have timed out before %s", testTimeout) + } + case <-time.After(testTimeout): + if tc.resetCounterOnce { + return + } + t.Fatalf("should have timed out before %s", testTimeout) + } + + }) + } + +} + +//nolint:tenv +func TestGetAttemptThresholdFromEnv(t *testing.T) { + tests := []struct { + name string + envValue string + expected int + }{ + {"Default attempt threshold when env is not set", "", defaultAttemptThreshold}, + {"Custom attempt threshold when env is set to a valid integer", "3", 3}, + {"Default attempt threshold when env is set to an invalid value", "invalid", defaultAttemptThreshold}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.envValue == "" { + os.Unsetenv(defaultAttemptThresholdEnv) + } else { + os.Setenv(defaultAttemptThresholdEnv, tt.envValue) + } + + result := getAttemptThresholdFromEnv() + if result != tt.expected { + t.Fatalf("Expected %d, got %d", tt.expected, 
result) + } + + os.Unsetenv(defaultAttemptThresholdEnv) + }) + } +} diff --git a/relay/server/peer.go b/relay/server/peer.go index 0de601996..a9c542f84 100644 --- a/relay/server/peer.go +++ b/relay/server/peer.go @@ -49,7 +49,7 @@ func (p *Peer) Work() { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - hc := healthcheck.NewSender() + hc := healthcheck.NewSender(p.log) go hc.StartHealthCheck(ctx) go p.handleHealthcheckEvents(ctx, hc) @@ -115,6 +115,7 @@ func (p *Peer) Write(b []byte) (int, error) { // connection. func (p *Peer) CloseGracefully(ctx context.Context) { p.connMu.Lock() + defer p.connMu.Unlock() err := p.writeWithTimeout(ctx, messages.MarshalCloseMsg()) if err != nil { p.log.Errorf("failed to send close message to peer: %s", p.String()) @@ -124,8 +125,15 @@ func (p *Peer) CloseGracefully(ctx context.Context) { if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } +} +func (p *Peer) Close() { + p.connMu.Lock() defer p.connMu.Unlock() + + if err := p.conn.Close(); err != nil { + p.log.Errorf("failed to close connection to peer: %s", err) + } } // String returns the peer ID @@ -167,6 +175,7 @@ func (p *Peer) handleHealthcheckEvents(ctx context.Context, hc *healthcheck.Send if err != nil { p.log.Errorf("failed to close connection to peer: %s", err) } + p.log.Info("peer connection closed due to healthcheck timeout") return case <-ctx.Done(): return diff --git a/relay/server/store.go b/relay/server/store.go index 96879dae1..4288e62c5 100644 --- a/relay/server/store.go +++ b/relay/server/store.go @@ -19,10 +19,14 @@ func NewStore() *Store { } // AddPeer adds a peer to the store -// todo: consider to close peer conn if the peer already exists func (s *Store) AddPeer(peer *Peer) { s.peersLock.Lock() defer s.peersLock.Unlock() + oldPeer, ok := s.peers[peer.String()] + if ok { + oldPeer.Close() + } + + s.peers[peer.String()] = peer } diff --git a/relay/server/store_test.go b/relay/server/store_test.go index
4a30bc131..41c7baa92 100644 --- a/relay/server/store_test.go +++ b/relay/server/store_test.go @@ -2,13 +2,57 @@ package server import ( "context" + "net" "testing" + "time" "go.opentelemetry.io/otel" "github.com/netbirdio/netbird/relay/metrics" ) +type mockConn struct { +} + +func (m mockConn) Read(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Write(b []byte) (n int, err error) { + //TODO implement me + panic("implement me") +} + +func (m mockConn) Close() error { + return nil +} + +func (m mockConn) LocalAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) RemoteAddr() net.Addr { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetReadDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + +func (m mockConn) SetWriteDeadline(t time.Time) error { + //TODO implement me + panic("implement me") +} + func TestStore_DeletePeer(t *testing.T) { s := NewStore() @@ -27,8 +71,9 @@ func TestStore_DeleteDeprecatedPeer(t *testing.T) { m, _ := metrics.NewMetrics(context.Background(), otel.Meter("")) - p1 := NewPeer(m, []byte("peer_id"), nil, nil) - p2 := NewPeer(m, []byte("peer_id"), nil, nil) + conn := &mockConn{} + p1 := NewPeer(m, []byte("peer_id"), conn, nil) + p2 := NewPeer(m, []byte("peer_id"), conn, nil) s.AddPeer(p1) s.AddPeer(p2) diff --git a/signal/cmd/run.go b/signal/cmd/run.go index 61f7a32a7..0bdc62ead 100644 --- a/signal/cmd/run.go +++ b/signal/cmd/run.go @@ -29,12 +29,9 @@ import ( "google.golang.org/grpc/keepalive" ) -const ( - metricsPort = 9090 -) - var ( signalPort int + metricsPort int signalLetsencryptDomain string signalSSLDir string defaultSignalSSLDir string @@ -288,6 +285,7 @@ func loadTLSConfig(certFile string, certKey string) (*tls.Config, error) { func init() { runCmd.PersistentFlags().IntVar(&signalPort, "port", 80, 
"Server port to listen on (defaults to 443 if TLS is enabled, 80 otherwise") + runCmd.Flags().IntVar(&metricsPort, "metrics-port", 9090, "metrics endpoint http port. Metrics are accessible under host:metrics-port/metrics") runCmd.Flags().StringVar(&signalSSLDir, "ssl-dir", defaultSignalSSLDir, "server ssl directory location. *Required only for Let's Encrypt certificates.") runCmd.Flags().StringVar(&signalLetsencryptDomain, "letsencrypt-domain", "", "a domain to issue Let's Encrypt certificate for. Enables TLS using Let's Encrypt. Will fetch and renew certificate, and run the server with TLS") runCmd.Flags().StringVar(&signalCertFile, "cert-file", "", "Location of your SSL certificate. Can be used when you have an existing certificate and don't want a new certificate be generated automatically. If letsencrypt-domain is specified this property has no effect") diff --git a/signal/peer/peer.go b/signal/peer/peer.go index 85de91581..ed2360d67 100644 --- a/signal/peer/peer.go +++ b/signal/peer/peer.go @@ -82,8 +82,11 @@ func (registry *Registry) Register(peer *Peer) { log.Warnf("peer [%s] is already registered [new streamID %d, previous StreamID %d]. 
Will override stream.", peer.Id, peer.StreamID, pp.StreamID) registry.Peers.Store(peer.Id, peer) + return } + log.Debugf("peer registered [%s]", peer.Id) + registry.metrics.ActivePeers.Add(context.Background(), 1) // record time as milliseconds registry.metrics.RegistrationDelay.Record(context.Background(), float64(time.Since(start).Nanoseconds())/1e6) @@ -105,8 +108,8 @@ func (registry *Registry) Deregister(peer *Peer) { peer.Id, pp.StreamID, peer.StreamID) return } + registry.metrics.ActivePeers.Add(context.Background(), -1) + log.Debugf("peer deregistered [%s]", peer.Id) + registry.metrics.Deregistrations.Add(context.Background(), 1) } - log.Debugf("peer deregistered [%s]", peer.Id) - - registry.metrics.Deregistrations.Add(context.Background(), 1) } diff --git a/signal/server/signal.go b/signal/server/signal.go index 69387cc69..b268aa3fc 100644 --- a/signal/server/signal.go +++ b/signal/server/signal.go @@ -133,8 +133,6 @@ func (s *Server) RegisterPeer(stream proto.SignalExchange_ConnectStreamServer) ( s.registry.Register(p) s.dispatcher.ListenForMessages(stream.Context(), p.Id, s.forwardMessageToPeer) - s.metrics.ActivePeers.Add(stream.Context(), 1) - return p, nil } else { s.metrics.RegistrationFailures.Add(stream.Context(), 1, metric.WithAttributes(attribute.String(labelError, labelErrorMissingId))) @@ -151,7 +149,6 @@ func (s *Server) DeregisterPeer(p *peer.Peer) { s.registry.Deregister(p) s.metrics.PeerConnectionDuration.Record(p.Stream.Context(), int64(time.Since(p.RegisteredAt).Seconds())) - s.metrics.ActivePeers.Add(context.Background(), -1) } func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedMessage) { diff --git a/util/file.go b/util/file.go index 2a6182556..8355488c9 100644 --- a/util/file.go +++ b/util/file.go @@ -10,51 +10,30 @@ import ( log "github.com/sirupsen/logrus" ) -// WriteJson writes JSON config object to a file creating parent directories if required -// The output JSON is pretty-formatted -func WriteJson(file 
string, obj interface{}) error { - +// WriteJsonWithRestrictedPermission writes JSON config object to a file. Enforces permission on the parent directory +func WriteJsonWithRestrictedPermission(file string, obj interface{}) error { configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - // make it pretty - bs, err := json.MarshalIndent(obj, "", " ") + err = EnforcePermission(file) if err != nil { return err } - tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) + return writeJson(file, obj, configDir, configFileName) +} + +// WriteJson writes JSON config object to a file creating parent directories if required +// The output JSON is pretty-formatted +func WriteJson(file string, obj interface{}) error { + configDir, configFileName, err := prepareConfigFileDir(file) if err != nil { return err } - tempFileName := tempFile.Name() - // closing file ops as windows doesn't allow to move it - err = tempFile.Close() - if err != nil { - return err - } - - defer func() { - _, err = os.Stat(tempFileName) - if err == nil { - os.Remove(tempFileName) - } - }() - - err = os.WriteFile(tempFileName, bs, 0600) - if err != nil { - return err - } - - err = os.Rename(tempFileName, file) - if err != nil { - return err - } - - return nil + return writeJson(file, obj, configDir, configFileName) } // DirectWriteJson writes JSON config object to a file creating parent directories if required without creating a temporary file @@ -96,6 +75,46 @@ func DirectWriteJson(ctx context.Context, file string, obj interface{}) error { return nil } +func writeJson(file string, obj interface{}, configDir string, configFileName string) error { + + // make it pretty + bs, err := json.MarshalIndent(obj, "", " ") + if err != nil { + return err + } + + tempFile, err := os.CreateTemp(configDir, ".*"+configFileName) + if err != nil { + return err + } + + tempFileName := tempFile.Name() + // closing file ops as windows doesn't allow to move it + err = 
tempFile.Close() + if err != nil { + return err + } + + defer func() { + _, err = os.Stat(tempFileName) + if err == nil { + os.Remove(tempFileName) + } + }() + + err = os.WriteFile(tempFileName, bs, 0600) + if err != nil { + return err + } + + err = os.Rename(tempFileName, file) + if err != nil { + return err + } + + return nil +} + func openOrCreateFile(file string) (*os.File, error) { s, err := os.Stat(file) if err == nil { @@ -172,5 +191,9 @@ func prepareConfigFileDir(file string) (string, string, error) { } err := os.MkdirAll(configDir, 0750) + if err != nil { + return "", "", err + } + return configDir, configFileName, err } diff --git a/util/permission.go b/util/permission.go new file mode 100644 index 000000000..666998cff --- /dev/null +++ b/util/permission.go @@ -0,0 +1,7 @@ +//go:build !windows + +package util + +func EnforcePermission(dirPath string) error { + return nil +} diff --git a/util/permission_windows.go b/util/permission_windows.go new file mode 100644 index 000000000..548fef824 --- /dev/null +++ b/util/permission_windows.go @@ -0,0 +1,86 @@ +package util + +import ( + "path/filepath" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/windows" +) + +const ( + securityFlags = windows.OWNER_SECURITY_INFORMATION | + windows.GROUP_SECURITY_INFORMATION | + windows.DACL_SECURITY_INFORMATION | + windows.PROTECTED_DACL_SECURITY_INFORMATION +) + +func EnforcePermission(file string) error { + dirPath := filepath.Dir(file) + + user, group, err := sids() + if err != nil { + return err + } + + adminGroupSid, err := windows.CreateWellKnownSid(windows.WinBuiltinAdministratorsSid) + if err != nil { + return err + } + + explicitAccess := []windows.EXPLICIT_ACCESS{ + { + AccessPermissions: windows.GENERIC_ALL, + AccessMode: windows.SET_ACCESS, + Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + Trustee: windows.TRUSTEE{ + MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE, + TrusteeForm: windows.TRUSTEE_IS_SID, + TrusteeType: 
windows.TRUSTEE_IS_USER, + TrusteeValue: windows.TrusteeValueFromSID(user), + }, + }, + { + AccessPermissions: windows.GENERIC_ALL, + AccessMode: windows.SET_ACCESS, + Inheritance: windows.SUB_CONTAINERS_AND_OBJECTS_INHERIT, + Trustee: windows.TRUSTEE{ + MultipleTrusteeOperation: windows.NO_MULTIPLE_TRUSTEE, + TrusteeForm: windows.TRUSTEE_IS_SID, + TrusteeType: windows.TRUSTEE_IS_WELL_KNOWN_GROUP, + TrusteeValue: windows.TrusteeValueFromSID(adminGroupSid), + }, + }, + } + + dacl, err := windows.ACLFromEntries(explicitAccess, nil) + if err != nil { + return err + } + + return windows.SetNamedSecurityInfo(dirPath, windows.SE_FILE_OBJECT, securityFlags, user, group, dacl, nil) +} + +func sids() (*windows.SID, *windows.SID, error) { + var token windows.Token + err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_QUERY, &token) + if err != nil { + return nil, nil, err + } + defer func() { + if err := token.Close(); err != nil { + log.Errorf("failed to close process token: %v", err) + } + }() + + tu, err := token.GetTokenUser() + if err != nil { + return nil, nil, err + } + + pg, err := token.GetTokenPrimaryGroup() + if err != nil { + return nil, nil, err + } + + return tu.User.Sid, pg.PrimaryGroup, nil +}