mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-31 21:19:55 +00:00
use api wrapper for permissions management
This commit is contained in:
@@ -6,7 +6,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -15,21 +18,15 @@ type handler struct {
|
||||
manager accesslogs.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var filter accesslogs.AccessLogFilter
|
||||
filter.ParseFromRequest(r)
|
||||
|
||||
|
||||
@@ -7,7 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,15 +20,15 @@ type handler struct {
|
||||
manager Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Read, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType {
|
||||
@@ -53,13 +56,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain {
|
||||
return resp
|
||||
}
|
||||
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -74,13 +71,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, ret)
|
||||
}
|
||||
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiReverseProxiesDomainsJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -96,13 +87,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, domainToApi(domain))
|
||||
}
|
||||
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
@@ -117,13 +102,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
domainID := mux.Vars(r)["domainId"]
|
||||
if domainID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w)
|
||||
|
||||
@@ -10,7 +10,10 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs"
|
||||
accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager"
|
||||
domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,30 +24,24 @@ type handler struct {
|
||||
}
|
||||
|
||||
// RegisterEndpoints registers all service HTTP endpoints.
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) {
|
||||
func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
domainRouter := router.PathPrefix("/reverse-proxies").Subrouter()
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager)
|
||||
domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager)
|
||||
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager)
|
||||
accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager)
|
||||
|
||||
router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -59,13 +56,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiServices)
|
||||
}
|
||||
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.ServiceRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -75,7 +66,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
service := new(reverseproxy.Service)
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -89,13 +80,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -111,13 +96,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, service.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
@@ -134,7 +113,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
service.ID = serviceID
|
||||
service.FromAPIRequest(&req, userAuth.AccountId)
|
||||
|
||||
if err = service.Validate(); err != nil {
|
||||
if err := service.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -148,13 +127,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
serviceID := mux.Vars(r)["serviceId"]
|
||||
if serviceID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w)
|
||||
|
||||
@@ -7,7 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,25 +20,19 @@ type handler struct {
|
||||
manager zones.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiZones)
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiDnsZonesJSONRequestBody
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
zone := new(zones.Zone)
|
||||
zone.FromAPIRequest(&req)
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
if err := zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
zone.FromAPIRequest(&req)
|
||||
zone.ID = zoneID
|
||||
|
||||
if err = zone.Validate(); err != nil {
|
||||
if err := zone.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -7,7 +7,10 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -17,25 +20,19 @@ type handler struct {
|
||||
manager records.Manager
|
||||
}
|
||||
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager) {
|
||||
func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
manager: manager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiRecords)
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
record := new(records.Record)
|
||||
record.FromAPIRequest(&req)
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
if err := record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, record.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody
|
||||
if err = json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
@@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
record.FromAPIRequest(&req)
|
||||
record.ID = recordID
|
||||
|
||||
if err = record.Validate(); err != nil {
|
||||
if err := record.Validate(); err != nil {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w)
|
||||
return
|
||||
}
|
||||
@@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
zoneID := mux.Vars(r)["zoneId"]
|
||||
if zoneID == "" {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w)
|
||||
@@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -154,27 +154,27 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
|
||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
||||
}
|
||||
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||
accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager)
|
||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||
users.AddEndpoints(accountManager, router)
|
||||
users.AddInvitesEndpoints(accountManager, router)
|
||||
users.AddEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddInvitesEndpoints(accountManager, router, permissionsManager)
|
||||
users.AddPublicInvitesEndpoints(accountManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router)
|
||||
setup_keys.AddEndpoints(accountManager, router, permissionsManager)
|
||||
policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager)
|
||||
policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router)
|
||||
groups.AddEndpoints(accountManager, router)
|
||||
routes.AddEndpoints(accountManager, router)
|
||||
dns.AddEndpoints(accountManager, router)
|
||||
events.AddEndpoints(accountManager, router)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager)
|
||||
idp.AddEndpoints(accountManager, router)
|
||||
groups.AddEndpoints(accountManager, router, permissionsManager)
|
||||
routes.AddEndpoints(accountManager, router, permissionsManager)
|
||||
dns.AddEndpoints(accountManager, router, permissionsManager)
|
||||
events.AddEndpoints(accountManager, router, permissionsManager)
|
||||
networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router)
|
||||
zonesManager.RegisterEndpoints(router, zManager, permissionsManager)
|
||||
recordsManager.RegisterEndpoints(router, rManager, permissionsManager)
|
||||
idp.AddEndpoints(accountManager, router, permissionsManager)
|
||||
instance.AddEndpoints(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router)
|
||||
instance.AddVersionEndpoint(instanceManager, router, permissionsManager)
|
||||
if reverseProxyManager != nil && reverseProxyDomainManager != nil {
|
||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router)
|
||||
reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)
|
||||
}
|
||||
|
||||
// Register OAuth callback handler for proxy authentication
|
||||
|
||||
@@ -13,9 +13,12 @@ import (
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/settings"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -40,11 +43,11 @@ type handler struct {
|
||||
settingsManager settings.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
accountsHandler := newHandler(accountManager, settingsManager)
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new handler HTTP handler
|
||||
@@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 {
|
||||
}
|
||||
|
||||
// getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account.
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID)
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp := toAccountResponse(accountID, settings, meta, onboarding)
|
||||
resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding)
|
||||
util.WriteJSONObject(r.Context(), w, []*api.Account{resp})
|
||||
}
|
||||
|
||||
@@ -223,15 +218,7 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
}
|
||||
|
||||
// updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings)
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
_, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
accountID := vars["accountId"]
|
||||
if len(accountID) == 0 {
|
||||
@@ -240,7 +227,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PutApiAccountsAccountIdJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -257,7 +244,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w)
|
||||
return
|
||||
}
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil {
|
||||
if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -272,19 +259,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding)
|
||||
updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
|
||||
updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID)
|
||||
meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -296,13 +283,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteAccount is a HTTP DELETE handler to delete an account
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetAccountID := vars["accountId"]
|
||||
if len(targetAccountID) == 0 {
|
||||
@@ -310,7 +291,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
|
||||
err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -72,6 +72,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
accountID := "test_account"
|
||||
adminUser := types.NewAdminUser("test_user")
|
||||
@@ -284,8 +296,8 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT")
|
||||
router.HandleFunc("/api/accounts", wrapHandler(handler.getAllAccounts)).Methods("GET")
|
||||
router.HandleFunc("/api/accounts/{accountId}", wrapHandler(handler.updateAccount)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,15 +21,15 @@ type dnsSettingsHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
addDNSSettingEndpoint(accountManager, router)
|
||||
addDNSNameserversEndpoint(accountManager, router)
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
addDNSSettingEndpoint(accountManager, router, permissionsManager)
|
||||
addDNSNameserversEndpoint(accountManager, router, permissionsManager)
|
||||
}
|
||||
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
dnsSettingsHandler := newDNSSettingsHandler(accountManager)
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS")
|
||||
}
|
||||
|
||||
// newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler
|
||||
@@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler {
|
||||
}
|
||||
|
||||
// getDNSSettings returns the DNS settings for the account
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID)
|
||||
func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque
|
||||
}
|
||||
|
||||
// updateDNSSettings handles update to DNS settings of an account
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PutApiDnsSettingsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re
|
||||
DisabledManagementGroups: req.DisabledManagementGroups,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings)
|
||||
err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -23,6 +23,18 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
testDNSSettingsAccountID = "test_id"
|
||||
testDNSSettingsExistingGroup = "test_group"
|
||||
@@ -115,8 +127,8 @@ func TestDNSSettingsHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/settings", wrapHandler(p.getDNSSettings)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/settings", wrapHandler(p.updateDNSSettings)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -6,11 +6,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +23,13 @@ type nameserversHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
nameserversHandler := newNameserversHandler(accountManager)
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newNameserversHandler returns a new instance of nameserversHandler handler
|
||||
@@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler {
|
||||
}
|
||||
|
||||
// getAllNameservers returns the list of nameserver groups for the account
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID)
|
||||
func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// createNameserverGroup handles nameserver group creation request
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiDnsNameserversJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled)
|
||||
nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updateNameserverGroup handles update to a nameserver group identified by a given ID
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
@@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
SearchDomainsEnabled: req.SearchDomainsEnabled,
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup)
|
||||
err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// deleteNameserverGroup handles nameserver group deletion request
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID)
|
||||
err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// getNameserverGroup handles a nameserver group Get request identified by ID
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nsGroupID := mux.Vars(r)["nsgroupId"]
|
||||
if len(nsGroupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID)
|
||||
nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -24,6 +24,18 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
existingNSGroupID = "existingNSGroupID"
|
||||
notFoundNSGroupID = "notFoundNSGroupID"
|
||||
@@ -201,10 +213,10 @@ func TestNameserversHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.getNameserverGroup)).Methods("GET")
|
||||
router.HandleFunc("/api/dns/nameservers", wrapHandler(p.createNameserverGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.deleteNameserverGroup)).Methods("DELETE")
|
||||
router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.updateNameserverGroup)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -5,11 +5,13 @@ import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,10 +21,10 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
eventsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new events handler
|
||||
@@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllEvents list of the given account
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -163,6 +163,18 @@ func generateEvents(accountID, userID string) []*activity.Event {
|
||||
return events
|
||||
}
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEvents_GetEvents(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -196,7 +208,7 @@ func TestEvents_GetEvents(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET")
|
||||
router.HandleFunc("/api/events/", wrapHandler(handler.getAllEvents)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,10 +8,12 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -22,13 +24,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
groupsHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new groups handler
|
||||
@@ -39,26 +41,18 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllGroups list for the account
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Error(err)
|
||||
http.Redirect(w, r, "/", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
// Check if filtering by name
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -71,13 +65,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// Get all groups
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -92,15 +86,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updateGroup handles update to a group identified by a given ID
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
groupID, ok := vars["groupId"]
|
||||
if !ok {
|
||||
@@ -112,13 +98,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,13 +152,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
IntegrationReference: existingGroup.IntegrationReference,
|
||||
}
|
||||
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err)
|
||||
if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -182,17 +168,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createGroup handles group creation request
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiGroupsJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -226,13 +204,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
|
||||
err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group)
|
||||
err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -242,22 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteGroup handles group deletion request
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID)
|
||||
err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID)
|
||||
if err != nil {
|
||||
wrappedErr, ok := err.(interface{ Unwrap() []error })
|
||||
if ok && len(wrappedErr.Unwrap()) > 0 {
|
||||
@@ -273,34 +243,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getGroup returns a group
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
groupID := mux.Vars(r)["groupId"]
|
||||
if len(groupID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID)
|
||||
group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "")
|
||||
accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "")
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group))
|
||||
|
||||
}
|
||||
|
||||
func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group {
|
||||
|
||||
@@ -33,6 +33,18 @@ var TestPeers = map[string]*nbpeer.Peer{
|
||||
"B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")},
|
||||
}
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
return &handler{
|
||||
accountManager: &mock_server.MockAccountManager{
|
||||
@@ -141,7 +153,7 @@ func TestGetGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET")
|
||||
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.getGroup)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -267,8 +279,8 @@ func TestWriteGroup(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", p.createGroup).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT")
|
||||
router.HandleFunc("/api/groups", wrapHandler(p.createGroup)).Methods("POST")
|
||||
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.updateGroup)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -345,7 +357,7 @@ func TestGetAllGroups(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET")
|
||||
router.HandleFunc("/api/groups", wrapHandler(p.getAllGroups)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -426,7 +438,7 @@ func TestDeleteGroup(t *testing.T) {
|
||||
AccountId: "test_id",
|
||||
})
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE")
|
||||
router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.deleteGroup)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,8 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -20,13 +23,13 @@ type handler struct {
|
||||
}
|
||||
|
||||
// AddEndpoints registers identity provider endpoints
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := newHandler(accountManager)
|
||||
router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(accountManager account.Manager) *handler {
|
||||
@@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllIdentityProviders returns all identity providers for the account
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// getIdentityProvider returns a specific identity provider
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID)
|
||||
provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createIdentityProvider creates a new identity provider
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.IdentityProviderRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp)
|
||||
created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// updateIdentityProvider updates an existing identity provider
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
idp := fromAPIRequest(&req)
|
||||
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp)
|
||||
updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
// deleteIdentityProvider deletes an identity provider
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
idpID := vars["idpId"]
|
||||
if idpID == "" {
|
||||
@@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request)
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil {
|
||||
if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -87,6 +87,18 @@ func initIDPTestData(existingIDP *types.IdentityProvider) *handler {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllIdentityProviders(t *testing.T) {
|
||||
existingIDP := &types.IdentityProvider{
|
||||
ID: existingIDPID,
|
||||
@@ -120,7 +132,7 @@ func TestGetAllIdentityProviders(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers", wrapHandler(h.getAllIdentityProviders)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -180,7 +192,7 @@ func TestGetIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.getIdentityProvider)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -242,7 +254,7 @@ func TestCreateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST")
|
||||
router.HandleFunc("/api/identity-providers", wrapHandler(h.createIdentityProvider)).Methods("POST")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -328,7 +340,7 @@ func TestUpdateIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.updateIdentityProvider)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -388,7 +400,7 @@ func TestDeleteIdentityProvider(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE")
|
||||
router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.deleteIdentityProvider)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -8,6 +8,10 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
}
|
||||
|
||||
// AddVersionEndpoint registers the authenticated version endpoint.
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) {
|
||||
func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := &handler{
|
||||
instanceManager: instanceManager,
|
||||
}
|
||||
|
||||
router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// getInstanceStatus returns the instance status including whether setup is required.
|
||||
@@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// getVersionInfo returns version information for NetBird components.
|
||||
// This endpoint requires authentication.
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
versionInfo, err := h.instanceManager.GetVersionInfo(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get version info: %v", err)
|
||||
|
||||
@@ -296,7 +296,7 @@ func TestSetup_ManagerError(t *testing.T) {
|
||||
func TestGetVersionInfo_Success(t *testing.T) {
|
||||
manager := &mockInstanceManager{}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
AddVersionEndpoint(manager, router, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
@@ -324,7 +324,7 @@ func TestGetVersionInfo_Error(t *testing.T) {
|
||||
},
|
||||
}
|
||||
router := mux.NewRouter()
|
||||
AddVersionEndpoint(manager, router)
|
||||
AddVersionEndpoint(manager, router, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/instance/version", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
@@ -10,14 +10,17 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/networks/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -33,16 +36,16 @@ type handler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, router)
|
||||
func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
addRouterEndpoints(routerManager, permissionsManager, router)
|
||||
addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router)
|
||||
|
||||
networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager)
|
||||
router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler {
|
||||
@@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana
|
||||
}
|
||||
}
|
||||
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID)
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID)
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID)
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account))
|
||||
}
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
network := &types.Network{}
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userID, network)
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID)
|
||||
network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.NetworkRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
network.FromAPIRequest(&req)
|
||||
|
||||
network.ID = networkID
|
||||
network.AccountID = accountID
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network)
|
||||
network.AccountID = userAuth.AccountId
|
||||
network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccount(r.Context(), accountID)
|
||||
account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs))
|
||||
}
|
||||
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
networkID := vars["networkId"]
|
||||
if len(networkID) == 0 {
|
||||
@@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID)
|
||||
err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,10 +6,13 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||
"github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -19,14 +22,14 @@ type resourceHandler struct {
|
||||
groupsManager groups.Manager
|
||||
}
|
||||
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) {
|
||||
func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
resourceHandler := newResourceHandler(resourcesManager, groupsManager)
|
||||
router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler {
|
||||
@@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups.
|
||||
}
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID)
|
||||
resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt
|
||||
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt
|
||||
util.WriteJSONObject(r.Context(), w, resourcesResponse)
|
||||
}
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
|
||||
resource.FromAPIRequest(&req)
|
||||
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource)
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkResourceRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
|
||||
|
||||
resource.ID = mux.Vars(r)["resourceId"]
|
||||
resource.NetworkID = mux.Vars(r)["networkId"]
|
||||
resource.AccountID = accountID
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource)
|
||||
resource.AccountID = userAuth.AccountId
|
||||
resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request)
|
||||
util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID]))
|
||||
}
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
resourceID := mux.Vars(r)["resourceId"]
|
||||
err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID)
|
||||
err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -6,9 +6,12 @@ import (
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||
"github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
)
|
||||
@@ -17,14 +20,14 @@ type routersHandler struct {
|
||||
routersManager routers.Manager
|
||||
}
|
||||
|
||||
func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) {
|
||||
func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
routersHandler := newRoutersHandler(routersManager)
|
||||
router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
@@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID)
|
||||
func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID)
|
||||
routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques
|
||||
util.WriteJSONObject(r.Context(), w, routersResponse)
|
||||
}
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -103,9 +82,9 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
router.FromAPIRequest(&req)
|
||||
|
||||
router.NetworkID = networkID
|
||||
router.AccountID = accountID
|
||||
router.AccountID = userAuth.AccountId
|
||||
router.Enabled = true
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userID, router)
|
||||
router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -114,18 +93,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -134,17 +105,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.NetworkRouterRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -155,9 +118,9 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
router.NetworkID = mux.Vars(r)["networkId"]
|
||||
router.ID = mux.Vars(r)["routerId"]
|
||||
router.AccountID = accountID
|
||||
router.AccountID = userAuth.AccountId
|
||||
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userID, router)
|
||||
router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,17 +129,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, router.ToAPIResponse())
|
||||
}
|
||||
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routerID := mux.Vars(r)["routerId"]
|
||||
networkID := mux.Vars(r)["networkId"]
|
||||
err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID)
|
||||
err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package peers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
@@ -14,13 +13,13 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/groups"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -35,14 +34,15 @@ type Handler struct {
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) {
|
||||
peersHandler := NewHandler(accountManager, networkMapController, permissionsManager)
|
||||
router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer).
|
||||
Methods("GET", "PUT", "DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// NewHandler creates a new peers Handler
|
||||
@@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
@@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID)
|
||||
jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) {
|
||||
for _, job := range jobs {
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
respBody = append(respBody, resp)
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, respBody)
|
||||
util.WriteJSONObject(r.Context(), w, respBody)
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
jobID := vars["jobId"]
|
||||
|
||||
job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := toSingleJobResponse(job)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
util.WriteJSONObject(r.Context(), w, resp)
|
||||
}
|
||||
|
||||
func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) {
|
||||
peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID)
|
||||
// GetPeer handles GET request for a single peer
|
||||
func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
if peer.ProxyMeta.Embedded {
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID)
|
||||
grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID)
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
_, valid := validPeers[peer.ID]
|
||||
reason := invalidPeers[peer.ID]
|
||||
|
||||
util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) {
|
||||
// UpdatePeer handles PUT request to update a peer
|
||||
func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
req := &api.PeerRequest{}
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
@@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
}
|
||||
|
||||
update := &nbpeer.Peer{
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
|
||||
ID: peerID,
|
||||
SSHEnabled: req.SshEnabled,
|
||||
Name: req.Name,
|
||||
LoginExpirationEnabled: req.LoginExpirationEnabled,
|
||||
InactivityExpirationEnabled: req.InactivityExpirationEnabled,
|
||||
}
|
||||
|
||||
@@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
if req.Ip != nil {
|
||||
addr, err := netip.ParseAddr(*req.Ip)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update)
|
||||
peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID)
|
||||
peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID)
|
||||
if err != nil {
|
||||
util.WriteError(ctx, err, w)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0)
|
||||
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID)
|
||||
validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(ctx, fmt.Errorf("internal error"), w)
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri
|
||||
util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason))
|
||||
}
|
||||
|
||||
func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) {
|
||||
err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(ctx, err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(ctx, w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// HandlePeer handles all peer requests for GET, PUT and DELETE operations
|
||||
func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
// DeletePeer handles DELETE request to delete a peer
|
||||
func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
switch r.Method {
|
||||
case http.MethodDelete:
|
||||
h.deletePeer(r.Context(), accountID, userID, peerID, w)
|
||||
err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
case http.MethodGet:
|
||||
h.getPeer(r.Context(), accountID, peerID, userID, w)
|
||||
return
|
||||
case http.MethodPut:
|
||||
h.updatePeer(r.Context(), accountID, userID, peerID, w, r)
|
||||
return
|
||||
default:
|
||||
util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w)
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
// GetAllPeers returns a list of all peers associated with a provided account
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
nameFilter := r.URL.Query().Get("name")
|
||||
ipFilter := r.URL.Query().Get("ip")
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter)
|
||||
peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator)
|
||||
settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
dnsDomain := h.networkMapController.GetDNSDomain(settings)
|
||||
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
|
||||
grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers))
|
||||
respBody := make([]*api.PeerBatch, 0, len(peers))
|
||||
@@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) {
|
||||
respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0))
|
||||
}
|
||||
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM
|
||||
}
|
||||
|
||||
// GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network.
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -372,19 +327,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userID)
|
||||
user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read)
|
||||
allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator)
|
||||
account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -408,7 +363,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID)
|
||||
validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err)
|
||||
util.WriteError(r.Context(), fmt.Errorf("internal error"), w)
|
||||
@@ -422,13 +377,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain))
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
peerID := vars["peerId"]
|
||||
if len(peerID) == 0 {
|
||||
@@ -437,7 +386,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
||||
var req api.PeerTemporaryAccessRequest
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
|
||||
@@ -34,6 +34,18 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/mock_server"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
testPeerID = "test_peer"
|
||||
noUpdateChannelTestPeerID = "no-update-channel"
|
||||
@@ -307,9 +319,9 @@ func TestGetPeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
router.HandleFunc("/api/peers/", wrapHandler(p.GetAllPeers)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.GetPeer)).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -498,7 +510,7 @@ func TestGetAccessiblePeers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET")
|
||||
router.HandleFunc("/api/peers/{peerId}/accessible-peers", wrapHandler(p.GetAccessiblePeers)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -582,7 +594,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT")
|
||||
router.HandleFunc("/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT")
|
||||
|
||||
router.ServeHTTP(rr, req)
|
||||
|
||||
|
||||
@@ -65,6 +65,18 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCitiesByCountry(t *testing.T) {
|
||||
tt := []struct {
|
||||
name string
|
||||
@@ -121,7 +133,7 @@ func TestGetCitiesByCountry(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries/{country}/cities", wrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -214,7 +226,7 @@ func TestGetAllCountries(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET")
|
||||
router.HandleFunc("/api/locations/countries", wrapHandler(geolocationHandler.getAllCountries)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,11 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -30,8 +30,8 @@ type geolocationsHandler struct {
|
||||
|
||||
func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) {
|
||||
locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager)
|
||||
router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS")
|
||||
}
|
||||
|
||||
// newGeolocationsHandlerHandler creates a new Geolocations handler
|
||||
@@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa
|
||||
}
|
||||
|
||||
// getAllCountries retrieves a list of all countries
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if l.geolocationManager == nil {
|
||||
// TODO: update error message to include geo db self hosted doc link when ready
|
||||
util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w)
|
||||
@@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req
|
||||
}
|
||||
|
||||
// getCitiesByCountry retrieves a list of cities based on the given country code
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) {
|
||||
if err := l.authenticateUser(r); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
countryCode := vars["country"]
|
||||
if !countryCodeRegex.MatchString(countryCode) {
|
||||
@@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.
|
||||
util.WriteJSONObject(r.Context(), w, cities)
|
||||
}
|
||||
|
||||
func (l *geolocationsHandler) authenticateUser(r *http.Request) error {
|
||||
ctx := r.Context()
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read)
|
||||
if err != nil {
|
||||
return status.NewPermissionValidationError(err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return status.NewPermissionDeniedError()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func toCountryResponse(country geolocation.Country) api.Country {
|
||||
return api.Country{
|
||||
CountryName: country.CountryName,
|
||||
|
||||
@@ -8,9 +8,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
policiesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new policies handler
|
||||
@@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllPolicies list for the account
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updatePolicy handles update to a policy identified by a given ID
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
_, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, policyID, false)
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false)
|
||||
}
|
||||
|
||||
// createPolicy handles policy creation request
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
h.savePolicy(w, r, accountID, userID, "", true)
|
||||
func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
}
|
||||
|
||||
// savePolicy handles policy creation and update
|
||||
@@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s
|
||||
}
|
||||
|
||||
// deletePolicy handles policy deletion request
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil {
|
||||
if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
@@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getPolicy handles a group Get request identified by ID
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
policyID := vars["policyId"]
|
||||
if len(policyID) == 0 {
|
||||
@@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID)
|
||||
policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID)
|
||||
allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -21,6 +21,18 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func initPoliciesTestData(policies ...*types.Policy) *handler {
|
||||
testPolicies := make(map[string]*types.Policy, len(policies))
|
||||
for _, policy := range policies {
|
||||
@@ -111,7 +123,7 @@ func TestPoliciesGetPolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET")
|
||||
router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.getPolicy)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -275,8 +287,8 @@ func TestPoliciesWritePolicy(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/policies", p.createPolicy).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT")
|
||||
router.HandleFunc("/api/policies", wrapHandler(p.createPolicy)).Methods("POST")
|
||||
router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.updatePolicy)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -7,9 +7,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type postureChecksHandler struct {
|
||||
geolocationManager geolocation.Geolocation
|
||||
}
|
||||
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) {
|
||||
func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
postureCheckHandler := newPostureChecksHandler(accountManager, locationManager)
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPostureChecksHandler creates a new PostureChecks handler
|
||||
@@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager
|
||||
}
|
||||
|
||||
// getAllPostureChecks list for the account
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID)
|
||||
func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt
|
||||
}
|
||||
|
||||
// updatePostureCheck handles update to a posture check identified by a given ID
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
_, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
_, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, postureChecksID, false)
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false)
|
||||
}
|
||||
|
||||
// createPostureCheck handles posture check creation request
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
p.savePostureChecks(w, r, accountID, userID, "", true)
|
||||
func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true)
|
||||
}
|
||||
|
||||
// getPostureCheck handles a posture check Get request identified by ID
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
return
|
||||
}
|
||||
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID)
|
||||
postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re
|
||||
}
|
||||
|
||||
// deletePostureCheck handles posture check deletion request
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
postureChecksID := vars["postureCheckId"]
|
||||
if len(postureChecksID) == 0 {
|
||||
@@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http
|
||||
return
|
||||
}
|
||||
|
||||
if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil {
|
||||
if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -26,6 +26,18 @@ import (
|
||||
var berlin = "Berlin"
|
||||
var losAngeles = "Los Angeles"
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler {
|
||||
testPostureChecks := make(map[string]*posture.Checks, len(postureChecks))
|
||||
for _, postureCheck := range postureChecks {
|
||||
@@ -183,7 +195,7 @@ func TestGetPostureCheck(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(p.getPostureCheck)).Methods("GET")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -841,8 +853,8 @@ func TestPostureCheckUpdate(t *testing.T) {
|
||||
}
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT")
|
||||
router.HandleFunc("/api/posture-checks", wrapHandler(defaultHandler.createPostureCheck)).Methods("POST")
|
||||
router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
@@ -26,13 +29,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
routesHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler returns a new instance of routes handler
|
||||
@@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// getAllRoutes returns the list of routes for the account
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createRoute handles route creation request
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.PostApiRoutesJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) {
|
||||
skipAutoApply = false
|
||||
}
|
||||
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply)
|
||||
newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds,
|
||||
req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply)
|
||||
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer *
|
||||
}
|
||||
|
||||
// updateRoute handles update to a route identified by a given ID
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
routeID := vars["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
@@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
_, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
newRoute.AccessControlGroups = *req.AccessControlGroups
|
||||
}
|
||||
|
||||
err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute)
|
||||
err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteRoute handles route deletion request
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getRoute handles a route Get request identified by ID
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
routeID := mux.Vars(r)["routeId"]
|
||||
if len(routeID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID)
|
||||
foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -25,6 +25,18 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
existingRouteID = "existingRouteID"
|
||||
existingRouteID2 = "existingRouteID2" // for peer_groups test
|
||||
@@ -501,10 +513,10 @@ func TestRoutesHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", p.createRoute).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT")
|
||||
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.getRoute)).Methods("GET")
|
||||
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.deleteRoute)).Methods("DELETE")
|
||||
router.HandleFunc("/api/routes", wrapHandler(p.createRoute)).Methods("POST")
|
||||
router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.updateRoute)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,8 +9,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -21,13 +24,13 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
keysHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newHandler creates a new setup key handler
|
||||
@@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// createSetupKey is a POST requests that creates a new SetupKey
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
req := &api.PostApiSetupKeysJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
allowExtraDNSLabels = *req.AllowExtraDnsLabels
|
||||
}
|
||||
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels)
|
||||
setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn,
|
||||
req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getSetupKey is a GET request to get a SetupKey by ID
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID)
|
||||
key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// updateSetupKey is a PUT request to update server.SetupKey
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
req := &api.PutApiSetupKeysKeyIdJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
newKey.Revoked = req.Revoked
|
||||
newKey.Id = keyID
|
||||
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID)
|
||||
newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getAllSetupKeys is a GET request that returns a list of SetupKey
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID)
|
||||
func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, apiSetupKeys)
|
||||
}
|
||||
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
keyID := vars["keyId"]
|
||||
if len(keyID) == 0 {
|
||||
@@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID)
|
||||
err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -22,6 +22,18 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
existingSetupKeyID = "existingSetupKeyID"
|
||||
newSetupKeyName = "New Setup Key"
|
||||
@@ -171,11 +183,11 @@ func TestSetupKeysHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", wrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys", wrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -10,9 +10,12 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -55,14 +58,14 @@ type invitesHandler struct {
|
||||
}
|
||||
|
||||
// AddInvitesEndpoints registers invite-related endpoints
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
h := &invitesHandler{accountManager: accountManager}
|
||||
|
||||
// Authenticated endpoints (require admin)
|
||||
router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS")
|
||||
}
|
||||
|
||||
// AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting
|
||||
@@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route
|
||||
}
|
||||
|
||||
// listInvites handles GET /api/users/invites
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createInvite handles POST /api/users/invites
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
var req api.UserInviteCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
@@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request
|
||||
}
|
||||
|
||||
// deleteInvite handles DELETE /api/users/invites/{inviteId}
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
inviteID := vars["inviteId"]
|
||||
if inviteID == "" {
|
||||
@@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.listInvites(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.listInvites(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) {
|
||||
})
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.createInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.createInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.regenerateInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.regenerateInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
|
||||
@@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) {
|
||||
}
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
handler.deleteInvite(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: testUserID,
|
||||
AccountId: testAccountID,
|
||||
}
|
||||
handler.deleteInvite(rr, req, userAuth)
|
||||
|
||||
assert.Equal(t, tc.expectedStatus, rr.Code)
|
||||
})
|
||||
|
||||
@@ -7,8 +7,11 @@ import (
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
@@ -19,12 +22,12 @@ type patHandler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) {
|
||||
func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
tokenHandler := newPATsHandler(accountManager)
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS")
|
||||
}
|
||||
|
||||
// newPATsHandler creates a new patHandler HTTP handler
|
||||
@@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler {
|
||||
}
|
||||
|
||||
// getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(userID) == 0 {
|
||||
if len(targetUserID) == 0 {
|
||||
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w)
|
||||
return
|
||||
}
|
||||
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID)
|
||||
pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getToken is HTTP GET handler that returns a personal access token for the given user
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createToken is HTTP POST handler that creates a personal access token for the given user
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
var req api.PostApiUsersUserIdTokensJSONRequestBody
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn)
|
||||
pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// deleteToken is HTTP DELETE handler that deletes a personal access token for the given user
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID)
|
||||
err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -24,6 +24,18 @@ import (
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
// wrapHandler wraps a handler function that requires userAuth parameter
|
||||
func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
h(w, r, userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
existingAccountID = "existingAccountID"
|
||||
notFoundAccountID = "notFoundAccountID"
|
||||
@@ -181,10 +193,10 @@ func TestTokenHandlers(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.getAllTokens)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.getToken)).Methods("GET")
|
||||
router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.createToken)).Methods("POST")
|
||||
router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.deleteToken)).Methods("DELETE")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
|
||||
@@ -9,13 +9,15 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/permissions"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
)
|
||||
|
||||
// handler is a handler that returns users of the account
|
||||
@@ -23,18 +25,18 @@ type handler struct {
|
||||
accountManager account.Manager
|
||||
}
|
||||
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router) {
|
||||
func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) {
|
||||
userHandler := newHandler(accountManager)
|
||||
router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router)
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS")
|
||||
router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS")
|
||||
addUsersTokensEndpoint(accountManager, router, permissionsManager)
|
||||
}
|
||||
|
||||
// newHandler creates a new UsersHandler HTTP handler
|
||||
@@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler {
|
||||
}
|
||||
|
||||
// updateUser is a PUT requests to update User data
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -89,7 +84,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{
|
||||
newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{
|
||||
Id: targetUserID,
|
||||
Role: userRole,
|
||||
AutoGroups: req.AutoGroups,
|
||||
@@ -102,23 +97,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// deleteUser is a DELETE request to delete a user
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -126,7 +114,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -136,21 +124,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// createUser creates a User in the system with a status "invited" (effectively this is a user invite).
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
req := &api.PostApiUsersJSONRequestBody{}
|
||||
err = json.NewDecoder(r.Body).Decode(&req)
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
@@ -171,7 +152,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
name = *req.Name
|
||||
}
|
||||
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{
|
||||
newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{
|
||||
Email: email,
|
||||
Name: name,
|
||||
Role: req.Role,
|
||||
@@ -183,25 +164,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID))
|
||||
util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId))
|
||||
}
|
||||
|
||||
// getAllUsers returns a list of users of the account this user belongs to.
|
||||
// It also gathers additional user data (like email and name) from the IDP manager.
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID)
|
||||
data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -215,7 +189,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
continue
|
||||
}
|
||||
if serviceUser == "" {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -226,7 +200,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
if includeServiceUser == d.IsServiceUser {
|
||||
users = append(users, toUserResponse(d, userID))
|
||||
users = append(users, toUserResponse(d, userAuth.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -235,19 +209,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// inviteUser resend invitations to users who haven't activated their accounts,
|
||||
// prior to the expiration period.
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
accountID, userID := userAuth.AccountId, userAuth.UserId
|
||||
|
||||
vars := mux.Vars(r)
|
||||
targetUserID := vars["userId"]
|
||||
if len(targetUserID) == 0 {
|
||||
@@ -255,7 +222,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID)
|
||||
err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -264,19 +231,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) {
|
||||
util.WriteJSONObject(r.Context(), w, util.EmptyObject{})
|
||||
}
|
||||
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodGet {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
}
|
||||
ctx := r.Context()
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(ctx)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth)
|
||||
user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -356,7 +317,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User {
|
||||
}
|
||||
|
||||
// approveUser is a POST request to approve a user that is pending approval
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPost {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -369,11 +330,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
@@ -385,7 +341,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// rejectUser is a DELETE request to reject a user that is pending approval
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodDelete {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -398,12 +354,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -421,7 +372,7 @@ type passwordChangeRequest struct {
|
||||
// changePassword is a PUT request to change user's password.
|
||||
// Only available when embedded IDP is enabled.
|
||||
// Users can only change their own password.
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) {
|
||||
if r.Method != http.MethodPut {
|
||||
util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w)
|
||||
return
|
||||
@@ -434,19 +385,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
var req passwordChangeRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w)
|
||||
return
|
||||
}
|
||||
|
||||
err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -232,7 +232,11 @@ func TestGetUsers(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.getAllUsers(recorder, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getAllUsers(recorder, req, userAuth)
|
||||
|
||||
res := recorder.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -343,7 +347,7 @@ func TestUpdateUser(t *testing.T) {
|
||||
})
|
||||
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT")
|
||||
router.HandleFunc("/api/users/{userId}", wrapHandler(userHandler.updateUser)).Methods("PUT")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
res := recorder.Result()
|
||||
@@ -439,7 +443,11 @@ func TestCreateUser(t *testing.T) {
|
||||
AccountId: existingAccountID,
|
||||
})
|
||||
|
||||
userHandler.createUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.createUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -490,7 +498,11 @@ func TestInviteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.inviteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.inviteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -549,7 +561,11 @@ func TestDeleteUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.deleteUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.deleteUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -682,7 +698,11 @@ func TestCurrentUser(t *testing.T) {
|
||||
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
userHandler.getCurrentUser(rr, req)
|
||||
userAuth := &auth.UserAuth{
|
||||
UserId: existingUserID,
|
||||
AccountId: existingAccountID,
|
||||
}
|
||||
userHandler.getCurrentUser(rr, req, userAuth)
|
||||
|
||||
res := rr.Result()
|
||||
defer res.Body.Close()
|
||||
@@ -779,7 +799,7 @@ func TestApproveUserEndpoint(t *testing.T) {
|
||||
|
||||
handler := newHandler(am)
|
||||
router := mux.NewRouter()
|
||||
router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST")
|
||||
router.HandleFunc("/users/{userId}/approve", wrapHandler(handler.approveUser)).Methods("POST")
|
||||
|
||||
req, err := http.NewRequest("POST", "/users/pending-user/approve", nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -4,20 +4,25 @@ package permissions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/http/util"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
type Manager interface {
|
||||
WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth)) http.HandlerFunc
|
||||
ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error)
|
||||
ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool
|
||||
ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error
|
||||
@@ -36,6 +41,37 @@ func NewManager(store store.Store) Manager {
|
||||
}
|
||||
}
|
||||
|
||||
// WithPermission wraps an HTTP handler with permission checking logic.
|
||||
func (m *managerImpl) WithPermission(
|
||||
module modules.Module,
|
||||
operation operations.Operation,
|
||||
handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth),
|
||||
) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userAuth, err := nbcontext.GetUserAuthFromContext(r.Context())
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err)
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
}
|
||||
|
||||
allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation)
|
||||
if err != nil {
|
||||
log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err)
|
||||
util.WriteError(r.Context(), status.NewPermissionValidationError(err), w)
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module)
|
||||
util.WriteError(r.Context(), status.NewPermissionDeniedError(), w)
|
||||
return
|
||||
}
|
||||
|
||||
handlerFunc(w, r, &userAuth)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *managerImpl) ValidateUserPermissions(
|
||||
ctx context.Context,
|
||||
accountID string,
|
||||
|
||||
@@ -6,6 +6,7 @@ package permissions
|
||||
|
||||
import (
|
||||
context "context"
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
@@ -14,6 +15,7 @@ import (
|
||||
operations "github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
roles "github.com/netbirdio/netbird/management/server/permissions/roles"
|
||||
types "github.com/netbirdio/netbird/management/server/types"
|
||||
auth "github.com/netbirdio/netbird/shared/auth"
|
||||
)
|
||||
|
||||
// MockManager is a mock of Manager interface.
|
||||
@@ -108,3 +110,17 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation)
|
||||
}
|
||||
|
||||
// WithPermission mocks base method.
|
||||
func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth)) http.HandlerFunc {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "WithPermission", module, operation, handlerFunc)
|
||||
ret0, _ := ret[0].(http.HandlerFunc)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// WithPermission indicates an expected call of WithPermission.
|
||||
func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), module, operation, handlerFunc)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user