use api wrapper for permissions management

This commit is contained in:
pascal
2026-02-24 00:39:54 +01:00
parent 9d123ec059
commit 32730af33f
43 changed files with 867 additions and 1320 deletions

View File

@@ -6,7 +6,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -15,21 +18,15 @@ type handler struct {
manager accesslogs.Manager
}
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var filter accesslogs.AccessLogFilter
filter.ParseFromRequest(r)

View File

@@ -7,7 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,15 +20,15 @@ type handler struct {
manager Manager
}
func RegisterEndpoints(router *mux.Router, manager Manager) {
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Read, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS")
}
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
@@ -53,13 +56,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
return resp
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -74,13 +71,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, ret)
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiReverseProxiesDomainsJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -96,13 +87,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
@@ -117,13 +102,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
domainID := mux.Vars(r)["domainId"]
if domainID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)

View File

@@ -10,7 +10,10 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,30 +24,24 @@ type handler struct {
}
// RegisterEndpoints registers all service HTTP endpoints.
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
h := &handler{
manager: manager,
}
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
domainmanager.RegisterEndpoints(domainRouter, domainManager)
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -59,13 +56,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiServices)
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.ServiceRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -75,7 +66,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
service := new(reverseproxy.Service)
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -89,13 +80,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -111,13 +96,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
@@ -134,7 +113,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
service.ID = serviceID
service.FromAPIRequest(&req, userAuth.AccountId)
if err = service.Validate(); err != nil {
if err := service.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -148,13 +127,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
serviceID := mux.Vars(r)["serviceId"]
if serviceID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)

View File

@@ -7,7 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager zones.Manager
}
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiZones)
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiDnsZonesJSONRequestBody
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
zone := new(zones.Zone)
zone.FromAPIRequest(&req)
if err = zone.Validate(); err != nil {
if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
}
var req api.PutApiDnsZonesZoneIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
zone.FromAPIRequest(&req)
zone.ID = zoneID
if err = zone.Validate(); err != nil {
if err := zone.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
return
}
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -7,7 +7,10 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -17,25 +20,19 @@ type handler struct {
manager records.Manager
}
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
h := &handler{
manager: manager,
}
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiRecords)
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
record := new(records.Record)
record.FromAPIRequest(&req)
if err = record.Validate(); err != nil {
if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
}
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
@@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
record.FromAPIRequest(&req)
record.ID = recordID
if err = record.Validate(); err != nil {
if err := record.Validate(); err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
return
}
@@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
zoneID := mux.Vars(r)["zoneId"]
if zoneID == "" {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
@@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
return
}
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -154,27 +154,27 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
return nil, fmt.Errorf("failed to create instance manager: %w", err)
}
accounts.AddEndpoints(accountManager, settingsManager, router)
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
users.AddEndpoints(accountManager, router)
users.AddInvitesEndpoints(accountManager, router)
users.AddEndpoints(accountManager, router, permissionsManager)
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
users.AddPublicInvitesEndpoints(accountManager, router)
setup_keys.AddEndpoints(accountManager, router)
policies.AddEndpoints(accountManager, LocationManager, router)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
groups.AddEndpoints(accountManager, router)
routes.AddEndpoints(accountManager, router)
dns.AddEndpoints(accountManager, router)
events.AddEndpoints(accountManager, router)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
groups.AddEndpoints(accountManager, router, permissionsManager)
routes.AddEndpoints(accountManager, router, permissionsManager)
dns.AddEndpoints(accountManager, router, permissionsManager)
events.AddEndpoints(accountManager, router, permissionsManager)
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
idp.AddEndpoints(accountManager, router, permissionsManager)
instance.AddEndpoints(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router)
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
}
// Register OAuth callback handler for proxy authentication

View File

