mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-29 20:19:56 +00:00
Compare commits
1 Commits
ui-refacto
...
refactor/m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
78ed62535e |
@@ -120,7 +120,7 @@ func (s *BaseServer) EventStore() activity.Store {
|
|||||||
|
|
||||||
func (s *BaseServer) APIHandler() http.Handler {
|
func (s *BaseServer) APIHandler() http.Handler {
|
||||||
return Create(s, func() http.Handler {
|
return Create(s, func() http.Handler {
|
||||||
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter(), s.IsValidChildAccount)
|
httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.Router(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.Metrics(), s.PermissionsManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.InstanceManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.AuthMiddleware())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create API handler: %v", err)
|
log.Fatalf("failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -153,6 +153,20 @@ func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) AuthMiddleware() mux.MiddlewareFunc {
|
||||||
|
return Create(s, func() mux.MiddlewareFunc {
|
||||||
|
m := middleware.NewAuthMiddleware(
|
||||||
|
s.AuthManager(),
|
||||||
|
s.AccountManager().GetAccountIDFromUserAuth,
|
||||||
|
s.AccountManager().SyncUserJWTGroups,
|
||||||
|
s.AccountManager().GetUserFromUserAuth,
|
||||||
|
s.RateLimiter(),
|
||||||
|
s.Metrics().GetMeter(),
|
||||||
|
)
|
||||||
|
return m.Handler
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (s *BaseServer) GRPCServer() *grpc.Server {
|
func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||||
return Create(s, func() *grpc.Server {
|
return Create(s, func() *grpc.Server {
|
||||||
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
trustedPeers := s.Config.ReverseProxy.TrustedPeers
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/idp"
|
"github.com/netbirdio/netbird/management/server/idp"
|
||||||
|
"github.com/netbirdio/netbird/management/server/instance"
|
||||||
"github.com/netbirdio/netbird/management/server/networks"
|
"github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/routers"
|
"github.com/netbirdio/netbird/management/server/networks/routers"
|
||||||
@@ -151,6 +152,16 @@ func (s *BaseServer) IdpManager() idp.Manager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *BaseServer) InstanceManager() instance.Manager {
|
||||||
|
return Create(s, func() instance.Manager {
|
||||||
|
m, err := instance.NewManager(context.Background(), s.Store(), s.IdpManager())
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create instance manager: %v", err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
// OAuthConfigProvider is only relevant when we have an embedded IdP service. Otherwise must be nil
|
||||||
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
func (s *BaseServer) OAuthConfigProvider() idp.OAuthConfigProvider {
|
||||||
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
if s.Config.EmbeddedIdP == nil || !s.Config.EmbeddedIdP.Enabled {
|
||||||
@@ -229,6 +240,3 @@ func (s *BaseServer) ReverseProxyDomainManager() *manager.Manager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) IsValidChildAccount(_ context.Context, _, _, _ string) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain/manager"
|
||||||
|
|
||||||
@@ -20,7 +19,6 @@ import (
|
|||||||
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
reverseproxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service/manager"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
idpmanager "github.com/netbirdio/netbird/management/server/idp"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
@@ -34,7 +32,6 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
"github.com/netbirdio/netbird/management/server/http/handlers/proxy"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/auth"
|
|
||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
nbgroups "github.com/netbirdio/netbird/management/server/groups"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
"github.com/netbirdio/netbird/management/server/http/handlers/accounts"
|
||||||
@@ -49,7 +46,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
"github.com/netbirdio/netbird/management/server/http/handlers/routes"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
"github.com/netbirdio/netbird/management/server/http/handlers/setup_keys"
|
||||||
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
"github.com/netbirdio/netbird/management/server/http/handlers/users"
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware"
|
|
||||||
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
"github.com/netbirdio/netbird/management/server/http/middleware/bypass"
|
||||||
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
nbinstance "github.com/netbirdio/netbird/management/server/instance"
|
||||||
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
nbnetworks "github.com/netbirdio/netbird/management/server/networks"
|
||||||
@@ -59,7 +55,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
|
||||||
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter, isValidChildAccount middleware.IsValidChildAccountFunc) (http.Handler, error) {
|
func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, appMetrics telemetry.AppMetrics, permissionsManager permissions.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, instanceManager nbinstance.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, authMiddleware mux.MiddlewareFunc) (http.Handler, error) {
|
||||||
|
|
||||||
// Register bypass paths for unauthenticated endpoints
|
// Register bypass paths for unauthenticated endpoints
|
||||||
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
if err := bypass.AddBypassPath("/api/instance"); err != nil {
|
||||||
@@ -80,32 +76,11 @@ func NewAPIHandler(ctx context.Context, router *mux.Router, accountManager accou
|
|||||||
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
return nil, fmt.Errorf("failed to add bypass path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if rateLimiter == nil {
|
|
||||||
log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled")
|
|
||||||
rateLimiter = middleware.NewAPIRateLimiter(nil)
|
|
||||||
rateLimiter.SetEnabled(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
authMiddleware := middleware.NewAuthMiddleware(
|
|
||||||
authManager,
|
|
||||||
accountManager.GetAccountIDFromUserAuth,
|
|
||||||
accountManager.SyncUserJWTGroups,
|
|
||||||
accountManager.GetUserFromUserAuth,
|
|
||||||
rateLimiter,
|
|
||||||
appMetrics.GetMeter(),
|
|
||||||
isValidChildAccount,
|
|
||||||
)
|
|
||||||
|
|
||||||
corsMiddleware := cors.AllowAll()
|
corsMiddleware := cors.AllowAll()
|
||||||
|
|
||||||
metricsMiddleware := appMetrics.HTTPMiddleware()
|
metricsMiddleware := appMetrics.HTTPMiddleware()
|
||||||
|
|
||||||
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware.Handler)
|
router.Use(metricsMiddleware.Handler, corsMiddleware.Handler, authMiddleware)
|
||||||
|
|
||||||
instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), idpManager)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create instance manager: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
accounts.AddEndpoints(accountManager, settingsManager, router)
|
accounts.AddEndpoints(accountManager, settingsManager, router)
|
||||||
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"go.opentelemetry.io/otel/metric"
|
"go.opentelemetry.io/otel/metric"
|
||||||
|
|
||||||
@@ -25,7 +26,8 @@ type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth auth.UserAuth) err
|
|||||||
|
|
||||||
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
type GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
|
||||||
|
|
||||||
type IsValidChildAccountFunc func(ctx context.Context, userID, accountID, childAccountID string) bool
|
// jwtTokenCtxKey carries the parsed JWT token.
|
||||||
|
type jwtTokenCtxKey struct{}
|
||||||
|
|
||||||
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
|
||||||
type AuthMiddleware struct {
|
type AuthMiddleware struct {
|
||||||
@@ -35,7 +37,6 @@ type AuthMiddleware struct {
|
|||||||
syncUserJWTGroups SyncUserJWTGroupsFunc
|
syncUserJWTGroups SyncUserJWTGroupsFunc
|
||||||
rateLimiter *APIRateLimiter
|
rateLimiter *APIRateLimiter
|
||||||
patUsageTracker *PATUsageTracker
|
patUsageTracker *PATUsageTracker
|
||||||
isValidChildAccount IsValidChildAccountFunc
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthMiddleware instance constructor
|
// NewAuthMiddleware instance constructor
|
||||||
@@ -46,7 +47,6 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth GetUserFromUserAuthFunc,
|
getUserFromUserAuth GetUserFromUserAuthFunc,
|
||||||
rateLimiter *APIRateLimiter,
|
rateLimiter *APIRateLimiter,
|
||||||
meter metric.Meter,
|
meter metric.Meter,
|
||||||
isValidChildAccount IsValidChildAccountFunc,
|
|
||||||
) *AuthMiddleware {
|
) *AuthMiddleware {
|
||||||
var patUsageTracker *PATUsageTracker
|
var patUsageTracker *PATUsageTracker
|
||||||
if meter != nil {
|
if meter != nil {
|
||||||
@@ -64,12 +64,18 @@ func NewAuthMiddleware(
|
|||||||
getUserFromUserAuth: getUserFromUserAuth,
|
getUserFromUserAuth: getUserFromUserAuth,
|
||||||
rateLimiter: rateLimiter,
|
rateLimiter: rateLimiter,
|
||||||
patUsageTracker: patUsageTracker,
|
patUsageTracker: patUsageTracker,
|
||||||
isValidChildAccount: isValidChildAccount,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handler method of the middleware which authenticates a user either by JWT claims or by PAT
|
// Handler composes the full authentication chain by wrapping the given
|
||||||
|
// handler with ValidationHandler followed by AccountAccessHandler.
|
||||||
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
||||||
|
return m.ValidationHandler(m.AccountAccessHandler(h))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidationHandler authenticates the caller via JWT or PAT and stores the
|
||||||
|
// resulting UserAuth in the request context. It performs no account-level work.
|
||||||
|
func (m *AuthMiddleware) ValidationHandler(h http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||||
return
|
return
|
||||||
@@ -86,14 +92,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
|
|
||||||
switch authType {
|
switch authType {
|
||||||
case "bearer":
|
case "bearer":
|
||||||
if err := m.checkJWTFromRequest(r, authHeader); err != nil {
|
if err := m.validateJWT(r, authHeader); err != nil {
|
||||||
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error())
|
||||||
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
case "token":
|
case "token":
|
||||||
if err := m.checkPATFromRequest(r, authHeader); err != nil {
|
if err := m.validatePAT(r, authHeader); err != nil {
|
||||||
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error())
|
||||||
// Check if it's a status error, otherwise default to Unauthorized
|
// Check if it's a status error, otherwise default to Unauthorized
|
||||||
if _, ok := status.FromError(err); !ok {
|
if _, ok := status.FromError(err); !ok {
|
||||||
@@ -110,66 +116,55 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckJWTFromRequest checks if the JWT is valid
|
// AccountAccessHandler runs post-validation access checks for JWT-authenticated
|
||||||
func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error {
|
// requests. PAT requests pass through unchanged.
|
||||||
token, err := getTokenFromJWTRequest(authHeaderParts)
|
func (m *AuthMiddleware) AccountAccessHandler(h http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if bypass.ShouldBypass(r.URL.Path, h, w, r) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// If an error occurs, call the error handler and return an error
|
userAuth, err := nbcontext.GetUserAuthFromRequest(r)
|
||||||
|
if err != nil {
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAuth.IsPAT {
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
validatedToken, _ := r.Context().Value(jwtTokenCtxKey{}).(*jwt.Token)
|
||||||
|
|
||||||
|
if err := m.applyAccountAccess(r, userAuth, validatedToken); err != nil {
|
||||||
|
log.WithContext(r.Context()).Errorf("Error applying JWT account access: %s", err.Error())
|
||||||
|
util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *AuthMiddleware) validateJWT(r *http.Request, authHeaderParts []string) error {
|
||||||
|
token, err := getTokenFromJWTRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := r.Context()
|
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(r.Context(), token)
|
||||||
|
|
||||||
userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
|
||||||
if m.isValidChildAccount(ctx, userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
|
||||||
userAuth.AccountId = impersonate[0]
|
|
||||||
userAuth.IsChild = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Email is now extracted in ToUserAuth (from claims or userinfo endpoint)
|
|
||||||
// Available as userAuth.Email
|
|
||||||
|
|
||||||
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
|
||||||
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if userAuth.AccountId != accountId {
|
|
||||||
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
|
||||||
userAuth.AccountId = accountId
|
|
||||||
}
|
|
||||||
|
|
||||||
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = m.getUserFromUserAuth(ctx, userAuth)
|
|
||||||
if err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// propagates ctx change to upstream middleware
|
// propagates ctx change to upstream middleware
|
||||||
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
|
*r = *r.WithContext(context.WithValue(r.Context(), jwtTokenCtxKey{}, validatedToken))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CheckPATFromRequest checks if the PAT is valid
|
func (m *AuthMiddleware) validatePAT(r *http.Request, authHeaderParts []string) error {
|
||||||
func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error {
|
|
||||||
token, err := getTokenFromPATRequest(authHeaderParts)
|
token, err := getTokenFromPATRequest(authHeaderParts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error extracting token: %w", err)
|
return fmt.Errorf("error extracting token: %w", err)
|
||||||
@@ -192,8 +187,7 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
return fmt.Errorf("token expired")
|
return fmt.Errorf("token expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.authManager.MarkPATUsed(ctx, pat.ID)
|
if err := m.authManager.MarkPATUsed(ctx, pat.ID); err != nil {
|
||||||
if err != nil {
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,11 +199,40 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []
|
|||||||
IsPAT: true,
|
IsPAT: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 {
|
// propagates ctx change to upstream middleware
|
||||||
if m.isValidChildAccount(r.Context(), userAuth.UserId, userAuth.AccountId, impersonate[0]) {
|
*r = *nbcontext.SetUserAuthInRequest(r, userAuth)
|
||||||
userAuth.AccountId = impersonate[0]
|
return nil
|
||||||
userAuth.IsChild = true
|
}
|
||||||
}
|
|
||||||
|
// applyAccountAccess executes account-level checks for an authenticated JWT
|
||||||
|
// user: ensures the account exists, verifies access via JWT groups, syncs
|
||||||
|
// groups, and fetches the user record.
|
||||||
|
func (m *AuthMiddleware) applyAccountAccess(r *http.Request, userAuth auth.UserAuth, validatedToken *jwt.Token) error {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
// we need to call this method because if user is new, we will automatically add it to existing or create a new account
|
||||||
|
accountId, _, err := m.ensureAccount(ctx, userAuth)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if userAuth.AccountId != accountId {
|
||||||
|
log.WithContext(ctx).Tracef("Auth middleware sets accountId from ensure, before %s, now %s", userAuth.AccountId, accountId)
|
||||||
|
userAuth.AccountId = accountId
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.syncUserJWTGroups(ctx, userAuth); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := m.getUserFromUserAuth(ctx, userAuth); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// propagates ctx change to upstream middleware
|
// propagates ctx change to upstream middleware
|
||||||
|
|||||||
@@ -211,7 +211,6 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handlerToTest := authMiddleware.Handler(nextHandler)
|
handlerToTest := authMiddleware.Handler(nextHandler)
|
||||||
@@ -271,7 +270,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -324,7 +322,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -368,7 +365,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -413,7 +409,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -478,7 +473,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -538,7 +532,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -594,7 +587,6 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
|||||||
},
|
},
|
||||||
NewAPIRateLimiter(rateLimitConfig),
|
NewAPIRateLimiter(rateLimitConfig),
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -695,7 +687,6 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) {
|
|||||||
},
|
},
|
||||||
disabledLimiter,
|
disabledLimiter,
|
||||||
nil,
|
nil,
|
||||||
func(_ context.Context, _, _, _ string) bool { return false },
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||||
"github.com/netbirdio/netbird/management/server/groups"
|
"github.com/netbirdio/netbird/management/server/groups"
|
||||||
http2 "github.com/netbirdio/netbird/management/server/http"
|
http2 "github.com/netbirdio/netbird/management/server/http"
|
||||||
|
"github.com/netbirdio/netbird/management/server/http/middleware"
|
||||||
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
"github.com/netbirdio/netbird/management/server/http/testing/testing_tools"
|
||||||
"github.com/netbirdio/netbird/management/server/networks"
|
"github.com/netbirdio/netbird/management/server/networks"
|
||||||
"github.com/netbirdio/netbird/management/server/networks/resources"
|
"github.com/netbirdio/netbird/management/server/networks/resources"
|
||||||
@@ -136,8 +137,12 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
|
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||||
|
rateLimiter.SetEnabled(false)
|
||||||
|
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||||
|
|
||||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
@@ -266,8 +271,12 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
|||||||
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "")
|
||||||
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager)
|
||||||
|
|
||||||
|
rateLimiter := middleware.NewAPIRateLimiter(nil)
|
||||||
|
rateLimiter.SetEnabled(false)
|
||||||
|
authMiddleware := middleware.NewAuthMiddleware(authManagerMock, am.GetAccountIDFromUserAuth, am.SyncUserJWTGroups, am.GetUserFromUserAuth, rateLimiter, metrics.GetMeter())
|
||||||
|
|
||||||
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
apiRouter := mux.NewRouter().PathPrefix("/api").Subrouter()
|
||||||
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil, nil)
|
apiHandler, err := http2.NewAPIHandler(context.Background(), apiRouter, am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, metrics, permissionsManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, authMiddleware.Handler)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create API handler: %v", err)
|
t.Fatalf("Failed to create API handler: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user