account for peers permission on the groups endpoint

This commit is contained in:
pascal
2026-04-20 14:00:57 +02:00
parent b65a8bcb9c
commit bed8d89d9f
2 changed files with 34 additions and 14 deletions

View File

@@ -21,11 +21,12 @@ import (
// handler is a handler that returns groups of the account
type handler struct {
accountManager account.Manager
accountManager account.Manager
permissionsManager permissions.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
groupsHandler := newHandler(accountManager)
groupsHandler := newHandler(accountManager, permissionsManager)
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
@@ -34,12 +35,18 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, permission
}
// newHandler creates a new groups handler
func newHandler(accountManager account.Manager) *handler {
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
return &handler{
accountManager: accountManager,
accountManager: accountManager,
permissionsManager: permissionsManager,
}
}
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
return err == nil && allowed
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
// Check if filtering by name
@@ -52,7 +59,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -71,7 +78,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -158,7 +165,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -210,7 +217,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -256,7 +263,7 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *aut
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -13,11 +13,14 @@ import (
"strings"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/management/internals/modules/permissions"
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
"github.com/netbirdio/netbird/management/server"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/mock_server"
@@ -34,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
func initGroupTestData(initGroups ...*types.Group) *handler {
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
t.Helper()
ctrl := gomock.NewController(t)
permissionsManagerMock := permissions.NewMockManager(ctrl)
permissionsManagerMock.EXPECT().
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
Return(true, nil).
AnyTimes()
return &handler{
permissionsManager: permissionsManagerMock,
accountManager: &mock_server.MockAccountManager{
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
if !strings.HasPrefix(group.ID, "id-") {
@@ -129,7 +142,7 @@ func TestGetGroup(t *testing.T) {
Name: "Group",
}
p := initGroupTestData(group)
p := initGroupTestData(t, group)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -255,7 +268,7 @@ func TestWriteGroup(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -333,7 +346,7 @@ func TestGetAllGroups(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {
@@ -415,7 +428,7 @@ func TestDeleteGroup(t *testing.T) {
},
}
p := initGroupTestData()
p := initGroupTestData(t)
for _, tc := range tt {
t.Run(tc.name, func(t *testing.T) {