@@ -13,9 +13,12 @@ import (
goversion "github.com/hashicorp/go-version"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/settings"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -40,11 +43,11 @@ type handler struct {
settingsManager settings.Manager
}
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
accountsHandler := newHandler(accountManager, settingsManager)
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
}
// newHandler creates a new handler HTTP handler
@@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 {
}
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp := toAccountResponse(accountID, settings, meta, onboarding)
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
}
@@ -223,15 +218,7 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
}
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
_, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
accountID := vars["accountId"]
if len(accountID) == 0 {
@@ -240,7 +227,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
}
var req api.PutApiAccountsAccountIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -257,7 +244,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
return
}
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -272,19 +259,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
}
}
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -296,13 +283,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
}
// deleteAccount is a HTTP DELETE handler to delete an account
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetAccountID := vars["accountId"]
if len(targetAccountID) == 0 {
@@ -310,7 +291,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -72,6 +72,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
}
}
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func TestAccounts_AccountsHandler(t *testing.T) {
accountID := "test_account"
adminUser := types.NewAdminUser("test_user")
@@ -284,8 +296,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
router.HandleFunc("/api/accounts", wrapHandler(handler.getAllAccounts)).Methods("GET")
router.HandleFunc("/api/accounts/{accountId}", wrapHandler(handler.updateAccount)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,15 +21,15 @@ type dnsSettingsHandler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
addDNSSettingEndpoint(accountManager, router)
addDNSNameserversEndpoint(accountManager, router)
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
addDNSSettingEndpoint(accountManager, router, permissionsManager)
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
}
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
}
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
@@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
}
// getDNSSettings returns the DNS settings for the account
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
}
// updateDNSSettings handles update to DNS settings of an account
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PutApiDnsSettingsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
DisabledManagementGroups: req.DisabledManagementGroups,
}
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -23,6 +23,18 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
testDNSSettingsAccountID = "test_id"
testDNSSettingsExistingGroup = "test_group"
@@ -115,8 +127,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
router.HandleFunc("/api/dns/settings", wrapHandler(p.getDNSSettings)).Methods("GET")
router.HandleFunc("/api/dns/settings", wrapHandler(p.updateDNSSettings)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -6,11 +6,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +23,13 @@ type nameserversHandler struct {
accountManager account.Manager
}
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
nameserversHandler := newNameserversHandler(accountManager)
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
}
// newNameserversHandler returns a new instance of nameserversHandler handler
@@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
}
// getAllNameservers returns the list of nameserver groups for the account
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
}
// createNameserverGroup handles nameserver group creation request
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiDnsNameserversJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
return
}
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
}
// updateNameserverGroup handles update to a nameserver group identified by a given ID
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
@@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
SearchDomainsEnabled: req.SearchDomainsEnabled,
}
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
}
// deleteNameserverGroup handles nameserver group deletion request
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
}
// getNameserverGroup handles a nameserver group Get request identified by ID
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nsGroupID := mux.Vars(r)["nsgroupId"]
if len(nsGroupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
return
}
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -24,6 +24,18 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingNSGroupID = "existingNSGroupID"
notFoundNSGroupID = "notFoundNSGroupID"
@@ -201,10 +213,10 @@ func TestNameserversHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.getNameserverGroup)).Methods("GET")
router.HandleFunc("/api/dns/nameservers", wrapHandler(p.createNameserverGroup)).Methods("POST")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.updateNameserverGroup)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -5,11 +5,13 @@ import (
"net/http"
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,10 +21,10 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
eventsHandler := newHandler(accountManager)
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
}
// newHandler creates a new events handler
@@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllEvents list of the given account
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -163,6 +163,18 @@ func generateEvents(accountID, userID string) []*activity.Event {
return events
}
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func TestEvents_GetEvents(t *testing.T) {
tt := []struct {
name string
@@ -196,7 +208,7 @@ func TestEvents_GetEvents(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
router.HandleFunc("/api/events/", wrapHandler(handler.getAllEvents)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,10 +8,12 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -22,13 +24,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
groupsHandler := newHandler(accountManager)
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new groups handler
@@ -39,26 +41,18 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllGroups list for the account
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Error(err)
http.Redirect(w, r, "/", http.StatusInternalServerError)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
// Check if filtering by name
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -71,13 +65,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
// Get all groups
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -92,15 +86,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
}
// updateGroup handles update to a group identified by a given ID
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
groupID, ok := vars["groupId"]
if !ok {
@@ -112,13 +98,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,13 +152,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
IntegrationReference: existingGroup.IntegrationReference,
}
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -182,17 +168,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
}
// createGroup handles group creation request
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiGroupsJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -226,13 +204,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
Issued: types.GroupIssuedAPI,
}
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -242,22 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
}
// deleteGroup handles group deletion request
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
if err != nil {
wrappedErr, ok := err.(interface{ Unwrap() []error })
if ok && len(wrappedErr.Unwrap()) > 0 {
@@ -273,34 +243,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
}
// getGroup returns a group
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
groupID := mux.Vars(r)["groupId"]
if len(groupID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
return
}
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
}
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {

View File

@@ -33,6 +33,18 @@ var TestPeers = map[string]*nbpeer.Peer{
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
}
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func initGroupTestData(initGroups ...*types.Group) *handler {
return &handler{
accountManager: &mock_server.MockAccountManager{
@@ -141,7 +153,7 @@ func TestGetGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.getGroup)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -267,8 +279,8 @@ func TestWriteGroup(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
router.HandleFunc("/api/groups", wrapHandler(p.createGroup)).Methods("POST")
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.updateGroup)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -345,7 +357,7 @@ func TestGetAllGroups(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
router.HandleFunc("/api/groups", wrapHandler(p.getAllGroups)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -426,7 +438,7 @@ func TestDeleteGroup(t *testing.T) {
AccountId: "test_id",
})
router := mux.NewRouter()
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.deleteGroup)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,8 +7,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -20,13 +23,13 @@ type handler struct {
}
// AddEndpoints registers identity provider endpoints
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := newHandler(accountManager)
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
}
func newHandler(accountManager account.Manager) *handler {
@@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllIdentityProviders returns all identity providers for the account
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
}
// getIdentityProvider returns a specific identity provider
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
return
}
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
}
// createIdentityProvider creates a new identity provider
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.IdentityProviderRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
}
// updateIdentityProvider updates an existing identity provider
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
idp := fromAPIRequest(&req)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
}
// deleteIdentityProvider deletes an identity provider
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
idpID := vars["idpId"]
if idpID == "" {
@@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request)
return
}
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -87,6 +87,18 @@ func initIDPTestData(existingIDP *types.IdentityProvider) *handler {
}
}
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func TestGetAllIdentityProviders(t *testing.T) {
existingIDP := &types.IdentityProvider{
ID: existingIDPID,
@@ -120,7 +132,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
router.HandleFunc("/api/identity-providers", wrapHandler(h.getAllIdentityProviders)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -180,7 +192,7 @@ func TestGetIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.getIdentityProvider)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -242,7 +254,7 @@ func TestCreateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
router.HandleFunc("/api/identity-providers", wrapHandler(h.createIdentityProvider)).Methods("POST")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -328,7 +340,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.updateIdentityProvider)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -388,7 +400,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -8,6 +8,10 @@ import (
log "github.com/sirupsen/logrus"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
}
// AddVersionEndpoint registers the authenticated version endpoint.
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &handler{
instanceManager: instanceManager,
}
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
}
// getInstanceStatus returns the instance status including whether setup is required.
@@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
// getVersionInfo returns version information for NetBird components.
// This endpoint requires authentication.
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)

