feat: add a flask_context_manager. (#21061)

Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
-LAN-
2025-06-17 16:31:29 +08:00
committed by GitHub
parent 7a2a8a2ffd
commit 0dcacdf83d
8 changed files with 239 additions and 126 deletions

3
.gitignore vendored
View File

@@ -210,3 +210,6 @@ mise.toml
# Next.js build output
.next/
# AI Assistant
.roo/

View File

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
@@ -31,6 +31,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
@@ -399,21 +400,18 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()
@copy_current_request_context
def worker_with_context():
# Run the worker within the copied context
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
context=context,
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
},
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
@@ -449,24 +447,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID
:return:
"""
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)

View File

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping
from typing import Any, Literal, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError
from configs import dify_config
@@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
@@ -182,21 +183,18 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()
@copy_current_request_context
def worker_with_context():
# Run the worker within the copied context
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
context=context,
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"context": context,
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"conversation_id": conversation.id,
"message_id": message.id,
},
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
@@ -229,24 +227,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
:param message_id: message ID
:return:
"""
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)

View File

@@ -5,7 +5,7 @@ import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Literal, Optional, Union, overload
from flask import Flask, copy_current_request_context, current_app, has_request_context
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
@@ -29,6 +29,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
@@ -209,20 +210,17 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()
@copy_current_request_context
def worker_with_context():
# Run the worker within the copied context
return context.run(
self._generate_worker,
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
context=context,
workflow_thread_pool_id=workflow_thread_pool_id,
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"application_generate_entity": application_generate_entity,
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
},
)
worker_thread = threading.Thread(target=worker_with_context)
worker_thread.start()
# return response or stream generator
@@ -408,24 +406,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_thread_pool_id: workflow thread pool id
:return:
"""
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,

View File

@@ -9,7 +9,7 @@ from copy import copy, deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, cast
from flask import Flask, current_app, has_request_context
from flask import Flask, current_app
from configs import dify_config
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
@@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -537,24 +538,9 @@ class GraphEngine:
"""
Run parallel nodes
"""
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
q.put(
ParallelBranchRunStartedEvent(
parallel_id=parallel_id,

View File

@@ -7,7 +7,7 @@ from datetime import UTC, datetime
from queue import Empty, Queue
from typing import TYPE_CHECKING, Any, Optional, cast
from flask import Flask, current_app, has_request_context
from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
@@ -37,6 +37,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@@ -583,23 +584,8 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
run single iteration in parallel mode
"""
for var, val in context.items():
var.set(val)
# FIXME(-LAN-): Save current user before entering new app context
from flask import g
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
with flask_app.app_context():
# Restore user in new app context
if saved_user is not None:
from flask import g
g._login_user = saved_user
with preserve_flask_contexts(flask_app, context_vars=context):
parallel_mode_run_id = uuid.uuid4().hex
graph_engine_copy = graph_engine.create_copy()
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool

65
api/libs/flask_utils.py Normal file
View File

@@ -0,0 +1,65 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from flask import Flask, g, has_request_context
T = TypeVar("T")
@contextmanager
def preserve_flask_contexts(
flask_app: Flask,
context_vars: contextvars.Context,
) -> Iterator[None]:
"""
A context manager that handles:
1. flask-login's UserProxy copy
2. ContextVars copy
3. flask_app.app_context()
This context manager ensures that the Flask application context is properly set up,
the current user is preserved across context boundaries, and any provided context variables
are set within the new context.
Note:
This manager aims to allow use current_user cross thread and app context,
but it's not the recommend use, it's better to pass user directly in parameters.
Args:
flask_app: The Flask application instance
context_vars: contextvars.Context object containing context variables to be set in the new context
Yields:
None
Example:
```python
with preserve_flask_contexts(flask_app, context_vars=context_vars):
# Code that needs Flask app context and context variables
# Current user will be preserved if available
```
"""
# Set context variables if provided
if context_vars:
for var, val in context_vars.items():
var.set(val)
# Save current user before entering new app context
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with flask_app.app_context():
try:
# Restore user in new app context if it was saved
if saved_user is not None:
g._login_user = saved_user
# Yield control back to the caller
yield
finally:
# Any cleanup can be added here if needed
pass

View File

@@ -0,0 +1,124 @@
import contextvars
import threading
from typing import Optional
import pytest
from flask import Flask
from flask_login import LoginManager, UserMixin, current_user, login_user
from libs.flask_utils import preserve_flask_contexts
class User(UserMixin):
"""Simple User class for testing."""
def __init__(self, id: str):
self.id = id
def get_id(self) -> str:
return self.id
@pytest.fixture
def login_app(app: Flask) -> Flask:
"""Set up a Flask app with flask-login."""
# Set a secret key for the app
app.config["SECRET_KEY"] = "test-secret-key"
login_manager = LoginManager()
login_manager.init_app(app)
@login_manager.user_loader
def load_user(user_id: str) -> Optional[User]:
if user_id == "test_user":
return User("test_user")
return None
return app
@pytest.fixture
def test_user() -> User:
"""Create a test user."""
return User("test_user")
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
"""
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
This test demonstrates that without the preserve_flask_contexts, we cannot access
current_user in a different thread, even with app_context.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Store the result of the thread execution
result = {"user_accessible": True, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread():
try:
# Try to access current_user in a different thread with app_context
with login_app.app_context():
# This should fail because current_user is not accessible across threads
# without preserve_flask_contexts
result["user_accessible"] = current_user.is_authenticated
except Exception as e:
result["error"] = str(e) # type: ignore
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread)
thread.start()
thread.join()
# Verify that we got an error or current_user is not authenticated
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
"""
Test that current_user is accessible in a different thread with preserve_flask_contexts.
This test demonstrates that with the preserve_flask_contexts, we can access
current_user in a different thread.
"""
# Log in the user in the main thread
with login_app.test_request_context():
login_user(test_user)
assert current_user.is_authenticated
assert current_user.id == "test_user"
# Save the context variables
context_vars = contextvars.copy_context()
# Store the result of the thread execution
result = {"user_accessible": False, "user_id": None, "error": None}
# Define a function to run in a separate thread
def check_user_in_thread_with_manager():
try:
# Use preserve_flask_contexts to access current_user in a different thread
with preserve_flask_contexts(login_app, context_vars):
from flask_login import current_user
if current_user:
result["user_accessible"] = True
result["user_id"] = current_user.id
else:
result["user_accessible"] = False
except Exception as e:
result["error"] = str(e) # type: ignore
# Run the function in a separate thread
thread = threading.Thread(target=check_user_in_thread_with_manager)
thread.start()
thread.join()
# Verify that current_user is accessible and has the correct ID
assert result["error"] is None
assert result["user_accessible"] is True
assert result["user_id"] == "test_user"