feat: support pinning, including, and excluding for model providers and tools (#7419)

Co-authored-by: GareArc <chen4851@purude.edu>
This commit is contained in:
Xiyuan Chen
2024-08-20 23:16:43 -04:00
committed by GitHub
parent 6c25d7bed3
commit 4e7b6aec3a
14 changed files with 363 additions and 57 deletions

View File

@@ -1,6 +1,8 @@
import json
import logging
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
@@ -43,14 +45,14 @@ class BuiltinToolManageService:
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
@@ -78,7 +80,7 @@ class BuiltinToolManageService:
BuiltinToolProvider.provider == provider_name,
).first()
try:
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
@@ -119,8 +121,8 @@ class BuiltinToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
@@ -135,7 +137,7 @@ class BuiltinToolManageService:
if provider is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
@@ -156,7 +158,7 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
@@ -165,8 +167,8 @@ class BuiltinToolManageService:
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
return {'result': 'success'}
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
@@ -179,7 +181,7 @@ class BuiltinToolManageService:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
@@ -202,6 +204,15 @@ class BuiltinToolManageService:
for provider_controller in provider_controllers:
try:
# handle include, exclude
if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
data=provider_controller,
name_func=lambda x: x.identity.name
):
continue
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -226,4 +237,3 @@ class BuiltinToolManageService:
raise e
return BuiltinToolProviderSort.sort(result)