Hide content based on user role (#541)

This commit is contained in:
Misha Bragin
2022-11-05 10:24:50 +01:00
committed by GitHub
parent e8d82c1bd3
commit 4321b71984
27 changed files with 305 additions and 142 deletions

View File

@@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group
// GetAllGroupsHandler list for the account
func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) {
// UpdateGroupHandler handles update to a group identified by a given ID
func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) {
// PatchGroupHandler handles patch updates to a group identified by a given ID
func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) {
// CreateGroupHandler handles group creation request
func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) {
// DeleteGroupHandler handles group deletion request
func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) {
// GetGroupHandler returns a group
func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return

View File

@@ -25,7 +25,7 @@ var TestPeers = map[string]*server.Peer{
"B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(groups ...*server.Group) *Groups {
func initGroupTestData(user *server.User, groups ...*server.Group) *Groups {
return &Groups{
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(accountID string, group *server.Group) error {
@@ -72,6 +72,9 @@ func initGroupTestData(groups ...*server.Group) *Groups {
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: TestPeers,
Users: map[string]*server.User{
user.Id: user,
},
Groups: map[string]*server.Group{
"id-existed": {ID: "id-existed", Peers: []string{"A", "B"}},
"id-all": {ID: "id-all", Name: "All"}},
@@ -120,7 +123,8 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(group)
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser, group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -270,7 +274,8 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData()
adminUser := server.NewAdminUser("test_user")
p := initGroupTestData(adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {

View File

@@ -1,17 +1,12 @@
package middleware
import (
"context"
"fmt"
"net/http"
"github.com/netbirdio/netbird/management/server/jwtclaims"
)
const (
IsUserAdminProperty = "isAdminUser"
)
type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error)
// AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only
@@ -50,10 +45,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler {
}
}
newRequest := r.Clone(context.WithValue(r.Context(), IsUserAdminProperty, ok)) //nolint
// Update the current request with the new context information.
*r = *newRequest
h.ServeHTTP(w, r)
})
}

View File

@@ -30,7 +30,7 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) *
// GetAllNameserversHandler returns the list of nameserver groups for the account
func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -53,7 +53,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re
// CreateNameserverGroupHandler handles nameserver group creation request
func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -84,7 +84,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt
// UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID
func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -131,7 +131,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt
// PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID
func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -202,7 +202,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http
// DeleteNameserverGroupHandler handles nameserver group deletion request
func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -226,7 +226,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt
// GetNameserverGroupHandler handles a nameserver group Get request identified by ID
func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -29,6 +29,9 @@ const (
var testingNSAccount = &server.Account{
Id: testNSGroupAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
var baseExistingNSGroup = &nbdns.NameServerGroup{

View File

@@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW
}
func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -95,15 +95,20 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) {
func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet:
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
peers, err := h.accountManager.GetPeers(account.Id, user.Id)
if err != nil {
return
}
respBody := []*api.Peer{}
for _, peer := range account.Peers {
for _, peer := range peers {
respBody = append(respBody, toPeerResponse(peer, account))
}
writeJSONObject(w, respBody)

View File

@@ -16,15 +16,21 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
func initTestMetaData(peer ...*server.Peer) *Peers {
func initTestMetaData(peers ...*server.Peer) *Peers {
return &Peers{
accountManager: &mock_server.MockAccountManager{
GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) {
return peers, nil
},
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: claims.AccountId,
Domain: "hotmail.com",
Peers: map[string]*server.Peer{
"test_peer": peer[0],
"test_peer": peers[0],
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},

View File

@@ -33,16 +33,24 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route
// GetAllRoutesHandler returns the list of routes for the account
func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
routes, err := h.accountManager.ListRoutes(account.Id)
routes, err := h.accountManager.ListRoutes(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
@@ -56,7 +64,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) {
// CreateRouteHandler handles route creation request
func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -104,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRouteHandler handles update to a route identified by a given ID
func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -117,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
if err != nil {
http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound)
return
@@ -177,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) {
// PatchRouteHandler handles patch updates to a route identified by a given ID
func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -190,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(account.Id, routeID)
_, err = h.accountManager.GetRoute(account.Id, routeID, "")
if err != nil {
log.Error(err)
http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound)
@@ -334,7 +342,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRouteHandler handles route deletion request
func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -363,7 +371,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) {
// GetRouteHandler handles a route Get request identified by ID
func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -375,7 +383,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) {
return
}
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID)
foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id)
if err != nil {
http.Error(w, "route not found", http.StatusNotFound)
return

View File

@@ -51,12 +51,15 @@ var testingAccount = &server.Account{
IP: netip.MustParseAddr(existingPeerID).AsSlice(),
},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}
func initRoutesTestData() *Routes {
return &Routes{
accountManager: &mock_server.MockAccountManager{
GetRouteFunc: func(_, routeID string) (*route.Route, error) {
GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) {
if routeID == existingRouteID {
return baseExistingRoute, nil
}

View File

@@ -31,15 +31,29 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules
// GetAllRulesHandler list for the account
func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountRules, err := h.accountManager.ListRules(account.Id, user.Id)
if err != nil {
log.Error(err)
if e, ok := server.FromError(err); ok {
switch e.Type() {
case server.PermissionDenied:
http.Error(w, e.Error(), http.StatusForbidden)
return
default:
}
}
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
rules := []*api.Rule{}
for _, r := range account.Rules {
for _, r := range accountRules {
rules = append(rules, toRuleResponse(account, r))
}
@@ -48,7 +62,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) {
// UpdateRuleHandler handles update to a rule identified by a given ID
func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -118,7 +132,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) {
// PatchRuleHandler handles patch updates to a rule identified by a given ID
func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -275,7 +289,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) {
// CreateRuleHandler handles rule creation request
func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -332,7 +346,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) {
// DeleteRuleHandler handles rule deletion request
func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -356,7 +370,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) {
// GetRuleHandler handles a group Get request identified by ID
func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
@@ -370,7 +384,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) {
return
}
rule, err := h.accountManager.GetRule(account.Id, ruleID)
rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id)
if err != nil {
http.Error(w, "rule not found", http.StatusNotFound)
return

View File

@@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
}
return nil
},
GetRuleFunc: func(_, ruleID string) (*server.Rule, error) {
GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) {
if ruleID != "idoftherule" {
return nil, fmt.Errorf("not found")
}
@@ -75,6 +75,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules {
"F": {ID: "F"},
"G": {ID: "G"},
},
Users: map[string]*server.User{
"test_user": server.NewAdminUser("test_user"),
},
}, nil
},
},

View File

@@ -6,15 +6,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/http/api"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/jwtclaims"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"net/http"
"strings"
"time"
"unicode/utf8"
)
// SetupKeys is a handler that returns a list of setup keys of the account
@@ -34,7 +31,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri
// CreateSetupKeyHandler is a POST requests that creates a new SetupKey
func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -84,7 +81,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetSetupKeyHandler is a GET request to get a SetupKey by ID
func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -98,7 +95,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(account.Id, keyID)
key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID)
if err != nil {
errStatus, ok := status.FromError(err)
if ok && errStatus.Code() == codes.NotFound {
@@ -110,17 +107,12 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
return
}
isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool)
if !ok || !isUserAdmin {
key = hideKey(key)
}
writeSuccess(w, key)
}
// UpdateSetupKeyHandler is a PUT request to update server.SetupKey
func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -176,19 +168,14 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request
// GetAllSetupKeysHandler is a GET request that returns a list of SetupKey
func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) {
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool)
if !ok {
isUserAdmin = false
}
setupKeys, err := h.accountManager.ListSetupKeys(account.Id)
setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -196,23 +183,12 @@ func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Reques
}
apiSetupKeys := make([]*api.SetupKey, 0)
for _, key := range setupKeys {
k := key.Copy()
if !isUserAdmin {
k = hideKey(key)
}
apiSetupKeys = append(apiSetupKeys, toResponseBody(k))
apiSetupKeys = append(apiSetupKeys, toResponseBody(key))
}
writeJSONObject(w, apiSetupKeys)
}
func hideKey(key *server.SetupKey) *server.SetupKey {
k := key.Copy()
prefix := k.Key[0:5]
k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix))
return k
}
func writeSuccess(w http.ResponseWriter, key *server.SetupKey) {
w.WriteHeader(200)
w.Header().Set("Content-Type", "application/json")

View File

@@ -2,7 +2,6 @@ package http
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gorilla/mux"
@@ -29,13 +28,17 @@ const (
notFoundSetupKeyID = "notFoundSetupKeyID"
)
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys {
func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey,
user *server.User) *SetupKeys {
return &SetupKeys{
accountManager: &mock_server.MockAccountManager{
GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) {
return &server.Account{
Id: testAccountID,
Domain: "hotmail.com",
Users: map[string]*server.User{
user.Id: user,
},
SetupKeys: map[string]*server.SetupKey{
defaultKey.Key: defaultKey,
},
@@ -50,7 +53,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
}
return nil, fmt.Errorf("failed creating setup key")
},
GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) {
GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) {
switch keyID {
case defaultKey.Id:
return defaultKey, nil
@@ -68,7 +71,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id)
},
ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) {
ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) {
return []*server.SetupKey{defaultKey}, nil
},
},
@@ -76,7 +79,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: user.Id,
Domain: "hotmail.com",
AccountId: testAccountID,
}
@@ -89,6 +92,8 @@ func TestSetupKeysHandlers(t *testing.T) {
defaultSetupKey := server.GenerateDefaultSetupKey()
defaultSetupKey.Id = existingSetupKeyID
adminUser := server.NewAdminUser("test_user")
newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"})
updatedDefaultSetupKey := defaultSetupKey.Copy()
updatedDefaultSetupKey.AutoGroups = []string{"group-1"}
@@ -154,13 +159,12 @@ func TestSetupKeysHandlers(t *testing.T) {
},
}
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey)
handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody)
req = req.Clone(context.WithValue(context.TODO(), "isAdminUser", true)) //nolint
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS")