View File

@@ -296,7 +296,7 @@ func TestSetup_ManagerError(t *testing.T) {
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
AddVersionEndpoint(manager, router, nil)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()
@@ -324,7 +324,7 @@ func TestGetVersionInfo_Error(t *testing.T) {
},
}
router := mux.NewRouter()
AddVersionEndpoint(manager, router)
AddVersionEndpoint(manager, router, nil)
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
rec := httptest.NewRecorder()

View File

@@ -10,14 +10,17 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
nbtypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -33,16 +36,16 @@ type handler struct {
groupsManager groups.Manager
}
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, router)
addResourceEndpoints(resourceManager, groupsManager, router)
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
addRouterEndpoints(routerManager, permissionsManager, router)
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
}
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
@@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
}
}
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
}
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
network := &types.Network{}
network.FromAPIRequest(&req)
network.AccountID = accountID
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
network.AccountID = userAuth.AccountId
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
}
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
return
}
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
}
var req api.NetworkRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
network.FromAPIRequest(&req)
network.ID = networkID
network.AccountID = accountID
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
network.AccountID = userAuth.AccountId
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
account, err := h.accountManager.GetAccount(r.Context(), accountID)
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
}
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
networkID := vars["networkId"]
if len(networkID) == 0 {
@@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
return
}
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,10 +6,13 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -19,14 +22,14 @@ type resourceHandler struct {
groupsManager groups.Manager
}
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
}
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
@@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
}
}
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
util.WriteJSONObject(r.Context(), w, resourcesResponse)
}
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
resource.FromAPIRequest(&req)
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkResourceRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
resource.ID = mux.Vars(r)["resourceId"]
resource.NetworkID = mux.Vars(r)["networkId"]
resource.AccountID = accountID
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
resource.AccountID = userAuth.AccountId
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
}
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
resourceID := mux.Vars(r)["resourceId"]
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -6,9 +6,12 @@ import (
"github.com/gorilla/mux"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
)
@@ -17,14 +20,14 @@ type routersHandler struct {
routersManager routers.Manager
}
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
routersHandler := newRoutersHandler(routersManager)
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
}
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
@@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
}
}
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
util.WriteJSONObject(r.Context(), w, routersResponse)
}
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
networkID := mux.Vars(r)["networkId"]
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -103,9 +82,9 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
router.FromAPIRequest(&req)
router.NetworkID = networkID
router.AccountID = accountID
router.AccountID = userAuth.AccountId
router.Enabled = true
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -114,18 +93,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -134,17 +105,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.NetworkRouterRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -155,9 +118,9 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
router.NetworkID = mux.Vars(r)["networkId"]
router.ID = mux.Vars(r)["routerId"]
router.AccountID = accountID
router.AccountID = userAuth.AccountId
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,17 +129,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
}
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routerID := mux.Vars(r)["routerId"]
networkID := mux.Vars(r)["networkId"]
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -1,7 +1,6 @@
package peers
import (
"context"
"encoding/json"
"fmt"
"net/http"
@@ -14,13 +13,13 @@ import (
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/groups"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -35,14 +34,15 @@ type Handler struct {
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
Methods("GET", "PUT", "DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
}
// NewHandler creates a new peers Handler
@@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
}
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
@@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(ctx, err, w)
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
util.WriteError(r.Context(), err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
@@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
for _, job := range jobs {
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
respBody = append(respBody, resp)
}
util.WriteJSONObject(ctx, w, respBody)
util.WriteJSONObject(r.Context(), w, respBody)
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(ctx, err, w)
return
}
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
jobID := vars["jobId"]
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
resp, err := toSingleJobResponse(job)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(ctx, w, resp)
util.WriteJSONObject(r.Context(), w, resp)
}
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
// GetPeer handles GET request for a single peer
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
if peer.ProxyMeta.Embedded {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
_, valid := validPeers[peer.ID]
reason := invalidPeers[peer.ID]
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
// UpdatePeer handles PUT request to update a peer
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
return
}
req := &api.PeerRequest{}
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
@@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
}
update := &nbpeer.Peer{
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
ID: peerID,
SSHEnabled: req.SshEnabled,
Name: req.Name,
LoginExpirationEnabled: req.LoginExpirationEnabled,
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
}
@@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
if req.Ip != nil {
addr, err := netip.ParseAddr(*req.Ip)
if err != nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
return
}
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
util.WriteError(ctx, err, w)
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
util.WriteError(r.Context(), err, w)
return
}
}
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
if err != nil {
util.WriteError(ctx, err, w)
util.WriteError(r.Context(), err, w)
return
}
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
util.WriteError(ctx, fmt.Errorf("internal error"), w)
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
return
}
@@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
}
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
util.WriteError(ctx, err, w)
return
}
util.WriteJSONObject(ctx, w, util.EmptyObject{})
}
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
// DeletePeer handles DELETE request to delete a peer
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
return
}
switch r.Method {
case http.MethodDelete:
h.deletePeer(r.Context(), accountID, userID, peerID, w)
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
util.WriteError(r.Context(), err, w)
return
case http.MethodGet:
h.getPeer(r.Context(), accountID, peerID, userID, w)
return
case http.MethodPut:
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
return
default:
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
}
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
// GetAllPeers returns a list of all peers associated with a provided account
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
nameFilter := r.URL.Query().Get("name")
ipFilter := r.URL.Query().Get("ip")
accountID, userID := userAuth.AccountId, userAuth.UserId
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
dnsDomain := h.networkMapController.GetDNSDomain(settings)
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
respBody := make([]*api.PeerBatch, 0, len(peers))
@@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
}
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
}
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -372,19 +327,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
return
}
user, err := h.accountManager.GetUserByID(r.Context(), userID)
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
if err != nil {
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -408,7 +363,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
}
}
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
@@ -422,13 +377,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
peerID := vars["peerId"]
if len(peerID) == 0 {
@@ -437,7 +386,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
}
var req api.PeerTemporaryAccessRequest
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return

