Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -1,7 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from datetime import datetime, timezone
|
||||
from mimetypes import guess_type
|
||||
from typing import Union
|
||||
|
||||
from yarl import URL
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@@ -17,6 +20,7 @@ from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool.workflow_tool import WorkflowTool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageFile
|
||||
@@ -115,7 +119,8 @@ class ToolEngine:
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str,
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler) \
|
||||
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
||||
workflow_call_depth: int) \
|
||||
-> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
@@ -127,6 +132,9 @@ class ToolEngine:
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
if isinstance(tool, WorkflowTool):
|
||||
tool.workflow_call_depth = workflow_call_depth + 1
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
# hit the callback handler
|
||||
@@ -195,8 +203,24 @@ class ToolEngine:
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
mimetype = None
|
||||
if response.meta.get('mime_type'):
|
||||
mimetype = response.meta.get('mime_type')
|
||||
else:
|
||||
try:
|
||||
url = URL(response.message)
|
||||
extension = url.suffix
|
||||
guess_type_result, _ = guess_type(f'a{extension}')
|
||||
if guess_type_result:
|
||||
mimetype = guess_type_result
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not mimetype:
|
||||
mimetype = 'image/jpeg'
|
||||
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
mimetype=response.meta.get('mime_type', 'image/jpeg'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
Reference in New Issue
Block a user