View File

@@ -34,7 +34,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
@@ -87,7 +87,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request)
http.Error(w, "", http.StatusNotFound)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
}
@@ -132,12 +132,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) {
http.Error(w, "", http.StatusBadRequest)
}
account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
data, err := h.accountManager.GetUsersFromAccount(account.Id)
data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id)
if err != nil {
log.Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)

View File

@@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler {
Users: users,
}, nil
},
GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) {
GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) {
users := make([]*server.UserInfo, 0)
for _, v := range user {
users = append(users, &server.UserInfo{
@@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler {
jwtExtractor: jwtclaims.ClaimsExtractor{
ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims {
return jwtclaims.AuthorizationClaims{
UserId: "test_user",
UserId: "1",
Domain: "hotmail.com",
AccountId: "test_id",
}

View File

@@ -56,16 +56,22 @@ func (d *Duration) UnmarshalJSON(b []byte) error {
func getJWTAccount(accountManager server.AccountManager,
jwtExtractor jwtclaims.ClaimsExtractor,
authAudience string, r *http.Request) (*server.Account, error) {
authAudience string, r *http.Request) (*server.Account, *server.User, error) {
jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience)
account, err := accountManager.GetAccountFromToken(jwtClaims)
account, err := accountManager.GetAccountFromToken(claims)
if err != nil {
return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err)
return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err)
}
return account, nil
user := account.Users[claims.UserId]
if user == nil {
// this is not really possible because we got an account by user ID
return nil, nil, fmt.Errorf("user %s not found", claims.UserId)
}
return account, user, nil
}
func toHTTPError(err error, w http.ResponseWriter) {