FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
0
api/core/app/task_pipeline/__init__.py
Normal file
0
api/core/app/task_pipeline/__init__.py
Normal file
152
api/core/app/task_pipeline/based_generate_task_pipeline.py
Normal file
152
api/core/app/task_pipeline/based_generate_task_pipeline.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import logging
|
||||
import time
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueErrorEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
ErrorStreamResponse,
|
||||
PingStreamResponse,
|
||||
TaskState,
|
||||
)
|
||||
from core.errors.error import QuotaExceededError
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.moderation.output_moderation import ModerationRule, OutputModeration
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser, Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasedGenerateTaskPipeline:
|
||||
"""
|
||||
BasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
|
||||
_task_state: TaskState
|
||||
_application_generate_entity: AppGenerateEntity
|
||||
|
||||
def __init__(self, application_generate_entity: AppGenerateEntity,
|
||||
queue_manager: AppQueueManager,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._queue_manager = queue_manager
|
||||
self._user = user
|
||||
self._start_at = time.perf_counter()
|
||||
self._output_moderation_handler = self._init_output_moderation()
|
||||
self._stream = stream
|
||||
|
||||
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None) -> Exception:
|
||||
"""
|
||||
Handle error event.
|
||||
:param event: event
|
||||
:param message: message
|
||||
:return:
|
||||
"""
|
||||
logger.debug("error: %s", event.error)
|
||||
e = event.error
|
||||
|
||||
if isinstance(e, InvokeAuthorizationError):
|
||||
err = InvokeAuthorizationError('Incorrect API key provided')
|
||||
elif isinstance(e, InvokeError) or isinstance(e, ValueError):
|
||||
err = e
|
||||
else:
|
||||
err = Exception(e.description if getattr(e, 'description', None) is not None else str(e))
|
||||
|
||||
if message:
|
||||
message = db.session.query(Message).filter(Message.id == message.id).first()
|
||||
err_desc = self._error_to_desc(err)
|
||||
message.status = 'error'
|
||||
message.error = err_desc
|
||||
|
||||
db.session.commit()
|
||||
|
||||
return err
|
||||
|
||||
def _error_to_desc(cls, e: Exception) -> str:
|
||||
"""
|
||||
Error to desc.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
if isinstance(e, QuotaExceededError):
|
||||
return ("Your quota for Dify Hosted Model Provider has been exhausted. "
|
||||
"Please go to Settings -> Model Provider to complete your own provider credentials.")
|
||||
|
||||
message = getattr(e, 'description', str(e))
|
||||
if not message:
|
||||
message = 'Internal Server Error, please contact support.'
|
||||
|
||||
return message
|
||||
|
||||
def _error_to_stream_response(self, e: Exception) -> ErrorStreamResponse:
|
||||
"""
|
||||
Error to stream response.
|
||||
:param e: exception
|
||||
:return:
|
||||
"""
|
||||
return ErrorStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
err=e
|
||||
)
|
||||
|
||||
def _ping_stream_response(self) -> PingStreamResponse:
|
||||
"""
|
||||
Ping stream response.
|
||||
:return:
|
||||
"""
|
||||
return PingStreamResponse(task_id=self._application_generate_entity.task_id)
|
||||
|
||||
def _init_output_moderation(self) -> Optional[OutputModeration]:
|
||||
"""
|
||||
Init output moderation.
|
||||
:return:
|
||||
"""
|
||||
app_config = self._application_generate_entity.app_config
|
||||
sensitive_word_avoidance = app_config.sensitive_word_avoidance
|
||||
|
||||
if sensitive_word_avoidance:
|
||||
return OutputModeration(
|
||||
tenant_id=app_config.tenant_id,
|
||||
app_id=app_config.app_id,
|
||||
rule=ModerationRule(
|
||||
type=sensitive_word_avoidance.type,
|
||||
config=sensitive_word_avoidance.config
|
||||
),
|
||||
queue_manager=self._queue_manager
|
||||
)
|
||||
|
||||
def _handle_output_moderation_when_task_finished(self, completion: str) -> Optional[str]:
|
||||
"""
|
||||
Handle output moderation when task finished.
|
||||
:param completion: completion
|
||||
:return:
|
||||
"""
|
||||
# response moderation
|
||||
if self._output_moderation_handler:
|
||||
self._output_moderation_handler.stop_thread()
|
||||
|
||||
completion = self._output_moderation_handler.moderation_completion(
|
||||
completion=completion,
|
||||
public_event=False
|
||||
)
|
||||
|
||||
self._output_moderation_handler = None
|
||||
|
||||
return completion
|
||||
|
||||
return None
|
@@ -0,0 +1,428 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentMessageEvent,
|
||||
QueueAgentThoughtEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueMessageEndEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueMessageReplaceEvent,
|
||||
QueuePingEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueStopEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentMessageStreamResponse,
|
||||
AgentThoughtStreamResponse,
|
||||
ChatbotAppBlockingResponse,
|
||||
ChatbotAppStreamResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
CompletionAppStreamResponse,
|
||||
EasyUITaskState,
|
||||
ErrorStreamResponse,
|
||||
MessageEndStreamResponse,
|
||||
StreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage):
|
||||
"""
|
||||
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
|
||||
"""
|
||||
_task_state: EasyUITaskState
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity
|
||||
]
|
||||
|
||||
def __init__(self, application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity
|
||||
],
|
||||
queue_manager: AppQueueManager,
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
user: Union[Account, EndUser],
|
||||
stream: bool) -> None:
|
||||
"""
|
||||
Initialize GenerateTaskPipeline.
|
||||
:param application_generate_entity: application generate entity
|
||||
:param queue_manager: queue manager
|
||||
:param conversation: conversation
|
||||
:param message: message
|
||||
:param user: user
|
||||
:param stream: stream
|
||||
"""
|
||||
super().__init__(application_generate_entity, queue_manager, user, stream)
|
||||
self._model_config = application_generate_entity.model_config
|
||||
self._conversation = conversation
|
||||
self._message = message
|
||||
|
||||
self._task_state = EasyUITaskState(
|
||||
llm_result=LLMResult(
|
||||
model=self._model_config.model,
|
||||
prompt_messages=[],
|
||||
message=AssistantPromptMessage(content=""),
|
||||
usage=LLMUsage.empty_usage()
|
||||
)
|
||||
)
|
||||
|
||||
def process(self) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse,
|
||||
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]
|
||||
]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
:return:
|
||||
"""
|
||||
db.session.refresh(self._conversation)
|
||||
db.session.refresh(self._message)
|
||||
db.session.close()
|
||||
|
||||
generator = self._process_stream_response()
|
||||
if self._stream:
|
||||
return self._to_stream_response(generator)
|
||||
else:
|
||||
return self._to_blocking_response(generator)
|
||||
|
||||
def _to_blocking_response(self, generator: Generator[StreamResponse, None, None]) -> Union[
|
||||
ChatbotAppBlockingResponse,
|
||||
CompletionAppBlockingResponse
|
||||
]:
|
||||
"""
|
||||
Process blocking response.
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, MessageEndStreamResponse):
|
||||
extras = {
|
||||
'usage': jsonable_encoder(self._task_state.llm_result.usage)
|
||||
}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
|
||||
if self._conversation.mode == AppMode.COMPLETION.value:
|
||||
response = CompletionAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=CompletionAppBlockingResponse.Data(
|
||||
id=self._message.id,
|
||||
mode=self._conversation.mode,
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras
|
||||
)
|
||||
)
|
||||
else:
|
||||
response = ChatbotAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
data=ChatbotAppBlockingResponse.Data(
|
||||
id=self._message.id,
|
||||
mode=self._conversation.mode,
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
answer=self._task_state.llm_result.message.content,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
**extras
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
else:
|
||||
continue
|
||||
|
||||
raise Exception('Queue listening stopped unexpectedly.')
|
||||
|
||||
def _to_stream_response(self, generator: Generator[StreamResponse, None, None]) \
|
||||
-> Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None]:
|
||||
"""
|
||||
To stream response.
|
||||
:return:
|
||||
"""
|
||||
for stream_response in generator:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
yield CompletionAppStreamResponse(
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
stream_response=stream_response
|
||||
)
|
||||
else:
|
||||
yield ChatbotAppStreamResponse(
|
||||
conversation_id=self._conversation.id,
|
||||
message_id=self._message.id,
|
||||
created_at=int(self._message.created_at.timestamp()),
|
||||
stream_response=stream_response
|
||||
)
|
||||
|
||||
def _process_stream_response(self) -> Generator[StreamResponse, None, None]:
|
||||
"""
|
||||
Process stream response.
|
||||
:return:
|
||||
"""
|
||||
for message in self._queue_manager.listen():
|
||||
event = message.event
|
||||
|
||||
if isinstance(event, QueueErrorEvent):
|
||||
err = self._handle_error(event, self._message)
|
||||
yield self._error_to_stream_response(err)
|
||||
break
|
||||
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
|
||||
if isinstance(event, QueueMessageEndEvent):
|
||||
self._task_state.llm_result = event.llm_result
|
||||
else:
|
||||
self._handle_stop(event)
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self._handle_output_moderation_when_task_finished(
|
||||
self._task_state.llm_result.message.content
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
yield self._message_replace_to_stream_response(answer=output_moderation_answer)
|
||||
|
||||
# Save message
|
||||
self._save_message()
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
elif isinstance(event, QueueRetrieverResourcesEvent):
|
||||
self._handle_retriever_resources(event)
|
||||
elif isinstance(event, QueueAnnotationReplyEvent):
|
||||
annotation = self._handle_annotation_reply(event)
|
||||
if annotation:
|
||||
self._task_state.llm_result.message.content = annotation.content
|
||||
elif isinstance(event, QueueAgentThoughtEvent):
|
||||
yield self._agent_thought_to_stream_response(event)
|
||||
elif isinstance(event, QueueMessageFileEvent):
|
||||
response = self._message_file_to_stream_response(event)
|
||||
if response:
|
||||
yield response
|
||||
elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent):
|
||||
chunk = event.chunk
|
||||
delta_text = chunk.delta.message.content
|
||||
if delta_text is None:
|
||||
continue
|
||||
|
||||
if not self._task_state.llm_result.prompt_messages:
|
||||
self._task_state.llm_result.prompt_messages = chunk.prompt_messages
|
||||
|
||||
# handle output moderation chunk
|
||||
should_direct_answer = self._handle_output_moderation_chunk(delta_text)
|
||||
if should_direct_answer:
|
||||
continue
|
||||
|
||||
self._task_state.llm_result.message.content += delta_text
|
||||
|
||||
if isinstance(event, QueueLLMChunkEvent):
|
||||
yield self._message_to_stream_response(delta_text, self._message.id)
|
||||
else:
|
||||
yield self._agent_message_to_stream_response(delta_text, self._message.id)
|
||||
elif isinstance(event, QueueMessageReplaceEvent):
|
||||
yield self._message_replace_to_stream_response(answer=event.text)
|
||||
elif isinstance(event, QueuePingEvent):
|
||||
yield self._ping_stream_response()
|
||||
else:
|
||||
continue
|
||||
|
||||
def _save_message(self) -> None:
|
||||
"""
|
||||
Save message.
|
||||
:return:
|
||||
"""
|
||||
llm_result = self._task_state.llm_result
|
||||
usage = llm_result.usage
|
||||
|
||||
self._message = db.session.query(Message).filter(Message.id == self._message.id).first()
|
||||
self._conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first()
|
||||
|
||||
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
self._model_config.mode,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
self._message.message_tokens = usage.prompt_tokens
|
||||
self._message.message_unit_price = usage.prompt_unit_price
|
||||
self._message.message_price_unit = usage.prompt_price_unit
|
||||
self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
|
||||
if llm_result.message.content else ''
|
||||
self._message.answer_tokens = usage.completion_tokens
|
||||
self._message.answer_unit_price = usage.completion_unit_price
|
||||
self._message.answer_price_unit = usage.completion_price_unit
|
||||
self._message.provider_response_latency = time.perf_counter() - self._start_at
|
||||
self._message.total_price = usage.total_price
|
||||
self._message.currency = usage.currency
|
||||
self._message.message_metadata = json.dumps(jsonable_encoder(self._task_state.metadata)) \
|
||||
if self._task_state.metadata else None
|
||||
|
||||
db.session.commit()
|
||||
|
||||
message_was_created.send(
|
||||
self._message,
|
||||
application_generate_entity=self._application_generate_entity,
|
||||
conversation=self._conversation,
|
||||
is_first_message=self._application_generate_entity.app_config.app_mode in [
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.CHAT
|
||||
] and self._application_generate_entity.conversation_id is None,
|
||||
extras=self._application_generate_entity.extras
|
||||
)
|
||||
|
||||
def _handle_stop(self, event: QueueStopEvent) -> None:
|
||||
"""
|
||||
Handle stop.
|
||||
:return:
|
||||
"""
|
||||
model_config = self._model_config
|
||||
model = model_config.model
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
# calculate num tokens
|
||||
prompt_tokens = 0
|
||||
if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
|
||||
prompt_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
self._task_state.llm_result.prompt_messages
|
||||
)
|
||||
|
||||
completion_tokens = 0
|
||||
if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
|
||||
completion_tokens = model_type_instance.get_num_tokens(
|
||||
model,
|
||||
model_config.credentials,
|
||||
[self._task_state.llm_result.message]
|
||||
)
|
||||
|
||||
credentials = model_config.credentials
|
||||
|
||||
# transform usage
|
||||
self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
|
||||
model,
|
||||
credentials,
|
||||
prompt_tokens,
|
||||
completion_tokens
|
||||
)
|
||||
|
||||
def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
|
||||
"""
|
||||
Message end to stream response.
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata['usage'] = jsonable_encoder(self._task_state.llm_result.usage)
|
||||
|
||||
extras = {}
|
||||
if self._task_state.metadata:
|
||||
extras['metadata'] = self._task_state.metadata
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message.id,
|
||||
**extras
|
||||
)
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
return AgentMessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer
|
||||
)
|
||||
|
||||
def _agent_thought_to_stream_response(self, event: QueueAgentThoughtEvent) -> Optional[AgentThoughtStreamResponse]:
|
||||
"""
|
||||
Agent thought to stream response.
|
||||
:param event: agent thought event
|
||||
:return:
|
||||
"""
|
||||
agent_thought: MessageAgentThought = (
|
||||
db.session.query(MessageAgentThought)
|
||||
.filter(MessageAgentThought.id == event.agent_thought_id)
|
||||
.first()
|
||||
)
|
||||
db.session.refresh(agent_thought)
|
||||
db.session.close()
|
||||
|
||||
if agent_thought:
|
||||
return AgentThoughtStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=agent_thought.id,
|
||||
position=agent_thought.position,
|
||||
thought=agent_thought.thought,
|
||||
observation=agent_thought.observation,
|
||||
tool=agent_thought.tool,
|
||||
tool_labels=agent_thought.tool_labels,
|
||||
tool_input=agent_thought.tool_input,
|
||||
message_files=agent_thought.files
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _handle_output_moderation_chunk(self, text: str) -> bool:
|
||||
"""
|
||||
Handle output moderation chunk.
|
||||
:param text: text
|
||||
:return: True if output moderation should direct output, otherwise False
|
||||
"""
|
||||
if self._output_moderation_handler:
|
||||
if self._output_moderation_handler.should_direct_output():
|
||||
# stop subscribe new token when output moderation should direct output
|
||||
self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
|
||||
self._queue_manager.publish(
|
||||
QueueLLMChunkEvent(
|
||||
chunk=LLMResultChunk(
|
||||
model=self._task_state.llm_result.model,
|
||||
prompt_messages=self._task_state.llm_result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
|
||||
)
|
||||
)
|
||||
), PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
|
||||
self._queue_manager.publish(
|
||||
QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION),
|
||||
PublishFrom.TASK_PIPELINE
|
||||
)
|
||||
return True
|
||||
else:
|
||||
self._output_moderation_handler.append_new_token(text)
|
||||
|
||||
return False
|
147
api/core/app/task_pipeline/message_cycle_manage.py
Normal file
147
api/core/app/task_pipeline/message_cycle_manage.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
)
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
_application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity
|
||||
]
|
||||
_task_state: Union[EasyUITaskState, AdvancedChatTaskState]
|
||||
|
||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata['annotation_reply'] = {
|
||||
'id': annotation.id,
|
||||
'account': {
|
||||
'id': annotation.account_id,
|
||||
'name': account.name if account else 'Dify user'
|
||||
}
|
||||
}
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
self._task_state.metadata['retriever_resources'] = event.retriever_resources
|
||||
|
||||
def _get_response_metadata(self) -> dict:
|
||||
"""
|
||||
Get response metadata by invoke from.
|
||||
:return:
|
||||
"""
|
||||
metadata = {}
|
||||
|
||||
# show_retrieve_source
|
||||
if 'retriever_resources' in self._task_state.metadata:
|
||||
metadata['retriever_resources'] = self._task_state.metadata['retriever_resources']
|
||||
|
||||
# show annotation reply
|
||||
if 'annotation_reply' in self._task_state.metadata:
|
||||
metadata['annotation_reply'] = self._task_state.metadata['annotation_reply']
|
||||
|
||||
# show usage
|
||||
if self._application_generate_entity.invoke_from in [InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API]:
|
||||
metadata['usage'] = self._task_state.metadata['usage']
|
||||
|
||||
return metadata
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file: MessageFile = (
|
||||
db.session.query(MessageFile)
|
||||
.filter(MessageFile.id == event.message_file_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if message_file:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split('/')[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split('.')[0]
|
||||
|
||||
# get extension
|
||||
if '.' in message_file.url:
|
||||
extension = f'.{message_file.url.split(".")[-1]}'
|
||||
if len(extension) > 10:
|
||||
extension = '.bin'
|
||||
else:
|
||||
extension = '.bin'
|
||||
# add sign url
|
||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
return MessageFileStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or 'user',
|
||||
url=url
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _message_to_stream_response(self, answer: str, message_id: str) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer
|
||||
)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
:return:
|
||||
"""
|
||||
return MessageReplaceStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
answer=answer
|
||||
)
|
592
api/core/app/task_pipeline/workflow_cycle_manage.py
Normal file
592
api/core/app/task_pipeline/workflow_cycle_manage.py
Normal file
@@ -0,0 +1,592 @@
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueStopEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AdvancedChatTaskState,
|
||||
NodeExecutionInfo,
|
||||
NodeFinishStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.file.file_obj import FileVar
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeType, SystemVariable
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
CreatedByRole,
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
WorkflowRunTriggeredFrom,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariable, Any]
|
||||
|
||||
def _init_workflow_run(self, workflow: Workflow,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
user: Union[Account, EndUser],
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None) -> WorkflowRun:
|
||||
"""
|
||||
Init workflow run
|
||||
:param workflow: Workflow instance
|
||||
:param triggered_from: triggered from
|
||||
:param user: account or end user
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:return:
|
||||
"""
|
||||
max_sequence = db.session.query(db.func.max(WorkflowRun.sequence_number)) \
|
||||
.filter(WorkflowRun.tenant_id == workflow.tenant_id) \
|
||||
.filter(WorkflowRun.app_id == workflow.app_id) \
|
||||
.scalar() or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**user_inputs}
|
||||
for key, value in (system_inputs or {}).items():
|
||||
if key.value == 'conversation':
|
||||
continue
|
||||
|
||||
inputs[f'sys.{key.value}'] = value
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
|
||||
# init workflow run
|
||||
workflow_run = WorkflowRun(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
sequence_number=new_sequence_number,
|
||||
workflow_id=workflow.id,
|
||||
type=workflow.type,
|
||||
triggered_from=triggered_from.value,
|
||||
version=workflow.version,
|
||||
graph=workflow.graph,
|
||||
inputs=json.dumps(inputs),
|
||||
status=WorkflowRunStatus.RUNNING.value,
|
||||
created_by_role=(CreatedByRole.ACCOUNT.value
|
||||
if isinstance(user, Account) else CreatedByRole.END_USER.value),
|
||||
created_by=user.id
|
||||
)
|
||||
|
||||
db.session.add(workflow_run)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_success(self, workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Optional[str] = None) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:return:
|
||||
"""
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = outputs
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _workflow_run_failed(self, workflow_run: WorkflowRun,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_run)
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _init_node_execution_from_workflow_run(self, workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Init workflow node execution from workflow run
|
||||
:param workflow_run: workflow run
|
||||
:param node_id: node id
|
||||
:param node_type: node type
|
||||
:param node_title: node title
|
||||
:param node_run_index: run index
|
||||
:param predecessor_node_id: predecessor node id if exists
|
||||
:return:
|
||||
"""
|
||||
# init workflow node execution
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_success(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution success
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param inputs: inputs
|
||||
:param process_data: process data
|
||||
:param outputs: outputs
|
||||
:param execution_metadata: execution metadata
|
||||
:return:
|
||||
"""
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = json.dumps(jsonable_encoder(execution_metadata)) \
|
||||
if execution_metadata else None
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_node_execution_failed(self, workflow_node_execution: WorkflowNodeExecution,
|
||||
start_at: float,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:param start_at: start time
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
inputs = WorkflowEngineManager.handle_special_values(inputs)
|
||||
outputs = WorkflowEngineManager.handle_special_values(outputs)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_node_execution.finished_at = datetime.utcnow()
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _workflow_start_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowStartStreamResponse:
|
||||
"""
|
||||
Workflow start to stream response.
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:return:
|
||||
"""
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=workflow_run.inputs_dict,
|
||||
created_at=int(workflow_run.created_at.timestamp())
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(self, task_id: str,
|
||||
workflow_run: WorkflowRun) -> WorkflowFinishStreamResponse:
|
||||
"""
|
||||
Workflow finish to stream response.
|
||||
:param task_id: task id
|
||||
:param workflow_run: workflow run
|
||||
:return:
|
||||
"""
|
||||
created_by = None
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value:
|
||||
created_by_account = workflow_run.created_by_account
|
||||
if created_by_account:
|
||||
created_by = {
|
||||
"id": created_by_account.id,
|
||||
"name": created_by_account.name,
|
||||
"email": created_by_account.email,
|
||||
}
|
||||
else:
|
||||
created_by_end_user = workflow_run.created_by_end_user
|
||||
if created_by_end_user:
|
||||
created_by = {
|
||||
"id": created_by_end_user.id,
|
||||
"user": created_by_end_user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
status=workflow_run.status,
|
||||
outputs=workflow_run.outputs_dict,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_run.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(self, event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeStartStreamResponse:
|
||||
"""
|
||||
Workflow node start to stream response.
|
||||
:param event: queue node started event
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp())
|
||||
)
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras['icon'] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(self, task_id: str, workflow_node_execution: WorkflowNodeExecution) \
|
||||
-> NodeFinishStreamResponse:
|
||||
"""
|
||||
Workflow node finish to stream response.
|
||||
:param task_id: task id
|
||||
:param workflow_node_execution: workflow node execution
|
||||
:return:
|
||||
"""
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict)
|
||||
)
|
||||
)
|
||||
|
||||
def _handle_workflow_start(self) -> WorkflowRun:
|
||||
self._task_state.start_at = time.perf_counter()
|
||||
|
||||
workflow_run = self._init_workflow_run(
|
||||
workflow=self._workflow,
|
||||
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN,
|
||||
user=self._user,
|
||||
user_inputs=self._application_generate_entity.inputs,
|
||||
system_inputs=self._workflow_system_variables
|
||||
)
|
||||
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_start(self, event: QueueNodeStartedEvent) -> WorkflowNodeExecution:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_node_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=event.node_type,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._task_state.total_steps += 1
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_node_finished(self, event: QueueNodeSucceededEvent | QueueNodeFailedEvent) -> WorkflowNodeExecution:
|
||||
current_node_execution = self._task_state.ran_node_execution_infos[event.node_id]
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_node_execution.workflow_node_execution_id).first()
|
||||
if isinstance(event, QueueNodeSucceededEvent):
|
||||
workflow_node_execution = self._workflow_node_execution_success(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs,
|
||||
execution_metadata=event.execution_metadata
|
||||
)
|
||||
|
||||
if event.execution_metadata and event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
self._task_state.total_tokens += (
|
||||
int(event.execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS)))
|
||||
|
||||
if workflow_node_execution.node_type == NodeType.LLM.value:
|
||||
outputs = workflow_node_execution.outputs_dict
|
||||
usage_dict = outputs.get('usage', {})
|
||||
self._task_state.metadata['usage'] = usage_dict
|
||||
else:
|
||||
workflow_node_execution = self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=current_node_execution.start_at,
|
||||
error=event.error,
|
||||
inputs=event.inputs,
|
||||
process_data=event.process_data,
|
||||
outputs=event.outputs
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_finished(self, event: QueueStopEvent | QueueWorkflowSucceededEvent | QueueWorkflowFailedEvent) \
|
||||
-> Optional[WorkflowRun]:
|
||||
workflow_run = db.session.query(WorkflowRun).filter(
|
||||
WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
if not workflow_run:
|
||||
return None
|
||||
|
||||
if isinstance(event, QueueStopEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.STOPPED,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
|
||||
latest_node_execution_info = self._task_state.latest_node_execution_info
|
||||
if latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == latest_node_execution_info.workflow_node_execution_id).first()
|
||||
if (workflow_node_execution
|
||||
and workflow_node_execution.status == WorkflowNodeExecutionStatus.RUNNING.value):
|
||||
self._workflow_node_execution_failed(
|
||||
workflow_node_execution=workflow_node_execution,
|
||||
start_at=latest_node_execution_info.start_at,
|
||||
error='Workflow stopped.'
|
||||
)
|
||||
elif isinstance(event, QueueWorkflowFailedEvent):
|
||||
workflow_run = self._workflow_run_failed(
|
||||
workflow_run=workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
status=WorkflowRunStatus.FAILED,
|
||||
error=event.error
|
||||
)
|
||||
else:
|
||||
if self._task_state.latest_node_execution_info:
|
||||
workflow_node_execution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == self._task_state.latest_node_execution_info.workflow_node_execution_id).first()
|
||||
outputs = workflow_node_execution.outputs
|
||||
else:
|
||||
outputs = None
|
||||
|
||||
workflow_run = self._workflow_run_success(
|
||||
workflow_run=workflow_run,
|
||||
start_at=self._task_state.start_at,
|
||||
total_tokens=self._task_state.total_tokens,
|
||||
total_steps=self._task_state.total_steps,
|
||||
outputs=outputs
|
||||
)
|
||||
|
||||
self._task_state.workflow_run_id = workflow_run.id
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: dict) -> list[dict]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
:return:
|
||||
"""
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = []
|
||||
for output_var, output_value in outputs_dict.items():
|
||||
file_vars = self._fetch_files_from_variable_value(output_value)
|
||||
if file_vars:
|
||||
files.extend(file_vars)
|
||||
|
||||
return files
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> list[dict]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = self._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
elif isinstance(value, dict):
|
||||
file_var = self._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Optional[dict]:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
Reference in New Issue
Block a user