diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 44d3a031c..7124486ad 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -949,8 +949,6 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &sk.AutoGroups) - } else { - sk.AutoGroups = []string{} } } return sk, err @@ -975,9 +973,46 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc peers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbpeer.Peer, error) { var p nbpeer.Peer p.Status = &nbpeer.PeerStatus{} + var lastLogin sql.NullTime + var sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool + var peerStatusLastSeen sql.NullTime + var peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval sql.NullBool var ip, extraDNS, netAddr, env, flags, files, connIP []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &p.SSHEnabled, &p.LoginExpirationEnabled, &p.InactivityExpirationEnabled, &p.LastLogin, &p.CreatedAt, &p.Ephemeral, &extraDNS, &p.AllowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &p.Status.LastSeen, &p.Status.Connected, &p.Status.LoginExpired, &p.Status.RequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + + err := row.Scan(&p.ID, &p.AccountID, &p.Key, &ip, &p.Name, &p.DNSLabel, &p.UserID, &p.SSHKey, &sshEnabled, &loginExpirationEnabled, &inactivityExpirationEnabled, &lastLogin, &p.CreatedAt, &ephemeral, &extraDNS, &allowExtraDNSLabels, &p.Meta.Hostname, &p.Meta.GoOS, &p.Meta.Kernel, &p.Meta.Core, &p.Meta.Platform, &p.Meta.OS, &p.Meta.OSVersion, &p.Meta.WtVersion, &p.Meta.UIVersion, &p.Meta.KernelVersion, &netAddr, &p.Meta.SystemSerialNumber, &p.Meta.SystemProductName, &p.Meta.SystemManufacturer, &env, &flags, &files, &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, &p.Location.CountryCode, &p.Location.CityName, &p.Location.GeoNameID) + if err == nil { + if lastLogin.Valid { + p.LastLogin = &lastLogin.Time + } + if sshEnabled.Valid { + p.SSHEnabled = sshEnabled.Bool + } + if loginExpirationEnabled.Valid { + p.LoginExpirationEnabled = loginExpirationEnabled.Bool + } + if inactivityExpirationEnabled.Valid { + p.InactivityExpirationEnabled = inactivityExpirationEnabled.Bool + } + if ephemeral.Valid { + p.Ephemeral = ephemeral.Bool + } + if allowExtraDNSLabels.Valid { + p.AllowExtraDNSLabels = allowExtraDNSLabels.Bool + } + if peerStatusLastSeen.Valid { + p.Status.LastSeen = peerStatusLastSeen.Time + } + if peerStatusConnected.Valid { + p.Status.Connected = peerStatusConnected.Bool + } + if peerStatusLoginExpired.Valid { + p.Status.LoginExpired = peerStatusLoginExpired.Bool + } + if peerStatusRequiresApproval.Valid { + p.Status.RequiresApproval = peerStatusRequiresApproval.Bool + } + if ip != nil { _ = json.Unmarshal(ip, &p.IP) } @@ -1042,8 +1077,6 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if autoGroups != nil { _ = json.Unmarshal(autoGroups, &u.AutoGroups) - } else { - u.AutoGroups = []string{} } } return u, err @@ -1079,11 +1112,7 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc } if resources != nil { _ = json.Unmarshal(resources, &g.Resources) - } else { - g.Resources = []types.Resource{} } - g.GroupPeers = []types.GroupPeer{} - g.Peers = []string{} } return &g, err }) @@ -1106,9 +1135,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc policies, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.Policy, error) { var p types.Policy var checks []byte - err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &p.Enabled, &checks) - if err == nil && checks != nil { - _ = json.Unmarshal(checks, &p.SourcePostureChecks) + var enabled sql.NullBool + err := row.Scan(&p.ID, &p.AccountID, &p.Name, &p.Description, &enabled, &checks) + if err == nil { + if enabled.Valid { + p.Enabled = enabled.Bool + } + if checks != nil { + _ = json.Unmarshal(checks, &p.SourcePostureChecks) + } } return &p, err }) @@ -1131,8 +1166,25 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc routes, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (route.Route, error) { var r route.Route var network, domains, peerGroups, groups, accessGroups []byte - err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &r.KeepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &r.Masquerade, &r.Metric, &r.Enabled, &groups, &accessGroups, &r.SkipAutoApply) + var keepRoute, masquerade, enabled, skipAutoApply sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.AccountID, &network, &domains, &keepRoute, &r.NetID, &r.Description, &r.Peer, &peerGroups, &r.NetworkType, &masquerade, &metric, &enabled, &groups, &accessGroups, &skipAutoApply) if err == nil { + if keepRoute.Valid { + r.KeepRoute = keepRoute.Bool + } + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if skipAutoApply.Valid { + r.SkipAutoApply = skipAutoApply.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } if network != nil { _ = json.Unmarshal(network, &r.Network) } @@ -1170,8 +1222,18 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc nsgs, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (nbdns.NameServerGroup, error) { var n nbdns.NameServerGroup var ns, groups, domains []byte - err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &n.Primary, &domains, &n.Enabled, &n.SearchDomainsEnabled) + var primary, enabled, searchDomainsEnabled sql.NullBool + err := row.Scan(&n.ID, &n.AccountID, &n.Name, &n.Description, &ns, &groups, &primary, &domains, &enabled, &searchDomainsEnabled) if err == nil { + if primary.Valid { + n.Primary = primary.Bool + } + if enabled.Valid { + n.Enabled = enabled.Bool + } + if searchDomainsEnabled.Valid { + n.SearchDomainsEnabled = searchDomainsEnabled.Bool + } if ns != nil { _ = json.Unmarshal(ns, &n.NameServers) } else { @@ -1251,20 +1313,36 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*routerTypes.NetworkRouter, error) { + routers, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (routerTypes.NetworkRouter, error) { var r routerTypes.NetworkRouter var peerGroups []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &r.Masquerade, &r.Metric, &r.Enabled) - if err == nil && peerGroups != nil { - _ = json.Unmarshal(peerGroups, &r.PeerGroups) + var masquerade, enabled sql.NullBool + var metric sql.NullInt64 + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Peer, &peerGroups, &masquerade, &metric, &enabled) + if err == nil { + if masquerade.Valid { + r.Masquerade = masquerade.Bool + } + if enabled.Valid { + r.Enabled = enabled.Bool + } + if metric.Valid { + r.Metric = int(metric.Int64) + } + if peerGroups != nil { + _ = json.Unmarshal(peerGroups, &r.PeerGroups) + } } - return &r, err + return r, err }) if err != nil { errChan <- err return } - account.NetworkRouters = routers + account.NetworkRouters = make([]*routerTypes.NetworkRouter, len(routers)) + for i := range routers { + account.NetworkRouters[i] = &routers[i] + } }() wg.Add(1) @@ -1276,35 +1354,52 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (*resourceTypes.NetworkResource, error) { + resources, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (resourceTypes.NetworkResource, error) { var r resourceTypes.NetworkResource var prefix []byte - err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &r.Enabled) - if err == nil && prefix != nil { - _ = json.Unmarshal(prefix, &r.Prefix) + var enabled sql.NullBool + err := row.Scan(&r.ID, &r.NetworkID, &r.AccountID, &r.Name, &r.Description, &r.Type, &r.Domain, &prefix, &enabled) + if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if prefix != nil { + _ = json.Unmarshal(prefix, &r.Prefix) + } } - return &r, err + return r, err }) if err != nil { errChan <- err return } - account.NetworkResources = resources + account.NetworkResources = make([]*resourceTypes.NetworkResource, len(resources)) + for i := range resources { + account.NetworkResources[i] = &resources[i] + } }() wg.Add(1) go func() { defer wg.Done() const query = `SELECT account_id, onboarding_flow_pending, signup_form_pending, created_at, updated_at FROM account_onboardings WHERE account_id = $1` + var onboardingFlowPending, signupFormPending sql.NullBool err := s.pool.QueryRow(ctx, query, accountID).Scan( &account.Onboarding.AccountID, - &account.Onboarding.OnboardingFlowPending, - &account.Onboarding.SignupFormPending, + &onboardingFlowPending, + &signupFormPending, &account.Onboarding.CreatedAt, &account.Onboarding.UpdatedAt, ) if err != nil && !errors.Is(err, pgx.ErrNoRows) { errChan <- err + return + } + if onboardingFlowPending.Valid { + account.Onboarding.OnboardingFlowPending = onboardingFlowPending.Bool + } + if signupFormPending.Valid { + account.Onboarding.SignupFormPending = signupFormPending.Bool } }() @@ -1344,7 +1439,20 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc errChan <- err return } - pats, err = pgx.CollectRows(rows, pgx.RowToStructByName[types.PersonalAccessToken]) + pats, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (types.PersonalAccessToken, error) { + var pat types.PersonalAccessToken + var expirationDate, lastUsed sql.NullTime + err := row.Scan(&pat.ID, &pat.UserID, &pat.Name, &pat.HashedToken, &expirationDate, &pat.CreatedBy, &pat.CreatedAt, &lastUsed) + if err == nil { + if expirationDate.Valid { + pat.ExpirationDate = &expirationDate.Time + } + if lastUsed.Valid { + pat.LastUsed = &lastUsed.Time + } + } + return pat, err + }) if err != nil { errChan <- err } @@ -1365,8 +1473,15 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc rules, err = pgx.CollectRows(rows, func(row pgx.CollectableRow) (*types.PolicyRule, error) { var r types.PolicyRule var dest, destRes, sources, sourceRes, ports, portRanges []byte - err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &r.Enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &r.Bidirectional, &r.Protocol, &ports, &portRanges) + var enabled, bidirectional sql.NullBool + err := row.Scan(&r.ID, &r.PolicyID, &r.Name, &r.Description, &enabled, &r.Action, &dest, &destRes, &sources, &sourceRes, &bidirectional, &r.Protocol, &ports, &portRanges) if err == nil { + if enabled.Valid { + r.Enabled = enabled.Bool + } + if bidirectional.Valid { + r.Bidirectional = bidirectional.Bool + } if dest != nil { _ = json.Unmarshal(dest, &r.Destinations) }