diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index dc7130209..e935d61af 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -21,11 +21,12 @@ import ( // handler is a handler that returns groups of the account type handler struct { - accountManager account.Manager + accountManager account.Manager + permissionsManager permissions.Manager } func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { - groupsHandler := newHandler(accountManager) + groupsHandler := newHandler(accountManager, permissionsManager) router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS") router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS") router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS") @@ -34,12 +35,18 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, permission } // newHandler creates a new groups handler -func newHandler(accountManager account.Manager) *handler { +func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler { return &handler{ - accountManager: accountManager, + accountManager: accountManager, + permissionsManager: permissionsManager, } } +func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool { + allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read) + return err == nil && allowed +} + // getAllGroups list for the account func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { // Check if filtering by name @@ -52,7 +59,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true) + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -71,7 +78,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true) + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -158,7 +165,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth * return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true) + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -210,7 +217,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth * return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true) + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return @@ -256,7 +263,7 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *aut return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true) + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth)) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index f29f18521..9ed8ea5cd 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -13,11 +13,14 @@ import ( "strings" "testing" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" "github.com/stretchr/testify/assert" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/modules/permissions" + "github.com/netbirdio/netbird/management/internals/modules/permissions/modules" + "github.com/netbirdio/netbird/management/internals/modules/permissions/operations" "github.com/netbirdio/netbird/management/server" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/mock_server" @@ -34,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } -func initGroupTestData(initGroups ...*types.Group) *handler { +func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler { + t.Helper() + + ctrl := gomock.NewController(t) + permissionsManagerMock := permissions.NewMockManager(ctrl) + permissionsManagerMock.EXPECT(). + ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)). + Return(true, nil). + AnyTimes() + return &handler{ + permissionsManager: permissionsManagerMock, accountManager: &mock_server.MockAccountManager{ SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error { if !strings.HasPrefix(group.ID, "id-") { @@ -129,7 +142,7 @@ func TestGetGroup(t *testing.T) { Name: "Group", } - p := initGroupTestData(group) + p := initGroupTestData(t, group) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -255,7 +268,7 @@ func TestWriteGroup(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -333,7 +346,7 @@ func TestGetAllGroups(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) { @@ -415,7 +428,7 @@ func TestDeleteGroup(t *testing.T) { }, } - p := initGroupTestData() + p := initGroupTestData(t) for _, tc := range tt { t.Run(tc.name, func(t *testing.T) {