mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 19:56:46 +00:00
account for peers permission on the groups endpoint
This commit is contained in:
@@ -21,11 +21,12 @@ import (
|
||||
|
||||
// handler is a handler that returns groups of the account
|
||||
type handler struct {
|
||||
accountManager account.Manager
|
||||
accountManager account.Manager
|
||||
permissionsManager permissions.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
groupsHandler := newHandler(accountManager)
|
||||
groupsHandler := newHandler(accountManager, permissionsManager)
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
|
||||
@@ -34,12 +35,18 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, permission
|
||||
}
|
||||
|
||||
// newHandler creates a new groups handler
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
func newHandler(accountManager account.Manager, permissionsManager permissions.Manager) *handler {
|
||||
return &handler{
|
||||
accountManager: accountManager,
|
||||
accountManager: accountManager,
|
||||
permissionsManager: permissionsManager,
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) canReadPeers(r *http.Request, userAuth *auth.UserAuth) bool {
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
|
||||
return err == nil && allowed
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
// Check if filtering by name
|
||||
@@ -52,7 +59,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -71,7 +78,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -158,7 +165,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -210,7 +217,7 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -256,7 +263,7 @@ func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *aut
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", true)
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "", h.canReadPeers(r, userAuth))
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -13,11 +13,14 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
@@ -34,8 +37,18 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
func initGroupTestData(t *testing.T, initGroups ...*types.Group) *handler {
|
||||
t.Helper()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||
permissionsManagerMock.EXPECT().
|
||||
ValidateUserPermissions(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Eq(modules.Peers), gomock.Eq(operations.Read)).
|
||||
Return(true, nil).
|
||||
AnyTimes()
|
||||
|
||||
return &handler{
|
||||
permissionsManager: permissionsManagerMock,
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
SaveGroupFunc: func(_ context.Context, accountID, userID string, group *types.Group, create bool) error {
|
||||
if !strings.HasPrefix(group.ID, "id-") {
|
||||
@@ -129,7 +142,7 @@ func TestGetGroup(t *testing.T) {
|
||||
Name: "Group",
|
||||
}
|
||||
|
||||
p := initGroupTestData(group)
|
||||
p := initGroupTestData(t, group)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -255,7 +268,7 @@ func TestWriteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -333,7 +346,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
@@ -415,7 +428,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
p := initGroupTestData()
|
||||
p := initGroupTestData(t)
|
||||
|
||||
for _, tc := range tt {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user