chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -8,46 +8,47 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
provider_type: Literal['builtin', 'api', 'workflow']
provider_name: str # redundancy
provider_type: Literal["builtin", "api", "workflow"]
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_label: str # redundancy
tool_configurations: dict[str, Any]
@field_validator('tool_configurations', mode='before')
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError('tool_configurations must be a dictionary')
for key in values.data.get('tool_configurations', {}).keys():
value = values.data.get('tool_configurations', {}).get(key)
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}).keys():
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f'{key} must be a string')
raise ValueError(f"{key} must be a string")
return value
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal['mixed', 'variable', 'constant']
type: Literal["mixed", "variable", "constant"]
@field_validator('type', mode='before')
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable':
value = validation_info.data.get("value")
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError('value must be a list')
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError('value must be a list of strings')
elif typ == 'constant' and not isinstance(value, str | int | float | bool):
raise ValueError('value must be a string, int, float, or bool')
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, str | int | float | bool):
raise ValueError("value must be a string, int, float, or bool")
return typ
"""

View File

@@ -34,10 +34,7 @@ class ToolNode(BaseNode):
node_data = cast(ToolNodeData, self.node_data)
# fetch tool icon
tool_info = {
'provider_type': node_data.provider_type,
'provider_id': node_data.provider_id
}
tool_info = {"provider_type": node_data.provider_type, "provider_id": node_data.provider_id}
# get tool runtime
try:
@@ -48,16 +45,21 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to get tool runtime: {str(e)}",
)
# get parameters
tool_parameters = tool_runtime.get_runtime_parameters() or []
parameters = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data)
parameters_for_log = self._generate_parameters(tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data, for_log=True)
parameters = self._generate_parameters(
tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=node_data
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
node_data=node_data,
for_log=True,
)
try:
messages = ToolEngine.workflow_invoke(
@@ -72,10 +74,8 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to invoke tool: {str(e)}',
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
error=f"Failed to invoke tool: {str(e)}",
)
# convert tool messages
@@ -83,15 +83,9 @@ class ToolNode(BaseNode):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
'text': plain_text,
'files': files,
'json': json
},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
inputs=parameters_for_log
outputs={"text": plain_text, "files": files, "json": json},
metadata={NodeRunMetadataKey.TOOL_INFO: tool_info},
inputs=parameters_for_log,
)
def _generate_parameters(
@@ -123,12 +117,10 @@ class ToolNode(BaseNode):
result[parameter_name] = None
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [
v.to_dict() for v in self._fetch_files(variable_pool)
]
result[parameter_name] = [v.to_dict() for v in self._fetch_files(variable_pool)]
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable':
if tool_input.type == "variable":
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
else:
@@ -142,12 +134,11 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
return list(variable.value) if variable else []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage])\
-> tuple[str, list[FileVar], list[dict]]:
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar], list[dict]]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
"""
@@ -172,38 +163,44 @@ class ToolNode(BaseNode):
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
if (
response.type == ToolInvokeMessage.MessageType.IMAGE_LINK
or response.type == ToolInvokeMessage.MessageType.IMAGE
):
url = response.message
ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1]
transfer_method = response.meta.get('transfer_method', FileTransferMethod.TOOL_FILE)
mimetype = response.meta.get("mime_type", "image/jpeg")
filename = response.save_as or url.split("/")[-1]
transfer_method = response.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
# get tool file id
tool_file_id = url.split('/')[-1].split('.')[0]
result.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
))
tool_file_id = url.split("/")[-1].split(".")[0]
result.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=url,
related_id=tool_file_id,
filename=filename,
extension=ext,
mime_type=mimetype,
)
)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = response.message.split('/')[-1].split('.')[0]
result.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
))
tool_file_id = response.message.split("/")[-1].split(".")[0]
result.append(
FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
filename=response.save_as,
extension=path.splitext(response.save_as)[1],
mime_type=response.meta.get("mime_type", "application/octet-stream"),
)
)
elif response.type == ToolInvokeMessage.MessageType.LINK:
pass # TODO:
@@ -213,21 +210,23 @@ class ToolNode(BaseNode):
"""
Extract tool response text
"""
return '\n'.join([
f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
return "\n".join(
[
f"{message.message}"
if message.type == ToolInvokeMessage.MessageType.TEXT
else f"Link: {message.message}"
if message.type == ToolInvokeMessage.MessageType.LINK
else ""
for message in tool_response
]
)
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
graph_config: Mapping[str, Any],
node_id: str,
node_data: ToolNodeData
cls, graph_config: Mapping[str, Any], node_id: str, node_data: ToolNodeData
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -239,17 +238,15 @@ class ToolNode(BaseNode):
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
if input.type == "mixed":
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == 'variable':
elif input.type == "variable":
result[parameter_name] = input.value
elif input.type == 'constant':
elif input.type == "constant":
pass
result = {
node_id + '.' + key: value for key, value in result.items()
}
result = {node_id + "." + key: value for key, value in result.items()}
return result