This commit is contained in:
Ricky
2024-01-31 11:58:07 +08:00
committed by GitHub
parent 9e37702d24
commit 2660fbaa20
58 changed files with 312 additions and 312 deletions

View File

@@ -12,7 +12,7 @@ from core.tools.errors import ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.app_tool_provider import AppBasedToolProviderEntity
from core.tools.entities.user_entities import UserToolProvider
from core.tools.utils.configration import ToolConfiguration
from core.tools.utils.configuration import ToolConfiguration
from core.tools.utils.encoder import serialize_base_model_dict
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
@@ -115,7 +115,7 @@ class ToolManager:
return tool
@staticmethod
def get_tool(provider_type: str, provider_id: str, tool_name: str, tanent_id: str = None) \
def get_tool(provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool
@@ -129,9 +129,9 @@ class ToolManager:
if provider_type == 'builtin':
return ToolManager.get_builtin_tool(provider_id, tool_name)
elif provider_type == 'api':
if tanent_id is None:
raise ValueError('tanent id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tanent_id, provider_id)
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, _ = ToolManager.get_api_provider_controller(tenant_id, provider_id)
return api_provider.get_tool(tool_name)
elif provider_type == 'app':
raise NotImplementedError('app provider not implemented')
@@ -139,7 +139,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tanent_id,
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id,
agent_callback: DifyAgentCallbackHandler = None) \
-> Union[BuiltinTool, ApiTool]:
"""
@@ -158,13 +158,13 @@ class ToolManager:
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': {},
}, agent_callback=agent_callback)
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tanent_id,
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
@@ -174,28 +174,28 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=controller)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=controller)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {}
}, agent_callback=agent_callback)
elif provider_type == 'api':
if tanent_id is None:
raise ValueError('tanent id is required for api provider')
if tenant_id is None:
raise ValueError('tenant id is required for api provider')
api_provider, credentials = ToolManager.get_api_provider_controller(tanent_id, provider_name)
api_provider, credentials = ToolManager.get_api_provider_controller(tenant_id, provider_name)
# decrypt the credentials
tool_configuration = ToolConfiguration(tenant_id=tanent_id, provider_controller=api_provider)
tool_configuration = ToolConfiguration(tenant_id=tenant_id, provider_controller=api_provider)
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
'tenant_id': tanent_id,
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
})
elif provider_type == 'app':
@@ -321,7 +321,7 @@ class ToolManager:
schema = provider.get_credentials_schema()
for name, value in schema.items():
result_providers[provider.identity.name].team_credentials[name] = \
ToolProviderCredentials.CredentialsType.defaut(value.type)
ToolProviderCredentials.CredentialsType.default(value.type)
# check if the provider need credentials
if not provider.need_credentials:
@@ -398,7 +398,7 @@ class ToolManager:
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@staticmethod
def get_api_provider_controller(tanent_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
def get_api_provider_controller(tenant_id: str, provider_id: str) -> Tuple[ApiBasedToolProviderController, Dict[str, Any]]:
"""
get the api provider
@@ -408,7 +408,7 @@ class ToolManager:
"""
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tanent_id,
ApiToolProvider.tenant_id == tenant_id,
).first()
if provider is None:
@@ -435,7 +435,7 @@ class ToolManager:
).first()
if provider is None:
raise ValueError(f'yout have not added provider {provider}')
raise ValueError(f'you have not added provider {provider}')
try:
credentials = json.loads(provider.credentials_str) or {}