FEAT: NEW WORKFLOW ENGINE (#3160)
Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: Yeuoly <admin@srmxy.cn> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: StyleZhang <jasonapring2015@outlook.com> Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: nite-knite <nkCoding@gmail.com> Co-authored-by: jyong <718720800@qq.com>
This commit is contained in:
@@ -8,7 +8,8 @@ import requests
|
||||
|
||||
import core.helper.ssrf_proxy as ssrf_proxy
|
||||
from core.tools.entities.tool_bundle import ApiBasedToolBundle
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
@@ -34,7 +35,7 @@ class ApiTool(Tool):
|
||||
api_bundle=self.api_bundle.copy() if self.api_bundle else None,
|
||||
runtime=Tool.Runtime(**meta)
|
||||
)
|
||||
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False) -> str:
|
||||
"""
|
||||
validate the credentials for Api tool
|
||||
@@ -49,6 +50,9 @@ class ApiTool(Tool):
|
||||
# validate response
|
||||
return self.validate_and_parse_response(response)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.API
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
@@ -1,6 +1,8 @@
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.user_entities import UserToolProvider
|
||||
from core.tools.model.tool_model_manager import ToolModelManager
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.web_reader_tool import get_url
|
||||
@@ -40,6 +42,9 @@ class BuiltinTool(Tool):
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return UserToolProvider.ProviderType.BUILTIN
|
||||
|
||||
def get_max_tokens(self) -> int:
|
||||
"""
|
||||
get max tokens
|
||||
|
@@ -2,11 +2,18 @@ from typing import Any
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.entities.application_entities import DatasetRetrieveConfigEntity, InvokeFrom
|
||||
from core.features.dataset_retrieval.dataset_retrieval import DatasetRetrievalFeature
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolDescription, ToolIdentity, ToolInvokeMessage, ToolParameter
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
|
||||
@@ -30,7 +37,7 @@ class DatasetRetrieverTool(Tool):
|
||||
if retrieve_config is None:
|
||||
return []
|
||||
|
||||
feature = DatasetRetrievalFeature()
|
||||
feature = DatasetRetrieval()
|
||||
|
||||
# save original retrieve strategy, and set retrieve strategy to SINGLE
|
||||
# Agent only support SINGLE mode
|
||||
@@ -52,7 +59,7 @@ class DatasetRetrieverTool(Tool):
|
||||
for langchain_tool in langchain_tools:
|
||||
tool = DatasetRetrieverTool(
|
||||
langchain_tool=langchain_tool,
|
||||
identity=ToolIdentity(author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
identity=ToolIdentity(provider='', author='', name=langchain_tool.name, label=I18nObject(en_US='', zh_Hans='')),
|
||||
parameters=[],
|
||||
is_team_authorization=True,
|
||||
description=ToolDescription(
|
||||
@@ -76,6 +83,9 @@ class DatasetRetrieverTool(Tool):
|
||||
required=True,
|
||||
default=''),
|
||||
]
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.DATASET_RETRIEVAL
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
|
@@ -11,7 +11,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ModelToolPropertyKey, ToolInvokeMessage, ToolProviderType
|
||||
from core.tools.tool.tool import Tool
|
||||
|
||||
VISION_PROMPT = """## Image Recognition Task
|
||||
@@ -79,6 +79,9 @@ class ModelTool(Tool):
|
||||
"""
|
||||
pass
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
|
||||
"""
|
||||
"""
|
||||
|
@@ -2,14 +2,14 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
ToolRuntimeImageVariable,
|
||||
ToolRuntimeVariable,
|
||||
ToolRuntimeVariablePool,
|
||||
@@ -22,8 +22,13 @@ class Tool(BaseModel, ABC):
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
agent_callback: Optional[DifyAgentCallbackHandler] = None
|
||||
use_callback: bool = False
|
||||
|
||||
@validator('parameters', pre=True, always=True)
|
||||
def set_parameters(cls, v, values):
|
||||
if not v:
|
||||
return []
|
||||
|
||||
return v
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@@ -45,15 +50,10 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
if not self.agent_callback:
|
||||
self.use_callback = False
|
||||
else:
|
||||
self.use_callback = True
|
||||
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -65,9 +65,16 @@ class Tool(BaseModel, ABC):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
"""
|
||||
get the tool provider type
|
||||
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
"""
|
||||
load variables from database
|
||||
@@ -174,50 +181,22 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
|
||||
# check if tool_parameters is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_start(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
# try parse tool parameters into the correct type
|
||||
tool_parameters = self._transform_tool_parameters_type(tool_parameters)
|
||||
|
||||
try:
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
except Exception as e:
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_error(e)
|
||||
raise e
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_end(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=self._convert_tool_response_to_str(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
@@ -242,6 +221,31 @@ class Tool(BaseModel, ABC):
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Transform tool parameters type
|
||||
"""
|
||||
for parameter in self.parameters:
|
||||
if parameter.name in tool_parameters:
|
||||
if parameter.type in [
|
||||
ToolParameter.ToolParameterType.SECRET_INPUT,
|
||||
ToolParameter.ToolParameterType.STRING,
|
||||
ToolParameter.ToolParameterType.SELECT,
|
||||
] and not isinstance(tool_parameters[parameter.name], str):
|
||||
tool_parameters[parameter.name] = str(tool_parameters[parameter.name])
|
||||
elif parameter.type == ToolParameter.ToolParameterType.NUMBER \
|
||||
and not isinstance(tool_parameters[parameter.name], int | float):
|
||||
if isinstance(tool_parameters[parameter.name], str):
|
||||
try:
|
||||
tool_parameters[parameter.name] = int(tool_parameters[parameter.name])
|
||||
except ValueError:
|
||||
tool_parameters[parameter.name] = float(tool_parameters[parameter.name])
|
||||
elif parameter.type == ToolParameter.ToolParameterType.BOOLEAN:
|
||||
if not isinstance(tool_parameters[parameter.name], bool):
|
||||
tool_parameters[parameter.name] = bool(tool_parameters[parameter.name])
|
||||
|
||||
return tool_parameters
|
||||
|
||||
@abstractmethod
|
||||
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
|
Reference in New Issue
Block a user