View File

@@ -34,6 +34,18 @@ import (
"github.com/netbirdio/netbird/management/server/mock_server"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
testPeerID = "test_peer"
noUpdateChannelTestPeerID = "no-update-channel"
@@ -307,9 +319,9 @@ func TestGetPeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.HandleFunc("/api/peers/", wrapHandler(p.GetAllPeers)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.GetPeer)).Methods("GET")
router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -498,7 +510,7 @@ func TestGetAccessiblePeers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
router.HandleFunc("/api/peers/{peerId}/accessible-peers", wrapHandler(p.GetAccessiblePeers)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -582,7 +594,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
rr := httptest.NewRecorder()
router := mux.NewRouter()
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
router.HandleFunc("/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT")
router.ServeHTTP(rr, req)

View File

@@ -65,6 +65,18 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
}
}
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func TestGetCitiesByCountry(t *testing.T) {
tt := []struct {
name string
@@ -121,7 +133,7 @@ func TestGetCitiesByCountry(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
router.HandleFunc("/api/locations/countries/{country}/cities", wrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -214,7 +226,7 @@ func TestGetAllCountries(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
router.HandleFunc("/api/locations/countries", wrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,11 +7,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
}
// newGeolocationsHandlerHandler creates a new Geolocations handler
@@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
}
// getAllCountries retrieves a list of all countries
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if l.geolocationManager == nil {
// TODO: update error message to include geo db self hosted doc link when ready
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
@@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
}
// getCitiesByCountry retrieves a list of cities based on the given country code
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
if err := l.authenticateUser(r); err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
countryCode := vars["country"]
if !countryCodeRegex.MatchString(countryCode) {
@@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
util.WriteJSONObject(r.Context(), w, cities)
}
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
return err
}
accountID, userID := userAuth.AccountId, userAuth.UserId
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
if err != nil {
return status.NewPermissionValidationError(err)
}
if !allowed {
return status.NewPermissionDeniedError()
}
return nil
}
func toCountryResponse(country geolocation.Country) api.Country {
return api.Country{
CountryName: country.CountryName,

View File

@@ -8,9 +8,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
policiesHandler := newHandler(accountManager)
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new policies handler
@@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllPolicies list for the account
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
}
// updatePolicy handles update to a policy identified by a given ID
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
h.savePolicy(w, r, accountID, userID, policyID, false)
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
}
// createPolicy handles policy creation request
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
h.savePolicy(w, r, accountID, userID, "", true)
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
}
// savePolicy handles policy creation and update
@@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
}
// deletePolicy handles policy deletion request
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
return
}
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}
@@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
}
// getPolicy handles a group Get request identified by ID
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
policyID := vars["policyId"]
if len(policyID) == 0 {
@@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
return
}
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -21,6 +21,18 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func initPoliciesTestData(policies ...*types.Policy) *handler {
testPolicies := make(map[string]*types.Policy, len(policies))
for _, policy := range policies {
@@ -111,7 +123,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.getPolicy)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -275,8 +287,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
router.HandleFunc("/api/policies", wrapHandler(p.createPolicy)).Methods("POST")
router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.updatePolicy)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -7,9 +7,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type postureChecksHandler struct {
geolocationManager geolocation.Geolocation
}
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
}
// newPostureChecksHandler creates a new PostureChecks handler
@@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
}
// getAllPostureChecks list for the account
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
}
// updatePostureCheck handles update to a posture check identified by a given ID
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
return
}
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
}
// createPostureCheck handles posture check creation request
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
p.savePostureChecks(w, r, accountID, userID, "", true)
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
}
// getPostureCheck handles a posture check Get request identified by ID
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
return
}
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
}
// deletePostureCheck handles posture check deletion request
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
postureChecksID := vars["postureCheckId"]
if len(postureChecksID) == 0 {
@@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
return
}
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
util.WriteError(r.Context(), err, w)
return
}

