feat: add a flask_context_manager. (#21061)
Signed-off-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -210,3 +210,6 @@ mise.toml
|
|||||||
|
|
||||||
# Next.js build output
|
# Next.js build output
|
||||||
.next/
|
.next/
|
||||||
|
|
||||||
|
# AI Assistant
|
||||||
|
.roo/
|
||||||
|
@@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@@ -31,6 +31,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
|||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
@@ -399,20 +400,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# new thread with request context and contextvars
|
# new thread with request context and contextvars
|
||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
@copy_current_request_context
|
worker_thread = threading.Thread(
|
||||||
def worker_with_context():
|
target=self._generate_worker,
|
||||||
# Run the worker within the copied context
|
kwargs={
|
||||||
return context.run(
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
self._generate_worker,
|
"application_generate_entity": application_generate_entity,
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
"queue_manager": queue_manager,
|
||||||
application_generate_entity=application_generate_entity,
|
"conversation_id": conversation.id,
|
||||||
queue_manager=queue_manager,
|
"message_id": message.id,
|
||||||
conversation_id=conversation.id,
|
"context": context,
|
||||||
message_id=message.id,
|
},
|
||||||
context=context,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
worker_thread = threading.Thread(target=worker_with_context)
|
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
@@ -449,24 +447,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
:param message_id: message ID
|
:param message_id: message ID
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
|
||||||
var.set(val)
|
|
||||||
|
|
||||||
# FIXME(-LAN-): Save current user before entering new app context
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
from flask import g
|
|
||||||
|
|
||||||
saved_user = None
|
|
||||||
if has_request_context() and hasattr(g, "_login_user"):
|
|
||||||
saved_user = g._login_user
|
|
||||||
|
|
||||||
with flask_app.app_context():
|
|
||||||
try:
|
try:
|
||||||
# Restore user in new app context
|
|
||||||
if saved_user is not None:
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
g._login_user = saved_user
|
|
||||||
|
|
||||||
# get conversation and message
|
# get conversation and message
|
||||||
conversation = self._get_conversation(conversation_id)
|
conversation = self._get_conversation(conversation_id)
|
||||||
message = self._get_message(message_id)
|
message = self._get_message(message_id)
|
||||||
|
@@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Any, Literal, Union, overload
|
from typing import Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
|||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models import Account, App, EndUser
|
from models import Account, App, EndUser
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.errors.message import MessageNotExistsError
|
from services.errors.message import MessageNotExistsError
|
||||||
@@ -182,20 +183,17 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# new thread with request context and contextvars
|
# new thread with request context and contextvars
|
||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
@copy_current_request_context
|
worker_thread = threading.Thread(
|
||||||
def worker_with_context():
|
target=self._generate_worker,
|
||||||
# Run the worker within the copied context
|
kwargs={
|
||||||
return context.run(
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
self._generate_worker,
|
"context": context,
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
"application_generate_entity": application_generate_entity,
|
||||||
context=context,
|
"queue_manager": queue_manager,
|
||||||
application_generate_entity=application_generate_entity,
|
"conversation_id": conversation.id,
|
||||||
queue_manager=queue_manager,
|
"message_id": message.id,
|
||||||
conversation_id=conversation.id,
|
},
|
||||||
message_id=message.id,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
worker_thread = threading.Thread(target=worker_with_context)
|
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
@@ -229,24 +227,9 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
:param message_id: message ID
|
:param message_id: message ID
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
|
||||||
var.set(val)
|
|
||||||
|
|
||||||
# FIXME(-LAN-): Save current user before entering new app context
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
from flask import g
|
|
||||||
|
|
||||||
saved_user = None
|
|
||||||
if has_request_context() and hasattr(g, "_login_user"):
|
|
||||||
saved_user = g._login_user
|
|
||||||
|
|
||||||
with flask_app.app_context():
|
|
||||||
try:
|
try:
|
||||||
# Restore user in new app context
|
|
||||||
if saved_user is not None:
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
g._login_user = saved_user
|
|
||||||
|
|
||||||
# get conversation and message
|
# get conversation and message
|
||||||
conversation = self._get_conversation(conversation_id)
|
conversation = self._get_conversation(conversation_id)
|
||||||
message = self._get_message(message_id)
|
message = self._get_message(message_id)
|
||||||
|
@@ -5,7 +5,7 @@ import uuid
|
|||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, Union, overload
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
from flask import Flask, copy_current_request_context, current_app, has_request_context
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
@@ -29,6 +29,7 @@ from core.workflow.repositories.workflow_execution_repository import WorkflowExe
|
|||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
|
||||||
@@ -209,19 +210,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
# new thread with request context and contextvars
|
# new thread with request context and contextvars
|
||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
@copy_current_request_context
|
worker_thread = threading.Thread(
|
||||||
def worker_with_context():
|
target=self._generate_worker,
|
||||||
# Run the worker within the copied context
|
kwargs={
|
||||||
return context.run(
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
self._generate_worker,
|
"application_generate_entity": application_generate_entity,
|
||||||
flask_app=current_app._get_current_object(), # type: ignore
|
"queue_manager": queue_manager,
|
||||||
application_generate_entity=application_generate_entity,
|
"context": context,
|
||||||
queue_manager=queue_manager,
|
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||||
context=context,
|
},
|
||||||
workflow_thread_pool_id=workflow_thread_pool_id,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
worker_thread = threading.Thread(target=worker_with_context)
|
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
@@ -408,24 +406,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
:param workflow_thread_pool_id: workflow thread pool id
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
|
||||||
var.set(val)
|
|
||||||
|
|
||||||
# FIXME(-LAN-): Save current user before entering new app context
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
from flask import g
|
|
||||||
|
|
||||||
saved_user = None
|
|
||||||
if has_request_context() and hasattr(g, "_login_user"):
|
|
||||||
saved_user = g._login_user
|
|
||||||
|
|
||||||
with flask_app.app_context():
|
|
||||||
try:
|
try:
|
||||||
# Restore user in new app context
|
|
||||||
if saved_user is not None:
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
g._login_user = saved_user
|
|
||||||
|
|
||||||
# workflow app
|
# workflow app
|
||||||
runner = WorkflowAppRunner(
|
runner = WorkflowAppRunner(
|
||||||
application_generate_entity=application_generate_entity,
|
application_generate_entity=application_generate_entity,
|
||||||
|
@@ -9,7 +9,7 @@ from copy import copy, deepcopy
|
|||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import Any, Optional, cast
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app, has_request_context
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||||
@@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
|||||||
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
|
||||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import WorkflowType
|
from models.workflow import WorkflowType
|
||||||
|
|
||||||
@@ -537,24 +538,9 @@ class GraphEngine:
|
|||||||
"""
|
"""
|
||||||
Run parallel nodes
|
Run parallel nodes
|
||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
|
||||||
var.set(val)
|
|
||||||
|
|
||||||
# FIXME(-LAN-): Save current user before entering new app context
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
from flask import g
|
|
||||||
|
|
||||||
saved_user = None
|
|
||||||
if has_request_context() and hasattr(g, "_login_user"):
|
|
||||||
saved_user = g._login_user
|
|
||||||
|
|
||||||
with flask_app.app_context():
|
|
||||||
try:
|
try:
|
||||||
# Restore user in new app context
|
|
||||||
if saved_user is not None:
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
g._login_user = saved_user
|
|
||||||
|
|
||||||
q.put(
|
q.put(
|
||||||
ParallelBranchRunStartedEvent(
|
ParallelBranchRunStartedEvent(
|
||||||
parallel_id=parallel_id,
|
parallel_id=parallel_id,
|
||||||
|
@@ -7,7 +7,7 @@ from datetime import UTC, datetime
|
|||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from typing import TYPE_CHECKING, Any, Optional, cast
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
||||||
|
|
||||||
from flask import Flask, current_app, has_request_context
|
from flask import Flask, current_app
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
|
||||||
@@ -37,6 +37,7 @@ from core.workflow.nodes.base import BaseNode
|
|||||||
from core.workflow.nodes.enums import NodeType
|
from core.workflow.nodes.enums import NodeType
|
||||||
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
|
||||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
|
||||||
from .exc import (
|
from .exc import (
|
||||||
InvalidIteratorValueError,
|
InvalidIteratorValueError,
|
||||||
@@ -583,23 +584,8 @@ class IterationNode(BaseNode[IterationNodeData]):
|
|||||||
"""
|
"""
|
||||||
run single iteration in parallel mode
|
run single iteration in parallel mode
|
||||||
"""
|
"""
|
||||||
for var, val in context.items():
|
|
||||||
var.set(val)
|
|
||||||
|
|
||||||
# FIXME(-LAN-): Save current user before entering new app context
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
saved_user = None
|
|
||||||
if has_request_context() and hasattr(g, "_login_user"):
|
|
||||||
saved_user = g._login_user
|
|
||||||
|
|
||||||
with flask_app.app_context():
|
|
||||||
# Restore user in new app context
|
|
||||||
if saved_user is not None:
|
|
||||||
from flask import g
|
|
||||||
|
|
||||||
g._login_user = saved_user
|
|
||||||
|
|
||||||
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
parallel_mode_run_id = uuid.uuid4().hex
|
parallel_mode_run_id = uuid.uuid4().hex
|
||||||
graph_engine_copy = graph_engine.create_copy()
|
graph_engine_copy = graph_engine.create_copy()
|
||||||
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool
|
||||||
|
65
api/libs/flask_utils.py
Normal file
65
api/libs/flask_utils.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
import contextvars
|
||||||
|
from collections.abc import Iterator
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from flask import Flask, g, has_request_context
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def preserve_flask_contexts(
|
||||||
|
flask_app: Flask,
|
||||||
|
context_vars: contextvars.Context,
|
||||||
|
) -> Iterator[None]:
|
||||||
|
"""
|
||||||
|
A context manager that handles:
|
||||||
|
1. flask-login's UserProxy copy
|
||||||
|
2. ContextVars copy
|
||||||
|
3. flask_app.app_context()
|
||||||
|
|
||||||
|
This context manager ensures that the Flask application context is properly set up,
|
||||||
|
the current user is preserved across context boundaries, and any provided context variables
|
||||||
|
are set within the new context.
|
||||||
|
|
||||||
|
Note:
|
||||||
|
This manager aims to allow use current_user cross thread and app context,
|
||||||
|
but it's not the recommend use, it's better to pass user directly in parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
flask_app: The Flask application instance
|
||||||
|
context_vars: contextvars.Context object containing context variables to be set in the new context
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
None
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
with preserve_flask_contexts(flask_app, context_vars=context_vars):
|
||||||
|
# Code that needs Flask app context and context variables
|
||||||
|
# Current user will be preserved if available
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Set context variables if provided
|
||||||
|
if context_vars:
|
||||||
|
for var, val in context_vars.items():
|
||||||
|
var.set(val)
|
||||||
|
|
||||||
|
# Save current user before entering new app context
|
||||||
|
saved_user = None
|
||||||
|
if has_request_context() and hasattr(g, "_login_user"):
|
||||||
|
saved_user = g._login_user
|
||||||
|
|
||||||
|
# Enter Flask app context
|
||||||
|
with flask_app.app_context():
|
||||||
|
try:
|
||||||
|
# Restore user in new app context if it was saved
|
||||||
|
if saved_user is not None:
|
||||||
|
g._login_user = saved_user
|
||||||
|
|
||||||
|
# Yield control back to the caller
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# Any cleanup can be added here if needed
|
||||||
|
pass
|
124
api/tests/unit_tests/libs/test_flask_utils.py
Normal file
124
api/tests/unit_tests/libs/test_flask_utils.py
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
import contextvars
|
||||||
|
import threading
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from flask import Flask
|
||||||
|
from flask_login import LoginManager, UserMixin, current_user, login_user
|
||||||
|
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
|
||||||
|
|
||||||
|
class User(UserMixin):
|
||||||
|
"""Simple User class for testing."""
|
||||||
|
|
||||||
|
def __init__(self, id: str):
|
||||||
|
self.id = id
|
||||||
|
|
||||||
|
def get_id(self) -> str:
|
||||||
|
return self.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def login_app(app: Flask) -> Flask:
|
||||||
|
"""Set up a Flask app with flask-login."""
|
||||||
|
# Set a secret key for the app
|
||||||
|
app.config["SECRET_KEY"] = "test-secret-key"
|
||||||
|
|
||||||
|
login_manager = LoginManager()
|
||||||
|
login_manager.init_app(app)
|
||||||
|
|
||||||
|
@login_manager.user_loader
|
||||||
|
def load_user(user_id: str) -> Optional[User]:
|
||||||
|
if user_id == "test_user":
|
||||||
|
return User("test_user")
|
||||||
|
return None
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_user() -> User:
|
||||||
|
"""Create a test user."""
|
||||||
|
return User("test_user")
|
||||||
|
|
||||||
|
|
||||||
|
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
|
||||||
|
"""
|
||||||
|
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
|
||||||
|
|
||||||
|
This test demonstrates that without the preserve_flask_contexts, we cannot access
|
||||||
|
current_user in a different thread, even with app_context.
|
||||||
|
"""
|
||||||
|
# Log in the user in the main thread
|
||||||
|
with login_app.test_request_context():
|
||||||
|
login_user(test_user)
|
||||||
|
assert current_user.is_authenticated
|
||||||
|
assert current_user.id == "test_user"
|
||||||
|
|
||||||
|
# Store the result of the thread execution
|
||||||
|
result = {"user_accessible": True, "error": None}
|
||||||
|
|
||||||
|
# Define a function to run in a separate thread
|
||||||
|
def check_user_in_thread():
|
||||||
|
try:
|
||||||
|
# Try to access current_user in a different thread with app_context
|
||||||
|
with login_app.app_context():
|
||||||
|
# This should fail because current_user is not accessible across threads
|
||||||
|
# without preserve_flask_contexts
|
||||||
|
result["user_accessible"] = current_user.is_authenticated
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = str(e) # type: ignore
|
||||||
|
|
||||||
|
# Run the function in a separate thread
|
||||||
|
thread = threading.Thread(target=check_user_in_thread)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Verify that we got an error or current_user is not authenticated
|
||||||
|
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
|
||||||
|
"""
|
||||||
|
Test that current_user is accessible in a different thread with preserve_flask_contexts.
|
||||||
|
|
||||||
|
This test demonstrates that with the preserve_flask_contexts, we can access
|
||||||
|
current_user in a different thread.
|
||||||
|
"""
|
||||||
|
# Log in the user in the main thread
|
||||||
|
with login_app.test_request_context():
|
||||||
|
login_user(test_user)
|
||||||
|
assert current_user.is_authenticated
|
||||||
|
assert current_user.id == "test_user"
|
||||||
|
|
||||||
|
# Save the context variables
|
||||||
|
context_vars = contextvars.copy_context()
|
||||||
|
|
||||||
|
# Store the result of the thread execution
|
||||||
|
result = {"user_accessible": False, "user_id": None, "error": None}
|
||||||
|
|
||||||
|
# Define a function to run in a separate thread
|
||||||
|
def check_user_in_thread_with_manager():
|
||||||
|
try:
|
||||||
|
# Use preserve_flask_contexts to access current_user in a different thread
|
||||||
|
with preserve_flask_contexts(login_app, context_vars):
|
||||||
|
from flask_login import current_user
|
||||||
|
|
||||||
|
if current_user:
|
||||||
|
result["user_accessible"] = True
|
||||||
|
result["user_id"] = current_user.id
|
||||||
|
else:
|
||||||
|
result["user_accessible"] = False
|
||||||
|
except Exception as e:
|
||||||
|
result["error"] = str(e) # type: ignore
|
||||||
|
|
||||||
|
# Run the function in a separate thread
|
||||||
|
thread = threading.Thread(target=check_user_in_thread_with_manager)
|
||||||
|
thread.start()
|
||||||
|
thread.join()
|
||||||
|
|
||||||
|
# Verify that current_user is accessible and has the correct ID
|
||||||
|
assert result["error"] is None
|
||||||
|
assert result["user_accessible"] is True
|
||||||
|
assert result["user_id"] == "test_user"
|
Reference in New Issue
Block a user