diff --git a/management/internals/modules/reverseproxy/accesslogs/manager/api.go b/management/internals/modules/reverseproxy/accesslogs/manager/api.go index 1e1414ca5..4859bc147 100644 --- a/management/internals/modules/reverseproxy/accesslogs/manager/api.go +++ b/management/internals/modules/reverseproxy/accesslogs/manager/api.go @@ -6,7 +6,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -15,21 +18,15 @@ type handler struct { manager accesslogs.Manager } -func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager) { +func RegisterEndpoints(router *mux.Router, manager accesslogs.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/events/proxy", h.getAccessLogs).Methods("GET", "OPTIONS") + router.HandleFunc("/events/proxy", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAccessLogs)).Methods("GET", "OPTIONS") } -func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAccessLogs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var filter accesslogs.AccessLogFilter filter.ParseFromRequest(r) diff --git a/management/internals/modules/reverseproxy/domain/manager/api.go b/management/internals/modules/reverseproxy/domain/manager/api.go index 2fbcdd5b8..63bc79a64 100644 --- a/management/internals/modules/reverseproxy/domain/manager/api.go +++ b/management/internals/modules/reverseproxy/domain/manager/api.go @@ -7,7 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,15 +20,15 @@ type handler struct { manager Manager } -func RegisterEndpoints(router *mux.Router, manager Manager) { +func RegisterEndpoints(router *mux.Router, manager Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/domains", h.getAllDomains).Methods("GET", "OPTIONS") - router.HandleFunc("/domains", h.createCustomDomain).Methods("POST", "OPTIONS") - router.HandleFunc("/domains/{domainId}", h.deleteCustomDomain).Methods("DELETE", "OPTIONS") - router.HandleFunc("/domains/{domainId}/validate", h.triggerCustomDomainValidation).Methods("GET", "OPTIONS") + router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllDomains)).Methods("GET", "OPTIONS") + router.HandleFunc("/domains", permissionsManager.WithPermission(modules.Services, operations.Create, h.createCustomDomain)).Methods("POST", "OPTIONS") + router.HandleFunc("/domains/{domainId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteCustomDomain)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/domains/{domainId}/validate", permissionsManager.WithPermission(modules.Services, operations.Read, h.triggerCustomDomainValidation)).Methods("GET", "OPTIONS") } func domainTypeToApi(t domain.Type) api.ReverseProxyDomainType { @@ -53,13 +56,7 @@ func domainToApi(d *domain.Domain) api.ReverseProxyDomain { return resp } -func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domains, err := h.manager.GetDomains(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -74,13 +71,7 @@ func (h *handler) getAllDomains(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, ret) } -func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiReverseProxiesDomainsJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -96,13 +87,7 @@ func (h *handler) createCustomDomain(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, domainToApi(domain)) } -func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domainID := mux.Vars(r)["domainId"] if domainID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) @@ -117,13 +102,7 @@ func (h *handler) deleteCustomDomain(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } -func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) triggerCustomDomainValidation(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { domainID := mux.Vars(r)["domainId"] if domainID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "domain ID is required"), w) diff --git a/management/internals/modules/reverseproxy/manager/api.go b/management/internals/modules/reverseproxy/manager/api.go index 9117ecd38..5cfab53a0 100644 --- a/management/internals/modules/reverseproxy/manager/api.go +++ b/management/internals/modules/reverseproxy/manager/api.go @@ -10,7 +10,10 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs" accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" domainmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,30 +24,24 @@ type handler struct { } // RegisterEndpoints registers all service HTTP endpoints. -func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, router *mux.Router) { +func RegisterEndpoints(manager reverseproxy.Manager, domainManager domainmanager.Manager, accessLogsManager accesslogs.Manager, permissionsManager permissions.Manager, router *mux.Router) { h := &handler{ manager: manager, } domainRouter := router.PathPrefix("/reverse-proxies").Subrouter() - domainmanager.RegisterEndpoints(domainRouter, domainManager) + domainmanager.RegisterEndpoints(domainRouter, domainManager, permissionsManager) - accesslogsmanager.RegisterEndpoints(router, accessLogsManager) + accesslogsmanager.RegisterEndpoints(router, accessLogsManager, permissionsManager) - router.HandleFunc("/reverse-proxies/services", h.getAllServices).Methods("GET", "OPTIONS") - router.HandleFunc("/reverse-proxies/services", h.createService).Methods("POST", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.getService).Methods("GET", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.updateService).Methods("PUT", "OPTIONS") - router.HandleFunc("/reverse-proxies/services/{serviceId}", h.deleteService).Methods("DELETE", "OPTIONS") + router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Read, h.getAllServices)).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/services", permissionsManager.WithPermission(modules.Services, operations.Create, h.createService)).Methods("POST", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Read, h.getService)).Methods("GET", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Update, h.updateService)).Methods("PUT", "OPTIONS") + router.HandleFunc("/reverse-proxies/services/{serviceId}", permissionsManager.WithPermission(modules.Services, operations.Delete, h.deleteService)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { allServices, err := h.manager.GetAllServices(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -59,13 +56,7 @@ func (h *handler) getAllServices(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiServices) } -func (h *handler) createService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.ServiceRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -75,7 +66,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { service := new(reverseproxy.Service) service.FromAPIRequest(&req, userAuth.AccountId) - if err = service.Validate(); err != nil { + if err := service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -89,13 +80,7 @@ func (h *handler) createService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdService.ToAPIResponse()) } -func (h *handler) getService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) @@ -111,13 +96,7 @@ func (h *handler) getService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, service.ToAPIResponse()) } -func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) @@ -134,7 +113,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { service.ID = serviceID service.FromAPIRequest(&req, userAuth.AccountId) - if err = service.Validate(); err != nil { + if err := service.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -148,13 +127,7 @@ func (h *handler) updateService(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedService.ToAPIResponse()) } -func (h *handler) deleteService(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteService(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { serviceID := mux.Vars(r)["serviceId"] if serviceID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "service ID is required"), w) diff --git a/management/internals/modules/zones/manager/api.go b/management/internals/modules/zones/manager/api.go index 919d77d61..d5a894eb5 100644 --- a/management/internals/modules/zones/manager/api.go +++ b/management/internals/modules/zones/manager/api.go @@ -7,7 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/internals/modules/zones" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,25 +20,19 @@ type handler struct { manager zones.Manager } -func RegisterEndpoints(router *mux.Router, manager zones.Manager) { +func RegisterEndpoints(router *mux.Router, manager zones.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/dns/zones", h.getAllZones).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones", h.createZone).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.getZone).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.updateZone).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}", h.deleteZone).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllZones)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createZone)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getZone)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateZone)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteZone)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { allZones, err := h.manager.GetAllZones(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -50,13 +47,7 @@ func (h *handler) getAllZones(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiZones) } -func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiDnsZonesJSONRequestBody if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -66,7 +57,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { zone := new(zones.Zone) zone.FromAPIRequest(&req) - if err = zone.Validate(); err != nil { + if err := zone.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -80,13 +71,7 @@ func (h *handler) createZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdZone.ToAPIResponse()) } -func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -102,13 +87,7 @@ func (h *handler) getZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, zone.ToAPIResponse()) } -func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -116,7 +95,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { } var req api.PutApiDnsZonesZoneIdJSONRequestBody - if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } @@ -125,7 +104,7 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { zone.FromAPIRequest(&req) zone.ID = zoneID - if err = zone.Validate(); err != nil { + if err := zone.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -139,20 +118,14 @@ func (h *handler) updateZone(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedZone.ToAPIResponse()) } -func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteZone(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) return } - if err = h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { + if err := h.manager.DeleteZone(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/internals/modules/zones/records/manager/api.go b/management/internals/modules/zones/records/manager/api.go index f8ecfef7d..ebf2b3ebd 100644 --- a/management/internals/modules/zones/records/manager/api.go +++ b/management/internals/modules/zones/records/manager/api.go @@ -7,7 +7,10 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/internals/modules/zones/records" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -17,25 +20,19 @@ type handler struct { manager records.Manager } -func RegisterEndpoints(router *mux.Router, manager records.Manager) { +func RegisterEndpoints(router *mux.Router, manager records.Manager, permissionsManager permissions.Manager) { h := &handler{ manager: manager, } - router.HandleFunc("/dns/zones/{zoneId}/records", h.getAllRecords).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records", h.createRecord).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.getRecord).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.updateRecord).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", h.deleteRecord).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getAllRecords)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records", permissionsManager.WithPermission(modules.Dns, operations.Create, h.createRecord)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Read, h.getRecord)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Update, h.updateRecord)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/zones/{zoneId}/records/{recordId}", permissionsManager.WithPermission(modules.Dns, operations.Delete, h.deleteRecord)).Methods("DELETE", "OPTIONS") } -func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -56,13 +53,7 @@ func (h *handler) getAllRecords(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiRecords) } -func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) createRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -78,7 +69,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { record := new(records.Record) record.FromAPIRequest(&req) - if err = record.Validate(); err != nil { + if err := record.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -92,13 +83,7 @@ func (h *handler) createRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, createdRecord.ToAPIResponse()) } -func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) getRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -120,13 +105,7 @@ func (h *handler) getRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, record.ToAPIResponse()) } -func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -140,7 +119,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { } var req api.PutApiDnsZonesZoneIdRecordsRecordIdJSONRequestBody - if err = json.NewDecoder(r.Body).Decode(&req); err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } @@ -149,7 +128,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { record.FromAPIRequest(&req) record.ID = recordID - if err = record.Validate(); err != nil { + if err := record.Validate(); err != nil { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "%s", err.Error()), w) return } @@ -163,13 +142,7 @@ func (h *handler) updateRecord(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, updatedRecord.ToAPIResponse()) } -func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { zoneID := mux.Vars(r)["zoneId"] if zoneID == "" { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "zone ID is required"), w) @@ -182,7 +155,7 @@ func (h *handler) deleteRecord(w http.ResponseWriter, r *http.Request) { return } - if err = h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { + if err := h.manager.DeleteRecord(r.Context(), userAuth.AccountId, userAuth.UserId, zoneID, recordID); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 9d2384cae..f5477574d 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -154,27 +154,27 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("failed to create instance manager: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, router) + accounts.AddEndpoints(accountManager, settingsManager, router, permissionsManager) peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) - users.AddEndpoints(accountManager, router) - users.AddInvitesEndpoints(accountManager, router) + users.AddEndpoints(accountManager, router, permissionsManager) + users.AddInvitesEndpoints(accountManager, router, permissionsManager) users.AddPublicInvitesEndpoints(accountManager, router) - setup_keys.AddEndpoints(accountManager, router) - policies.AddEndpoints(accountManager, LocationManager, router) - policies.AddPostureCheckEndpoints(accountManager, LocationManager, router) + setup_keys.AddEndpoints(accountManager, router, permissionsManager) + policies.AddEndpoints(accountManager, LocationManager, router, permissionsManager) + policies.AddPostureCheckEndpoints(accountManager, LocationManager, router, permissionsManager) policies.AddLocationsEndpoints(accountManager, LocationManager, permissionsManager, router) - groups.AddEndpoints(accountManager, router) - routes.AddEndpoints(accountManager, router) - dns.AddEndpoints(accountManager, router) - events.AddEndpoints(accountManager, router) - networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, router) - zonesManager.RegisterEndpoints(router, zManager) - recordsManager.RegisterEndpoints(router, rManager) - idp.AddEndpoints(accountManager, router) + groups.AddEndpoints(accountManager, router, permissionsManager) + routes.AddEndpoints(accountManager, router, permissionsManager) + dns.AddEndpoints(accountManager, router, permissionsManager) + events.AddEndpoints(accountManager, router, permissionsManager) + networks.AddEndpoints(networksManager, resourceManager, routerManager, groupsManager, accountManager, permissionsManager, router) + zonesManager.RegisterEndpoints(router, zManager, permissionsManager) + recordsManager.RegisterEndpoints(router, rManager, permissionsManager) + idp.AddEndpoints(accountManager, router, permissionsManager) instance.AddEndpoints(instanceManager, router) - instance.AddVersionEndpoint(instanceManager, router) + instance.AddVersionEndpoint(instanceManager, router, permissionsManager) if reverseProxyManager != nil && reverseProxyDomainManager != nil { - reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, router) + reverseproxymanager.RegisterEndpoints(reverseProxyManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router) } // Register OAuth callback handler for proxy authentication diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 122c061ce..02aaf5246 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -13,9 +13,12 @@ import ( goversion "github.com/hashicorp/go-version" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/settings" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -40,11 +43,11 @@ type handler struct { settingsManager settings.Manager } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router, permissionsManager permissions.Manager) { accountsHandler := newHandler(accountManager, settingsManager) - router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") - router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") - router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Update, accountsHandler.updateAccount)).Methods("PUT", "OPTIONS") + router.HandleFunc("/accounts/{accountId}", permissionsManager.WithPermission(modules.Accounts, operations.Delete, accountsHandler.deleteAccount)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/accounts", permissionsManager.WithPermission(modules.Accounts, operations.Read, accountsHandler.getAllAccounts)).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler @@ -136,34 +139,26 @@ func calculateRequiredAddresses(peerCount int) int64 { } // getAllAccounts is HTTP GET handler that returns a list of accounts. Effectively returns just a single account. -func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + meta, err := h.accountManager.GetAccountMeta(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + settings, err := h.settingsManager.GetSettings(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - settings, err := h.settingsManager.GetSettings(r.Context(), accountID, userID) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toAccountResponse(accountID, settings, meta, onboarding) + resp := toAccountResponse(userAuth.AccountId, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -223,15 +218,7 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS } // updateAccount is HTTP PUT handler that updates the provided account. Updates only account settings (server.Settings) -func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - _, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) accountID := vars["accountId"] if len(accountID) == 0 { @@ -240,7 +227,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { } var req api.PutApiAccountsAccountIdJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -257,7 +244,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid CIDR format: %v", err), w) return } - if err := h.validateNetworkRange(r.Context(), accountID, userID, prefix); err != nil { + if err := h.validateNetworkRange(r.Context(), accountID, userAuth.UserId, prefix); err != nil { util.WriteError(r.Context(), err, w) return } @@ -272,19 +259,19 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { } } - updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) + updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userAuth.UserId, onboarding) if err != nil { util.WriteError(r.Context(), err, w) return } - updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userAuth.UserId, settings) if err != nil { util.WriteError(r.Context(), err, w) return } - meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userID) + meta, err := h.accountManager.GetAccountMeta(r.Context(), accountID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -296,13 +283,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { } // deleteAccount is a HTTP DELETE handler to delete an account -func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetAccountID := vars["accountId"] if len(targetAccountID) == 0 { @@ -310,7 +291,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId) + err := h.accountManager.DeleteAccount(r.Context(), targetAccountID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index 6cbd5908d..d4ce3ab9d 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -72,6 +72,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { } } +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func TestAccounts_AccountsHandler(t *testing.T) { accountID := "test_account" adminUser := types.NewAdminUser("test_user") @@ -284,8 +296,8 @@ func TestAccounts_AccountsHandler(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/accounts", handler.getAllAccounts).Methods("GET") - router.HandleFunc("/api/accounts/{accountId}", handler.updateAccount).Methods("PUT") + router.HandleFunc("/api/accounts", wrapHandler(handler.getAllAccounts)).Methods("GET") + router.HandleFunc("/api/accounts/{accountId}", wrapHandler(handler.updateAccount)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/dns/dns_settings_handler.go b/management/server/http/handlers/dns/dns_settings_handler.go index 67638aea5..c7f99a784 100644 --- a/management/server/http/handlers/dns/dns_settings_handler.go +++ b/management/server/http/handlers/dns/dns_settings_handler.go @@ -5,11 +5,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,15 +21,15 @@ type dnsSettingsHandler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { - addDNSSettingEndpoint(accountManager, router) - addDNSNameserversEndpoint(accountManager, router) +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { + addDNSSettingEndpoint(accountManager, router, permissionsManager) + addDNSNameserversEndpoint(accountManager, router, permissionsManager) } -func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router) { +func addDNSSettingEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { dnsSettingsHandler := newDNSSettingsHandler(accountManager) - router.HandleFunc("/dns/settings", dnsSettingsHandler.getDNSSettings).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/settings", dnsSettingsHandler.updateDNSSettings).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Read, dnsSettingsHandler.getDNSSettings)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/settings", permissionsManager.WithPermission(modules.Dns, operations.Update, dnsSettingsHandler.updateDNSSettings)).Methods("PUT", "OPTIONS") } // newDNSSettingsHandler returns a new instance of dnsSettingsHandler handler @@ -36,17 +38,8 @@ func newDNSSettingsHandler(accountManager account.Manager) *dnsSettingsHandler { } // getDNSSettings returns the DNS settings for the account -func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), accountID, userID) +func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + dnsSettings, err := h.accountManager.GetDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,17 +53,9 @@ func (h *dnsSettingsHandler) getDNSSettings(w http.ResponseWriter, r *http.Reque } // updateDNSSettings handles update to DNS settings of an account -func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PutApiDnsSettingsJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -80,7 +65,7 @@ func (h *dnsSettingsHandler) updateDNSSettings(w http.ResponseWriter, r *http.Re DisabledManagementGroups: req.DisabledManagementGroups, } - err = h.accountManager.SaveDNSSettings(r.Context(), accountID, userID, updateDNSSettings) + err = h.accountManager.SaveDNSSettings(r.Context(), userAuth.AccountId, userAuth.UserId, updateDNSSettings) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/dns/dns_settings_handler_test.go b/management/server/http/handlers/dns/dns_settings_handler_test.go index a027c067e..fbe7ca578 100644 --- a/management/server/http/handlers/dns/dns_settings_handler_test.go +++ b/management/server/http/handlers/dns/dns_settings_handler_test.go @@ -23,6 +23,18 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( testDNSSettingsAccountID = "test_id" testDNSSettingsExistingGroup = "test_group" @@ -115,8 +127,8 @@ func TestDNSSettingsHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/dns/settings", p.getDNSSettings).Methods("GET") - router.HandleFunc("/api/dns/settings", p.updateDNSSettings).Methods("PUT") + router.HandleFunc("/api/dns/settings", wrapHandler(p.getDNSSettings)).Methods("GET") + router.HandleFunc("/api/dns/settings", wrapHandler(p.updateDNSSettings)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/dns/nameservers_handler.go b/management/server/http/handlers/dns/nameservers_handler.go index bce1c4b78..9421b8f45 100644 --- a/management/server/http/handlers/dns/nameservers_handler.go +++ b/management/server/http/handlers/dns/nameservers_handler.go @@ -6,11 +6,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +23,13 @@ type nameserversHandler struct { accountManager account.Manager } -func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router) { +func addDNSNameserversEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { nameserversHandler := newNameserversHandler(accountManager) - router.HandleFunc("/dns/nameservers", nameserversHandler.getAllNameservers).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/nameservers", nameserversHandler.createNameserverGroup).Methods("POST", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.updateNameserverGroup).Methods("PUT", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.getNameserverGroup).Methods("GET", "OPTIONS") - router.HandleFunc("/dns/nameservers/{nsgroupId}", nameserversHandler.deleteNameserverGroup).Methods("DELETE", "OPTIONS") + router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getAllNameservers)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers", permissionsManager.WithPermission(modules.Nameservers, operations.Create, nameserversHandler.createNameserverGroup)).Methods("POST", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Update, nameserversHandler.updateNameserverGroup)).Methods("PUT", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Read, nameserversHandler.getNameserverGroup)).Methods("GET", "OPTIONS") + router.HandleFunc("/dns/nameservers/{nsgroupId}", permissionsManager.WithPermission(modules.Nameservers, operations.Delete, nameserversHandler.deleteNameserverGroup)).Methods("DELETE", "OPTIONS") } // newNameserversHandler returns a new instance of nameserversHandler handler @@ -36,17 +38,8 @@ func newNameserversHandler(accountManager account.Manager) *nameserversHandler { } // getAllNameservers returns the list of nameserver groups for the account -func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), accountID, userID) +func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + nsGroups, err := h.accountManager.ListNameServerGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -61,17 +54,9 @@ func (h *nameserversHandler) getAllNameservers(w http.ResponseWriter, r *http.Re } // createNameserverGroup handles nameserver group creation request -func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiDnsNameserversJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -83,7 +68,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt return } - nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), accountID, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userID, req.SearchDomainsEnabled) + nsGroup, err := h.accountManager.CreateNameServerGroup(r.Context(), userAuth.AccountId, req.Name, req.Description, nsList, req.Groups, req.Primary, req.Domains, req.Enabled, userAuth.UserId, req.SearchDomainsEnabled) if err != nil { util.WriteError(r.Context(), err, w) return @@ -95,15 +80,7 @@ func (h *nameserversHandler) createNameserverGroup(w http.ResponseWriter, r *htt } // updateNameserverGroup handles update to a nameserver group identified by a given ID -func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) @@ -111,7 +88,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt } var req api.PutApiDnsNameserversNsgroupIdJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -135,7 +112,7 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt SearchDomainsEnabled: req.SearchDomainsEnabled, } - err = h.accountManager.SaveNameServerGroup(r.Context(), accountID, userID, updatedNSGroup) + err = h.accountManager.SaveNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, updatedNSGroup) if err != nil { util.WriteError(r.Context(), err, w) return @@ -147,22 +124,14 @@ func (h *nameserversHandler) updateNameserverGroup(w http.ResponseWriter, r *htt } // deleteNameserverGroup handles nameserver group deletion request -func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - err = h.accountManager.DeleteNameServerGroup(r.Context(), accountID, nsGroupID, userID) + err := h.accountManager.DeleteNameServerGroup(r.Context(), userAuth.AccountId, nsGroupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -172,22 +141,14 @@ func (h *nameserversHandler) deleteNameserverGroup(w http.ResponseWriter, r *htt } // getNameserverGroup handles a nameserver group Get request identified by ID -func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *nameserversHandler) getNameserverGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nsGroupID := mux.Vars(r)["nsgroupId"] if len(nsGroupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid nameserver group ID"), w) return } - nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), accountID, userID, nsGroupID) + nsGroup, err := h.accountManager.GetNameServerGroup(r.Context(), userAuth.AccountId, userAuth.UserId, nsGroupID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/dns/nameservers_handler_test.go b/management/server/http/handlers/dns/nameservers_handler_test.go index 4716782f3..07285175c 100644 --- a/management/server/http/handlers/dns/nameservers_handler_test.go +++ b/management/server/http/handlers/dns/nameservers_handler_test.go @@ -24,6 +24,18 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( existingNSGroupID = "existingNSGroupID" notFoundNSGroupID = "notFoundNSGroupID" @@ -201,10 +213,10 @@ func TestNameserversHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.getNameserverGroup).Methods("GET") - router.HandleFunc("/api/dns/nameservers", p.createNameserverGroup).Methods("POST") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.deleteNameserverGroup).Methods("DELETE") - router.HandleFunc("/api/dns/nameservers/{nsgroupId}", p.updateNameserverGroup).Methods("PUT") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.getNameserverGroup)).Methods("GET") + router.HandleFunc("/api/dns/nameservers", wrapHandler(p.createNameserverGroup)).Methods("POST") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.deleteNameserverGroup)).Methods("DELETE") + router.HandleFunc("/api/dns/nameservers/{nsgroupId}", wrapHandler(p.updateNameserverGroup)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/events/events_handler.go b/management/server/http/handlers/events/events_handler.go index ae1e64e5c..add30848b 100644 --- a/management/server/http/handlers/events/events_handler.go +++ b/management/server/http/handlers/events/events_handler.go @@ -5,11 +5,13 @@ import ( "net/http" "github.com/gorilla/mux" - log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,10 +21,10 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { eventsHandler := newHandler(accountManager) - router.HandleFunc("/events", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") - router.HandleFunc("/events/audit", eventsHandler.getAllEvents).Methods("GET", "OPTIONS") + router.HandleFunc("/events", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS") + router.HandleFunc("/events/audit", permissionsManager.WithPermission(modules.Events, operations.Read, eventsHandler.getAllEvents)).Methods("GET", "OPTIONS") } // newHandler creates a new events handler @@ -31,17 +33,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllEvents list of the given account -func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - accountEvents, err := h.accountManager.GetEvents(r.Context(), accountID, userID) +func (h *handler) getAllEvents(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + accountEvents, err := h.accountManager.GetEvents(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/events/events_handler_test.go b/management/server/http/handlers/events/events_handler_test.go index 923a24e31..c3b932248 100644 --- a/management/server/http/handlers/events/events_handler_test.go +++ b/management/server/http/handlers/events/events_handler_test.go @@ -163,6 +163,18 @@ func generateEvents(accountID, userID string) []*activity.Event { return events } +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func TestEvents_GetEvents(t *testing.T) { tt := []struct { name string @@ -196,7 +208,7 @@ func TestEvents_GetEvents(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/events/", handler.getAllEvents).Methods("GET") + router.HandleFunc("/api/events/", wrapHandler(handler.getAllEvents)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/groups/groups_handler.go b/management/server/http/handlers/groups/groups_handler.go index 56ccc9d0b..41e1af825 100644 --- a/management/server/http/handlers/groups/groups_handler.go +++ b/management/server/http/handlers/groups/groups_handler.go @@ -8,10 +8,12 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" - + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -22,13 +24,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { groupsHandler := newHandler(accountManager) - router.HandleFunc("/groups", groupsHandler.getAllGroups).Methods("GET", "OPTIONS") - router.HandleFunc("/groups", groupsHandler.createGroup).Methods("POST", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.updateGroup).Methods("PUT", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.getGroup).Methods("GET", "OPTIONS") - router.HandleFunc("/groups/{groupId}", groupsHandler.deleteGroup).Methods("DELETE", "OPTIONS") + router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getAllGroups)).Methods("GET", "OPTIONS") + router.HandleFunc("/groups", permissionsManager.WithPermission(modules.Groups, operations.Create, groupsHandler.createGroup)).Methods("POST", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Update, groupsHandler.updateGroup)).Methods("PUT", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Read, groupsHandler.getGroup)).Methods("GET", "OPTIONS") + router.HandleFunc("/groups/{groupId}", permissionsManager.WithPermission(modules.Groups, operations.Delete, groupsHandler.deleteGroup)).Methods("DELETE", "OPTIONS") } // newHandler creates a new groups handler @@ -39,26 +41,18 @@ func newHandler(accountManager account.Manager) *handler { } // getAllGroups list for the account -func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - log.WithContext(r.Context()).Error(err) - http.Redirect(w, r, "/", http.StatusInternalServerError) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { // Check if filtering by name groupName := r.URL.Query().Get("name") if groupName != "" { // Get single group by name - group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID) + group, err := h.accountManager.GetGroupByName(r.Context(), groupName, userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -71,13 +65,13 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } // Get all groups - groups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -92,15 +86,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) { } // updateGroup handles update to a group identified by a given ID -func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) groupID, ok := vars["groupId"] if !ok { @@ -112,13 +98,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { return } - existingGroup, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + existingGroup, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID) + allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -166,13 +152,13 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { IntegrationReference: existingGroup.IntegrationReference, } - if err := h.accountManager.UpdateGroup(r.Context(), accountID, userID, &group); err != nil { - log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, accountID, err) + if err := h.accountManager.UpdateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group); err != nil { + log.WithContext(r.Context()).Errorf("failed updating group %s under account %s %v", groupID, userAuth.AccountId, err) util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -182,17 +168,9 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) { } // createGroup handles group creation request -func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiGroupsJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -226,13 +204,13 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { Issued: types.GroupIssuedAPI, } - err = h.accountManager.CreateGroup(r.Context(), accountID, userID, &group) + err = h.accountManager.CreateGroup(r.Context(), userAuth.AccountId, userAuth.UserId, &group) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "") if err != nil { util.WriteError(r.Context(), err, w) return @@ -242,22 +220,14 @@ func (h *handler) createGroup(w http.ResponseWriter, r *http.Request) { } // deleteGroup handles group deletion request -func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - err = h.accountManager.DeleteGroup(r.Context(), accountID, userID, groupID) + err := h.accountManager.DeleteGroup(r.Context(), userAuth.AccountId, userAuth.UserId, groupID) if err != nil { wrappedErr, ok := err.(interface{ Unwrap() []error }) if ok && len(wrappedErr.Unwrap()) > 0 { @@ -273,34 +243,26 @@ func (h *handler) deleteGroup(w http.ResponseWriter, r *http.Request) { } // getGroup returns a group -func (h *handler) getGroup(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) getGroup(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { groupID := mux.Vars(r)["groupId"] if len(groupID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid group ID"), w) return } - group, err := h.accountManager.GetGroup(r.Context(), accountID, groupID, userID) + group, err := h.accountManager.GetGroup(r.Context(), userAuth.AccountId, groupID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountPeers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, "", "") + accountPeers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, "", "") if err != nil { util.WriteError(r.Context(), err, w) return } util.WriteJSONObject(r.Context(), w, toGroupResponse(accountPeers, group)) - } func toGroupResponse(peers []*nbpeer.Peer, group *types.Group) *api.Group { diff --git a/management/server/http/handlers/groups/groups_handler_test.go b/management/server/http/handlers/groups/groups_handler_test.go index 458a15c11..a40fc9f15 100644 --- a/management/server/http/handlers/groups/groups_handler_test.go +++ b/management/server/http/handlers/groups/groups_handler_test.go @@ -33,6 +33,18 @@ var TestPeers = map[string]*nbpeer.Peer{ "B": {Key: "B", ID: "peer-B-ID", IP: net.ParseIP("200.200.200.200")}, } +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func initGroupTestData(initGroups ...*types.Group) *handler { return &handler{ accountManager: &mock_server.MockAccountManager{ @@ -141,7 +153,7 @@ func TestGetGroup(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.getGroup).Methods("GET") + router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.getGroup)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -267,8 +279,8 @@ func TestWriteGroup(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.createGroup).Methods("POST") - router.HandleFunc("/api/groups/{groupId}", p.updateGroup).Methods("PUT") + router.HandleFunc("/api/groups", wrapHandler(p.createGroup)).Methods("POST") + router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.updateGroup)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -345,7 +357,7 @@ func TestGetAllGroups(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/groups", p.getAllGroups).Methods("GET") + router.HandleFunc("/api/groups", wrapHandler(p.getAllGroups)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -426,7 +438,7 @@ func TestDeleteGroup(t *testing.T) { AccountId: "test_id", }) router := mux.NewRouter() - router.HandleFunc("/api/groups/{groupId}", p.deleteGroup).Methods("DELETE") + router.HandleFunc("/api/groups/{groupId}", wrapHandler(p.deleteGroup)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/idp/idp_handler.go b/management/server/http/handlers/idp/idp_handler.go index 077507b89..5a192e7df 100644 --- a/management/server/http/handlers/idp/idp_handler.go +++ b/management/server/http/handlers/idp/idp_handler.go @@ -7,8 +7,11 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -20,13 +23,13 @@ type handler struct { } // AddEndpoints registers identity provider endpoints -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := newHandler(accountManager) - router.HandleFunc("/identity-providers", h.getAllIdentityProviders).Methods("GET", "OPTIONS") - router.HandleFunc("/identity-providers", h.createIdentityProvider).Methods("POST", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT", "OPTIONS") - router.HandleFunc("/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE", "OPTIONS") + router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getAllIdentityProviders)).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers", permissionsManager.WithPermission(modules.IdentityProviders, operations.Create, h.createIdentityProvider)).Methods("POST", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Read, h.getIdentityProvider)).Methods("GET", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Update, h.updateIdentityProvider)).Methods("PUT", "OPTIONS") + router.HandleFunc("/identity-providers/{idpId}", permissionsManager.WithPermission(modules.IdentityProviders, operations.Delete, h.deleteIdentityProvider)).Methods("DELETE", "OPTIONS") } func newHandler(accountManager account.Manager) *handler { @@ -36,16 +39,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllIdentityProviders returns all identity providers for the account -func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - providers, err := h.accountManager.GetIdentityProviders(r.Context(), accountID, userID) +func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + providers, err := h.accountManager.GetIdentityProviders(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -60,15 +55,7 @@ func (h *handler) getAllIdentityProviders(w http.ResponseWriter, r *http.Request } // getIdentityProvider returns a specific identity provider -func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -76,7 +63,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { return } - provider, err := h.accountManager.GetIdentityProvider(r.Context(), accountID, idpID, userID) + provider, err := h.accountManager.GetIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -86,15 +73,7 @@ func (h *handler) getIdentityProvider(w http.ResponseWriter, r *http.Request) { } // createIdentityProvider creates a new identity provider -func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.IdentityProviderRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -103,7 +82,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) idp := fromAPIRequest(&req) - created, err := h.accountManager.CreateIdentityProvider(r.Context(), accountID, userID, idp) + created, err := h.accountManager.CreateIdentityProvider(r.Context(), userAuth.AccountId, userAuth.UserId, idp) if err != nil { util.WriteError(r.Context(), err, w) return @@ -113,15 +92,7 @@ func (h *handler) createIdentityProvider(w http.ResponseWriter, r *http.Request) } // updateIdentityProvider updates an existing identity provider -func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -137,7 +108,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) idp := fromAPIRequest(&req) - updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), accountID, idpID, userID, idp) + updated, err := h.accountManager.UpdateIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId, idp) if err != nil { util.WriteError(r.Context(), err, w) return @@ -147,15 +118,7 @@ func (h *handler) updateIdentityProvider(w http.ResponseWriter, r *http.Request) } // deleteIdentityProvider deletes an identity provider -func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) idpID := vars["idpId"] if idpID == "" { @@ -163,7 +126,7 @@ func (h *handler) deleteIdentityProvider(w http.ResponseWriter, r *http.Request) return } - if err := h.accountManager.DeleteIdentityProvider(r.Context(), accountID, idpID, userID); err != nil { + if err := h.accountManager.DeleteIdentityProvider(r.Context(), userAuth.AccountId, idpID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/handlers/idp/idp_handler_test.go b/management/server/http/handlers/idp/idp_handler_test.go index 74b204048..39eaafae4 100644 --- a/management/server/http/handlers/idp/idp_handler_test.go +++ b/management/server/http/handlers/idp/idp_handler_test.go @@ -87,6 +87,18 @@ func initIDPTestData(existingIDP *types.IdentityProvider) *handler { } } +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func TestGetAllIdentityProviders(t *testing.T) { existingIDP := &types.IdentityProvider{ ID: existingIDPID, @@ -120,7 +132,7 @@ func TestGetAllIdentityProviders(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers", h.getAllIdentityProviders).Methods("GET") + router.HandleFunc("/api/identity-providers", wrapHandler(h.getAllIdentityProviders)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -180,7 +192,7 @@ func TestGetIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.getIdentityProvider).Methods("GET") + router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.getIdentityProvider)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -242,7 +254,7 @@ func TestCreateIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers", h.createIdentityProvider).Methods("POST") + router.HandleFunc("/api/identity-providers", wrapHandler(h.createIdentityProvider)).Methods("POST") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -328,7 +340,7 @@ func TestUpdateIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.updateIdentityProvider).Methods("PUT") + router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.updateIdentityProvider)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -388,7 +400,7 @@ func TestDeleteIdentityProvider(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/identity-providers/{idpId}", h.deleteIdentityProvider).Methods("DELETE") + router.HandleFunc("/api/identity-providers/{idpId}", wrapHandler(h.deleteIdentityProvider)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index cd9fae6b8..77e1efb57 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -8,6 +8,10 @@ import ( log "github.com/sirupsen/logrus" nbinstance "github.com/netbirdio/netbird/management/server/instance" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -29,12 +33,12 @@ func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) { } // AddVersionEndpoint registers the authenticated version endpoint. -func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router) { +func AddVersionEndpoint(instanceManager nbinstance.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := &handler{ instanceManager: instanceManager, } - router.HandleFunc("/instance/version", h.getVersionInfo).Methods("GET", "OPTIONS") + router.HandleFunc("/instance/version", permissionsManager.WithPermission(modules.Settings, operations.Read, h.getVersionInfo)).Methods("GET", "OPTIONS") } // getInstanceStatus returns the instance status including whether setup is required. @@ -77,7 +81,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) { // getVersionInfo returns version information for NetBird components. // This endpoint requires authentication. -func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request) { +func (h *handler) getVersionInfo(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { versionInfo, err := h.instanceManager.GetVersionInfo(r.Context()) if err != nil { log.WithContext(r.Context()).Errorf("failed to get version info: %v", err) diff --git a/management/server/http/handlers/instance/instance_handler_test.go b/management/server/http/handlers/instance/instance_handler_test.go index 470079c85..ce0ce0ce7 100644 --- a/management/server/http/handlers/instance/instance_handler_test.go +++ b/management/server/http/handlers/instance/instance_handler_test.go @@ -296,7 +296,7 @@ func TestSetup_ManagerError(t *testing.T) { func TestGetVersionInfo_Success(t *testing.T) { manager := &mockInstanceManager{} router := mux.NewRouter() - AddVersionEndpoint(manager, router) + AddVersionEndpoint(manager, router, nil) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) rec := httptest.NewRecorder() @@ -324,7 +324,7 @@ func TestGetVersionInfo_Error(t *testing.T) { }, } router := mux.NewRouter() - AddVersionEndpoint(manager, router) + AddVersionEndpoint(manager, router, nil) req := httptest.NewRequest(http.MethodGet, "/instance/version", nil) rec := httptest.NewRecorder() diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index f99eca794..f8fec0948 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -10,14 +10,17 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/routers" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/networks/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" nbtypes "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -33,16 +36,16 @@ type handler struct { groupsManager groups.Manager } -func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, router *mux.Router) { - addRouterEndpoints(routerManager, router) - addResourceEndpoints(resourceManager, groupsManager, router) +func AddEndpoints(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager, permissionsManager permissions.Manager, router *mux.Router) { + addRouterEndpoints(routerManager, permissionsManager, router) + addResourceEndpoints(resourceManager, groupsManager, permissionsManager, router) networksHandler := newHandler(networksManager, resourceManager, routerManager, groupsManager, accountManager) - router.HandleFunc("/networks", networksHandler.getAllNetworks).Methods("GET", "OPTIONS") - router.HandleFunc("/networks", networksHandler.createNetwork).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.getNetwork).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.updateNetwork).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}", networksHandler.deleteNetwork).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getAllNetworks)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks", permissionsManager.WithPermission(modules.Networks, operations.Create, networksHandler.createNetwork)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Read, networksHandler.getNetwork)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Update, networksHandler.updateNetwork)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, networksHandler.deleteNetwork)).Methods("DELETE", "OPTIONS") } func newHandler(networksManager networks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager groups.Manager, accountManager account.Manager) *handler { @@ -55,40 +58,32 @@ func newHandler(networksManager networks.Manager, resourceManager resources.Mana } } -func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + networks, err := h.networksManager.GetAllNetworks(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - networks, err := h.networksManager.GetAllNetworks(r.Context(), accountID, userID) + resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - resourceIDs, err := h.resourceManager.GetAllResourceIDsInAccount(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) + routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - routers, err := h.routerManager.GetAllRoutersInAccount(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -97,16 +92,9 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, h.generateNetworkResponse(networks, routers, resourceIDs, groups, account)) } -func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -115,14 +103,14 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { network := &types.Network{} network.FromAPIRequest(&req) - network.AccountID = accountID - network, err = h.networksManager.CreateNetwork(r.Context(), userID, network) + network.AccountID = userAuth.AccountId + network, err = h.networksManager.CreateNetwork(r.Context(), userAuth.UserId, network) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -133,14 +121,7 @@ func (h *handler) createNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse([]string{}, []string{}, 0, policyIDs)) } -func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -148,19 +129,19 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { return } - network, err := h.networksManager.GetNetwork(r.Context(), accountID, userID, networkID) + network, err := h.networksManager.GetNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -171,14 +152,7 @@ func (h *handler) getNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } -func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -187,7 +161,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { } var req api.NetworkRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -197,20 +171,20 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { network.FromAPIRequest(&req) network.ID = networkID - network.AccountID = accountID - network, err = h.networksManager.UpdateNetwork(r.Context(), userID, network) + network.AccountID = userAuth.AccountId + network, err = h.networksManager.UpdateNetwork(r.Context(), userAuth.UserId, network) if err != nil { util.WriteError(r.Context(), err, w) return } - routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), accountID, userID, networkID) + routerIDs, resourceIDs, peerCount, err := h.collectIDsInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - account, err := h.accountManager.GetAccount(r.Context(), accountID) + account, err := h.accountManager.GetAccount(r.Context(), userAuth.AccountId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -221,14 +195,7 @@ func (h *handler) updateNetwork(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, network.ToAPIResponse(routerIDs, resourceIDs, peerCount, policyIDs)) } -func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) networkID := vars["networkId"] if len(networkID) == 0 { @@ -236,7 +203,7 @@ func (h *handler) deleteNetwork(w http.ResponseWriter, r *http.Request) { return } - err = h.networksManager.DeleteNetwork(r.Context(), accountID, userID, networkID) + err := h.networksManager.DeleteNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/networks/resources_handler.go b/management/server/http/handlers/networks/resources_handler.go index c31729a39..4df968d6b 100644 --- a/management/server/http/handlers/networks/resources_handler.go +++ b/management/server/http/handlers/networks/resources_handler.go @@ -6,10 +6,13 @@ import ( "github.com/gorilla/mux" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/networks/resources" "github.com/netbirdio/netbird/management/server/networks/resources/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -19,14 +22,14 @@ type resourceHandler struct { groupsManager groups.Manager } -func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, router *mux.Router) { +func addResourceEndpoints(resourcesManager resources.Manager, groupsManager groups.Manager, permissionsManager permissions.Manager, router *mux.Router) { resourceHandler := newResourceHandler(resourcesManager, groupsManager) - router.HandleFunc("/networks/resources", resourceHandler.getAllResourcesInAccount).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources", resourceHandler.getAllResourcesInNetwork).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources", resourceHandler.createResource).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.getResource).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.updateResource).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}/resources/{resourceId}", resourceHandler.deleteResource).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInAccount)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getAllResourcesInNetwork)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources", permissionsManager.WithPermission(modules.Networks, operations.Create, resourceHandler.createResource)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Read, resourceHandler.getResource)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Update, resourceHandler.updateResource)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/resources/{resourceId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, resourceHandler.deleteResource)).Methods("DELETE", "OPTIONS") } func newResourceHandler(resourceManager resources.Manager, groupsManager groups.Manager) *resourceHandler { @@ -36,22 +39,15 @@ func newResourceHandler(resourceManager resources.Manager, groupsManager groups. } } -func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] - resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), accountID, userID, networkID) + resources, err := h.resourceManager.GetAllResourcesInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -66,22 +62,14 @@ func (h *resourceHandler) getAllResourcesInNetwork(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } -func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - resources, err := h.resourceManager.GetAllResourcesInAccount(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -97,17 +85,9 @@ func (h *resourceHandler) getAllResourcesInAccount(w http.ResponseWriter, r *htt util.WriteJSONObject(r.Context(), w, resourcesResponse) } -func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkResourceRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -117,14 +97,14 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) resource.FromAPIRequest(&req) resource.NetworkID = mux.Vars(r)["networkId"] - resource.AccountID = accountID - resource, err = h.resourceManager.CreateResource(r.Context(), userID, resource) + resource.AccountID = userAuth.AccountId + resource, err = h.resourceManager.CreateResource(r.Context(), userAuth.UserId, resource) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -135,23 +115,16 @@ func (h *resourceHandler) createResource(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] - resource, err := h.resourceManager.GetResource(r.Context(), accountID, userID, networkID, resourceID) + resource, err := h.resourceManager.GetResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -162,16 +135,9 @@ func (h *resourceHandler) getResource(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkResourceRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -182,14 +148,14 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) resource.ID = mux.Vars(r)["resourceId"] resource.NetworkID = mux.Vars(r)["networkId"] - resource.AccountID = accountID - resource, err = h.resourceManager.UpdateResource(r.Context(), userID, resource) + resource.AccountID = userAuth.AccountId + resource, err = h.resourceManager.UpdateResource(r.Context(), userAuth.UserId, resource) if err != nil { util.WriteError(r.Context(), err, w) return } - grps, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + grps, err := h.groupsManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -200,17 +166,10 @@ func (h *resourceHandler) updateResource(w http.ResponseWriter, r *http.Request) util.WriteJSONObject(r.Context(), w, resource.ToAPIResponse(grpsInfoMap[resource.ID])) } -func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *resourceHandler) deleteResource(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] resourceID := mux.Vars(r)["resourceId"] - err = h.resourceManager.DeleteResource(r.Context(), accountID, userID, networkID, resourceID) + err := h.resourceManager.DeleteResource(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, resourceID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index c311a29fe..0e33d2757 100644 --- a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -6,9 +6,12 @@ import ( "github.com/gorilla/mux" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/networks/routers" "github.com/netbirdio/netbird/management/server/networks/routers/types" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" ) @@ -17,14 +20,14 @@ type routersHandler struct { routersManager routers.Manager } -func addRouterEndpoints(routersManager routers.Manager, router *mux.Router) { +func addRouterEndpoints(routersManager routers.Manager, permissionsManager permissions.Manager, router *mux.Router) { routersHandler := newRoutersHandler(routersManager) - router.HandleFunc("/networks/routers", routersHandler.getAllRouters).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers", routersHandler.getNetworkRouters).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers", routersHandler.createRouter).Methods("POST", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.getRouter).Methods("GET", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.updateRouter).Methods("PUT", "OPTIONS") - router.HandleFunc("/networks/{networkId}/routers/{routerId}", routersHandler.deleteRouter).Methods("DELETE", "OPTIONS") + router.HandleFunc("/networks/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getAllRouters)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getNetworkRouters)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers", permissionsManager.WithPermission(modules.Networks, operations.Create, routersHandler.createRouter)).Methods("POST", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Read, routersHandler.getRouter)).Methods("GET", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Update, routersHandler.updateRouter)).Methods("PUT", "OPTIONS") + router.HandleFunc("/networks/{networkId}/routers/{routerId}", permissionsManager.WithPermission(modules.Networks, operations.Delete, routersHandler.deleteRouter)).Methods("DELETE", "OPTIONS") } func newRoutersHandler(routersManager routers.Manager) *routersHandler { @@ -33,16 +36,8 @@ func newRoutersHandler(routersManager routers.Manager) *routersHandler { } } -func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), accountID, userID) +func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + routersMap, err := h.routersManager.GetAllRoutersInAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -58,17 +53,9 @@ func (h *routersHandler) getAllRouters(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, routersResponse) } -func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] - routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), accountID, userID, networkID) + routers, err := h.routersManager.GetAllRoutersInNetwork(r.Context(), userAuth.AccountId, userAuth.UserId, networkID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -82,18 +69,10 @@ func (h *routersHandler) getNetworkRouters(w http.ResponseWriter, r *http.Reques util.WriteJSONObject(r.Context(), w, routersResponse) } -func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { networkID := mux.Vars(r)["networkId"] var req api.NetworkRouterRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -103,9 +82,9 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.FromAPIRequest(&req) router.NetworkID = networkID - router.AccountID = accountID + router.AccountID = userAuth.AccountId router.Enabled = true - router, err = h.routersManager.CreateRouter(r.Context(), userID, router) + router, err = h.routersManager.CreateRouter(r.Context(), userAuth.UserId, router) if err != nil { util.WriteError(r.Context(), err, w) return @@ -114,18 +93,10 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] - router, err := h.routersManager.GetRouter(r.Context(), accountID, userID, networkID, routerID) + router, err := h.routersManager.GetRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -134,17 +105,9 @@ func (h *routersHandler) getRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.NetworkRouterRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -155,9 +118,9 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = mux.Vars(r)["networkId"] router.ID = mux.Vars(r)["routerId"] - router.AccountID = accountID + router.AccountID = userAuth.AccountId - router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) + router, err = h.routersManager.UpdateRouter(r.Context(), userAuth.UserId, router) if err != nil { util.WriteError(r.Context(), err, w) return @@ -166,17 +129,10 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, router.ToAPIResponse()) } -func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *routersHandler) deleteRouter(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routerID := mux.Vars(r)["routerId"] networkID := mux.Vars(r)["networkId"] - err = h.routersManager.DeleteRouter(r.Context(), accountID, userID, networkID, routerID) + err := h.routersManager.DeleteRouter(r.Context(), userAuth.AccountId, userAuth.UserId, networkID, routerID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 6b9a69f04..566fcc064 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -1,7 +1,6 @@ package peers import ( - "context" "encoding/json" "fmt" "net/http" @@ -14,13 +13,13 @@ import ( "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -35,14 +34,15 @@ type Handler struct { func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) - router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). - Methods("GET", "PUT", "DELETE", "OPTIONS") - router.HandleFunc("/peers/{peerId}/accessible-peers", peersHandler.GetAccessiblePeers).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}/temporary-access", peersHandler.CreateTemporaryAccess).Methods("POST", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs", peersHandler.ListJobs).Methods("GET", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs", peersHandler.CreateJob).Methods("POST", "OPTIONS") - router.HandleFunc("/peers/{peerId}/jobs/{jobId}", peersHandler.GetJob).Methods("GET", "OPTIONS") + router.HandleFunc("/peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAllPeers)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetPeer)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Update, peersHandler.UpdatePeer)).Methods("PUT", "OPTIONS") + router.HandleFunc("/peers/{peerId}", permissionsManager.WithPermission(modules.Peers, operations.Delete, peersHandler.DeletePeer)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/peers/{peerId}/accessible-peers", permissionsManager.WithPermission(modules.Peers, operations.Read, peersHandler.GetAccessiblePeers)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/temporary-access", permissionsManager.WithPermission(modules.Peers, operations.Create, peersHandler.CreateTemporaryAccess)).Methods("POST", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.ListJobs)).Methods("GET", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs", permissionsManager.WithPermission(modules.RemoteJobs, operations.Create, peersHandler.CreateJob)).Methods("POST", "OPTIONS") + router.HandleFunc("/peers/{peerId}/jobs/{jobId}", permissionsManager.WithPermission(modules.RemoteJobs, operations.Read, peersHandler.GetJob)).Methods("GET", "OPTIONS") } // NewHandler creates a new peers Handler @@ -54,14 +54,7 @@ func NewHandler(accountManager account.Manager, networkMapController network_map } } -func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] @@ -73,37 +66,30 @@ func (h *Handler) CreateJob(w http.ResponseWriter, r *http.Request) { job, err := types.NewJob(userAuth.UserId, userAuth.AccountId, peerID, req) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - if err := h.accountManager.CreatePeerJob(ctx, userAuth.AccountId, peerID, userAuth.UserId, job); err != nil { - util.WriteError(ctx, err, w) + if err := h.accountManager.CreatePeerJob(r.Context(), userAuth.AccountId, peerID, userAuth.UserId, job); err != nil { + util.WriteError(r.Context(), err, w) return } resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(ctx, w, resp) + util.WriteJSONObject(r.Context(), w, resp) } -func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] - jobs, err := h.accountManager.GetAllPeerJobs(ctx, userAuth.AccountId, userAuth.UserId, peerID) + jobs, err := h.accountManager.GetAllPeerJobs(r.Context(), userAuth.AccountId, userAuth.UserId, peerID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } @@ -111,79 +97,88 @@ func (h *Handler) ListJobs(w http.ResponseWriter, r *http.Request) { for _, job := range jobs { resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } respBody = append(respBody, resp) } - util.WriteJSONObject(ctx, w, respBody) + util.WriteJSONObject(r.Context(), w, respBody) } -func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(ctx, err, w) - return - } - +func (h *Handler) GetJob(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] jobID := vars["jobId"] - job, err := h.accountManager.GetPeerJobByID(ctx, userAuth.AccountId, userAuth.UserId, peerID, jobID) + job, err := h.accountManager.GetPeerJobByID(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, jobID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } resp, err := toSingleJobResponse(job) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(ctx, w, resp) + util.WriteJSONObject(r.Context(), w, resp) } -func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, w http.ResponseWriter) { - peer, err := h.accountManager.GetPeer(ctx, accountID, peerID, userID) +// GetPeer handles GET request for a single peer +func (h *Handler) GetPeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + + peer, err := h.accountManager.GetPeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } if peer.ProxyMeta.Embedded { - util.WriteError(ctx, status.Errorf(status.InvalidArgument, "not allowed to read peer"), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "not allowed to read peer"), w) return } - settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + grps, _ := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peerID) grpsInfoMap := groups.ToGroupsInfoMap(grps, 0) - validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { - log.WithContext(ctx).Errorf("failed to list approved peers: %v", err) - util.WriteError(ctx, fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } _, valid := validPeers[peer.ID] reason := invalidPeers[peer.ID] - util.WriteJSONObject(ctx, w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) + util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } -func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID string, w http.ResponseWriter, r *http.Request) { +// UpdatePeer handles PUT request to update a peer +func (h *Handler) UpdatePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + vars := mux.Vars(r) + peerID := vars["peerId"] + if len(peerID) == 0 { + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid peer ID"), w) + return + } + req := &api.PeerRequest{} err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -192,11 +187,10 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri } update := &nbpeer.Peer{ - ID: peerID, - SSHEnabled: req.SshEnabled, - Name: req.Name, - LoginExpirationEnabled: req.LoginExpirationEnabled, - + ID: peerID, + SSHEnabled: req.SshEnabled, + Name: req.Name, + LoginExpirationEnabled: req.LoginExpirationEnabled, InactivityExpirationEnabled: req.InactivityExpirationEnabled, } @@ -210,41 +204,41 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri if req.Ip != nil { addr, err := netip.ParseAddr(*req.Ip) if err != nil { - util.WriteError(ctx, status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w) + util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid IP address %s: %v", *req.Ip, err), w) return } - if err = h.accountManager.UpdatePeerIP(ctx, accountID, userID, peerID, addr); err != nil { - util.WriteError(ctx, err, w) + if err = h.accountManager.UpdatePeerIP(r.Context(), userAuth.AccountId, userAuth.UserId, peerID, addr); err != nil { + util.WriteError(r.Context(), err, w) return } } - peer, err := h.accountManager.UpdatePeer(ctx, accountID, userID, update) + peer, err := h.accountManager.UpdatePeer(r.Context(), userAuth.AccountId, userAuth.UserId, update) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } - settings, err := h.accountManager.GetAccountSettings(ctx, accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - peerGroups, err := h.accountManager.GetPeerGroups(ctx, accountID, peer.ID) + peerGroups, err := h.accountManager.GetPeerGroups(r.Context(), userAuth.AccountId, peer.ID) if err != nil { - util.WriteError(ctx, err, w) + util.WriteError(r.Context(), err, w) return } grpsInfoMap := groups.ToGroupsInfoMap(peerGroups, 0) - validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) + validPeers, invalidPeers, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { - log.WithContext(ctx).Errorf("failed to get validated peers: %v", err) - util.WriteError(ctx, fmt.Errorf("internal error"), w) + log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) + util.WriteError(r.Context(), fmt.Errorf("internal error"), w) return } @@ -254,25 +248,8 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri util.WriteJSONObject(r.Context(), w, toSinglePeerResponse(peer, grpsInfoMap[peerID], dnsDomain, valid, reason)) } -func (h *Handler) deletePeer(ctx context.Context, accountID, userID string, peerID string, w http.ResponseWriter) { - err := h.accountManager.DeletePeer(ctx, accountID, peerID, userID) - if err != nil { - log.WithContext(ctx).Errorf("failed to delete peer: %v", err) - util.WriteError(ctx, err, w) - return - } - util.WriteJSONObject(ctx, w, util.EmptyObject{}) -} - -// HandlePeer handles all peer requests for GET, PUT and DELETE operations -func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +// DeletePeer handles DELETE request to delete a peer +func (h *Handler) DeletePeer(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -280,48 +257,34 @@ func (h *Handler) HandlePeer(w http.ResponseWriter, r *http.Request) { return } - switch r.Method { - case http.MethodDelete: - h.deletePeer(r.Context(), accountID, userID, peerID, w) + err := h.accountManager.DeletePeer(r.Context(), userAuth.AccountId, peerID, userAuth.UserId) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to delete peer: %v", err) + util.WriteError(r.Context(), err, w) return - case http.MethodGet: - h.getPeer(r.Context(), accountID, peerID, userID, w) - return - case http.MethodPut: - h.updatePeer(r.Context(), accountID, userID, peerID, w, r) - return - default: - util.WriteError(r.Context(), status.Errorf(status.NotFound, "unknown METHOD"), w) } + util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } // GetAllPeers returns a list of all peers associated with a provided account -func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { nameFilter := r.URL.Query().Get("name") ipFilter := r.URL.Query().Get("ip") - accountID, userID := userAuth.AccountId, userAuth.UserId - - peers, err := h.accountManager.GetPeers(r.Context(), accountID, userID, nameFilter, ipFilter) + peers, err := h.accountManager.GetPeers(r.Context(), userAuth.AccountId, userAuth.UserId, nameFilter, ipFilter) if err != nil { util.WriteError(r.Context(), err, w) return } - settings, err := h.accountManager.GetAccountSettings(r.Context(), accountID, activity.SystemInitiator) + settings, err := h.accountManager.GetAccountSettings(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return } dnsDomain := h.networkMapController.GetDNSDomain(settings) - grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + grps, _ := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) grpsInfoMap := groups.ToGroupsInfoMap(grps, len(peers)) respBody := make([]*api.PeerBatch, 0, len(peers)) @@ -332,7 +295,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { respBody = append(respBody, toPeerListItemResponse(peer, grpsInfoMap[peer.ID], dnsDomain, 0)) } - validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeersMap, invalidPeersMap, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { log.WithContext(r.Context()).Errorf("failed to get validated peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -356,15 +319,7 @@ func (h *Handler) setApprovalRequiredFlag(respBody []*api.PeerBatch, validPeersM } // GetAccessiblePeers returns a list of all peers that the specified peer can connect to within the network. -func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -372,19 +327,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - user, err := h.accountManager.GetUserByID(r.Context(), userID) + user, err := h.accountManager.GetUserByID(r.Context(), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), accountID, userID, modules.Peers, operations.Read) + allowed, err := h.permissionsManager.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, modules.Peers, operations.Read) if err != nil { util.WriteError(r.Context(), status.NewPermissionValidationError(err), w) return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) + account, err := h.accountManager.GetAccountByID(r.Context(), userAuth.AccountId, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return @@ -408,7 +363,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { } } - validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), accountID) + validPeers, _, err := h.accountManager.GetValidatedPeers(r.Context(), userAuth.AccountId) if err != nil { log.WithContext(r.Context()).Errorf("failed to list approved peers: %v", err) util.WriteError(r.Context(), fmt.Errorf("internal error"), w) @@ -422,13 +377,7 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, toAccessiblePeers(netMap, dnsDomain)) } -func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) peerID := vars["peerId"] if len(peerID) == 0 { @@ -437,7 +386,7 @@ func (h *Handler) CreateTemporaryAccess(w http.ResponseWriter, r *http.Request) } var req api.PeerTemporaryAccessRequest - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 6b3616597..ae4cbe0af 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -34,6 +34,18 @@ import ( "github.com/netbirdio/netbird/management/server/mock_server" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( testPeerID = "test_peer" noUpdateChannelTestPeerID = "no-update-channel" @@ -307,9 +319,9 @@ func TestGetPeers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/peers/", p.GetAllPeers).Methods("GET") - router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("GET") - router.HandleFunc("/api/peers/{peerId}", p.HandlePeer).Methods("PUT") + router.HandleFunc("/api/peers/", wrapHandler(p.GetAllPeers)).Methods("GET") + router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.GetPeer)).Methods("GET") + router.HandleFunc("/api/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -498,7 +510,7 @@ func TestGetAccessiblePeers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/peers/{peerId}/accessible-peers", p.GetAccessiblePeers).Methods("GET") + router.HandleFunc("/api/peers/{peerId}/accessible-peers", wrapHandler(p.GetAccessiblePeers)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -582,7 +594,7 @@ func TestPeersHandlerUpdatePeerIP(t *testing.T) { rr := httptest.NewRecorder() router := mux.NewRouter() - router.HandleFunc("/peers/{peerId}", p.HandlePeer).Methods("PUT") + router.HandleFunc("/peers/{peerId}", wrapHandler(p.UpdatePeer)).Methods("PUT") router.ServeHTTP(rr, req) diff --git a/management/server/http/handlers/policies/geolocation_handler_test.go b/management/server/http/handlers/policies/geolocation_handler_test.go index 094a36e38..8084cac75 100644 --- a/management/server/http/handlers/policies/geolocation_handler_test.go +++ b/management/server/http/handlers/policies/geolocation_handler_test.go @@ -65,6 +65,18 @@ func initGeolocationTestData(t *testing.T) *geolocationsHandler { } } +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func TestGetCitiesByCountry(t *testing.T) { tt := []struct { name string @@ -121,7 +133,7 @@ func TestGetCitiesByCountry(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries/{country}/cities", geolocationHandler.getCitiesByCountry).Methods("GET") + router.HandleFunc("/api/locations/countries/{country}/cities", wrapHandler(geolocationHandler.getCitiesByCountry)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -214,7 +226,7 @@ func TestGetAllCountries(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/locations/countries", geolocationHandler.getAllCountries).Methods("GET") + router.HandleFunc("/api/locations/countries", wrapHandler(geolocationHandler.getAllCountries)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/policies/geolocations_handler.go b/management/server/http/handlers/policies/geolocations_handler.go index a2d656a47..793e943a8 100644 --- a/management/server/http/handlers/policies/geolocations_handler.go +++ b/management/server/http/handlers/policies/geolocations_handler.go @@ -7,11 +7,11 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -30,8 +30,8 @@ type geolocationsHandler struct { func AddLocationsEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, permissionsManager permissions.Manager, router *mux.Router) { locationHandler := newGeolocationsHandlerHandler(accountManager, locationManager, permissionsManager) - router.HandleFunc("/locations/countries", locationHandler.getAllCountries).Methods("GET", "OPTIONS") - router.HandleFunc("/locations/countries/{country}/cities", locationHandler.getCitiesByCountry).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getAllCountries)).Methods("GET", "OPTIONS") + router.HandleFunc("/locations/countries/{country}/cities", permissionsManager.WithPermission(modules.Policies, operations.Read, locationHandler.getCitiesByCountry)).Methods("GET", "OPTIONS") } // newGeolocationsHandlerHandler creates a new Geolocations handler @@ -44,12 +44,7 @@ func newGeolocationsHandlerHandler(accountManager account.Manager, geolocationMa } // getAllCountries retrieves a list of all countries -func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if l.geolocationManager == nil { // TODO: update error message to include geo db self hosted doc link when ready util.WriteError(r.Context(), status.Errorf(status.PreconditionFailed, "Geo location database is not initialized"), w) @@ -70,12 +65,7 @@ func (l *geolocationsHandler) getAllCountries(w http.ResponseWriter, r *http.Req } // getCitiesByCountry retrieves a list of cities based on the given country code -func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request) { - if err := l.authenticateUser(r); err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) countryCode := vars["country"] if !countryCodeRegex.MatchString(countryCode) { @@ -102,27 +92,6 @@ func (l *geolocationsHandler) getCitiesByCountry(w http.ResponseWriter, r *http. util.WriteJSONObject(r.Context(), w, cities) } -func (l *geolocationsHandler) authenticateUser(r *http.Request) error { - ctx := r.Context() - - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - return err - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - allowed, err := l.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Policies, operations.Read) - if err != nil { - return status.NewPermissionValidationError(err) - } - - if !allowed { - return status.NewPermissionDeniedError() - } - return nil -} - func toCountryResponse(country geolocation.Country) api.Country { return api.Country{ CountryName: country.CountryName, diff --git a/management/server/http/handlers/policies/policies_handler.go b/management/server/http/handlers/policies/policies_handler.go index e4d1d73df..b382168d9 100644 --- a/management/server/http/handlers/policies/policies_handler.go +++ b/management/server/http/handlers/policies/policies_handler.go @@ -8,9 +8,12 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) { policiesHandler := newHandler(accountManager) - router.HandleFunc("/policies", policiesHandler.getAllPolicies).Methods("GET", "OPTIONS") - router.HandleFunc("/policies", policiesHandler.createPolicy).Methods("POST", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.updatePolicy).Methods("PUT", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.getPolicy).Methods("GET", "OPTIONS") - router.HandleFunc("/policies/{policyId}", policiesHandler.deletePolicy).Methods("DELETE", "OPTIONS") + router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getAllPolicies)).Methods("GET", "OPTIONS") + router.HandleFunc("/policies", permissionsManager.WithPermission(modules.Policies, operations.Create, policiesHandler.createPolicy)).Methods("POST", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Update, policiesHandler.updatePolicy)).Methods("PUT", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Read, policiesHandler.getPolicy)).Methods("GET", "OPTIONS") + router.HandleFunc("/policies/{policyId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, policiesHandler.deletePolicy)).Methods("DELETE", "OPTIONS") } // newHandler creates a new policies handler @@ -38,22 +41,14 @@ func newHandler(accountManager account.Manager) *handler { } // getAllPolicies list for the account -func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) +func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + listPolicies, err := h.accountManager.ListPolicies(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - accountID, userID := userAuth.AccountId, userAuth.UserId - - listPolicies, err := h.accountManager.ListPolicies(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -73,15 +68,7 @@ func (h *handler) getAllPolicies(w http.ResponseWriter, r *http.Request) { } // updatePolicy handles update to a policy identified by a given ID -func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -89,26 +76,18 @@ func (h *handler) updatePolicy(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + _, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - h.savePolicy(w, r, accountID, userID, policyID, false) + h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, policyID, false) } // createPolicy handles policy creation request -func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - h.savePolicy(w, r, accountID, userID, "", true) +func (h *handler) createPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + h.savePolicy(w, r, userAuth.AccountId, userAuth.UserId, "", true) } // savePolicy handles policy creation and update @@ -303,14 +282,7 @@ func (h *handler) savePolicy(w http.ResponseWriter, r *http.Request, accountID s } // deletePolicy handles policy deletion request -func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -318,7 +290,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { return } - if err = h.accountManager.DeletePolicy(r.Context(), accountID, policyID, userID); err != nil { + if err := h.accountManager.DeletePolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } @@ -327,15 +299,7 @@ func (h *handler) deletePolicy(w http.ResponseWriter, r *http.Request) { } // getPolicy handles a group Get request identified by ID -func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) policyID := vars["policyId"] if len(policyID) == 0 { @@ -343,13 +307,13 @@ func (h *handler) getPolicy(w http.ResponseWriter, r *http.Request) { return } - policy, err := h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) + policy, err := h.accountManager.GetPolicy(r.Context(), userAuth.AccountId, policyID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - allGroups, err := h.accountManager.GetAllGroups(r.Context(), accountID, userID) + allGroups, err := h.accountManager.GetAllGroups(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/policies/policies_handler_test.go b/management/server/http/handlers/policies/policies_handler_test.go index ca5a0a6ab..258b28670 100644 --- a/management/server/http/handlers/policies/policies_handler_test.go +++ b/management/server/http/handlers/policies/policies_handler_test.go @@ -21,6 +21,18 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func initPoliciesTestData(policies ...*types.Policy) *handler { testPolicies := make(map[string]*types.Policy, len(policies)) for _, policy := range policies { @@ -111,7 +123,7 @@ func TestPoliciesGetPolicy(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/policies/{policyId}", p.getPolicy).Methods("GET") + router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.getPolicy)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -275,8 +287,8 @@ func TestPoliciesWritePolicy(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/policies", p.createPolicy).Methods("POST") - router.HandleFunc("/api/policies/{policyId}", p.updatePolicy).Methods("PUT") + router.HandleFunc("/api/policies", wrapHandler(p.createPolicy)).Methods("POST") + router.HandleFunc("/api/policies/{policyId}", wrapHandler(p.updatePolicy)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/policies/posture_checks_handler.go b/management/server/http/handlers/policies/posture_checks_handler.go index 744cde10b..01d919151 100644 --- a/management/server/http/handlers/policies/posture_checks_handler.go +++ b/management/server/http/handlers/policies/posture_checks_handler.go @@ -7,9 +7,12 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/geolocation" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/posture" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type postureChecksHandler struct { geolocationManager geolocation.Geolocation } -func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router) { +func AddPostureCheckEndpoints(accountManager account.Manager, locationManager geolocation.Geolocation, router *mux.Router, permissionsManager permissions.Manager) { postureCheckHandler := newPostureChecksHandler(accountManager, locationManager) - router.HandleFunc("/posture-checks", postureCheckHandler.getAllPostureChecks).Methods("GET", "OPTIONS") - router.HandleFunc("/posture-checks", postureCheckHandler.createPostureCheck).Methods("POST", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.updatePostureCheck).Methods("PUT", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.getPostureCheck).Methods("GET", "OPTIONS") - router.HandleFunc("/posture-checks/{postureCheckId}", postureCheckHandler.deletePostureCheck).Methods("DELETE", "OPTIONS") + router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getAllPostureChecks)).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks", permissionsManager.WithPermission(modules.Policies, operations.Create, postureCheckHandler.createPostureCheck)).Methods("POST", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Update, postureCheckHandler.updatePostureCheck)).Methods("PUT", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Read, postureCheckHandler.getPostureCheck)).Methods("GET", "OPTIONS") + router.HandleFunc("/posture-checks/{postureCheckId}", permissionsManager.WithPermission(modules.Policies, operations.Delete, postureCheckHandler.deletePostureCheck)).Methods("DELETE", "OPTIONS") } // newPostureChecksHandler creates a new PostureChecks handler @@ -39,15 +42,8 @@ func newPostureChecksHandler(accountManager account.Manager, geolocationManager } // getAllPostureChecks list for the account -func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), accountID, userID) +func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + listPostureChecks, err := p.accountManager.ListPostureChecks(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -62,15 +58,7 @@ func (p *postureChecksHandler) getAllPostureChecks(w http.ResponseWriter, r *htt } // updatePostureCheck handles update to a posture check identified by a given ID -func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -78,37 +66,22 @@ func (p *postureChecksHandler) updatePostureCheck(w http.ResponseWriter, r *http return } - _, err = p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + _, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return } - p.savePostureChecks(w, r, accountID, userID, postureChecksID, false) + p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, postureChecksID, false) } // createPostureCheck handles posture check creation request -func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - p.savePostureChecks(w, r, accountID, userID, "", true) +func (p *postureChecksHandler) createPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + p.savePostureChecks(w, r, userAuth.AccountId, userAuth.UserId, "", true) } // getPostureCheck handles a posture check Get request identified by ID -func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -116,7 +89,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re return } - postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), accountID, postureChecksID, userID) + postureChecks, err := p.accountManager.GetPostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -126,14 +99,7 @@ func (p *postureChecksHandler) getPostureCheck(w http.ResponseWriter, r *http.Re } // deletePostureCheck handles posture check deletion request -func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) postureChecksID := vars["postureCheckId"] if len(postureChecksID) == 0 { @@ -141,7 +107,7 @@ func (p *postureChecksHandler) deletePostureCheck(w http.ResponseWriter, r *http return } - if err = p.accountManager.DeletePostureChecks(r.Context(), accountID, postureChecksID, userID); err != nil { + if err := p.accountManager.DeletePostureChecks(r.Context(), userAuth.AccountId, postureChecksID, userAuth.UserId); err != nil { util.WriteError(r.Context(), err, w) return } diff --git a/management/server/http/handlers/policies/posture_checks_handler_test.go b/management/server/http/handlers/policies/posture_checks_handler_test.go index a5999f6c7..104fbf0e5 100644 --- a/management/server/http/handlers/policies/posture_checks_handler_test.go +++ b/management/server/http/handlers/policies/posture_checks_handler_test.go @@ -26,6 +26,18 @@ import ( var berlin = "Berlin" var losAngeles = "Los Angeles" +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + func initPostureChecksTestData(postureChecks ...*posture.Checks) *postureChecksHandler { testPostureChecks := make(map[string]*posture.Checks, len(postureChecks)) for _, postureCheck := range postureChecks { @@ -183,7 +195,7 @@ func TestGetPostureCheck(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/posture-checks/{postureCheckId}", p.getPostureCheck).Methods("GET") + router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(p.getPostureCheck)).Methods("GET") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -841,8 +853,8 @@ func TestPostureCheckUpdate(t *testing.T) { } router := mux.NewRouter() - router.HandleFunc("/api/posture-checks", defaultHandler.createPostureCheck).Methods("POST") - router.HandleFunc("/api/posture-checks/{postureCheckId}", defaultHandler.updatePostureCheck).Methods("PUT") + router.HandleFunc("/api/posture-checks", wrapHandler(defaultHandler.createPostureCheck)).Methods("POST") + router.HandleFunc("/api/posture-checks/{postureCheckId}", wrapHandler(defaultHandler.updatePostureCheck)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/routes/routes_handler.go b/management/server/http/handlers/routes/routes_handler.go index 7bb6f2372..19eb1f305 100644 --- a/management/server/http/handlers/routes/routes_handler.go +++ b/management/server/http/handlers/routes/routes_handler.go @@ -9,8 +9,11 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -26,13 +29,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { routesHandler := newHandler(accountManager) - router.HandleFunc("/routes", routesHandler.getAllRoutes).Methods("GET", "OPTIONS") - router.HandleFunc("/routes", routesHandler.createRoute).Methods("POST", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.updateRoute).Methods("PUT", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.getRoute).Methods("GET", "OPTIONS") - router.HandleFunc("/routes/{routeId}", routesHandler.deleteRoute).Methods("DELETE", "OPTIONS") + router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getAllRoutes)).Methods("GET", "OPTIONS") + router.HandleFunc("/routes", permissionsManager.WithPermission(modules.Routes, operations.Create, routesHandler.createRoute)).Methods("POST", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Update, routesHandler.updateRoute)).Methods("PUT", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Read, routesHandler.getRoute)).Methods("GET", "OPTIONS") + router.HandleFunc("/routes/{routeId}", permissionsManager.WithPermission(modules.Routes, operations.Delete, routesHandler.deleteRoute)).Methods("DELETE", "OPTIONS") } // newHandler returns a new instance of routes handler @@ -43,16 +46,8 @@ func newHandler(accountManager account.Manager) *handler { } // getAllRoutes returns the list of routes for the account -func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - - routes, err := h.accountManager.ListRoutes(r.Context(), accountID, userID) +func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + routes, err := h.accountManager.ListRoutes(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -71,17 +66,9 @@ func (h *handler) getAllRoutes(w http.ResponseWriter, r *http.Request) { } // createRoute handles route creation request -func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) createRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.PostApiRoutesJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -134,8 +121,8 @@ func (h *handler) createRoute(w http.ResponseWriter, r *http.Request) { skipAutoApply = false } - newRoute, err := h.accountManager.CreateRoute(r.Context(), accountID, newPrefix, networkType, domains, peerId, peerGroupIds, - req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userID, req.KeepRoute, skipAutoApply) + newRoute, err := h.accountManager.CreateRoute(r.Context(), userAuth.AccountId, newPrefix, networkType, domains, peerId, peerGroupIds, + req.Description, route.NetID(req.NetworkId), req.Masquerade, req.Metric, req.Groups, accessControlGroupIds, req.Enabled, userAuth.UserId, req.KeepRoute, skipAutoApply) if err != nil { util.WriteError(r.Context(), err, w) @@ -185,14 +172,7 @@ func (h *handler) validateRouteCommon(network *string, domains *[]string, peer * } // updateRoute handles update to a route identified by a given ID -func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) routeID := vars["routeId"] if len(routeID) == 0 { @@ -200,7 +180,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { return } - _, err = h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) + _, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -271,7 +251,7 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { newRoute.AccessControlGroups = *req.AccessControlGroups } - err = h.accountManager.SaveRoute(r.Context(), accountID, userID, newRoute) + err = h.accountManager.SaveRoute(r.Context(), userAuth.AccountId, userAuth.UserId, newRoute) if err != nil { util.WriteError(r.Context(), err, w) return @@ -287,21 +267,14 @@ func (h *handler) updateRoute(w http.ResponseWriter, r *http.Request) { } // deleteRoute handles route deletion request -func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - err = h.accountManager.DeleteRoute(r.Context(), accountID, route.ID(routeID), userID) + err := h.accountManager.DeleteRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -311,22 +284,14 @@ func (h *handler) deleteRoute(w http.ResponseWriter, r *http.Request) { } // getRoute handles a route Get request identified by ID -func (h *handler) getRoute(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getRoute(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { routeID := mux.Vars(r)["routeId"] if len(routeID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid route ID"), w) return } - foundRoute, err := h.accountManager.GetRoute(r.Context(), accountID, route.ID(routeID), userID) + foundRoute, err := h.accountManager.GetRoute(r.Context(), userAuth.AccountId, route.ID(routeID), userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/routes/routes_handler_test.go b/management/server/http/handlers/routes/routes_handler_test.go index a44d81e3e..a4a0c1904 100644 --- a/management/server/http/handlers/routes/routes_handler_test.go +++ b/management/server/http/handlers/routes/routes_handler_test.go @@ -25,6 +25,18 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( existingRouteID = "existingRouteID" existingRouteID2 = "existingRouteID2" // for peer_groups test @@ -501,10 +513,10 @@ func TestRoutesHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/routes/{routeId}", p.getRoute).Methods("GET") - router.HandleFunc("/api/routes/{routeId}", p.deleteRoute).Methods("DELETE") - router.HandleFunc("/api/routes", p.createRoute).Methods("POST") - router.HandleFunc("/api/routes/{routeId}", p.updateRoute).Methods("PUT") + router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.getRoute)).Methods("GET") + router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.deleteRoute)).Methods("DELETE") + router.HandleFunc("/api/routes", wrapHandler(p.createRoute)).Methods("POST") + router.HandleFunc("/api/routes/{routeId}", wrapHandler(p.updateRoute)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler.go b/management/server/http/handlers/setup_keys/setupkeys_handler.go index d267b6eea..12506464d 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler.go @@ -9,8 +9,11 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -21,13 +24,13 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { keysHandler := newHandler(accountManager) - router.HandleFunc("/setup-keys", keysHandler.getAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/setup-keys", keysHandler.createSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.getSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.updateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/setup-keys/{keyId}", keysHandler.deleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getAllSetupKeys)).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys", permissionsManager.WithPermission(modules.SetupKeys, operations.Create, keysHandler.createSetupKey)).Methods("POST", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Read, keysHandler.getSetupKey)).Methods("GET", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Update, keysHandler.updateSetupKey)).Methods("PUT", "OPTIONS") + router.HandleFunc("/setup-keys/{keyId}", permissionsManager.WithPermission(modules.SetupKeys, operations.Delete, keysHandler.deleteSetupKey)).Methods("DELETE", "OPTIONS") } // newHandler creates a new setup key handler @@ -38,16 +41,9 @@ func newHandler(accountManager account.Manager) *handler { } // createSetupKey is a POST requests that creates a new SetupKey -func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { req := &api.PostApiSetupKeysJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -85,8 +81,8 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { allowExtraDNSLabels = *req.AllowExtraDnsLabels } - setupKey, err := h.accountManager.CreateSetupKey(r.Context(), accountID, req.Name, types.SetupKeyType(req.Type), expiresIn, - req.AutoGroups, req.UsageLimit, userID, ephemeral, allowExtraDNSLabels) + setupKey, err := h.accountManager.CreateSetupKey(r.Context(), userAuth.AccountId, req.Name, types.SetupKeyType(req.Type), expiresIn, + req.AutoGroups, req.UsageLimit, userAuth.UserId, ephemeral, allowExtraDNSLabels) if err != nil { util.WriteError(r.Context(), err, w) return @@ -100,14 +96,7 @@ func (h *handler) createSetupKey(w http.ResponseWriter, r *http.Request) { } // getSetupKey is a GET request to get a SetupKey by ID -func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - +func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -115,7 +104,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { return } - key, err := h.accountManager.GetSetupKey(r.Context(), accountID, userID, keyID) + key, err := h.accountManager.GetSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -125,14 +114,7 @@ func (h *handler) getSetupKey(w http.ResponseWriter, r *http.Request) { } // updateSetupKey is a PUT request to update server.SetupKey -func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -141,7 +123,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { } req := &api.PutApiSetupKeysKeyIdJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -157,7 +139,7 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { newKey.Revoked = req.Revoked newKey.Id = keyID - newKey, err = h.accountManager.SaveSetupKey(r.Context(), accountID, newKey, userID) + newKey, err = h.accountManager.SaveSetupKey(r.Context(), userAuth.AccountId, newKey, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -166,15 +148,8 @@ func (h *handler) updateSetupKey(w http.ResponseWriter, r *http.Request) { } // getAllSetupKeys is a GET request that returns a list of SetupKey -func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), accountID, userID) +func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { + setupKeys, err := h.accountManager.ListSetupKeys(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -188,14 +163,7 @@ func (h *handler) getAllSetupKeys(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, apiSetupKeys) } -func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) keyID := vars["keyId"] if len(keyID) == 0 { @@ -203,7 +171,7 @@ func (h *handler) deleteSetupKey(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteSetupKey(r.Context(), accountID, userID, keyID) + err := h.accountManager.DeleteSetupKey(r.Context(), userAuth.AccountId, userAuth.UserId, keyID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go index b137b6dd1..0da445632 100644 --- a/management/server/http/handlers/setup_keys/setupkeys_handler_test.go +++ b/management/server/http/handlers/setup_keys/setupkeys_handler_test.go @@ -22,6 +22,18 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( existingSetupKeyID = "existingSetupKeyID" newSetupKeyName = "New Setup Key" @@ -171,11 +183,11 @@ func TestSetupKeysHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/setup-keys", handler.getAllSetupKeys).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys", handler.createSetupKey).Methods("POST", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.getSetupKey).Methods("GET", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.updateSetupKey).Methods("PUT", "OPTIONS") - router.HandleFunc("/api/setup-keys/{keyId}", handler.deleteSetupKey).Methods("DELETE", "OPTIONS") + router.HandleFunc("/api/setup-keys", wrapHandler(handler.getAllSetupKeys)).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys", wrapHandler(handler.createSetupKey)).Methods("POST", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.getSetupKey)).Methods("GET", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.updateSetupKey)).Methods("PUT", "OPTIONS") + router.HandleFunc("/api/setup-keys/{keyId}", wrapHandler(handler.deleteSetupKey)).Methods("DELETE", "OPTIONS") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/users/invites_handler.go b/management/server/http/handlers/users/invites_handler.go index 0f0f57c29..41f2b9211 100644 --- a/management/server/http/handlers/users/invites_handler.go +++ b/management/server/http/handlers/users/invites_handler.go @@ -10,9 +10,12 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -55,14 +58,14 @@ type invitesHandler struct { } // AddInvitesEndpoints registers invite-related endpoints -func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router) { +func AddInvitesEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { h := &invitesHandler{accountManager: accountManager} // Authenticated endpoints (require admin) - router.HandleFunc("/users/invites", h.listInvites).Methods("GET", "OPTIONS") - router.HandleFunc("/users/invites", h.createInvite).Methods("POST", "OPTIONS") - router.HandleFunc("/users/invites/{inviteId}", h.deleteInvite).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users/invites/{inviteId}/regenerate", h.regenerateInvite).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Read, h.listInvites)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/invites", permissionsManager.WithPermission(modules.Users, operations.Create, h.createInvite)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}", permissionsManager.WithPermission(modules.Users, operations.Delete, h.deleteInvite)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/invites/{inviteId}/regenerate", permissionsManager.WithPermission(modules.Users, operations.Update, h.regenerateInvite)).Methods("POST", "OPTIONS") } // AddPublicInvitesEndpoints registers public (unauthenticated) invite endpoints with rate limiting @@ -79,14 +82,7 @@ func AddPublicInvitesEndpoints(accountManager account.Manager, router *mux.Route } // listInvites handles GET /api/users/invites -func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { invites, err := h.accountManager.ListUserInvites(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) @@ -102,14 +98,7 @@ func (h *invitesHandler) listInvites(w http.ResponseWriter, r *http.Request) { } // createInvite handles POST /api/users/invites -func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) createInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { var req api.UserInviteCreateRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) @@ -191,18 +180,12 @@ func (h *invitesHandler) acceptInvite(w http.ResponseWriter, r *http.Request) { } // regenerateInvite handles POST /api/users/invites/{inviteId}/regenerate -func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request) { +func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - vars := mux.Vars(r) inviteID := vars["inviteId"] if inviteID == "" { @@ -238,14 +221,7 @@ func (h *invitesHandler) regenerateInvite(w http.ResponseWriter, r *http.Request } // deleteInvite handles DELETE /api/users/invites/{inviteId} -func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { - - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - +func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) inviteID := vars["inviteId"] if inviteID == "" { @@ -253,7 +229,7 @@ func (h *invitesHandler) deleteInvite(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) + err := h.accountManager.DeleteUserInvite(r.Context(), userAuth.AccountId, userAuth.UserId, inviteID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go index 529ea24d6..5017c49b5 100644 --- a/management/server/http/handlers/users/invites_handler_test.go +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -110,7 +110,11 @@ func TestListInvites(t *testing.T) { }) rr := httptest.NewRecorder() - handler.listInvites(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.listInvites(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -235,7 +239,11 @@ func TestCreateInvite(t *testing.T) { }) rr := httptest.NewRecorder() - handler.createInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.createInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -573,7 +581,11 @@ func TestRegenerateInvite(t *testing.T) { } rr := httptest.NewRecorder() - handler.regenerateInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.regenerateInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) @@ -651,7 +663,11 @@ func TestDeleteInvite(t *testing.T) { } rr := httptest.NewRecorder() - handler.deleteInvite(rr, req) + userAuth := &auth.UserAuth{ + UserId: testUserID, + AccountId: testAccountID, + } + handler.deleteInvite(rr, req, userAuth) assert.Equal(t, tc.expectedStatus, rr.Code) }) diff --git a/management/server/http/handlers/users/pat_handler.go b/management/server/http/handlers/users/pat_handler.go index 867db3ca9..79ee1fe9d 100644 --- a/management/server/http/handlers/users/pat_handler.go +++ b/management/server/http/handlers/users/pat_handler.go @@ -7,8 +7,11 @@ import ( "github.com/gorilla/mux" "github.com/netbirdio/netbird/management/server/account" - nbcontext "github.com/netbirdio/netbird/management/server/context" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" @@ -19,12 +22,12 @@ type patHandler struct { accountManager account.Manager } -func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router) { +func addUsersTokensEndpoint(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { tokenHandler := newPATsHandler(accountManager) - router.HandleFunc("/users/{userId}/tokens", tokenHandler.getAllTokens).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens", tokenHandler.createToken).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.getToken).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}/tokens/{tokenId}", tokenHandler.deleteToken).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getAllTokens)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens", permissionsManager.WithPermission(modules.Pats, operations.Create, tokenHandler.createToken)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Read, tokenHandler.getToken)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}/tokens/{tokenId}", permissionsManager.WithPermission(modules.Pats, operations.Delete, tokenHandler.deleteToken)).Methods("DELETE", "OPTIONS") } // newPATsHandler creates a new patHandler HTTP handler @@ -35,22 +38,15 @@ func newPATsHandler(accountManager account.Manager) *patHandler { } // getAllTokens is HTTP GET handler that returns a list of all personal access tokens for the given user -func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] - if len(userID) == 0 { + if len(targetUserID) == 0 { util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid user ID"), w) return } - pats, err := h.accountManager.GetAllPATs(r.Context(), accountID, userID, targetUserID) + pats, err := h.accountManager.GetAllPATs(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -65,14 +61,7 @@ func (h *patHandler) getAllTokens(w http.ResponseWriter, r *http.Request) { } // getToken is HTTP GET handler that returns a personal access token for the given user -func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -86,7 +75,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { return } - pat, err := h.accountManager.GetPAT(r.Context(), accountID, userID, targetUserID, tokenID) + pat, err := h.accountManager.GetPAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -96,14 +85,7 @@ func (h *patHandler) getToken(w http.ResponseWriter, r *http.Request) { } // createToken is HTTP POST handler that creates a personal access token for the given user -func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -112,13 +94,13 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { } var req api.PostApiUsersUserIdTokensJSONRequestBody - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } - pat, err := h.accountManager.CreatePAT(r.Context(), accountID, userID, targetUserID, req.Name, req.ExpiresIn) + pat, err := h.accountManager.CreatePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.Name, req.ExpiresIn) if err != nil { util.WriteError(r.Context(), err, w) return @@ -128,14 +110,7 @@ func (h *patHandler) createToken(w http.ResponseWriter, r *http.Request) { } // deleteToken is HTTP DELETE handler that deletes a personal access token for the given user -func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId +func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -149,7 +124,7 @@ func (h *patHandler) deleteToken(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeletePAT(r.Context(), accountID, userID, targetUserID, tokenID) + err := h.accountManager.DeletePAT(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, tokenID) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/users/pat_handler_test.go b/management/server/http/handlers/users/pat_handler_test.go index 7cda14468..c8f7e9741 100644 --- a/management/server/http/handlers/users/pat_handler_test.go +++ b/management/server/http/handlers/users/pat_handler_test.go @@ -24,6 +24,18 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +// wrapHandler wraps a handler function that requires userAuth parameter +func wrapHandler(h func(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth)) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + h(w, r, userAuth) + } +} + const ( existingAccountID = "existingAccountID" notFoundAccountID = "notFoundAccountID" @@ -181,10 +193,10 @@ func TestTokenHandlers(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}/tokens", p.getAllTokens).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.getToken).Methods("GET") - router.HandleFunc("/api/users/{userId}/tokens", p.createToken).Methods("POST") - router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", p.deleteToken).Methods("DELETE") + router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.getAllTokens)).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.getToken)).Methods("GET") + router.HandleFunc("/api/users/{userId}/tokens", wrapHandler(p.createToken)).Methods("POST") + router.HandleFunc("/api/users/{userId}/tokens/{tokenId}", wrapHandler(p.deleteToken)).Methods("DELETE") router.ServeHTTP(recorder, req) res := recorder.Result() diff --git a/management/server/http/handlers/users/users_handler.go b/management/server/http/handlers/users/users_handler.go index 40ad585d2..11480f104 100644 --- a/management/server/http/handlers/users/users_handler.go +++ b/management/server/http/handlers/users/users_handler.go @@ -9,13 +9,15 @@ import ( log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" + "github.com/netbirdio/netbird/management/server/permissions" + "github.com/netbirdio/netbird/management/server/permissions/modules" + "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" + "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" - - nbcontext "github.com/netbirdio/netbird/management/server/context" ) // handler is a handler that returns users of the account @@ -23,18 +25,18 @@ type handler struct { accountManager account.Manager } -func AddEndpoints(accountManager account.Manager, router *mux.Router) { +func AddEndpoints(accountManager account.Manager, router *mux.Router, permissionsManager permissions.Manager) { userHandler := newHandler(accountManager) - router.HandleFunc("/users", userHandler.getAllUsers).Methods("GET", "OPTIONS") - router.HandleFunc("/users/current", userHandler.getCurrentUser).Methods("GET", "OPTIONS") - router.HandleFunc("/users/{userId}", userHandler.updateUser).Methods("PUT", "OPTIONS") - router.HandleFunc("/users/{userId}", userHandler.deleteUser).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users", userHandler.createUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/invite", userHandler.inviteUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/approve", userHandler.approveUser).Methods("POST", "OPTIONS") - router.HandleFunc("/users/{userId}/reject", userHandler.rejectUser).Methods("DELETE", "OPTIONS") - router.HandleFunc("/users/{userId}/password", userHandler.changePassword).Methods("PUT", "OPTIONS") - addUsersTokensEndpoint(accountManager, router) + router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getAllUsers)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/current", permissionsManager.WithPermission(modules.Users, operations.Read, userHandler.getCurrentUser)).Methods("GET", "OPTIONS") + router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.updateUser)).Methods("PUT", "OPTIONS") + router.HandleFunc("/users/{userId}", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.deleteUser)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.createUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/invite", permissionsManager.WithPermission(modules.Users, operations.Create, userHandler.inviteUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/approve", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.approveUser)).Methods("POST", "OPTIONS") + router.HandleFunc("/users/{userId}/reject", permissionsManager.WithPermission(modules.Users, operations.Delete, userHandler.rejectUser)).Methods("DELETE", "OPTIONS") + router.HandleFunc("/users/{userId}/password", permissionsManager.WithPermission(modules.Users, operations.Update, userHandler.changePassword)).Methods("PUT", "OPTIONS") + addUsersTokensEndpoint(accountManager, router, permissionsManager) } // newHandler creates a new UsersHandler HTTP handler @@ -45,19 +47,12 @@ func newHandler(accountManager account.Manager) *handler { } // updateUser is a PUT requests to update User data -func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) updateUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -89,7 +84,7 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { return } - newUser, err := h.accountManager.SaveUser(r.Context(), accountID, userID, &types.User{ + newUser, err := h.accountManager.SaveUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.User{ Id: targetUserID, Role: userRole, AutoGroups: req.AutoGroups, @@ -102,23 +97,16 @@ func (h *handler) updateUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId)) } // deleteUser is a DELETE request to delete a user -func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -126,7 +114,7 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.DeleteUser(r.Context(), accountID, userID, targetUserID) + err := h.accountManager.DeleteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -136,21 +124,14 @@ func (h *handler) deleteUser(w http.ResponseWriter, r *http.Request) { } // createUser creates a User in the system with a status "invited" (effectively this is a user invite). -func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) createUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - req := &api.PostApiUsersJSONRequestBody{} - err = json.NewDecoder(r.Body).Decode(&req) + err := json.NewDecoder(r.Body).Decode(&req) if err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return @@ -171,7 +152,7 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { name = *req.Name } - newUser, err := h.accountManager.CreateUser(r.Context(), accountID, userID, &types.UserInfo{ + newUser, err := h.accountManager.CreateUser(r.Context(), userAuth.AccountId, userAuth.UserId, &types.UserInfo{ Email: email, Name: name, Role: req.Role, @@ -183,25 +164,18 @@ func (h *handler) createUser(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userID)) + util.WriteJSONObject(r.Context(), w, toUserResponse(newUser, userAuth.UserId)) } // getAllUsers returns a list of users of the account this user belongs to. // It also gathers additional user data (like email and name) from the IDP manager. -func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { +func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - accountID, userID := userAuth.AccountId, userAuth.UserId - data, err := h.accountManager.GetUsersFromAccount(r.Context(), accountID, userID) + data, err := h.accountManager.GetUsersFromAccount(r.Context(), userAuth.AccountId, userAuth.UserId) if err != nil { util.WriteError(r.Context(), err, w) return @@ -215,7 +189,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { continue } if serviceUser == "" { - users = append(users, toUserResponse(d, userID)) + users = append(users, toUserResponse(d, userAuth.UserId)) continue } @@ -226,7 +200,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { return } if includeServiceUser == d.IsServiceUser { - users = append(users, toUserResponse(d, userID)) + users = append(users, toUserResponse(d, userAuth.UserId)) } } @@ -235,19 +209,12 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) { // inviteUser resend invitations to users who haven't activated their accounts, // prior to the expiration period. -func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - accountID, userID := userAuth.AccountId, userAuth.UserId - vars := mux.Vars(r) targetUserID := vars["userId"] if len(targetUserID) == 0 { @@ -255,7 +222,7 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { return } - err = h.accountManager.InviteUser(r.Context(), accountID, userID, targetUserID) + err := h.accountManager.InviteUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -264,19 +231,13 @@ func (h *handler) inviteUser(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) getCurrentUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodGet { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return } - ctx := r.Context() - userAuth, err := nbcontext.GetUserAuthFromContext(ctx) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - user, err := h.accountManager.GetCurrentUserInfo(ctx, userAuth) + user, err := h.accountManager.GetCurrentUserInfo(r.Context(), *userAuth) if err != nil { util.WriteError(r.Context(), err, w) return @@ -356,7 +317,7 @@ func toUserResponse(user *types.UserInfo, currenUserID string) *api.User { } // approveUser is a POST request to approve a user that is pending approval -func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) approveUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPost { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -369,11 +330,6 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } user, err := h.accountManager.ApproveUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) @@ -385,7 +341,7 @@ func (h *handler) approveUser(w http.ResponseWriter, r *http.Request) { } // rejectUser is a DELETE request to reject a user that is pending approval -func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { +func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodDelete { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -398,12 +354,7 @@ func (h *handler) rejectUser(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - err = h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) + err := h.accountManager.RejectUser(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -421,7 +372,7 @@ type passwordChangeRequest struct { // changePassword is a PUT request to change user's password. // Only available when embedded IDP is enabled. // Users can only change their own password. -func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { +func (h *handler) changePassword(w http.ResponseWriter, r *http.Request, userAuth *auth.UserAuth) { if r.Method != http.MethodPut { util.WriteErrorResponse("wrong HTTP method", http.StatusMethodNotAllowed, w) return @@ -434,19 +385,13 @@ func (h *handler) changePassword(w http.ResponseWriter, r *http.Request) { return } - userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - var req passwordChangeRequest if err := json.NewDecoder(r.Body).Decode(&req); err != nil { util.WriteErrorResponse("couldn't parse JSON request", http.StatusBadRequest, w) return } - err = h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword) + err := h.accountManager.UpdateUserPassword(r.Context(), userAuth.AccountId, userAuth.UserId, targetUserID, req.OldPassword, req.NewPassword) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/users/users_handler_test.go b/management/server/http/handlers/users/users_handler_test.go index aa77dd843..ef2f96685 100644 --- a/management/server/http/handlers/users/users_handler_test.go +++ b/management/server/http/handlers/users/users_handler_test.go @@ -232,7 +232,11 @@ func TestGetUsers(t *testing.T) { AccountId: existingAccountID, }) - userHandler.getAllUsers(recorder, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.getAllUsers(recorder, req, userAuth) res := recorder.Result() defer res.Body.Close() @@ -343,7 +347,7 @@ func TestUpdateUser(t *testing.T) { }) router := mux.NewRouter() - router.HandleFunc("/api/users/{userId}", userHandler.updateUser).Methods("PUT") + router.HandleFunc("/api/users/{userId}", wrapHandler(userHandler.updateUser)).Methods("PUT") router.ServeHTTP(recorder, req) res := recorder.Result() @@ -439,7 +443,11 @@ func TestCreateUser(t *testing.T) { AccountId: existingAccountID, }) - userHandler.createUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.createUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -490,7 +498,11 @@ func TestInviteUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.inviteUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.inviteUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -549,7 +561,11 @@ func TestDeleteUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.deleteUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.deleteUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -682,7 +698,11 @@ func TestCurrentUser(t *testing.T) { rr := httptest.NewRecorder() - userHandler.getCurrentUser(rr, req) + userAuth := &auth.UserAuth{ + UserId: existingUserID, + AccountId: existingAccountID, + } + userHandler.getCurrentUser(rr, req, userAuth) res := rr.Result() defer res.Body.Close() @@ -779,7 +799,7 @@ func TestApproveUserEndpoint(t *testing.T) { handler := newHandler(am) router := mux.NewRouter() - router.HandleFunc("/users/{userId}/approve", handler.approveUser).Methods("POST") + router.HandleFunc("/users/{userId}/approve", wrapHandler(handler.approveUser)).Methods("POST") req, err := http.NewRequest("POST", "/users/pending-user/approve", nil) require.NoError(t, err) diff --git a/management/server/permissions/manager.go b/management/server/permissions/manager.go index e6bdd2025..914b88324 100644 --- a/management/server/permissions/manager.go +++ b/management/server/permissions/manager.go @@ -4,20 +4,25 @@ package permissions import ( "context" + "net/http" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/permissions/roles" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" + "github.com/netbirdio/netbird/shared/auth" + "github.com/netbirdio/netbird/shared/management/http/util" "github.com/netbirdio/netbird/shared/management/status" ) type Manager interface { + WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth)) http.HandlerFunc ValidateUserPermissions(ctx context.Context, accountID, userID string, module modules.Module, operation operations.Operation) (bool, error) ValidateRoleModuleAccess(ctx context.Context, accountID string, role roles.RolePermissions, module modules.Module, operation operations.Operation) bool ValidateAccountAccess(ctx context.Context, accountID string, user *types.User, allowOwnerAndAdmin bool) error @@ -36,6 +41,37 @@ func NewManager(store store.Store) Manager { } } +// WithPermission wraps an HTTP handler with permission checking logic. +func (m *managerImpl) WithPermission( + module modules.Module, + operation operations.Operation, + handlerFunc func(w http.ResponseWriter, r *http.Request, auth *auth.UserAuth), +) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + userAuth, err := nbcontext.GetUserAuthFromContext(r.Context()) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to get user auth from context: %v", err) + util.WriteError(r.Context(), err, w) + return + } + + allowed, err := m.ValidateUserPermissions(r.Context(), userAuth.AccountId, userAuth.UserId, module, operation) + if err != nil { + log.WithContext(r.Context()).Errorf("failed to validate permissions for user %s on account %s: %v", userAuth.UserId, userAuth.AccountId, err) + util.WriteError(r.Context(), status.NewPermissionValidationError(err), w) + return + } + + if !allowed { + log.WithContext(r.Context()).Tracef("user %s on account %s is not allowed to %s in %s", userAuth.UserId, userAuth.AccountId, operation, module) + util.WriteError(r.Context(), status.NewPermissionDeniedError(), w) + return + } + + handlerFunc(w, r, &userAuth) + } +} + func (m *managerImpl) ValidateUserPermissions( ctx context.Context, accountID string, diff --git a/management/server/permissions/manager_mock.go b/management/server/permissions/manager_mock.go index ec9f263f9..82e728fd4 100644 --- a/management/server/permissions/manager_mock.go +++ b/management/server/permissions/manager_mock.go @@ -6,6 +6,7 @@ package permissions import ( context "context" + http "net/http" reflect "reflect" gomock "github.com/golang/mock/gomock" @@ -14,6 +15,7 @@ import ( operations "github.com/netbirdio/netbird/management/server/permissions/operations" roles "github.com/netbirdio/netbird/management/server/permissions/roles" types "github.com/netbirdio/netbird/management/server/types" + auth "github.com/netbirdio/netbird/shared/auth" ) // MockManager is a mock of Manager interface. @@ -108,3 +110,17 @@ func (mr *MockManagerMockRecorder) ValidateUserPermissions(ctx, accountID, userI mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ValidateUserPermissions", reflect.TypeOf((*MockManager)(nil).ValidateUserPermissions), ctx, accountID, userID, module, operation) } + +// WithPermission mocks base method. +func (m *MockManager) WithPermission(module modules.Module, operation operations.Operation, handlerFunc func(http.ResponseWriter, *http.Request, *auth.UserAuth)) http.HandlerFunc { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithPermission", module, operation, handlerFunc) + ret0, _ := ret[0].(http.HandlerFunc) + return ret0 +} + +// WithPermission indicates an expected call of WithPermission. +func (mr *MockManagerMockRecorder) WithPermission(module, operation, handlerFunc interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPermission", reflect.TypeOf((*MockManager)(nil).WithPermission), module, operation, handlerFunc) +}