View File

@@ -26,6 +26,18 @@ import (
var berlin = "Berlin"
var losAngeles = "Los Angeles"
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler {
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
for _, postureCheck := range postureChecks {
@@ -183,7 +195,7 @@ func TestGetPostureCheck(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(p.getPostureCheck)).Methods("GET")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -841,8 +853,8 @@ func TestPostureCheckUpdate(t *testing.T) {
}
router := mux.NewRouter()
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
router.HandleFunc("/api/posture-checks", wrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,8 +9,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
@@ -26,13 +29,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
routesHandler := newHandler(accountManager)
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
}
// newHandler returns a new instance of routes handler
@@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler {
}
// getAllRoutes returns the list of routes for the account
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
}
// createRoute handles route creation request
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.PostApiRoutesJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
skipAutoApply = false
}
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
}
// updateRoute handles update to a route identified by a given ID
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
routeID := vars["routeId"]
if len(routeID) == 0 {
@@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
return
}
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
newRoute.AccessControlGroups = *req.AccessControlGroups
}
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
}
// deleteRoute handles route deletion request
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
}
// getRoute handles a route Get request identified by ID
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
routeID := mux.Vars(r)["routeId"]
if len(routeID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
return
}
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -25,6 +25,18 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingRouteID = "existingRouteID"
existingRouteID2 = "existingRouteID2" // for peer_groups test
@@ -501,10 +513,10 @@ func TestRoutesHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.getRoute)).Methods("GET")
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.deleteRoute)).Methods("DELETE")
router.HandleFunc("/api/routes", wrapHandler(p.createRoute)).Methods("POST")
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.updateRoute)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,8 +9,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -21,13 +24,13 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
keysHandler := newHandler(accountManager)
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
}
// newHandler creates a new setup key handler
@@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler {
}
// createSetupKey is a POST requests that creates a new SetupKey
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
req := &api.PostApiSetupKeysJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
allowExtraDNSLabels = *req.AllowExtraDnsLabels
}
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
}
// getSetupKey is a GET request to get a SetupKey by ID
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
}
// updateSetupKey is a PUT request to update server.SetupKey
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
}
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
newKey.Revoked = req.Revoked
newKey.Id = keyID
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
}
// getAllSetupKeys is a GET request that returns a list of SetupKey
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
}
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
keyID := vars["keyId"]
if len(keyID) == 0 {
@@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -22,6 +22,18 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingSetupKeyID = "existingSetupKeyID"
newSetupKeyName = "New Setup Key"
@@ -171,11 +183,11 @@ func TestSetupKeysHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
router.HandleFunc("/api/setup-keys", wrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys", wrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -10,9 +10,12 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/http/middleware"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -55,14 +58,14 @@ type invitesHandler struct {
}
// AddInvitesEndpoints registers invite-related endpoints
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
h := &invitesHandler{accountManager: accountManager}
// Authenticated endpoints (require admin)
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
}
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
}
// listInvites handles GET /api/users/invites
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
}
// createInvite handles POST /api/users/invites
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
var req api.UserInviteCreateRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
}
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
}
// deleteInvite handles DELETE /api/users/invites/{inviteId}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
inviteID := vars["inviteId"]
if inviteID == "" {
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.listInvites(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.listInvites(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
})
rr := httptest.NewRecorder()
handler.createInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.createInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.regenerateInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.regenerateInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
}
rr := httptest.NewRecorder()
handler.deleteInvite(rr, req)
userAuth := &auth.UserAuth{
UserId: testUserID,
AccountId: testAccountID,
}
handler.deleteInvite(rr, req, userAuth)
assert.Equal(t, tc.expectedStatus, rr.Code)
})

