Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: Literal['builtin', 'api']
|
||||
provider_type: Literal['builtin', 'api', 'workflow']
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
|
@@ -1,13 +1,14 @@
|
||||
from os import path
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
@@ -35,20 +36,23 @@ class ToolNode(BaseNode):
|
||||
'provider_id': node_data.provider_id
|
||||
}
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data)
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
)
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
@@ -56,7 +60,8 @@ class ToolNode(BaseNode):
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_id=self.workflow_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler()
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth + 1
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
@@ -83,19 +88,32 @@ class ToolNode(BaseNode):
|
||||
inputs=parameters
|
||||
)
|
||||
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
|
||||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
tool_parameters = tool_runtime.get_all_runtime_parameters()
|
||||
|
||||
def fetch_parameter(name: str) -> Optional[ToolParameter]:
|
||||
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
|
||||
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
parameter = fetch_parameter(parameter_name)
|
||||
if not parameter:
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [
|
||||
v.to_dict() for v in self._fetch_files(variable_pool)
|
||||
]
|
||||
else:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
|
||||
return result
|
||||
|
||||
@@ -109,6 +127,13 @@ class ToolNode(BaseNode):
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
|
||||
return template_parser.format(inputs)
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
return files
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user