diff --git a/management/server/store/sqlstore_bench_test.go b/management/server/store/sqlstore_bench_test.go index add3fffe5..c538c41aa 100644 --- a/management/server/store/sqlstore_bench_test.go +++ b/management/server/store/sqlstore_bench_test.go @@ -2,6 +2,7 @@ package store import ( "context" + "database/sql" "encoding/json" "errors" "fmt" @@ -20,7 +21,6 @@ import ( "github.com/jackc/pgx/v5/pgxpool" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/dns" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -695,79 +695,6 @@ func normalizeNilSlices(acc *types.Account) { } } -func TestGetAccountEquals(t *testing.T) { - store, accountID := setupBenchmarkDB(t) - ctx := context.Background() - expected, err := store.GetAccountSlow(ctx, accountID) - require.NoError(t, err) - actual, err := store.GetAccount(ctx, accountID) - require.NoError(t, err) - - require.Equal(t, expected.DNSSettings, actual.DNSSettings) - require.Equal(t, expected.Domain, actual.Domain) - require.Equal(t, expected.DomainCategory, actual.DomainCategory) - require.Equal(t, expected.IsDomainPrimaryAccount, actual.IsDomainPrimaryAccount) - require.Equal(t, expected.Id, actual.Id) - require.Equal(t, expected.CreatedBy, actual.CreatedBy) - require.WithinDuration(t, expected.CreatedAt, actual.CreatedAt, time.Second) - - require.Equal(t, len(expected.SetupKeys), len(actual.SetupKeys)) - for k, v := range expected.SetupKeys { - v2, ok := actual.SetupKeys[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, len(expected.Peers), len(actual.Peers)) - for k, v := range expected.Peers { - v2, ok := actual.Peers[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, len(expected.Users), len(actual.Users)) - for k, v := range expected.Users { - v2, ok := actual.Users[k] - require.True(t, ok) - require.Equal(t, v, v2) - require.Equal(t, len(v.PATs), len(v2.PATs)) - for k3, v3 := range v.PATs { - v4, ok := v2.PATs[k3] - require.True(t, ok) - require.Equal(t, v3, v4) - } - } - require.Equal(t, len(expected.Groups), len(actual.Groups)) - for k, v := range expected.Groups { - v2, ok := actual.Groups[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - require.Equal(t, len(expected.Routes), len(actual.Routes)) - for k, v := range expected.Routes { - v2, ok := actual.Routes[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - require.Equal(t, len(expected.NameServerGroups), len(actual.NameServerGroups)) - for k, v := range expected.NameServerGroups { - v2, ok := actual.NameServerGroups[k] - require.True(t, ok) - require.Equal(t, v, v2) - } - - require.Equal(t, expected.Policies, actual.Policies) - - require.Equal(t, expected.PostureChecks, actual.PostureChecks) - - require.Equal(t, expected.Network, actual.Network) - - require.Equal(t, expected.Networks, actual.Networks) - require.Equal(t, expected.NetworkRouters, actual.NetworkRouters) - require.Equal(t, expected.NetworkResources, actual.NetworkResources) - require.Equal(t, expected.Onboarding, actual.Onboarding) -} - func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*types.Account, error) { var account types.Account account.Network = &types.Network{} @@ -837,10 +764,46 @@ func (s *SqlStore) GetAccountPureSQL(ctx context.Context, accountID string) (*ty peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { var p nbpeer.Peer - p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool var ip, extraDNS, netAddr, env, flags, files, connIP []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if ip != nil { _ = json.Unmarshal(ip, &p.IP) }