diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 483bb989a..3d4de31d0 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -62,6 +62,7 @@ func NewAPIHandler( authManager, accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, + accountManager.GetUserFromUserAuth, ) corsMiddleware := cors.AllowAll() diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index a8e6790a9..6f0d1556f 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -15,16 +15,20 @@ import ( "github.com/netbirdio/netbird/management/server/http/middleware/bypass" "github.com/netbirdio/netbird/management/server/http/util" "github.com/netbirdio/netbird/management/server/status" + "github.com/netbirdio/netbird/management/server/types" ) type EnsureAccountFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) type SyncUserJWTGroupsFunc func(ctx context.Context, userAuth nbcontext.UserAuth) error +type GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + // AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens type AuthMiddleware struct { - authManager auth.Manager - ensureAccount EnsureAccountFunc - syncUserJWTGroups SyncUserJWTGroupsFunc + authManager auth.Manager + ensureAccount EnsureAccountFunc + getUserFromUserAuth GetUserFromUserAuthFunc + syncUserJWTGroups SyncUserJWTGroupsFunc } // NewAuthMiddleware instance constructor @@ -32,11 +36,13 @@ func NewAuthMiddleware( authManager auth.Manager, ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, + getUserFromUserAuth GetUserFromUserAuthFunc, ) *AuthMiddleware { return &AuthMiddleware{ - authManager: authManager, - ensureAccount: ensureAccount, - syncUserJWTGroups: syncUserJWTGroups, + authManager: authManager, + ensureAccount: ensureAccount, + syncUserJWTGroups: syncUserJWTGroups, + getUserFromUserAuth: getUserFromUserAuth, } } @@ -123,6 +129,12 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, auth []string) (*h log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err) } + _, err = m.getUserFromUserAuth(ctx, userAuth) + if err != nil { + log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) + return r, err + } + return nbcontext.SetUserAuthInRequest(r, userAuth), nil } diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index 3dc7d51cb..410ff7e15 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -190,6 +190,9 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) handlerToTest := authMiddleware.Handler(nextHandler) @@ -291,6 +294,9 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbcontext.UserAuth) error { return nil }, + func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) { + return &types.User{}, nil + }, ) for _, tc := range tt {