diff --git a/management/server/account.go b/management/server/account.go index b0555726d..31206124f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -47,15 +47,16 @@ type AccountManager interface { ) (*SetupKey, error) SaveSetupKey(accountID string, key *SetupKey) (*SetupKey, error) CreateUser(accountID string, key *UserInfo) (*UserInfo, error) - ListSetupKeys(accountID string) ([]*SetupKey, error) + ListSetupKeys(accountID, userID string) ([]*SetupKey, error) SaveUser(accountID string, key *User) (*UserInfo, error) - GetSetupKey(accountID, keyID string) (*SetupKey, error) + GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) GetAccountById(accountId string) (*Account, error) GetAccountByUserOrAccountId(userId, accountId, domain string) (*Account, error) GetAccountFromToken(claims jwtclaims.AuthorizationClaims) (*Account, error) IsUserAdmin(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExists(accountId string) (*bool, error) GetPeer(peerKey string) (*Peer, error) + GetPeers(accountID, userID string) ([]*Peer, error) MarkPeerConnected(peerKey string, connected bool) error RenamePeer(accountId string, peerKey string, newName string) (*Peer, error) DeletePeer(accountId string, peerKey string) (*Peer, error) @@ -66,7 +67,7 @@ type AccountManager interface { AddPeer(setupKey string, userId string, peer *Peer) (*Peer, error) UpdatePeerMeta(peerKey string, meta PeerSystemMeta) error UpdatePeerSSHKey(peerKey string, sshKey string) error - GetUsersFromAccount(accountId string) ([]*UserInfo, error) + GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) GetGroup(accountId, groupID string) (*Group, error) SaveGroup(accountId string, group *Group) error UpdateGroup(accountID string, groupID string, operations []GroupUpdateOperation) (*Group, error) @@ -75,17 +76,17 @@ type AccountManager interface { GroupAddPeer(accountId, groupID, peerKey string) error GroupDeletePeer(accountId, groupID, peerKey string) error GroupListPeers(accountId, groupID string) ([]*Peer, error) - GetRule(accountId, ruleID string) (*Rule, error) + GetRule(accountID, ruleID, userID string) (*Rule, error) SaveRule(accountID string, rule *Rule) error UpdateRule(accountID string, ruleID string, operations []RuleUpdateOperation) (*Rule, error) DeleteRule(accountId, ruleID string) error - ListRules(accountId string) ([]*Rule, error) - GetRoute(accountID, routeID string) (*route.Route, error) + ListRules(accountID, userID string) ([]*Rule, error) + GetRoute(accountID, routeID, userID string) (*route.Route, error) CreateRoute(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) SaveRoute(accountID string, route *route.Route) error UpdateRoute(accountID string, routeID string, operations []RouteUpdateOperation) (*route.Route, error) DeleteRoute(accountID, routeID string) error - ListRoutes(accountID string) ([]*route.Route, error) + ListRoutes(accountID, userID string) ([]*route.Route, error) GetNameServerGroup(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroup(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) SaveNameServerGroup(accountID string, nsGroupToSave *nbdns.NameServerGroup) error @@ -142,6 +143,16 @@ type UserInfo struct { Status string `json:"-"` } +// FindUser looks for a given user in the Account or returns error if user wasn't found. +func (a *Account) FindUser(userID string) (*User, error) { + user := a.Users[userID] + if user == nil { + return nil, Errorf(UserNotFound, "user %s not found", userID) + } + + return user, nil +} + func (a *Account) Copy() *Account { peers := map[string]*Peer{} for id, peer := range a.Peers { diff --git a/management/server/account_test.go b/management/server/account_test.go index 503a36805..5c479436c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -874,7 +874,7 @@ func TestGetUsersFromAccount(t *testing.T) { account.Users[user.Id] = user } - userInfos, err := manager.GetUsersFromAccount(accountId) + userInfos, err := manager.GetUsersFromAccount(accountId, "1") if err != nil { t.Fatal(err) } diff --git a/management/server/error.go b/management/server/error.go index 1f7bce7f8..f84cdb9d7 100644 --- a/management/server/error.go +++ b/management/server/error.go @@ -11,6 +11,12 @@ const ( AccountNotFound ErrorType = iota // PreconditionFailed indicates that some pre-condition for the operation hasn't been fulfilled PreconditionFailed ErrorType = iota + + // UserNotFound indicates that user wasn't found in the system (or under a given Account) + UserNotFound ErrorType = iota + + // PermissionDenied indicates that user has no permissions to view data + PermissionDenied ErrorType = iota ) // ErrorType is a type of the Error diff --git a/management/server/http/groups.go b/management/server/http/groups.go index 185e0ec0c..fbf660ad8 100644 --- a/management/server/http/groups.go +++ b/management/server/http/groups.go @@ -33,7 +33,7 @@ func NewGroups(accountManager server.AccountManager, authAudience string) *Group // GetAllGroupsHandler list for the account func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -50,7 +50,7 @@ func (h *Groups) GetAllGroupsHandler(w http.ResponseWriter, r *http.Request) { // UpdateGroupHandler handles update to a group identified by a given ID func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -111,7 +111,7 @@ func (h *Groups) UpdateGroupHandler(w http.ResponseWriter, r *http.Request) { // PatchGroupHandler handles patch updates to a group identified by a given ID func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -226,7 +226,7 @@ func (h *Groups) PatchGroupHandler(w http.ResponseWriter, r *http.Request) { // CreateGroupHandler handles group creation request func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -260,7 +260,7 @@ func (h *Groups) CreateGroupHandler(w http.ResponseWriter, r *http.Request) { // DeleteGroupHandler handles group deletion request func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -295,7 +295,7 @@ func (h *Groups) DeleteGroupHandler(w http.ResponseWriter, r *http.Request) { // GetGroupHandler returns a group func (h *Groups) GetGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return diff --git a/management/server/http/groups_test.go b/management/server/http/groups_test.go index aab18df49..1bfb62292 100644 --- a/management/server/http/groups_test.go +++ b/management/server/http/groups_test.go @@ -25,7 +25,7 @@ var TestPeers = map[string]*server.Peer{ "B": &server.Peer{Key: "B", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(groups ...*server.Group) *Groups { +func initGroupTestData(user *server.User, groups ...*server.Group) *Groups { return &Groups{ accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(accountID string, group *server.Group) error { @@ -72,6 +72,9 @@ func initGroupTestData(groups ...*server.Group) *Groups { Id: claims.AccountId, Domain: "hotmail.com", Peers: TestPeers, + Users: map[string]*server.User{ + user.Id: user, + }, Groups: map[string]*server.Group{ "id-existed": {ID: "id-existed", Peers: []string{"A", "B"}}, "id-all": {ID: "id-all", Name: "All"}}, @@ -120,7 +123,8 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - p := initGroupTestData(group) + adminUser := server.NewAdminUser("test_user") + p := initGroupTestData(adminUser, group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -270,7 +274,8 @@ func TestWriteGroup(t *testing.T) { }, } - p := initGroupTestData() + adminUser := server.NewAdminUser("test_user") + p := initGroupTestData(adminUser) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { diff --git a/management/server/http/middleware/access_control.go b/management/server/http/middleware/access_control.go index 3592363a7..be4276206 100644 --- a/management/server/http/middleware/access_control.go +++ b/management/server/http/middleware/access_control.go @@ -1,17 +1,12 @@ package middleware import ( - "context" "fmt" "net/http" "github.com/netbirdio/netbird/management/server/jwtclaims" ) -const ( - IsUserAdminProperty = "isAdminUser" -) - type IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) // AccessControl middleware to restrict to make POST/PUT/DELETE requests by admin only @@ -50,10 +45,6 @@ func (a *AccessControl) Handler(h http.Handler) http.Handler { } } - newRequest := r.Clone(context.WithValue(r.Context(), IsUserAdminProperty, ok)) //nolint - // Update the current request with the new context information. - *r = *newRequest - h.ServeHTTP(w, r) }) } diff --git a/management/server/http/nameservers.go b/management/server/http/nameservers.go index fed939374..55ac13ddf 100644 --- a/management/server/http/nameservers.go +++ b/management/server/http/nameservers.go @@ -30,7 +30,7 @@ func NewNameservers(accountManager server.AccountManager, authAudience string) * // GetAllNameserversHandler returns the list of nameserver groups for the account func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -53,7 +53,7 @@ func (h *Nameservers) GetAllNameserversHandler(w http.ResponseWriter, r *http.Re // CreateNameserverGroupHandler handles nameserver group creation request func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -84,7 +84,7 @@ func (h *Nameservers) CreateNameserverGroupHandler(w http.ResponseWriter, r *htt // UpdateNameserverGroupHandler handles update to a nameserver group identified by a given ID func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -131,7 +131,7 @@ func (h *Nameservers) UpdateNameserverGroupHandler(w http.ResponseWriter, r *htt // PatchNameserverGroupHandler handles patch updates to a nameserver group identified by a given ID func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -202,7 +202,7 @@ func (h *Nameservers) PatchNameserverGroupHandler(w http.ResponseWriter, r *http // DeleteNameserverGroupHandler handles nameserver group deletion request func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -226,7 +226,7 @@ func (h *Nameservers) DeleteNameserverGroupHandler(w http.ResponseWriter, r *htt // GetNameserverGroupHandler handles a nameserver group Get request identified by ID func (h *Nameservers) GetNameserverGroupHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) diff --git a/management/server/http/nameservers_test.go b/management/server/http/nameservers_test.go index 2433078f8..5e78e0624 100644 --- a/management/server/http/nameservers_test.go +++ b/management/server/http/nameservers_test.go @@ -29,6 +29,9 @@ const ( var testingNSAccount = &server.Account{ Id: testNSGroupAccountID, Domain: "hotmail.com", + Users: map[string]*server.User{ + "test_user": server.NewAdminUser("test_user"), + }, } var baseExistingNSGroup = &nbdns.NameServerGroup{ diff --git a/management/server/http/peers.go b/management/server/http/peers.go index f9c04a831..2bd7e7243 100644 --- a/management/server/http/peers.go +++ b/management/server/http/peers.go @@ -56,7 +56,7 @@ func (h *Peers) deletePeer(accountId string, peer *server.Peer, w http.ResponseW } func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -95,15 +95,20 @@ func (h *Peers) HandlePeer(w http.ResponseWriter, r *http.Request) { func (h *Peers) GetPeers(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + peers, err := h.accountManager.GetPeers(account.Id, user.Id) + if err != nil { + return + } + respBody := []*api.Peer{} - for _, peer := range account.Peers { + for _, peer := range peers { respBody = append(respBody, toPeerResponse(peer, account)) } writeJSONObject(w, respBody) diff --git a/management/server/http/peers_test.go b/management/server/http/peers_test.go index 45339f338..648ac48b4 100644 --- a/management/server/http/peers_test.go +++ b/management/server/http/peers_test.go @@ -16,15 +16,21 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) -func initTestMetaData(peer ...*server.Peer) *Peers { +func initTestMetaData(peers ...*server.Peer) *Peers { return &Peers{ accountManager: &mock_server.MockAccountManager{ + GetPeersFunc: func(accountID, userID string) ([]*server.Peer, error) { + return peers, nil + }, GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: claims.AccountId, Domain: "hotmail.com", Peers: map[string]*server.Peer{ - "test_peer": peer[0], + "test_peer": peers[0], + }, + Users: map[string]*server.User{ + "test_user": server.NewAdminUser("test_user"), }, }, nil }, diff --git a/management/server/http/routes.go b/management/server/http/routes.go index 1d8fa1503..8158d452d 100644 --- a/management/server/http/routes.go +++ b/management/server/http/routes.go @@ -33,16 +33,24 @@ func NewRoutes(accountManager server.AccountManager, authAudience string) *Route // GetAllRoutesHandler returns the list of routes for the account func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - routes, err := h.accountManager.ListRoutes(account.Id) + routes, err := h.accountManager.ListRoutes(account.Id, user.Id) if err != nil { log.Error(err) + if e, ok := server.FromError(err); ok { + switch e.Type() { + case server.PermissionDenied: + http.Error(w, e.Error(), http.StatusForbidden) + return + default: + } + } http.Redirect(w, r, "/", http.StatusInternalServerError) return } @@ -56,7 +64,7 @@ func (h *Routes) GetAllRoutesHandler(w http.ResponseWriter, r *http.Request) { // CreateRouteHandler handles route creation request func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -104,7 +112,7 @@ func (h *Routes) CreateRouteHandler(w http.ResponseWriter, r *http.Request) { // UpdateRouteHandler handles update to a route identified by a given ID func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -117,7 +125,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(account.Id, routeID) + _, err = h.accountManager.GetRoute(account.Id, routeID, "") if err != nil { http.Error(w, fmt.Sprintf("couldn't find route for ID %s", routeID), http.StatusNotFound) return @@ -177,7 +185,7 @@ func (h *Routes) UpdateRouteHandler(w http.ResponseWriter, r *http.Request) { // PatchRouteHandler handles patch updates to a route identified by a given ID func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -190,7 +198,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(account.Id, routeID) + _, err = h.accountManager.GetRoute(account.Id, routeID, "") if err != nil { log.Error(err) http.Error(w, fmt.Sprintf("couldn't find route ID %s", routeID), http.StatusNotFound) @@ -334,7 +342,7 @@ func (h *Routes) PatchRouteHandler(w http.ResponseWriter, r *http.Request) { // DeleteRouteHandler handles route deletion request func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -363,7 +371,7 @@ func (h *Routes) DeleteRouteHandler(w http.ResponseWriter, r *http.Request) { // GetRouteHandler handles a route Get request identified by ID func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -375,7 +383,7 @@ func (h *Routes) GetRouteHandler(w http.ResponseWriter, r *http.Request) { return } - foundRoute, err := h.accountManager.GetRoute(account.Id, routeID) + foundRoute, err := h.accountManager.GetRoute(account.Id, routeID, user.Id) if err != nil { http.Error(w, "route not found", http.StatusNotFound) return diff --git a/management/server/http/routes_test.go b/management/server/http/routes_test.go index fe4f1bddf..aaaf2f6b8 100644 --- a/management/server/http/routes_test.go +++ b/management/server/http/routes_test.go @@ -51,12 +51,15 @@ var testingAccount = &server.Account{ IP: netip.MustParseAddr(existingPeerID).AsSlice(), }, }, + Users: map[string]*server.User{ + "test_user": server.NewAdminUser("test_user"), + }, } func initRoutesTestData() *Routes { return &Routes{ accountManager: &mock_server.MockAccountManager{ - GetRouteFunc: func(_, routeID string) (*route.Route, error) { + GetRouteFunc: func(_, routeID, _ string) (*route.Route, error) { if routeID == existingRouteID { return baseExistingRoute, nil } diff --git a/management/server/http/rules.go b/management/server/http/rules.go index b1f4db2b8..1842f5557 100644 --- a/management/server/http/rules.go +++ b/management/server/http/rules.go @@ -31,15 +31,29 @@ func NewRules(accountManager server.AccountManager, authAudience string) *Rules // GetAllRulesHandler list for the account func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } + accountRules, err := h.accountManager.ListRules(account.Id, user.Id) + if err != nil { + log.Error(err) + if e, ok := server.FromError(err); ok { + switch e.Type() { + case server.PermissionDenied: + http.Error(w, e.Error(), http.StatusForbidden) + return + default: + } + } + http.Redirect(w, r, "/", http.StatusInternalServerError) + return + } rules := []*api.Rule{} - for _, r := range account.Rules { + for _, r := range accountRules { rules = append(rules, toRuleResponse(account, r)) } @@ -48,7 +62,7 @@ func (h *Rules) GetAllRulesHandler(w http.ResponseWriter, r *http.Request) { // UpdateRuleHandler handles update to a rule identified by a given ID func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -118,7 +132,7 @@ func (h *Rules) UpdateRuleHandler(w http.ResponseWriter, r *http.Request) { // PatchRuleHandler handles patch updates to a rule identified by a given ID func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -275,7 +289,7 @@ func (h *Rules) PatchRuleHandler(w http.ResponseWriter, r *http.Request) { // CreateRuleHandler handles rule creation request func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -332,7 +346,7 @@ func (h *Rules) CreateRuleHandler(w http.ResponseWriter, r *http.Request) { // DeleteRuleHandler handles rule deletion request func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -356,7 +370,7 @@ func (h *Rules) DeleteRuleHandler(w http.ResponseWriter, r *http.Request) { // GetRuleHandler handles a group Get request identified by ID func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { http.Redirect(w, r, "/", http.StatusInternalServerError) return @@ -370,7 +384,7 @@ func (h *Rules) GetRuleHandler(w http.ResponseWriter, r *http.Request) { return } - rule, err := h.accountManager.GetRule(account.Id, ruleID) + rule, err := h.accountManager.GetRule(account.Id, ruleID, user.Id) if err != nil { http.Error(w, "rule not found", http.StatusNotFound) return diff --git a/management/server/http/rules_test.go b/management/server/http/rules_test.go index e5c62324e..f6f20f0d5 100644 --- a/management/server/http/rules_test.go +++ b/management/server/http/rules_test.go @@ -28,7 +28,7 @@ func initRulesTestData(rules ...*server.Rule) *Rules { } return nil }, - GetRuleFunc: func(_, ruleID string) (*server.Rule, error) { + GetRuleFunc: func(_, ruleID, _ string) (*server.Rule, error) { if ruleID != "idoftherule" { return nil, fmt.Errorf("not found") } @@ -75,6 +75,9 @@ func initRulesTestData(rules ...*server.Rule) *Rules { "F": {ID: "F"}, "G": {ID: "G"}, }, + Users: map[string]*server.User{ + "test_user": server.NewAdminUser("test_user"), + }, }, nil }, }, diff --git a/management/server/http/setupkeys.go b/management/server/http/setupkeys.go index d7b1ce8b7..a07dacd60 100644 --- a/management/server/http/setupkeys.go +++ b/management/server/http/setupkeys.go @@ -6,15 +6,12 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/http/api" - "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/jwtclaims" log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "net/http" - "strings" "time" - "unicode/utf8" ) // SetupKeys is a handler that returns a list of setup keys of the account @@ -34,7 +31,7 @@ func NewSetupKeysHandler(accountManager server.AccountManager, authAudience stri // CreateSetupKeyHandler is a POST requests that creates a new SetupKey func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -84,7 +81,7 @@ func (h *SetupKeys) CreateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetSetupKeyHandler is a GET request to get a SetupKey by ID func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -98,7 +95,7 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(account.Id, keyID) + key, err := h.accountManager.GetSetupKey(account.Id, user.Id, keyID) if err != nil { errStatus, ok := status.FromError(err) if ok && errStatus.Code() == codes.NotFound { @@ -110,17 +107,12 @@ func (h *SetupKeys) GetSetupKeyHandler(w http.ResponseWriter, r *http.Request) { return } - isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool) - if !ok || !isUserAdmin { - key = hideKey(key) - } - writeSuccess(w, key) } // UpdateSetupKeyHandler is a PUT request to update server.SetupKey func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -176,19 +168,14 @@ func (h *SetupKeys) UpdateSetupKeyHandler(w http.ResponseWriter, r *http.Request // GetAllSetupKeysHandler is a GET request that returns a list of SetupKey func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Request) { - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) return } - isUserAdmin, ok := r.Context().Value(middleware.IsUserAdminProperty).(bool) - if !ok { - isUserAdmin = false - } - - setupKeys, err := h.accountManager.ListSetupKeys(account.Id) + setupKeys, err := h.accountManager.ListSetupKeys(account.Id, user.Id) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -196,23 +183,12 @@ func (h *SetupKeys) GetAllSetupKeysHandler(w http.ResponseWriter, r *http.Reques } apiSetupKeys := make([]*api.SetupKey, 0) for _, key := range setupKeys { - k := key.Copy() - if !isUserAdmin { - k = hideKey(key) - } - apiSetupKeys = append(apiSetupKeys, toResponseBody(k)) + apiSetupKeys = append(apiSetupKeys, toResponseBody(key)) } writeJSONObject(w, apiSetupKeys) } -func hideKey(key *server.SetupKey) *server.SetupKey { - k := key.Copy() - prefix := k.Key[0:5] - k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix)) - return k -} - func writeSuccess(w http.ResponseWriter, key *server.SetupKey) { w.WriteHeader(200) w.Header().Set("Content-Type", "application/json") diff --git a/management/server/http/setupkeys_test.go b/management/server/http/setupkeys_test.go index 6819ae03d..87d7c53b4 100644 --- a/management/server/http/setupkeys_test.go +++ b/management/server/http/setupkeys_test.go @@ -2,7 +2,6 @@ package http import ( "bytes" - "context" "encoding/json" "fmt" "github.com/gorilla/mux" @@ -29,13 +28,17 @@ const ( notFoundSetupKeyID = "notFoundSetupKeyID" ) -func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey) *SetupKeys { +func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.SetupKey, updatedSetupKey *server.SetupKey, + user *server.User) *SetupKeys { return &SetupKeys{ accountManager: &mock_server.MockAccountManager{ GetAccountFromTokenFunc: func(claims jwtclaims.AuthorizationClaims) (*server.Account, error) { return &server.Account{ Id: testAccountID, Domain: "hotmail.com", + Users: map[string]*server.User{ + user.Id: user, + }, SetupKeys: map[string]*server.SetupKey{ defaultKey.Key: defaultKey, }, @@ -50,7 +53,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup } return nil, fmt.Errorf("failed creating setup key") }, - GetSetupKeyFunc: func(accountID string, keyID string) (*server.SetupKey, error) { + GetSetupKeyFunc: func(accountID, userID, keyID string) (*server.SetupKey, error) { switch keyID { case defaultKey.Id: return defaultKey, nil @@ -68,7 +71,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup return nil, status.Errorf(codes.NotFound, "key %s not found", key.Id) }, - ListSetupKeysFunc: func(accountID string) ([]*server.SetupKey, error) { + ListSetupKeysFunc: func(accountID, userID string) ([]*server.SetupKey, error) { return []*server.SetupKey{defaultKey}, nil }, }, @@ -76,7 +79,7 @@ func initSetupKeysTestMetaData(defaultKey *server.SetupKey, newKey *server.Setup jwtExtractor: jwtclaims.ClaimsExtractor{ ExtractClaimsFromRequestContext: func(r *http.Request, authAudience string) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ - UserId: "test_user", + UserId: user.Id, Domain: "hotmail.com", AccountId: testAccountID, } @@ -89,6 +92,8 @@ func TestSetupKeysHandlers(t *testing.T) { defaultSetupKey := server.GenerateDefaultSetupKey() defaultSetupKey.Id = existingSetupKeyID + adminUser := server.NewAdminUser("test_user") + newSetupKey := server.GenerateSetupKey(newSetupKeyName, server.SetupKeyReusable, 0, []string{"group-1"}) updatedDefaultSetupKey := defaultSetupKey.Copy() updatedDefaultSetupKey.AutoGroups = []string{"group-1"} @@ -154,13 +159,12 @@ func TestSetupKeysHandlers(t *testing.T) { }, } - handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey) + handler := initSetupKeysTestMetaData(defaultSetupKey, newSetupKey, updatedDefaultSetupKey, adminUser) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { recorder := httptest.NewRecorder() req := httptest.NewRequest(tc.requestType, tc.requestPath, tc.requestBody) - req = req.Clone(context.WithValue(context.TODO(), "isAdminUser", true)) //nolint router := mux.NewRouter() router.HandleFunc("/api/setup-keys", handler.GetAllSetupKeysHandler).Methods("GET", "OPTIONS") diff --git a/management/server/http/users.go b/management/server/http/users.go index 3242c96e7..f7b227c40 100644 --- a/management/server/http/users.go +++ b/management/server/http/users.go @@ -34,7 +34,7 @@ func (h *UserHandler) UpdateUser(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusBadRequest) } - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) @@ -87,7 +87,7 @@ func (h *UserHandler) CreateUserHandler(w http.ResponseWriter, r *http.Request) http.Error(w, "", http.StatusNotFound) } - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, _, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { log.Error(err) } @@ -132,12 +132,13 @@ func (h *UserHandler) GetUsers(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusBadRequest) } - account, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) + account, user, err := getJWTAccount(h.accountManager, h.jwtExtractor, h.authAudience, r) if err != nil { - log.Error(err) + http.Redirect(w, r, "/", http.StatusInternalServerError) + return } - data, err := h.accountManager.GetUsersFromAccount(account.Id) + data, err := h.accountManager.GetUsersFromAccount(account.Id, user.Id) if err != nil { log.Error(err) http.Redirect(w, r, "/", http.StatusInternalServerError) diff --git a/management/server/http/users_test.go b/management/server/http/users_test.go index 597274c70..e402712c3 100644 --- a/management/server/http/users_test.go +++ b/management/server/http/users_test.go @@ -27,7 +27,7 @@ func initUsers(user ...*server.User) *UserHandler { Users: users, }, nil }, - GetUsersFromAccountFunc: func(accountID string) ([]*server.UserInfo, error) { + GetUsersFromAccountFunc: func(accountID, userID string) ([]*server.UserInfo, error) { users := make([]*server.UserInfo, 0) for _, v := range user { users = append(users, &server.UserInfo{ @@ -44,7 +44,7 @@ func initUsers(user ...*server.User) *UserHandler { jwtExtractor: jwtclaims.ClaimsExtractor{ ExtractClaimsFromRequestContext: func(r *http.Request, authAudiance string) jwtclaims.AuthorizationClaims { return jwtclaims.AuthorizationClaims{ - UserId: "test_user", + UserId: "1", Domain: "hotmail.com", AccountId: "test_id", } diff --git a/management/server/http/util.go b/management/server/http/util.go index 084ff3d9b..dbf28fdb8 100644 --- a/management/server/http/util.go +++ b/management/server/http/util.go @@ -56,16 +56,22 @@ func (d *Duration) UnmarshalJSON(b []byte) error { func getJWTAccount(accountManager server.AccountManager, jwtExtractor jwtclaims.ClaimsExtractor, - authAudience string, r *http.Request) (*server.Account, error) { + authAudience string, r *http.Request) (*server.Account, *server.User, error) { - jwtClaims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) + claims := jwtExtractor.ExtractClaimsFromRequestContext(r, authAudience) - account, err := accountManager.GetAccountFromToken(jwtClaims) + account, err := accountManager.GetAccountFromToken(claims) if err != nil { - return nil, fmt.Errorf("failed getting account of a user %s: %v", jwtClaims.UserId, err) + return nil, nil, fmt.Errorf("failed getting account of a user %s: %v", claims.UserId, err) } - return account, nil + user := account.Users[claims.UserId] + if user == nil { + // this is not really possible because we got an account by user ID + return nil, nil, fmt.Errorf("user %s not found", claims.UserId) + } + + return account, user, nil } func toHTTPError(err error, w http.ResponseWriter) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ae7018183..9505dc9fd 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -14,12 +14,13 @@ type MockAccountManager struct { GetOrCreateAccountByUserFunc func(userId, domain string) (*server.Account, error) GetAccountByUserFunc func(userId string) (*server.Account, error) CreateSetupKeyFunc func(accountId string, keyName string, keyType server.SetupKeyType, expiresIn time.Duration, autoGroups []string) (*server.SetupKey, error) - GetSetupKeyFunc func(accountID string, keyID string) (*server.SetupKey, error) + GetSetupKeyFunc func(accountID, userID, keyID string) (*server.SetupKey, error) GetAccountByIdFunc func(accountId string) (*server.Account, error) GetAccountByUserOrAccountIdFunc func(userId, accountId, domain string) (*server.Account, error) IsUserAdminFunc func(claims jwtclaims.AuthorizationClaims) (bool, error) AccountExistsFunc func(accountId string) (*bool, error) GetPeerFunc func(peerKey string) (*server.Peer, error) + GetPeersFunc func(accountID, userID string) ([]*server.Peer, error) MarkPeerConnectedFunc func(peerKey string, connected bool) error RenamePeerFunc func(accountId string, peerKey string, newName string) (*server.Peer, error) DeletePeerFunc func(accountId string, peerKey string) (*server.Peer, error) @@ -35,23 +36,23 @@ type MockAccountManager struct { GroupAddPeerFunc func(accountID, groupID, peerKey string) error GroupDeletePeerFunc func(accountID, groupID, peerKey string) error GroupListPeersFunc func(accountID, groupID string) ([]*server.Peer, error) - GetRuleFunc func(accountID, ruleID string) (*server.Rule, error) + GetRuleFunc func(accountID, ruleID, userID string) (*server.Rule, error) SaveRuleFunc func(accountID string, rule *server.Rule) error UpdateRuleFunc func(accountID string, ruleID string, operations []server.RuleUpdateOperation) (*server.Rule, error) DeleteRuleFunc func(accountID, ruleID string) error - ListRulesFunc func(accountID string) ([]*server.Rule, error) - GetUsersFromAccountFunc func(accountID string) ([]*server.UserInfo, error) + ListRulesFunc func(accountID, userID string) ([]*server.Rule, error) + GetUsersFromAccountFunc func(accountID, userID string) ([]*server.UserInfo, error) UpdatePeerMetaFunc func(peerKey string, meta server.PeerSystemMeta) error UpdatePeerSSHKeyFunc func(peerKey string, sshKey string) error UpdatePeerFunc func(accountID string, peer *server.Peer) (*server.Peer, error) CreateRouteFunc func(accountID string, prefix, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) - GetRouteFunc func(accountID, routeID string) (*route.Route, error) + GetRouteFunc func(accountID, routeID, userID string) (*route.Route, error) SaveRouteFunc func(accountID string, route *route.Route) error UpdateRouteFunc func(accountID string, routeID string, operations []server.RouteUpdateOperation) (*route.Route, error) DeleteRouteFunc func(accountID, routeID string) error - ListRoutesFunc func(accountID string) ([]*route.Route, error) + ListRoutesFunc func(accountID, userID string) ([]*route.Route, error) SaveSetupKeyFunc func(accountID string, key *server.SetupKey) (*server.SetupKey, error) - ListSetupKeysFunc func(accountID string) ([]*server.SetupKey, error) + ListSetupKeysFunc func(accountID, userID string) ([]*server.SetupKey, error) SaveUserFunc func(accountID string, user *server.User) (*server.UserInfo, error) GetNameServerGroupFunc func(accountID, nsGroupID string) (*nbdns.NameServerGroup, error) CreateNameServerGroupFunc func(accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool) (*nbdns.NameServerGroup, error) @@ -64,9 +65,9 @@ type MockAccountManager struct { } // GetUsersFromAccount mock implementation of GetUsersFromAccount from server.AccountManager interface -func (am *MockAccountManager) GetUsersFromAccount(accountID string) ([]*server.UserInfo, error) { +func (am *MockAccountManager) GetUsersFromAccount(accountID string, userID string) ([]*server.UserInfo, error) { if am.GetUsersFromAccountFunc != nil { - return am.GetUsersFromAccountFunc(accountID) + return am.GetUsersFromAccountFunc(accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetUsersFromAccount is not implemented") } @@ -272,9 +273,9 @@ func (am *MockAccountManager) GroupListPeers(accountID, groupID string) ([]*serv } // GetRule mock implementation of GetRule from server.AccountManager interface -func (am *MockAccountManager) GetRule(accountID, ruleID string) (*server.Rule, error) { +func (am *MockAccountManager) GetRule(accountID, ruleID, userID string) (*server.Rule, error) { if am.GetRuleFunc != nil { - return am.GetRuleFunc(accountID, ruleID) + return am.GetRuleFunc(accountID, ruleID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetRule is not implemented") } @@ -304,9 +305,9 @@ func (am *MockAccountManager) DeleteRule(accountID, ruleID string) error { } // ListRules mock implementation of ListRules from server.AccountManager interface -func (am *MockAccountManager) ListRules(accountID string) ([]*server.Rule, error) { +func (am *MockAccountManager) ListRules(accountID, userID string) ([]*server.Rule, error) { if am.ListRulesFunc != nil { - return am.ListRulesFunc(accountID) + return am.ListRulesFunc(accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListRules is not implemented") } @@ -345,16 +346,16 @@ func (am *MockAccountManager) UpdatePeer(accountID string, peer *server.Peer) (* // CreateRoute mock implementation of CreateRoute from server.AccountManager interface func (am *MockAccountManager) CreateRoute(accountID string, network, peer, description, netID string, masquerade bool, metric int, enabled bool) (*route.Route, error) { - if am.GetRouteFunc != nil { + if am.CreateRouteFunc != nil { return am.CreateRouteFunc(accountID, network, peer, description, netID, masquerade, metric, enabled) } return nil, status.Errorf(codes.Unimplemented, "method CreateRoute is not implemented") } // GetRoute mock implementation of GetRoute from server.AccountManager interface -func (am *MockAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) { +func (am *MockAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { if am.GetRouteFunc != nil { - return am.GetRouteFunc(accountID, routeID) + return am.GetRouteFunc(accountID, routeID, userID) } return nil, status.Errorf(codes.Unimplemented, "method GetRoute is not implemented") } @@ -384,9 +385,9 @@ func (am *MockAccountManager) DeleteRoute(accountID, routeID string) error { } // ListRoutes mock implementation of ListRoutes from server.AccountManager interface -func (am *MockAccountManager) ListRoutes(accountID string) ([]*route.Route, error) { +func (am *MockAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { if am.ListRoutesFunc != nil { - return am.ListRoutesFunc(accountID) + return am.ListRoutesFunc(accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListRoutes is not implemented") } @@ -401,18 +402,18 @@ func (am *MockAccountManager) SaveSetupKey(accountID string, key *server.SetupKe } // GetSetupKey mocks GetSetupKey of the AccountManager interface -func (am *MockAccountManager) GetSetupKey(accountID, keyID string) (*server.SetupKey, error) { +func (am *MockAccountManager) GetSetupKey(accountID, userID, keyID string) (*server.SetupKey, error) { if am.GetSetupKeyFunc != nil { - return am.GetSetupKeyFunc(accountID, keyID) + return am.GetSetupKeyFunc(accountID, userID, keyID) } return nil, status.Errorf(codes.Unimplemented, "method GetSetupKey is not implemented") } // ListSetupKeys mocks ListSetupKeys of the AccountManager interface -func (am *MockAccountManager) ListSetupKeys(accountID string) ([]*server.SetupKey, error) { +func (am *MockAccountManager) ListSetupKeys(accountID, userID string) ([]*server.SetupKey, error) { if am.ListSetupKeysFunc != nil { - return am.ListSetupKeysFunc(accountID) + return am.ListSetupKeysFunc(accountID, userID) } return nil, status.Errorf(codes.Unimplemented, "method ListSetupKeys is not implemented") @@ -489,3 +490,11 @@ func (am *MockAccountManager) GetAccountFromToken(claims jwtclaims.Authorization } return nil, status.Errorf(codes.Unimplemented, "method GetAccountFromToken is not implemented") } + +// GetPeers mocks GetPeers of the AccountManager interface +func (am *MockAccountManager) GetPeers(accountID, userID string) ([]*server.Peer, error) { + if am.GetAccountFromTokenFunc != nil { + return am.GetPeersFunc(accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetPeersFunc is not implemented") +} diff --git a/management/server/peer.go b/management/server/peer.go index 615137f79..212b1c449 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -81,6 +81,33 @@ func (am *DefaultAccountManager) GetPeer(peerKey string) (*Peer, error) { return peer, nil } +// GetPeers returns a list of peers under the given account filtering out peers that do not belong to a user if +// the current user is not an admin. +func (am *DefaultAccountManager) GetPeers(accountID, userID string) ([]*Peer, error) { + am.mux.Lock() + defer am.mux.Unlock() + account, err := am.Store.GetAccount(accountID) + if err != nil { + return nil, err + } + + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + peers := make([]*Peer, 0, len(account.Peers)) + for _, peer := range account.Peers { + if !user.IsAdmin() && user.Id != peer.UserID { + // only display peers that belong to the current user if the current user is not an admin + continue + } + peers = append(peers, peer.Copy()) + } + + return peers, nil +} + // MarkPeerConnected marks peer as connected (true) or disconnected (false) func (am *DefaultAccountManager) MarkPeerConnected(peerKey string, connected bool) error { am.mux.Lock() diff --git a/management/server/peer_test.go b/management/server/peer_test.go index ea392b076..54f4629ba 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -134,7 +134,7 @@ func TestAccountManager_GetNetworkMapWithRule(t *testing.T) { return } - rules, err := manager.ListRules(account.Id) + rules, err := manager.ListRules(account.Id, userId) if err != nil { t.Errorf("expecting to get a list of rules, got failure %v", err) return diff --git a/management/server/route.go b/management/server/route.go index 40d331283..f523c8725 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -60,7 +60,7 @@ type RouteUpdateOperation struct { } // GetRoute gets a route object from account and route IDs -func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Route, error) { +func (am *DefaultAccountManager) GetRoute(accountID, routeID, userID string) (*route.Route, error) { am.mux.Lock() defer am.mux.Unlock() @@ -69,6 +69,15 @@ func (am *DefaultAccountManager) GetRoute(accountID, routeID string) (*route.Rou return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.IsAdmin() { + return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") + } + wantedRoute, found := account.Routes[routeID] if found { return wantedRoute, nil @@ -325,7 +334,7 @@ func (am *DefaultAccountManager) DeleteRoute(accountID, routeID string) error { } // ListRoutes returns a list of routes from account -func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, error) { +func (am *DefaultAccountManager) ListRoutes(accountID, userID string) ([]*route.Route, error) { am.mux.Lock() defer am.mux.Unlock() @@ -334,6 +343,15 @@ func (am *DefaultAccountManager) ListRoutes(accountID string) ([]*route.Route, e return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.IsAdmin() { + return nil, Errorf(PermissionDenied, "Only administrators can view Network Routes") + } + routes := make([]*route.Route, 0, len(account.Routes)) for _, item := range account.Routes { routes = append(routes, item) diff --git a/management/server/route_test.go b/management/server/route_test.go index 6e96c9db2..05c055e84 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -740,7 +740,7 @@ func TestGetNetworkMap_RouteSync(t *testing.T) { err = am.SaveGroup(account.Id, newGroup) require.NoError(t, err) - rules, err := am.ListRules(account.Id) + rules, err := am.ListRules(account.Id, "testingUser") require.NoError(t, err) defaultRule := rules[0] diff --git a/management/server/rule.go b/management/server/rule.go index 82689f347..8133d4f4a 100644 --- a/management/server/rule.go +++ b/management/server/rule.go @@ -89,7 +89,7 @@ func (r *Rule) Copy() *Rule { } // GetRule of ACL from the store -func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error) { +func (am *DefaultAccountManager) GetRule(accountID, ruleID, userID string) (*Rule, error) { am.mux.Lock() defer am.mux.Unlock() @@ -98,6 +98,15 @@ func (am *DefaultAccountManager) GetRule(accountID, ruleID string) (*Rule, error return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.IsAdmin() { + return nil, Errorf(PermissionDenied, "only admins are allowed to view rules") + } + rule, ok := account.Rules[ruleID] if ok { return rule, nil @@ -222,7 +231,7 @@ func (am *DefaultAccountManager) DeleteRule(accountID, ruleID string) error { } // ListRules of ACL from the store -func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) { +func (am *DefaultAccountManager) ListRules(accountID, userID string) ([]*Rule, error) { am.mux.Lock() defer am.mux.Unlock() @@ -231,6 +240,15 @@ func (am *DefaultAccountManager) ListRules(accountID string) ([]*Rule, error) { return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + + if !user.IsAdmin() { + return nil, Errorf(PermissionDenied, "Only Administrators can view Access Rules") + } + rules := make([]*Rule, 0, len(account.Rules)) for _, item := range account.Rules { rules = append(rules, item) diff --git a/management/server/setupkey.go b/management/server/setupkey.go index 2cb428b16..1a30d9ac9 100644 --- a/management/server/setupkey.go +++ b/management/server/setupkey.go @@ -9,6 +9,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" ) const ( @@ -100,6 +101,15 @@ func (key *SetupKey) Copy() *SetupKey { } } +// HiddenCopy returns a copy of the key with a Key value hidden with "*" and a 5 character prefix. +// E.g., "831F6*******************************" +func (key *SetupKey) HiddenCopy() *SetupKey { + k := key.Copy() + prefix := k.Key[0:5] + k.Key = prefix + strings.Repeat("*", utf8.RuneCountInString(key.Key)-len(prefix)) + return k +} + // IncrementUsage makes a copy of a key, increments the UsedTimes by 1 and sets LastUsed to now func (key *SetupKey) IncrementUsage() *SetupKey { c := key.Copy() @@ -238,7 +248,7 @@ func (am *DefaultAccountManager) SaveSetupKey(accountID string, keyToSave *Setup } // ListSetupKeys returns a list of all setup keys of the account -func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, error) { +func (am *DefaultAccountManager) ListSetupKeys(accountID, userID string) ([]*SetupKey, error) { am.mux.Lock() defer am.mux.Unlock() account, err := am.Store.GetAccount(accountID) @@ -246,16 +256,27 @@ func (am *DefaultAccountManager) ListSetupKeys(accountID string) ([]*SetupKey, e return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + keys := make([]*SetupKey, 0, len(account.SetupKeys)) for _, key := range account.SetupKeys { - keys = append(keys, key.Copy()) + var k *SetupKey + if !user.IsAdmin() { + k = key.HiddenCopy() + } else { + k = key.Copy() + } + keys = append(keys, k) } return keys, nil } // GetSetupKey looks up a SetupKey by KeyID, returns NotFound error if not found. -func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey, error) { +func (am *DefaultAccountManager) GetSetupKey(accountID, userID, keyID string) (*SetupKey, error) { am.mux.Lock() defer am.mux.Unlock() @@ -264,6 +285,11 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey return nil, status.Errorf(codes.NotFound, "account not found") } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + var foundKey *SetupKey for _, key := range account.SetupKeys { if key.Id == keyID { @@ -280,5 +306,9 @@ func (am *DefaultAccountManager) GetSetupKey(accountID, keyID string) (*SetupKey foundKey.UpdatedAt = foundKey.CreatedAt } + if !user.IsAdmin() { + foundKey = foundKey.HiddenCopy() + } + return foundKey, nil } diff --git a/management/server/user.go b/management/server/user.go index 300fe897a..0929d3e76 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -46,6 +46,11 @@ type User struct { AutoGroups []string } +// IsAdmin returns true if user is an admin, false otherwise +func (u *User) IsAdmin() bool { + return u.Role == UserRoleAdmin +} + // toUserInfo converts a User object to a UserInfo object. func (u *User) toUserInfo(userData *idp.UserData) (*UserInfo, error) { autoGroups := u.AutoGroups @@ -288,13 +293,19 @@ func (am *DefaultAccountManager) IsUserAdmin(claims jwtclaims.AuthorizationClaim return user.Role == UserRoleAdmin, nil } -// GetUsersFromAccount performs a batched request for users from IDP by account ID -func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserInfo, error) { +// GetUsersFromAccount performs a batched request for users from IDP by account ID apply filter on what data to return +// based on provided user role. +func (am *DefaultAccountManager) GetUsersFromAccount(accountID, userID string) ([]*UserInfo, error) { account, err := am.GetAccountById(accountID) if err != nil { return nil, err } + user, err := account.FindUser(userID) + if err != nil { + return nil, err + } + queriedUsers := make([]*idp.UserData, 0) if !isNil(am.idpManager) { users := make(map[string]struct{}, len(account.Users)) @@ -311,8 +322,12 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI // in case of self-hosted, or IDP doesn't return anything, we will return the locally stored userInfo if len(queriedUsers) == 0 { - for _, user := range account.Users { - info, err := user.toUserInfo(nil) + for _, accountUser := range account.Users { + if !user.IsAdmin() && user.Id != accountUser.Id { + // if user is not an admin then show only current user and do not show other users + continue + } + info, err := accountUser.toUserInfo(nil) if err != nil { return nil, err } @@ -322,6 +337,10 @@ func (am *DefaultAccountManager) GetUsersFromAccount(accountID string) ([]*UserI } for _, queriedUser := range queriedUsers { + if !user.IsAdmin() && user.Id != queriedUser.ID { + // if user is not an admin then show only current user and do not show other users + continue + } if localUser, contains := account.Users[queriedUser.ID]; contains { info, err := localUser.toUserInfo(queriedUser)