View File

@@ -7,8 +7,11 @@ import (
"github.com/gorilla/mux"
"github.com/netbirdio/netbird/management/server/account"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
@@ -19,12 +22,12 @@ type patHandler struct {
accountManager account.Manager
}
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
tokenHandler := newPATsHandler(accountManager)
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
}
// newPATsHandler creates a new patHandler HTTP handler
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
}
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(userID) == 0 {
if len(targetUserID) == 0 {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
return
}
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
}
// getToken is HTTP GET handler that returns a personal access token for the given user
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
return
}
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
}
// createToken is HTTP POST handler that creates a personal access token for the given user
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
var req api.PostApiUsersUserIdTokensJSONRequestBody
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
}
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -24,6 +24,18 @@ import (
"github.com/netbirdio/netbird/shared/management/status"
)
// wrapHandler wraps a handler function that requires userAuth parameter
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
h(w, r, userAuth)
}
}
const (
existingAccountID = "existingAccountID"
notFoundAccountID = "notFoundAccountID"
@@ -181,10 +193,10 @@ func TestTokenHandlers(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.getAllTokens)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.getToken)).Methods("GET")
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.createToken)).Methods("POST")
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.deleteToken)).Methods("DELETE")
router.ServeHTTP(recorder, req)
res := recorder.Result()

