chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -8,46 +8,47 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: Literal['builtin', 'api', 'workflow']
|
||||
provider_name: str # redundancy
|
||||
provider_type: Literal["builtin", "api", "workflow"]
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
|
||||
@field_validator('tool_configurations', mode='before')
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError('tool_configurations must be a dictionary')
|
||||
|
||||
for key in values.data.get('tool_configurations', {}).keys():
|
||||
value = values.data.get('tool_configurations', {}).get(key)
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}).keys():
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f'{key} must be a string')
|
||||
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
|
||||
@field_validator('type', mode='before')
|
||||
@field_validator("type", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get('value')
|
||||
if typ == 'mixed' and not isinstance(value, str):
|
||||
raise ValueError('value must be a string')
|
||||
elif typ == 'variable':
|
||||
value = validation_info.data.get("value")
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError('value must be a list')
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError('value must be a list of strings')
|
||||
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError('value must be a string, int, float, or bool')
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||
raise ValueError("value must be a string, int, float, or bool")
|
||||
return typ
|
||||
|
||||
"""
|
||||
|
@@ -34,10 +34,7 @@ class ToolNode(BaseNode):
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
|
||||
# fetch tool icon
|
||||
tool_info = {
|
||||
'provider_type': node_data.provider_type,
|
||||
'provider_id': node_data.provider_id
|
||||
}
|
||||
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
@@ -48,16 +45,21 @@ class ToolNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to get tool runtime: {str(e)}",
|
||||
)
|
||||
|
||||
# get parameters
|
||||
tool_parameters = tool_runtime.get_runtime_parameters() or []
|
||||
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
|
||||
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
|
||||
parameters = self._generate_parameters(
|
||||
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
|
||||
)
|
||||
parameters_for_log = self._generate_parameters(
|
||||
tool_parameters=tool_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=node_data,
|
||||
for_log=True,
|
||||
)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
@@ -72,10 +74,8 @@ class ToolNode(BaseNode):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to invoke tool: {str(e)}',
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
error=f"Failed to invoke tool: {str(e)}",
|
||||
)
|
||||
|
||||
# convert tool messages
|
||||
@@ -83,15 +83,9 @@ class ToolNode(BaseNode):
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'text': plain_text,
|
||||
'files': files,
|
||||
'json': json
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
inputs=parameters_for_log
|
||||
outputs={"text": plain_text, "files": files, "json": json},
|
||||
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
|
||||
inputs=parameters_for_log,
|
||||
)
|
||||
|
||||
def _generate_parameters(
|
||||
@@ -123,12 +117,10 @@ class ToolNode(BaseNode):
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [
|
||||
v.to_dict() for v in self._fetch_files(variable_pool)
|
||||
]
|
||||
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == 'variable':
|
||||
if tool_input.type == "variable":
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
else:
|
||||
@@ -142,12 +134,11 @@ class ToolNode(BaseNode):
|
||||
return result
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
|
||||
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||
return list(variable.value) if variable else []
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
|
||||
-> tuple[str, list[FileVar], list[dict]]:
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@@ -172,38 +163,44 @@ class ToolNode(BaseNode):
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
if (
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
|
||||
or response.type == ToolInvokeMessage.MessageType.IMAGE
|
||||
):
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
|
||||
mimetype = response.meta.get("mime_type", "image/jpeg")
|
||||
filename = response.save_as or url.split("/")[-1]
|
||||
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split('/')[-1].split('.')[0]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
))
|
||||
tool_file_id = url.split("/")[-1].split(".")[0]
|
||||
result.append(
|
||||
FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=url,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split('/')[-1].split('.')[0]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
|
||||
))
|
||||
tool_file_id = response.message.split("/")[-1].split(".")[0]
|
||||
result.append(
|
||||
FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get("mime_type", "application/octet-stream"),
|
||||
)
|
||||
)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
|
||||
@@ -213,21 +210,23 @@ class ToolNode(BaseNode):
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
return '\n'.join([
|
||||
f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
return "\n".join(
|
||||
[
|
||||
f"{message.message}"
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT
|
||||
else f"Link: {message.message}"
|
||||
if message.type == ToolInvokeMessage.MessageType.LINK
|
||||
else ""
|
||||
for message in tool_response
|
||||
]
|
||||
)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: ToolNodeData
|
||||
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
@@ -239,17 +238,15 @@ class ToolNode(BaseNode):
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
if input.type == "mixed":
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == 'variable':
|
||||
elif input.type == "variable":
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == 'constant':
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
result = {
|
||||
node_id + '.' + key: value for key, value in result.items()
|
||||
}
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
Reference in New Issue
Block a user