refactor: streamline initialization of application_generate_entity and task_state in task pipeline classes (#12326)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Author: -LAN-
Date: 2025-01-03 18:41:44 +08:00
Committed by: GitHub
parent 478150e850
commit 7ed6485f86
6 changed files with 280 additions and 222 deletions

View File

@@ -15,7 +15,6 @@ from core.app.entities.queue_entities import (
from core.app.entities.task_entities import (
ErrorStreamResponse,
PingStreamResponse,
TaskState,
)
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
@@ -30,22 +29,12 @@ class BasedGenerateTaskPipeline:
BasedGenerateTaskPipeline is a class that generates stream output and manages state for the application.
"""
_task_state: TaskState
_application_generate_entity: AppGenerateEntity
def __init__(
self,
application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager,
stream: bool,
) -> None:
"""
Initialize GenerateTaskPipeline.
:param application_generate_entity: application generate entity
:param queue_manager: queue manager
:param user: user
:param stream: stream
"""
self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager
self._start_at = time.perf_counter()
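The hunk above drops the class-level attribute declarations and the old docstring and assigns the injected dependencies directly in __init__. A minimal standalone sketch of that pattern, using a simplified stand-in for AppGenerateEntity rather than the real entity class:

import time
from dataclasses import dataclass


@dataclass
class FakeGenerateEntity:  # simplified stand-in for AppGenerateEntity
    app_id: str


class PipelineSketch:
    # Dependencies are injected and assigned per instance, mirroring
    # the streamlined BasedGenerateTaskPipeline.__init__ above.
    def __init__(self, application_generate_entity: FakeGenerateEntity) -> None:
        self._application_generate_entity = application_generate_entity
        self._start_at = time.perf_counter()


pipeline = PipelineSketch(FakeGenerateEntity(app_id="demo-app"))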

View File

@@ -31,10 +31,19 @@ from services.annotation_service import AppAnnotationService
class MessageCycleManage:
_application_generate_entity: Union[
ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity, AdvancedChatAppGenerateEntity
]
_task_state: Union[EasyUITaskState, WorkflowTaskState]
def __init__(
self,
*,
application_generate_entity: Union[
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
AgentChatAppGenerateEntity,
AdvancedChatAppGenerateEntity,
],
task_state: Union[EasyUITaskState, WorkflowTaskState],
) -> None:
self._application_generate_entity = application_generate_entity
self._task_state = task_state
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
"""

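MessageCycleManage now receives its collaborators as keyword-only arguments (the bare * in the signature above) and assigns them in __init__. A small runnable sketch of what the keyword-only marker enforces, with simplified stand-ins for the task-state classes:

from typing import Union


class EasyUITaskStateStub: ...      # stand-in for EasyUITaskState
class WorkflowTaskStateStub: ...    # stand-in for WorkflowTaskState


class CycleManageSketch:
    def __init__(self, *, task_state: Union[EasyUITaskStateStub, WorkflowTaskStateStub]) -> None:
        self._task_state = task_state


CycleManageSketch(task_state=WorkflowTaskStateStub())   # OK: explicit keyword
# CycleManageSketch(WorkflowTaskStateStub())            # TypeError: positional arguments are rejected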
View File

@@ -34,7 +34,6 @@ from core.app.entities.task_entities import (
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
WorkflowTaskState,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -58,13 +57,20 @@ from models.workflow import (
WorkflowRunStatus,
)
from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any]
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
def _handle_workflow_run_start(
self,
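WorkflowCycleManage now builds its caches (_workflow_run and the _workflow_node_executions dict) inside __init__ rather than leaving them as class-level declarations. The removed annotations were only declarations, but initializing in __init__ is the safer idiom: a dict created at class level would be shared by every instance, while one created in __init__ is per instance. A standalone illustration with hypothetical class names:

class SharedCacheSketch:
    cache: dict[str, str] = {}           # class attribute: one dict for all instances


class PerInstanceCacheSketch:
    def __init__(self) -> None:
        self.cache: dict[str, str] = {}  # created in __init__: one dict per instance


a, b = SharedCacheSketch(), SharedCacheSketch()
a.cache["run-1"] = "leaks into b"
print(b.cache)  # {'run-1': 'leaks into b'}

c, d = PerInstanceCacheSketch(), PerInstanceCacheSketch()
c.cache["run-1"] = "stays in c"
print(d.cache)  # {}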
@@ -240,7 +246,7 @@ class WorkflowCycleManage:
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count
stmt = select(WorkflowNodeExecution).where(
stmt = select(WorkflowNodeExecution.node_execution_id).where(
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
WorkflowNodeExecution.app_id == workflow_run.app_id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
@@ -248,16 +254,18 @@ class WorkflowCycleManage:
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
running_workflow_node_executions = session.scalars(stmt).all()
ids = session.scalars(stmt).all()
# Use self._get_workflow_node_execution here to make sure the cache is updated
running_workflow_node_executions = [
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
]
for workflow_node_execution in running_workflow_node_executions:
now = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error
workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds()
workflow_node_execution.finished_at = now
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
if trace_manager:
trace_manager.add_trace_task(
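The failure handler above now selects only node_execution_id values and resolves each id through _get_workflow_node_execution, so the rows marked as failed are the same cached instances the manager tracks elsewhere. A self-contained sketch of that id-then-cache resolution, with hypothetical stand-in names:

from dataclasses import dataclass, field


@dataclass
class NodeExecutionSketch:  # stand-in for WorkflowNodeExecution
    node_execution_id: str
    status: str = "running"
    error: str | None = None


@dataclass
class CycleManagerSketch:
    _node_executions: dict[str, NodeExecutionSketch] = field(default_factory=dict)

    def track(self, execution: NodeExecutionSketch) -> None:
        self._node_executions[execution.node_execution_id] = execution

    def _get_node_execution(self, node_execution_id: str) -> NodeExecutionSketch:
        if node_execution_id not in self._node_executions:
            raise ValueError(f"Workflow node execution not found: {node_execution_id}")
        return self._node_executions[node_execution_id]

    def fail_running(self, running_ids: list[str], error: str) -> None:
        # Resolve ids through the cache so we mutate the tracked instances.
        for execution in (self._get_node_execution(i) for i in running_ids if i):
            execution.status = "failed"
            execution.error = error


manager = CycleManagerSketch()
manager.track(NodeExecutionSketch("n1"))
manager.fail_running(["n1"], error="workflow run failed")
print(manager._node_executions["n1"].status)  # failed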
@@ -299,6 +307,8 @@ class WorkflowCycleManage:
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
def _handle_workflow_node_execution_success(
@@ -326,6 +336,7 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_failed(
@@ -365,6 +376,7 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution = session.merge(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(
@@ -416,6 +428,8 @@ class WorkflowCycleManage:
workflow_node_execution.index = event.node_run_index
session.add(workflow_node_execution)
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
return workflow_node_execution
#################################################
@@ -812,22 +826,20 @@ class WorkflowCycleManage:
return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
"""
Refetch workflow run
:param workflow_run_id: workflow run id
:return:
"""
if self._workflow_run and self._workflow_run.id == workflow_run_id:
cached_workflow_run = self._workflow_run
cached_workflow_run = session.merge(cached_workflow_run)
return cached_workflow_run
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id)
self._workflow_run = workflow_run
return workflow_run
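On a cache hit, _get_workflow_run re-attaches the previously loaded WorkflowRun to the current session via session.merge before returning it. A self-contained SQLAlchemy sketch of that merge-on-cache-hit pattern; the Run model and get_run helper are simplified stand-ins, not the real classes:

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Run(Base):  # stand-in for WorkflowRun
    __tablename__ = "runs"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="running")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

_cached_run: Run | None = None


def get_run(session: Session, run_id: str) -> Run:
    global _cached_run
    if _cached_run is not None and _cached_run.id == run_id:
        # merge() attaches the (possibly detached) cached object to this session.
        return session.merge(_cached_run)
    run = session.scalar(select(Run).where(Run.id == run_id))
    if run is None:
        raise ValueError(f"run not found: {run_id}")
    _cached_run = run
    return run


with Session(engine) as s:
    s.add(Run(id="r1"))
    s.commit()

with Session(engine) as s1:
    print(get_run(s1, "r1").status)   # loaded from the database

with Session(engine) as s2:
    print(get_run(s2, "r1").status)   # served from the cache, merged into s2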
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.node_execution_id == node_execution_id)
workflow_node_execution = session.scalar(stmt)
if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id)
return workflow_node_execution
if node_execution_id not in self._workflow_node_executions:
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
return cached_workflow_node_execution
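Design note: the lookup above now relies solely on the in-memory _workflow_node_executions cache populated when node executions start or are retried, raising a plain ValueError on a miss instead of refetching from the database and raising the now-removed WorkflowNodeExecutionNotFoundError (hence the slimmed-down .exc import earlier in this file).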