diff --git a/management/server/account.go b/management/server/account.go index fe86b4481..b470e8036 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -76,7 +76,7 @@ type AccountManager interface { SaveOrAddUsers(ctx context.Context, accountID, initiatorUserID string, updates []*User, addIfNotExists bool) ([]*UserInfo, error) GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*SetupKey, error) GetAccountByUserOrAccountID(ctx context.Context, userID, accountID, domain string) (*Account, error) - GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) + GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error) DeleteAccount(ctx context.Context, accountID, userID string) error @@ -1738,10 +1738,10 @@ func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token st return account, user, pat, nil } -// GetAccountFromToken returns an account associated with this token -func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*Account, *User, error) { +// GetAccountFromToken returns an account associated with this token. +func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { if claims.UserId == "" { - return nil, nil, fmt.Errorf("user ID is empty") + return "", "", fmt.Errorf("user ID is empty") } if am.singleAccountMode && am.singleAccountModeDomain != "" { // This section is mostly related to self-hosted installations. @@ -1753,7 +1753,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims newAcc, err := am.getAccountWithAuthorizationClaims(ctx, claims) if err != nil { - return nil, nil, err + return "", "", err } unlock := am.Store.AcquireWriteLockByUID(ctx, newAcc.Id) alreadyUnlocked := false @@ -1765,26 +1765,27 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims account, err := am.Store.GetAccount(ctx, newAcc.Id) if err != nil { - return nil, nil, err + return "", "", err } user := account.Users[claims.UserId] if user == nil { // this is not really possible because we got an account by user ID - return nil, nil, status.Errorf(status.NotFound, "user %s not found", claims.UserId) + return "", "", status.Errorf(status.NotFound, "user %s not found", claims.UserId) } if !user.IsServiceUser && claims.Invited { err = am.redeemInvite(ctx, account, claims.UserId) if err != nil { - return nil, nil, err + return "", "", err } } if account.Settings.JWTGroupsEnabled { if account.Settings.JWTGroupsClaimName == "" { log.WithContext(ctx).Errorf("JWT groups are enabled but no claim name is set") - return account, user, nil + + return account.Id, user.Id, nil } jwtGroupsNames := extractJWTGroups(ctx, account.Settings.JWTGroupsClaimName, claims) @@ -1837,7 +1838,7 @@ func (am *DefaultAccountManager) GetAccountFromToken(ctx context.Context, claims } } - return account, user, nil + return account.Id, user.Id, nil } // getAccountWithAuthorizationClaims retrievs an account using JWT Claims. @@ -1863,6 +1864,7 @@ func (am *DefaultAccountManager) getAccountWithAuthorizationClaims(ctx context.C if claims.UserId == "" { return nil, fmt.Errorf("user ID is empty") } + // if Account ID is part of the claims // it means that we've already classified the domain and user has an account if claims.DomainCategory != PrivateCategory || !isDomainValid(claims.Domain) { @@ -2020,13 +2022,9 @@ func (am *DefaultAccountManager) GetDNSDomain() string { // CheckUserAccessByJWTGroups checks if the user has access, particularly in cases where the admin enabled JWT // group propagation and set the list of groups with access permissions. func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { - accountID := claims.AccountId - if accountID == "" { - user, err := am.GetUserByID(ctx, claims.UserId) - if err != nil { - return fmt.Errorf("failed to retrieve account for user %s: %v", claims.UserId, err) - } - accountID = user.AccountID + accountID, _, err := am.GetAccountFromToken(ctx, claims) + if err != nil { + return err } settings, err := am.Store.GetAccountSettings(ctx, accountID) diff --git a/management/server/http/accounts_handler.go b/management/server/http/accounts_handler.go index ffa5b9a28..fdeca8509 100644 --- a/management/server/http/accounts_handler.go +++ b/management/server/http/accounts_handler.go @@ -2,6 +2,7 @@ package http import ( "encoding/json" + "fmt" "net/http" "time" @@ -35,12 +36,18 @@ func NewAccountsHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return } + user, ok := account.Users[claims.UserId] + if !ok { + util.WriteError(r.Context(), fmt.Errorf("user %s not found in account", claims.UserId), w) + return + } + if !(user.HasAdminPower() || user.IsServiceUser) { util.WriteError(r.Context(), status.Errorf(status.PermissionDenied, "the user has no permission to access account data"), w) return @@ -53,7 +60,7 @@ func (h *AccountsHandler) GetAllAccounts(w http.ResponseWriter, r *http.Request) // UpdateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - _, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + _, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,7 +103,7 @@ func (h *AccountsHandler) UpdateAccount(w http.ResponseWriter, r *http.Request) settings.JWTAllowGroups = *req.Settings.JwtAllowGroups } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, user.Id, settings) + updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/accounts_handler_test.go b/management/server/http/accounts_handler_test.go index 45c7679e5..b4e79ceb7 100644 --- a/management/server/http/accounts_handler_test.go +++ b/management/server/http/accounts_handler_test.go @@ -23,8 +23,11 @@ import ( func initAccountsTestData(account *server.Account, admin *server.User) *AccountsHandler { return &AccountsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return account, admin, nil + GetAccountFromTokenFunc: func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return account.Id, admin.Id, nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { + return account, nil }, UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *server.Settings) (*server.Account, error) { halfYearLimit := 180 * 24 * time.Hour diff --git a/management/server/http/dns_settings_handler.go b/management/server/http/dns_settings_handler.go index 74b0e1a55..f247489c5 100644 --- a/management/server/http/dns_settings_handler.go +++ b/management/server/http/dns_settings_handler.go @@ -32,14 +32,14 @@ func NewDNSSettingsHandler(accountManager server.AccountManager, authCfg AuthCfg // GetDNSSettings returns the DNS settings for the account func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), account.Id, user.Id) + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -55,7 +55,7 @@ func (h *DNSSettingsHandler) GetDNSSettings(w http.ResponseWriter, r *http.Reque // UpdateDNSSettings handles update to DNS settings of an account func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -72,7 +72,7 @@ func (h *DNSSettingsHandler) UpdateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), account.Id, user.Id, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/dns_settings_handler_test.go b/management/server/http/dns_settings_handler_test.go index 897ae63dc..da1522bdc 100644 --- a/management/server/http/dns_settings_handler_test.go +++ b/management/server/http/dns_settings_handler_test.go @@ -52,8 +52,8 @@ func initDNSSettingsTestData() *DNSSettingsHandler { } return status.Errorf(status.InvalidArgument, "the dns settings provided are nil") }, - GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingDNSSettingsAccount, testingDNSSettingsAccount.Users[testDNSSettingsUserID], nil + GetAccountFromTokenFunc: func(ctx context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingDNSSettingsAccount.Id, testDNSSettingsUserID, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/events_handler.go b/management/server/http/events_handler.go index 428b4c164..d525a6aad 100644 --- a/management/server/http/events_handler.go +++ b/management/server/http/events_handler.go @@ -34,14 +34,14 @@ func NewEventsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ev // GetAllEvents list of the given account func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - accountEvents, err := h.accountManager.GetEvents(r.Context(), account.Id, user.Id) + accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -51,7 +51,7 @@ func (h *EventsHandler) GetAllEvents(w http.ResponseWriter, r *http.Request) { events[i] = toEventResponse(e) } - err = h.fillEventsWithUserInfo(r.Context(), events, account.Id, user.Id) + err = h.fillEventsWithUserInfo(r.Context(), events, accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/events_handler_test.go b/management/server/http/events_handler_test.go index 8bdd508bf..2f2c8f67d 100644 --- a/management/server/http/events_handler_test.go +++ b/management/server/http/events_handler_test.go @@ -29,14 +29,8 @@ func initEventsTestData(account string, user *server.User, events ...*activity.E } return []*activity.Event{}, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: claims.AccountId, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - }, user, nil + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, user.Id, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { return make([]*server.UserInfo, 0), nil diff --git a/management/server/http/geolocation_handler_test.go b/management/server/http/geolocation_handler_test.go index 7f4d6dc7c..8233d8872 100644 --- a/management/server/http/geolocation_handler_test.go +++ b/management/server/http/geolocation_handler_test.go @@ -43,14 +43,9 @@ func initGeolocationTestData(t *testing.T) *GeolocationsHandler { return &GeolocationsHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") - return &server.Account{ - Id: claims.AccountId, - Users: map[string]*server.User{ - "test_user": user, - }, - }, user, nil + GetUserByIDFunc: func(ctx context.Context, userID string) (*server.User, error) { + user := server.NewAdminUser(userID) + return user, nil }, }, geolocationManager: geo, diff --git a/management/server/http/geolocations_handler.go b/management/server/http/geolocations_handler.go index af4d3116f..8891e3522 100644 --- a/management/server/http/geolocations_handler.go +++ b/management/server/http/geolocations_handler.go @@ -98,7 +98,7 @@ func (l *GeolocationsHandler) GetCitiesByCountry(w http.ResponseWriter, r *http. func (l *GeolocationsHandler) authenticateUser(r *http.Request) error { claims := l.claimsExtractor.FromRequestContext(r) - _, user, err := l.accountManager.GetAccountFromToken(r.Context(), claims) + user, err := l.accountManager.GetUserByID(r.Context(), claims.UserId) if err != nil { return err } diff --git a/management/server/http/groups_handler.go b/management/server/http/groups_handler.go index c622d873a..0841a2c58 100644 --- a/management/server/http/groups_handler.go +++ b/management/server/http/groups_handler.go @@ -35,21 +35,15 @@ func NewGroupsHandler(accountManager server.AccountManager, authCfg AuthCfg) *Gr // GetAllGroups list for the account func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - groups, err := h.accountManager.GetAllGroups(r.Context(), account.Id, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - groupsResponse := make([]*api.Group, 0, len(groups)) - for _, group := range groups { + groupsResponse := make([]*api.Group, 0, len(account.Groups)) + for _, group := range account.Groups { groupsResponse = append(groupsResponse, toGroupResponse(account, group)) } @@ -59,7 +53,7 @@ func (h *GroupsHandler) GetAllGroups(w http.ResponseWriter, r *http.Request) { // UpdateGroup handles update to a group identified by a given ID func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -118,7 +112,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: eg.IntegrationReference, } - if err := h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group); err != nil { + if err := h.accountManager.SaveGroup(r.Context(), account.Id, claims.UserId, &group); err != nil { log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, account.Id, err) util.WriteError(r.Context(), err, w) return @@ -130,7 +124,7 @@ func (h *GroupsHandler) UpdateGroup(w http.ResponseWriter, r *http.Request) { // CreateGroup handles group creation request func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +154,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { Issued: nbgroup.GroupIssuedAPI, } - err = h.accountManager.SaveGroup(r.Context(), account.Id, user.Id, &group) + err = h.accountManager.SaveGroup(r.Context(), account.Id, claims.UserId, &group) if err != nil { util.WriteError(r.Context(), err, w) return @@ -172,7 +166,7 @@ func (h *GroupsHandler) CreateGroup(w http.ResponseWriter, r *http.Request) { // DeleteGroup handles group deletion request func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -196,7 +190,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteGroup(r.Context(), aID, user.Id, groupID) + err = h.accountManager.DeleteGroup(r.Context(), aID, claims.UserId, groupID) if err != nil { _, ok := err.(*server.GroupLinkError) if ok { @@ -213,7 +207,7 @@ func (h *GroupsHandler) DeleteGroup(w http.ResponseWriter, r *http.Request) { // GetGroup returns a group func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -227,7 +221,7 @@ func (h *GroupsHandler) GetGroup(w http.ResponseWriter, r *http.Request) { return } - group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, user.Id) + group, err := h.accountManager.GetGroup(r.Context(), account.Id, groupID, claims.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/groups_handler_test.go b/management/server/http/groups_handler_test.go index d5ed07c9e..0a8699a17 100644 --- a/management/server/http/groups_handler_test.go +++ b/management/server/http/groups_handler_test.go @@ -56,20 +56,23 @@ func initGroupTestData(user *server.User, _ ...*nbgroup.Group) *GroupsHandler { Issued: nbgroup.GroupIssuedAPI, }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, user.Id, nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { return &server.Account{ - Id: claims.AccountId, + Id: accountId, Domain: "hotmail.com", Peers: TestPeers, Users: map[string]*server.User{ - user.Id: user, + userId: user, }, Groups: map[string]*nbgroup.Group{ "id-jwt-group": {ID: "id-jwt-group", Name: "From JWT", Issued: nbgroup.GroupIssuedJWT}, "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}, Issued: nbgroup.GroupIssuedAPI}, "id-all": {ID: "id-all", Name: "All", Issued: nbgroup.GroupIssuedAPI}, }, - }, user, nil + }, nil }, DeleteGroupFunc: func(_ context.Context, accountID, userId, groupID string) error { if groupID == "linked-grp" { diff --git a/management/server/http/nameservers_handler.go b/management/server/http/nameservers_handler.go index c6e00bb2d..9ade31e7f 100644 --- a/management/server/http/nameservers_handler.go +++ b/management/server/http/nameservers_handler.go @@ -36,14 +36,14 @@ func NewNameserversHandler(accountManager server.AccountManager, authCfg AuthCfg // GetAllNameservers returns the list of nameserver groups for the account func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), account.Id, user.Id) + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (h *NameserversHandler) GetAllNameservers(w http.ResponseWriter, r *http.Re // CreateNameserverGroup handles nameserver group creation request func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -79,7 +79,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), account.Id, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, user.Id, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -93,7 +93,7 @@ func (h *NameserversHandler) CreateNameserverGroup(w http.ResponseWriter, r *htt // UpdateNameserverGroup handles update to a nameserver group identified by a given ID func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -130,7 +130,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), account.Id, user.Id, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -144,7 +144,7 @@ func (h *NameserversHandler) UpdateNameserverGroup(w http.ResponseWriter, r *htt // DeleteNameserverGroup handles nameserver group deletion request func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -156,7 +156,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), account.Id, nsGroupID, user.Id) + err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *NameserversHandler) DeleteNameserverGroup(w http.ResponseWriter, r *htt // GetNameserverGroup handles a nameserver group Get request identified by ID func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { log.WithContext(r.Context()).Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -181,7 +181,7 @@ func (h *NameserversHandler) GetNameserverGroup(w http.ResponseWriter, r *http.R return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), account.Id, user.Id, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/nameservers_handler_test.go b/management/server/http/nameservers_handler_test.go index 28b080571..645e8585c 100644 --- a/management/server/http/nameservers_handler_test.go +++ b/management/server/http/nameservers_handler_test.go @@ -90,8 +90,8 @@ func initNameserversTestData() *NameserversHandler { } return status.Errorf(status.NotFound, "nameserver group with ID %s was not found", nsGroupToSave.ID) }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingNSAccount, testingAccount.Users["test_user"], nil + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingNSAccount.Id, "test_user", nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/pat_handler.go b/management/server/http/pat_handler.go index 9d8448d3d..d03c214ab 100644 --- a/management/server/http/pat_handler.go +++ b/management/server/http/pat_handler.go @@ -34,20 +34,20 @@ func NewPATsHandler(accountManager server.AccountManager, authCfg AuthCfg) *PATH // GetAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } vars := mux.Vars(r) - userID := vars["userId"] + targetUserID := vars["userId"] if len(userID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), account.Id, user.Id, userID) + pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -64,7 +64,7 @@ func (h *PATHandler) GetAllTokens(w http.ResponseWriter, r *http.Request) { // GetToken is HTTP GET handler that returns a personal access token for the given user func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -83,7 +83,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,7 +95,7 @@ func (h *PATHandler) GetToken(w http.ResponseWriter, r *http.Request) { // CreateToken is HTTP POST handler that creates a personal access token for the given user func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.CreatePAT(r.Context(), account.Id, user.Id, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -127,7 +127,7 @@ func (h *PATHandler) CreateToken(w http.ResponseWriter, r *http.Request) { // DeleteToken is HTTP DELETE handler that deletes a personal access token for the given user func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -146,7 +146,7 @@ func (h *PATHandler) DeleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), account.Id, user.Id, targetUserID, tokenID) + err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/pat_handler_test.go b/management/server/http/pat_handler_test.go index b72f71468..e0361fd5e 100644 --- a/management/server/http/pat_handler_test.go +++ b/management/server/http/pat_handler_test.go @@ -77,8 +77,8 @@ func initPATTestData() *PATHandler { }, nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testAccount, testAccount.Users[existingUserID], nil + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testAccount.Id, existingUserID, nil }, DeletePATFunc: func(_ context.Context, accountID string, initiatorUserID string, targetUserID string, tokenID string) error { if accountID != existingAccountID { diff --git a/management/server/http/peers_handler.go b/management/server/http/peers_handler.go index 1487bbc39..83d3cd3b4 100644 --- a/management/server/http/peers_handler.go +++ b/management/server/http/peers_handler.go @@ -75,7 +75,7 @@ func (h *PeersHandler) getPeer(ctx context.Context, account *server.Account, pee util.WriteJSONObject(ctx, w, toSinglePeerResponse(peerToReturn, groupsInfo, dnsDomain, valid)) } -func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, user *server.User, peerID string, w http.ResponseWriter, r *http.Request) { +func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, userID, peerID string, w http.ResponseWriter, r *http.Request) { req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -97,7 +97,7 @@ func (h *PeersHandler) updatePeer(ctx context.Context, account *server.Account, } } - peer, err := h.accountManager.UpdatePeer(ctx, account.Id, user.Id, update) + peer, err := h.accountManager.UpdatePeer(ctx, account.Id, userID, update) if err != nil { util.WriteError(ctx, err, w) return @@ -131,7 +131,7 @@ func (h *PeersHandler) deletePeer(ctx context.Context, accountID, userID string, // HandlePeer handles all peer requests for GET, PUT and DELETE operations func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -145,13 +145,20 @@ func (h *PeersHandler) HandlePeer(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodDelete: - h.deletePeer(r.Context(), account.Id, user.Id, peerID, w) + h.deletePeer(r.Context(), accountID, userID, peerID, w) return - case http.MethodPut: - h.updatePeer(r.Context(), account, user, peerID, w, r) - return - case http.MethodGet: - h.getPeer(r.Context(), account, peerID, user.Id, w) + case http.MethodGet, http.MethodPut: + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, accountID, "") + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + if r.Method == http.MethodGet { + h.getPeer(r.Context(), account, peerID, userID, w) + } else { + h.updatePeer(r.Context(), account, userID, peerID, w, r) + } return default: util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) @@ -166,13 +173,13 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return } - peers, err := h.accountManager.GetPeers(r.Context(), account.Id, user.Id) + peers, err := h.accountManager.GetPeers(r.Context(), account.Id, claims.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -194,7 +201,7 @@ func (h *PeersHandler) GetAllPeers(w http.ResponseWriter, r *http.Request) { validPeersMap, err := h.accountManager.GetValidatedPeers(account) if err != nil { - log.WithContext(r.Context()).Errorf("failed to list appreoved peers: %v", err) + log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } @@ -215,7 +222,7 @@ func (h *PeersHandler) setApprovalRequiredFlag(respBody []*api.PeerBatch, approv // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. func (h *PeersHandler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, _, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/peers_handler_test.go b/management/server/http/peers_handler_test.go index 153c8f03a..e271992a6 100644 --- a/management/server/http/peers_handler_test.go +++ b/management/server/http/peers_handler_test.go @@ -12,6 +12,7 @@ import ( "time" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -20,7 +21,6 @@ import ( "github.com/magiconair/properties/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/mock_server" ) @@ -59,17 +59,20 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { GetDNSDomainFunc: func() string { return "netbird.selfhosted" }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, "test_user", nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { user := server.NewAdminUser("test_user") return &server.Account{ - Id: claims.AccountId, + Id: accountId, Domain: "hotmail.com", Peers: map[string]*nbpeer.Peer{ peers[0].ID: peers[0], peers[1].ID: peers[1], }, Users: map[string]*server.User{ - "test_user": user, + userId: user, }, Settings: &server.Settings{ PeerLoginExpirationEnabled: true, @@ -83,7 +86,7 @@ func initTestMetaData(peers ...*nbpeer.Peer) *PeersHandler { }, Serial: 51, }, - }, user, nil + }, nil }, HasConnectedChannelFunc: func(peerID string) bool { statuses := make(map[string]struct{}) diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 9622668f4..aa5bce47a 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -35,13 +35,13 @@ func NewPoliciesHandler(accountManager server.AccountManager, authCfg AuthCfg) * // GetAllPolicies list for the account func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return } - accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, user.Id) + accountPolicies, err := h.accountManager.ListPolicies(r.Context(), account.Id, claims.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -63,7 +63,7 @@ func (h *Policies) GetAllPolicies(w http.ResponseWriter, r *http.Request) { // UpdatePolicy handles update to a policy identified by a given ID func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -88,19 +88,19 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - h.savePolicy(w, r, account, user, policyID) + h.savePolicy(w, r, account, claims.UserId, policyID) } // CreatePolicy handles policy creation request func (h *Policies) CreatePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, account, user, "") + h.savePolicy(w, r, account, claims.UserId, "") } // savePolicy handles policy creation and update @@ -108,7 +108,7 @@ func (h *Policies) savePolicy( w http.ResponseWriter, r *http.Request, account *server.Account, - user *server.User, + userID string, policyID string, ) { var req api.PutApiPoliciesPolicyIdJSONRequestBody @@ -210,7 +210,7 @@ func (h *Policies) savePolicy( policy.SourcePostureChecks = sourcePostureChecksToStrings(account, *req.SourcePostureChecks) } - if err := h.accountManager.SavePolicy(r.Context(), account.Id, user.Id, &policy); err != nil { + if err := h.accountManager.SavePolicy(r.Context(), account.Id, userID, &policy); err != nil { util.WriteError(r.Context(), err, w) return } @@ -227,12 +227,11 @@ func (h *Policies) savePolicy( // DeletePolicy handles policy deletion request func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - aID := account.Id vars := mux.Vars(r) policyID := vars["policyId"] @@ -241,7 +240,7 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), aID, policyID, user.Id); err != nil { + if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -252,37 +251,33 @@ func (h *Policies) DeletePolicy(w http.ResponseWriter, r *http.Request) { // GetPolicy handles a group Get request identified by ID func (h *Policies) GetPolicy(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return } - switch r.Method { - case http.MethodGet: - vars := mux.Vars(r) - policyID := vars["policyId"] - if len(policyID) == 0 { - util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) - return - } - - policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, user.Id) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toPolicyResponse(account, policy) - if len(resp.Rules) == 0 { - util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) - return - } - - util.WriteJSONObject(r.Context(), w, resp) - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "method not found"), w) + vars := mux.Vars(r) + policyID := vars["policyId"] + if len(policyID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid policy ID"), w) + return } + + policy, err := h.accountManager.GetPolicy(r.Context(), account.Id, policyID, claims.UserId) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toPolicyResponse(account, policy) + if len(resp.Rules) == 0 { + util.WriteError(r.Context(), status.Errorf(status.Internal, "no rules in the policy"), w) + return + } + + util.WriteJSONObject(r.Context(), w, resp) + } func toPolicyResponse(account *server.Account, policy *server.Policy) *api.Policy { diff --git a/management/server/http/policies_handler_test.go b/management/server/http/policies_handler_test.go index 06274fb07..13801aa3b 100644 --- a/management/server/http/policies_handler_test.go +++ b/management/server/http/policies_handler_test.go @@ -45,10 +45,13 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, "test_user", nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { + user := server.NewAdminUser(userId) return &server.Account{ - Id: claims.AccountId, + Id: accountId, Domain: "hotmail.com", Policies: []*server.Policy{ {ID: "id-existed"}, @@ -58,9 +61,9 @@ func initPoliciesTestData(policies ...*server.Policy) *Policies { "G": {ID: "G"}, }, Users: map[string]*server.User{ - "test_user": user, + userId: user, }, - }, user, nil + }, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/posture_checks_handler.go b/management/server/http/posture_checks_handler.go index 059cb3b80..042d423e4 100644 --- a/management/server/http/posture_checks_handler.go +++ b/management/server/http/posture_checks_handler.go @@ -37,13 +37,13 @@ func NewPostureChecksHandler(accountManager server.AccountManager, geolocationMa // GetAllPostureChecks list for the account func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), account.Id, user.Id) + accountPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,7 +60,7 @@ func (p *PostureChecksHandler) GetAllPostureChecks(w http.ResponseWriter, r *htt // UpdatePostureCheck handles update to a posture check identified by a given ID func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := p.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -85,7 +85,7 @@ func (p *PostureChecksHandler) UpdatePostureCheck(w http.ResponseWriter, r *http return } - p.savePostureChecks(w, r, account, user, postureChecksID) + p.savePostureChecks(w, r, account.Id, claims.UserId, postureChecksID) } // CreatePostureCheck handles posture check creation request @@ -103,7 +103,7 @@ func (p *PostureChecksHandler) CreatePostureCheck(w http.ResponseWriter, r *http // GetPostureCheck handles a posture check Get request identified by ID func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -116,7 +116,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), account.Id, postureChecksID, user.Id) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,7 +128,7 @@ func (p *PostureChecksHandler) GetPostureCheck(w http.ResponseWriter, r *http.Re // DeletePostureCheck handles posture check deletion request func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http.Request) { claims := p.claimsExtractor.FromRequestContext(r) - account, user, err := p.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := p.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -141,7 +141,7 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), account.Id, postureChecksID, user.Id); err != nil { + if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { util.WriteError(r.Context(), err, w) return } @@ -153,8 +153,8 @@ func (p *PostureChecksHandler) DeletePostureCheck(w http.ResponseWriter, r *http func (p *PostureChecksHandler) savePostureChecks( w http.ResponseWriter, r *http.Request, - account *server.Account, - user *server.User, + accountID string, + userID string, postureChecksID string, ) { var ( @@ -181,7 +181,7 @@ func (p *PostureChecksHandler) savePostureChecks( return } - if err := p.accountManager.SavePostureChecks(r.Context(), account.Id, user.Id, postureChecks); err != nil { + if err := p.accountManager.SavePostureChecks(r.Context(), accountID, userID, postureChecks); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/posture_checks_handler_test.go b/management/server/http/posture_checks_handler_test.go index 974edafde..71bca7de8 100644 --- a/management/server/http/posture_checks_handler_test.go +++ b/management/server/http/posture_checks_handler_test.go @@ -12,9 +12,9 @@ import ( "testing" "github.com/gorilla/mux" + "github.com/netbirdio/netbird/management/server" "github.com/stretchr/testify/assert" - "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" @@ -67,15 +67,18 @@ func initPostureChecksTestData(postureChecks ...*posture.Checks) *PostureChecksH } return accountPostureChecks, nil }, - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - user := server.NewAdminUser("test_user") + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return claims.AccountId, "test_user", nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { + user := server.NewAdminUser(userId) return &server.Account{ - Id: claims.AccountId, + Id: accountId, Users: map[string]*server.User{ - "test_user": user, + userId: user, }, PostureChecks: postureChecks, - }, user, nil + }, nil }, }, geolocationManager: &geolocation.Geolocation{}, diff --git a/management/server/http/routes_handler.go b/management/server/http/routes_handler.go index 18c347334..cf72108cd 100644 --- a/management/server/http/routes_handler.go +++ b/management/server/http/routes_handler.go @@ -43,13 +43,13 @@ func NewRoutesHandler(accountManager server.AccountManager, authCfg AuthCfg) *Ro // GetAllRoutes returns the list of routes for the account func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - routes, err := h.accountManager.ListRoutes(r.Context(), account.Id, user.Id) + routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -70,7 +70,7 @@ func (h *RoutesHandler) GetAllRoutes(w http.ResponseWriter, r *http.Request) { // CreateRoute handles route creation request func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -125,7 +125,7 @@ func (h *RoutesHandler) CreateRoute(w http.ResponseWriter, r *http.Request) { } } - newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, user.Id, req.KeepRoute) + newRoute, err := h.accountManager.CreateRoute(r.Context(), account.Id, newPrefix, networkType, domains, peerId, peerGroupIds, req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, req.Enabled, claims.UserId, req.KeepRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -168,7 +168,7 @@ func (h *RoutesHandler) validateRoute(req api.PostApiRoutesJSONRequestBody) erro // UpdateRoute handles update to a route identified by a given ID func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -181,7 +181,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + _, err = h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), claims.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -247,7 +247,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { newRoute.PeerGroups = *req.PeerGroups } - err = h.accountManager.SaveRoute(r.Context(), account.Id, user.Id, newRoute) + err = h.accountManager.SaveRoute(r.Context(), account.Id, claims.UserId, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -265,7 +265,7 @@ func (h *RoutesHandler) UpdateRoute(w http.ResponseWriter, r *http.Request) { // DeleteRoute handles route deletion request func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -277,7 +277,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -289,7 +289,7 @@ func (h *RoutesHandler) DeleteRoute(w http.ResponseWriter, r *http.Request) { // GetRoute handles a route Get request identified by ID func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -301,7 +301,7 @@ func (h *RoutesHandler) GetRoute(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), account.Id, route.ID(routeID), user.Id) + foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) if err != nil { util.WriteError(r.Context(), status.Errorf(status.NotFound, "route not found"), w) return diff --git a/management/server/http/routes_handler_test.go b/management/server/http/routes_handler_test.go index 40075eb9d..0a0fd8109 100644 --- a/management/server/http/routes_handler_test.go +++ b/management/server/http/routes_handler_test.go @@ -139,8 +139,11 @@ func initRoutesTestData() *RoutesHandler { } return nil }, - GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return testingAccount, testingAccount.Users["test_user"], nil + GetAccountFromTokenFunc: func(_ context.Context, _ jwtclaims.AuthorizationClaims) (string, string, error) { + return testingAccount.Id, "test_user", nil + }, + GetAccountByUserOrAccountIdFunc: func(ctx context.Context, userId, accountId, domain string) (*server.Account, error) { + return testingAccount, nil }, }, claimsExtractor: jwtclaims.NewClaimsExtractor( diff --git a/management/server/http/setupkeys_handler.go b/management/server/http/setupkeys_handler.go index 8ee7dfaba..b0e54f254 100644 --- a/management/server/http/setupkeys_handler.go +++ b/management/server/http/setupkeys_handler.go @@ -35,7 +35,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authCfg AuthCfg) // CreateSetupKey is a POST requests that creates a new SetupKey func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -76,8 +76,8 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request if req.Ephemeral != nil { ephemeral = *req.Ephemeral } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), account.Id, req.Name, server.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, user.Id, ephemeral) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, server.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userID, ephemeral) if err != nil { util.WriteError(r.Context(), err, w) return @@ -89,7 +89,7 @@ func (h *SetupKeysHandler) CreateSetupKey(w http.ResponseWriter, r *http.Request // GetSetupKey is a GET request to get a SetupKey by ID func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -102,7 +102,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), account.Id, user.Id, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,7 +114,7 @@ func (h *SetupKeysHandler) GetSetupKey(w http.ResponseWriter, r *http.Request) { // UpdateSetupKey is a PUT request to update server.SetupKey func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -150,7 +150,7 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request newKey.Name = req.Name newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), account.Id, newKey, user.Id) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -161,13 +161,13 @@ func (h *SetupKeysHandler) UpdateSetupKey(w http.ResponseWriter, r *http.Request // GetAllSetupKeys is a GET request that returns a list of SetupKey func (h *SetupKeysHandler) GetAllSetupKeys(w http.ResponseWriter, r *http.Request) { claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), account.Id, user.Id) + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/setupkeys_handler_test.go b/management/server/http/setupkeys_handler_test.go index bfa0ec008..daff5d6e2 100644 --- a/management/server/http/setupkeys_handler_test.go +++ b/management/server/http/setupkeys_handler_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server" - nbgroup "github.com/netbirdio/netbird/management/server/group" "github.com/netbirdio/netbird/management/server/http/api" "github.com/netbirdio/netbird/management/server/jwtclaims" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,21 +33,8 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup ) *SetupKeysHandler { return &SetupKeysHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return &server.Account{ - Id: testAccountID, - Domain: "hotmail.com", - Users: map[string]*server.User{ - user.Id: user, - }, - SetupKeys: map[string]*server.SetupKey{ - defaultKey.Key: defaultKey, - }, - Groups: map[string]*nbgroup.Group{ - "group-1": {ID: "group-1", Peers: []string{"A", "B"}}, - "id-all": {ID: "id-all", Name: "All"}, - }, - }, user, nil + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return testAccountID, user.Id, nil }, CreateSetupKeyFunc: func(_ context.Context, _ string, keyName string, typ server.SetupKeyType, _ time.Duration, _ []string, _ int, _ string, ephemeral bool, diff --git a/management/server/http/users_handler.go b/management/server/http/users_handler.go index 2c2aed842..cb78b1801 100644 --- a/management/server/http/users_handler.go +++ b/management/server/http/users_handler.go @@ -41,7 +41,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + account, err := h.accountManager.GetAccountByUserOrAccountID(r.Context(), claims.UserId, claims.AccountId, "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -78,7 +78,7 @@ func (h *UsersHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, user.Id, &server.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), account.Id, claims.UserId, &server.User{ Id: userID, Role: userRole, AutoGroups: req.AutoGroups, @@ -102,7 +102,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -115,7 +115,7 @@ func (h *UsersHandler) DeleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -132,7 +132,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -160,7 +160,7 @@ func (h *UsersHandler) CreateUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), account.Id, user.Id, &server.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &server.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -184,13 +184,13 @@ func (h *UsersHandler) GetAllUsers(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return } - data, err := h.accountManager.GetUsersFromAccount(r.Context(), account.Id, user.Id) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -231,7 +231,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { } claims := h.claimsExtractor.FromRequestContext(r) - account, user, err := h.accountManager.GetAccountFromToken(r.Context(), claims) + accountID, userID, err := h.accountManager.GetAccountFromToken(r.Context(), claims) if err != nil { util.WriteError(r.Context(), err, w) return @@ -244,7 +244,7 @@ func (h *UsersHandler) InviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), account.Id, user.Id, targetUserID) + err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/users_handler_test.go b/management/server/http/users_handler_test.go index a78ac3a4e..5388c1159 100644 --- a/management/server/http/users_handler_test.go +++ b/management/server/http/users_handler_test.go @@ -64,8 +64,11 @@ var usersTestAccount = &server.Account{ func initUsersTestData() *UsersHandler { return &UsersHandler{ accountManager: &mock_server.MockAccountManager{ - GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) { - return usersTestAccount, usersTestAccount.Users[claims.UserId], nil + GetAccountFromTokenFunc: func(_ context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) { + return usersTestAccount.Id, claims.UserId, nil + }, + GetAccountByUserOrAccountIdFunc: func(_ context.Context, _, _, _ string) (*server.Account, error) { + return usersTestAccount, nil }, GetUsersFromAccountFunc: func(_ context.Context, accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 476edf19f..7be6f2dca 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -80,7 +80,7 @@ type MockAccountManager struct { DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) CreateUserFunc func(ctx context.Context, accountID, userID string, key *server.UserInfo) (*server.UserInfo, error) - GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, error) + GetAccountFromTokenFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error) CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.AuthorizationClaims) error DeleteAccountFunc func(ctx context.Context, accountID, userID string) error GetDNSDomainFunc func() string @@ -611,13 +611,13 @@ func (am *MockAccountManager) CreateUser(ctx context.Context, accountID, userID } // GetAccountFromToken mocks GetAccountFromToken of the AccountManager interface -func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*server.Account, *server.User, +func (am *MockAccountManager) GetAccountFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error, ) { if am.GetAccountFromTokenFunc != nil { return am.GetAccountFromTokenFunc(ctx, claims) } - return nil, nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") + return "", "", status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } func (am *MockAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error { diff --git a/management/server/user.go b/management/server/user.go index 25916ae98..85840f408 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -371,15 +371,15 @@ func (am *DefaultAccountManager) GetUserByID(ctx context.Context, id string) (*U // GetUser looks up a user by provided authorization claims. // It will also create an account if didn't exist for this user before. func (am *DefaultAccountManager) GetUser(ctx context.Context, claims jwtclaims.AuthorizationClaims) (*User, error) { - account, _, err := am.GetAccountFromToken(ctx, claims) + accountID, _, err := am.GetAccountFromToken(ctx, claims) if err != nil { return nil, fmt.Errorf("failed to get account with token claims %v", err) } - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err = am.Store.GetAccount(ctx, account.Id) + account, err := am.Store.GetAccount(ctx, accountID) if err != nil { return nil, fmt.Errorf("failed to get an account from store %v", err) }