View File

@@ -9,13 +9,15 @@ import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/users"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
nbcontext "github.com/netbirdio/netbird/management/server/context"
)
// handler is a handler that returns users of the account
@@ -23,18 +25,18 @@ type handler struct {
accountManager account.Manager
}
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
userHandler := newHandler(accountManager)
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router)
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
addUsersTokensEndpoint(accountManager, router, permissionsManager)
}
// newHandler creates a new UsersHandler HTTP handler
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
}
// updateUser is a PUT requests to update User data
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -89,7 +84,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
return
}
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
Id: targetUserID,
Role: userRole,
AutoGroups: req.AutoGroups,
@@ -102,23 +97,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// deleteUser is a DELETE request to delete a user
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -126,7 +114,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -136,21 +124,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
}
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
req := &api.PostApiUsersJSONRequestBody{}
err = json.NewDecoder(r.Body).Decode(&req)
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
@@ -171,7 +152,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
name = *req.Name
}
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
Email: email,
Name: name,
Role: req.Role,
@@ -183,25 +164,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
util.WriteError(r.Context(), err, w)
return
}
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
}
// getAllUsers returns a list of users of the account this user belongs to.
// It also gathers additional user data (like email and name) from the IDP manager.
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -215,7 +189,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
continue
}
if serviceUser == "" {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
continue
}
@@ -226,7 +200,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
return
}
if includeServiceUser == d.IsServiceUser {
users = append(users, toUserResponse(d, userID))
users = append(users, toUserResponse(d, userAuth.UserId))
}
}
@@ -235,19 +209,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
// inviteUser resend invitations to users who haven't activated their accounts,
// prior to the expiration period.
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
accountID, userID := userAuth.AccountId, userAuth.UserId
vars := mux.Vars(r)
targetUserID := vars["userId"]
if len(targetUserID) == 0 {
@@ -255,7 +222,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
return
}
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -264,19 +231,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
}
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodGet {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
}
ctx := r.Context()
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -356,7 +317,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
}
// approveUser is a POST request to approve a user that is pending approval
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPost {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -369,11 +330,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
@@ -385,7 +341,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
}
// rejectUser is a DELETE request to reject a user that is pending approval
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodDelete {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -398,12 +354,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -421,7 +372,7 @@ type passwordChangeRequest struct {
// changePassword is a PUT request to change user's password.
// Only available when embedded IDP is enabled.
// Users can only change their own password.
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
if r.Method != http.MethodPut {
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
return
@@ -434,19 +385,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
return
}
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
util.WriteError(r.Context(), err, w)
return
}
var req passwordChangeRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
return
}
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -232,7 +232,11 @@ func TestGetUsers(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.getAllUsers(recorder, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getAllUsers(recorder, req, userAuth)
res := recorder.Result()
defer res.Body.Close()
@@ -343,7 +347,7 @@ func TestUpdateUser(t *testing.T) {
})
router := mux.NewRouter()
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
router.HandleFunc("/api/users/{userId}", wrapHandler(userHandler.updateUser)).Methods("PUT")
router.ServeHTTP(recorder, req)
res := recorder.Result()
@@ -439,7 +443,11 @@ func TestCreateUser(t *testing.T) {
AccountId: existingAccountID,
})
userHandler.createUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.createUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -490,7 +498,11 @@ func TestInviteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.inviteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.inviteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -549,7 +561,11 @@ func TestDeleteUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.deleteUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.deleteUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -682,7 +698,11 @@ func TestCurrentUser(t *testing.T) {
rr := httptest.NewRecorder()
userHandler.getCurrentUser(rr, req)
userAuth := &auth.UserAuth{
UserId: existingUserID,
AccountId: existingAccountID,
}
userHandler.getCurrentUser(rr, req, userAuth)
res := rr.Result()
defer res.Body.Close()
@@ -779,7 +799,7 @@ func TestApproveUserEndpoint(t *testing.T) {
handler := newHandler(am)
router := mux.NewRouter()
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
router.HandleFunc("/users/{userId}/approve", wrapHandler(handler.approveUser)).Methods("POST")
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
require.NoError(t, err)

View File

@@ -4,20 +4,25 @@ package permissions
import (
"context"
"net/http"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/permissions/roles"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
type Manager interface {
WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth)) http.HandlerFunc
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
@@ -36,6 +41,37 @@ func NewManager(store store.Store) Manager {
}
}
// WithPermission wraps an HTTP handler with permission checking logic.
func (m *managerImpl) WithPermission(
module modules.Module,
operation operations.Operation,
handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth),
) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
if err != nil {
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
util.WriteError(r.Context(), err, w)
return
}
allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation)
if err != nil {
log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err)
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
return
}
if !allowed {
log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module)
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
return
}
handlerFunc(w, r, &userAuth)
}
}
func (m *managerImpl) ValidateUserPermissions(
ctx context.Context,
accountID string,

View File

@@ -6,6 +6,7 @@ package permissions
import (
context "context"
http "net/http"
reflect "reflect"
gomock "github.com/golang/mock/gomock"
@@ -14,6 +15,7 @@ import (
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
types "github.com/netbirdio/netbird/management/server/types"
auth "github.com/netbirdio/netbird/shared/auth"
)
// MockManager is a mock of Manager interface.
@@ -108,3 +110,17 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
}
// WithPermission mocks base method.
func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth)) http.HandlerFunc {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "WithPermission", module, operation, handlerFunc)
ret0, _ := ret[0].(http.HandlerFunc)
return ret0
}
// WithPermission indicates an expected call of WithPermission.
func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), module, operation, handlerFunc)
}