Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly
2024-05-27 22:01:11 +08:00
committed by GitHub
parent 45deaee762
commit e852a21634
139 changed files with 5997 additions and 779 deletions

View File

@@ -75,6 +75,35 @@ class AppGenerateService:
else:
raise ValueError(f'Invalid app mode {app_model.mode}')
@classmethod
def generate_single_iteration(cls, app_model: App,
user: Union[Account, EndUser],
node_id: str,
args: Any,
streaming: bool = True):
if app_model.mode == AppMode.ADVANCED_CHAT.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return AdvancedChatAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
stream=streaming
)
elif app_model.mode == AppMode.WORKFLOW.value:
workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER)
return WorkflowAppGenerator().single_iteration_generate(
app_model=app_model,
workflow=workflow,
node_id=node_id,
user=user,
args=args,
stream=streaming
)
else:
raise ValueError(f'Invalid app mode {app_model.mode}')
@classmethod
def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \

View File

@@ -4,97 +4,30 @@ import logging
from httpx import get
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ApiProviderSchemaType,
ToolCredentialsOption,
ToolProviderCredentials,
)
from core.tools.entities.user_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider
from services.model_provider_service import ModelProviderService
from services.tools_transform_service import ToolTransformService
from models.tools import ApiToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolManageService:
class ApiToolManageService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str):
"""
list tool providers
:return: the list of tool providers
"""
providers = ToolManager.user_list_providers(
user_id, tenant_id
)
# add icon
for provider in providers:
ToolTransformService.repack_provider(provider)
result = [provider.to_dict() for provider in providers]
return result
@staticmethod
def list_builtin_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
"""
list builtin tool provider tools
"""
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool, credentials=credentials, tenant_id=tenant_id
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
):
"""
list builtin provider credentials schema
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([
v for _, v in (provider.credentials_schema or {}).items()
])
@staticmethod
def parser_api_schema(schema: str) -> list[ApiBasedToolBundle]:
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
"""
parse api schema to tool bundle
"""
@@ -162,7 +95,7 @@ class ToolManageService:
raise ValueError(f'invalid schema: {str(e)}')
@staticmethod
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiBasedToolBundle]:
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
"""
convert schema to tool bundles
@@ -177,7 +110,7 @@ class ToolManageService:
@staticmethod
def create_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
):
"""
create api tool provider
@@ -197,7 +130,7 @@ class ToolManageService:
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
if len(tool_bundles) > 100:
raise ValueError('the number of apis should be less than 100')
@@ -224,7 +157,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -236,6 +169,9 @@ class ToolManageService:
db.session.add(db_provider)
db.session.commit()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
@staticmethod
@@ -257,7 +193,7 @@ class ToolManageService:
schema = response.text
# try to parse schema, avoid SSRF attack
ToolManageService.parser_api_schema(schema)
ApiToolManageService.parser_api_schema(schema)
except Exception as e:
logger.error(f"parse api schema error: {str(e)}")
raise ValueError('invalid schema, please check the url you provided')
@@ -281,91 +217,20 @@ class ToolManageService:
if provider is None:
raise ValueError(f'you have not added provider {provider}')
return [
ToolTransformService.tool_to_user_tool(tool_bundle) for tool_bundle in provider.tools
]
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
"""
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(credentials)
# encrypt credentials
credentials = tool_configuration.encrypt_tool_credentials(credentials)
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
raise ValueError(str(e))
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
db.session.add(provider)
db.session.commit()
else:
provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
if provider is None:
return {}
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(controller)
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
return [
ToolTransformService.tool_to_user_tool(
tool_bundle,
labels=labels,
) for tool_bundle in provider.tools
]
@staticmethod
def update_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
):
"""
update api tool provider
@@ -385,7 +250,7 @@ class ToolManageService:
# parse openapi to tool bundle
extra_info = {}
# extra info like description will be set here
tool_bundles, schema_type = ToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
# update db provider
provider.name = provider_name
@@ -404,7 +269,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
provider_controller = ApiBasedToolProviderController.from_db(provider, auth_type)
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -427,84 +292,11 @@ class ToolManageService:
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
@staticmethod
def delete_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
"""
delete tool provider
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
return { 'result': 'success' }
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
with open(icon_path, 'rb') as f:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def get_model_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
service = ModelProviderService()
icon_bytes, mime_type = service.get_model_provider_icon(provider=provider, icon_type='icon_small', lang='en_US')
if icon_bytes is None:
raise ValueError(f'provider {provider} does not exists')
return icon_bytes, mime_type
@staticmethod
def list_model_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
"""
list model tool provider tools
"""
provider_controller = ToolManager.get_model_provider(tenant_id=tenant_id, provider_name=provider)
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
result = [
UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=tool.parameters or []
) for tool in tools
]
return jsonable_encoder(result)
@staticmethod
def delete_api_tool_provider(
user_id: str, tenant_id: str, provider_name: str
@@ -583,7 +375,7 @@ class ToolManageService:
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
# create provider entity
provider_controller = ApiBasedToolProviderController.from_db(db_provider, auth_type)
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
# load tools into provider entity
provider_controller.load_bundled_tools(tool_bundles)
@@ -604,7 +396,7 @@ class ToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(meta={
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
@@ -614,49 +406,6 @@ class ToolManageService:
return { 'result': result or 'empty response' }
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
"""
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers()
# get all user added providers
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id
).all() or []
# find provider
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[UserToolProvider] = []
for provider_controller in provider_controllers:
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name),
decrypt_credentials=True
)
# add icon
ToolTransformService.repack_provider(user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools:
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
))
result.append(user_builtin_provider)
return BuiltinToolProviderSort.sort(result)
@staticmethod
def list_api_tools(
user_id: str, tenant_id: str
@@ -674,6 +423,7 @@ class ToolManageService:
for provider in db_providers:
# convert provider controller to user provider
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
labels = ToolLabelManager.get_tool_labels(provider_controller)
user_provider = ToolTransformService.api_provider_to_user_provider(
provider_controller,
db_provider=provider,
@@ -692,6 +442,7 @@ class ToolManageService:
tenant_id=tenant_id,
tool=tool,
credentials=user_provider.original_credentials,
labels=labels
))
result.append(user_provider)

View File

@@ -0,0 +1,226 @@
import json
import logging
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.provider.builtin._positions import BuiltinToolProviderSort
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolConfigurationManager
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
@staticmethod
def list_builtin_tool_provider_tools(
user_id: str, tenant_id: str, provider: str
) -> list[UserTool]:
"""
list builtin tool provider tools
"""
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
tools = provider_controller.get_tools()
tool_provider_configurations = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# check if user has added the provider
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
result = []
for tool in tools:
result.append(ToolTransformService.tool_to_user_tool(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
return result
@staticmethod
def list_builtin_provider_credentials_schema(
provider_name
):
"""
list builtin provider credentials schema
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([
v for _, v in (provider.credentials_schema or {}).items()
])
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict
):
"""
update builtin tool provider
"""
# get if the provider exists
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
if not provider_controller.need_credentials:
raise ValueError(f'provider {provider_name} does not need credentials')
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
# get original credentials if exists
if provider is not None:
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(credentials)
# encrypt credentials
credentials = tool_configuration.encrypt_tool_credentials(credentials)
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
raise ValueError(str(e))
if provider is None:
# create provider
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
)
db.session.add(provider)
db.session.commit()
else:
provider.encrypted_credentials = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
@staticmethod
def get_builtin_tool_provider_credentials(
user_id: str, tenant_id: str, provider: str
):
"""
get builtin tool provider credentials
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
).first()
if provider is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider.provider)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
return credentials
@staticmethod
def delete_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str
):
"""
delete tool provider
"""
provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider_name,
).first()
if provider is None:
raise ValueError(f'you have not added provider {provider_name}')
db.session.delete(provider)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
tool_configuration.delete_tool_credentials_cache()
return { 'result': 'success' }
@staticmethod
def get_builtin_tool_provider_icon(
provider: str
):
"""
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
with open(icon_path, 'rb') as f:
icon_bytes = f.read()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(
user_id: str, tenant_id: str
) -> list[UserToolProvider]:
"""
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers()
# get all user added providers
db_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider).filter(
BuiltinToolProvider.tenant_id == tenant_id
).all() or []
# find provider
find_provider = lambda provider: next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[UserToolProvider] = []
for provider_controller in provider_controllers:
# convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
db_provider=find_provider(provider_controller.identity.name),
decrypt_credentials=True
)
# add icon
ToolTransformService.repack_provider(user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools:
user_builtin_provider.tools.append(ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller)
))
result.append(user_builtin_provider)
return BuiltinToolProviderSort.sort(result)

View File

@@ -0,0 +1,8 @@
from core.tools.entities.tool_entities import ToolLabel
from core.tools.entities.values import default_tool_labels
class ToolLabelsService:
@classmethod
def list_tool_labels(cls) -> list[ToolLabel]:
return default_tool_labels

View File

@@ -0,0 +1,29 @@
import logging
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
"""
list tool providers
:return: the list of tool providers
"""
providers = ToolManager.user_list_providers(
user_id, tenant_id, typ
)
# add icon
for provider in providers:
ToolTransformService.repack_provider(provider)
result = [provider.to_dict() for provider in providers]
return result

View File

@@ -5,14 +5,21 @@ from typing import Optional, Union
from flask import current_app
from core.model_runtime.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiBasedToolBundle
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolParameter, ToolProviderCredentials
from core.tools.entities.user_entities import UserTool, UserToolProvider
from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderCredentials,
ToolProviderType,
)
from core.tools.provider.api_tool_provider import ApiToolProviderController
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool.tool import Tool
from core.tools.tool.workflow_tool import WorkflowTool
from core.tools.utils.configuration import ToolConfigurationManager
from models.tools import ApiToolProvider, BuiltinToolProvider
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
@@ -25,9 +32,9 @@ class ToolTransformService:
url_prefix = (current_app.config.get("CONSOLE_API_URL")
+ "/console/api/workspaces/current/tool-provider/")
if provider_type == UserToolProvider.ProviderType.BUILTIN.value:
if provider_type == ToolProviderType.BUILT_IN.value:
return url_prefix + 'builtin/' + provider_name + '/icon'
elif provider_type == UserToolProvider.ProviderType.API.value:
elif provider_type in [ToolProviderType.API.value, ToolProviderType.WORKFLOW.value]:
try:
return json.loads(icon)
except:
@@ -62,7 +69,7 @@ class ToolTransformService:
def builtin_provider_to_user_provider(
provider_controller: BuiltinToolProviderController,
db_provider: Optional[BuiltinToolProvider],
decrypt_credentials: bool = True
decrypt_credentials: bool = True,
) -> UserToolProvider:
"""
convert provider controller to user provider
@@ -80,10 +87,11 @@ class ToolTransformService:
en_US=provider_controller.identity.label.en_US,
zh_Hans=provider_controller.identity.label.zh_Hans,
),
type=UserToolProvider.ProviderType.BUILTIN,
type=ToolProviderType.BUILT_IN,
masked_credentials={},
is_team_authorization=False,
tools=[]
tools=[],
labels=provider_controller.tool_labels
)
# get credentials schema
@@ -119,24 +127,62 @@ class ToolTransformService:
@staticmethod
def api_provider_to_controller(
db_provider: ApiToolProvider,
) -> ApiBasedToolProviderController:
) -> ApiToolProviderController:
"""
convert provider controller to user provider
"""
# package tool provider controller
controller = ApiBasedToolProviderController.from_db(
controller = ApiToolProviderController.from_db(
db_provider=db_provider,
auth_type=ApiProviderAuthType.API_KEY if db_provider.credentials['auth_type'] == 'api_key' else
ApiProviderAuthType.NONE
)
return controller
@staticmethod
def workflow_provider_to_controller(
db_provider: WorkflowToolProvider
) -> WorkflowToolProviderController:
"""
convert provider controller to provider
"""
return WorkflowToolProviderController.from_db(db_provider)
@staticmethod
def workflow_provider_to_user_provider(
provider_controller: WorkflowToolProviderController,
labels: list[str] = None
):
"""
convert provider controller to user provider
"""
return UserToolProvider(
id=provider_controller.provider_id,
author=provider_controller.identity.author,
name=provider_controller.identity.name,
description=I18nObject(
en_US=provider_controller.identity.description.en_US,
zh_Hans=provider_controller.identity.description.zh_Hans,
),
icon=provider_controller.identity.icon,
label=I18nObject(
en_US=provider_controller.identity.label.en_US,
zh_Hans=provider_controller.identity.label.zh_Hans,
),
type=ToolProviderType.WORKFLOW,
masked_credentials={},
is_team_authorization=True,
tools=[],
labels=labels or []
)
@staticmethod
def api_provider_to_user_provider(
provider_controller: ApiBasedToolProviderController,
provider_controller: ApiToolProviderController,
db_provider: ApiToolProvider,
decrypt_credentials: bool = True
decrypt_credentials: bool = True,
labels: list[str] = None
) -> UserToolProvider:
"""
convert provider controller to user provider
@@ -161,10 +207,11 @@ class ToolTransformService:
en_US=db_provider.name,
zh_Hans=db_provider.name,
),
type=UserToolProvider.ProviderType.API,
type=ToolProviderType.API,
masked_credentials={},
is_team_authorization=True,
tools=[]
tools=[],
labels=labels or []
)
if decrypt_credentials:
@@ -184,14 +231,17 @@ class ToolTransformService:
@staticmethod
def tool_to_user_tool(
tool: Union[ApiBasedToolBundle, Tool], credentials: dict = None, tenant_id: str = None
tool: Union[ApiToolBundle, WorkflowTool, Tool],
credentials: dict = None,
tenant_id: str = None,
labels: list[str] = None
) -> UserTool:
"""
convert tool to user tool
"""
if isinstance(tool, Tool):
# fork tool runtime
tool = tool.fork_tool_runtime(meta={
tool = tool.fork_tool_runtime(runtime={
'credentials': credentials,
'tenant_id': tenant_id,
})
@@ -213,17 +263,15 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
user_tool = UserTool(
return UserTool(
author=tool.identity.author,
name=tool.identity.name,
label=tool.identity.label,
description=tool.description.human,
parameters=current_parameters
parameters=current_parameters,
labels=labels
)
return user_tool
if isinstance(tool, ApiBasedToolBundle):
if isinstance(tool, ApiToolBundle):
return UserTool(
author=tool.author,
name=tool.operation_id,
@@ -235,5 +283,6 @@ class ToolTransformService:
en_US=tool.summary or '',
zh_Hans=tool.summary or ''
),
parameters=tool.parameters
parameters=tool.parameters,
labels=labels
)

View File

@@ -0,0 +1,326 @@
import json
from datetime import datetime
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider
from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from extensions.ext_database import db
from models.model import App
from models.tools import WorkflowToolProvider
from models.workflow import Workflow
from services.tools.tools_transform_service import ToolTransformService
class WorkflowToolManageService:
"""
Service class for managing workflow tools.
"""
@classmethod
def create_workflow_tool(cls, user_id: str, tenant_id: str, workflow_app_id: str, name: str,
label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
"""
Create a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param name: the name
:param icon: the icon
:param description: the description
:param parameters: the parameters
:param privacy_policy: the privacy policy
:return: the created tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id)
).first()
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} or app_id {workflow_app_id} already exists')
app: App = db.session.query(App).filter(
App.id == workflow_app_id,
App.tenant_id == tenant_id
).first()
if app is None:
raise ValueError(f'App {workflow_app_id} not found')
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_app_id}')
workflow_tool_provider = WorkflowToolProvider(
tenant_id=tenant_id,
user_id=user_id,
app_id=workflow_app_id,
name=name,
label=label,
icon=json.dumps(icon),
description=description,
parameter_configuration=json.dumps(parameters),
privacy_policy=privacy_policy,
version=workflow.version,
)
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
db.session.add(workflow_tool_provider)
db.session.commit()
return {
'result': 'success'
}
@classmethod
def update_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str,
name: str, label: str, icon: dict, description: str,
parameters: list[dict], privacy_policy: str = '', labels: list[str] = None) -> dict:
"""
Update a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param tool: the tool
:return: the updated tool
"""
WorkflowToolConfigurationUtils.check_parameter_configurations(parameters)
# check if the name is unique
existing_workflow_tool_provider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id
).first()
if existing_workflow_tool_provider is not None:
raise ValueError(f'Tool with name {name} already exists')
workflow_tool_provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
if workflow_tool_provider is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
app: App = db.session.query(App).filter(
App.id == workflow_tool_provider.app_id,
App.tenant_id == tenant_id
).first()
if app is None:
raise ValueError(f'App {workflow_tool_provider.app_id} not found')
workflow: Workflow = app.workflow
if workflow is None:
raise ValueError(f'Workflow not found for app {workflow_tool_provider.app_id}')
workflow_tool_provider.name = name
workflow_tool_provider.label = label
workflow_tool_provider.icon = json.dumps(icon)
workflow_tool_provider.description = description
workflow_tool_provider.parameter_configuration = json.dumps(parameters)
workflow_tool_provider.privacy_policy = privacy_policy
workflow_tool_provider.version = workflow.version
workflow_tool_provider.updated_at = datetime.now()
try:
WorkflowToolProviderController.from_db(workflow_tool_provider)
except Exception as e:
raise ValueError(str(e))
db.session.add(workflow_tool_provider)
db.session.commit()
if labels is not None:
ToolLabelManager.update_tool_labels(
ToolTransformService.workflow_provider_to_controller(workflow_tool_provider),
labels
)
return {
'result': 'success'
}
@classmethod
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
"""
List workflow tools.
:param user_id: the user id
:param tenant_id: the tenant id
:return: the list of tools
"""
db_tools = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id
).all()
tools = []
for provider in db_tools:
try:
tools.append(ToolTransformService.workflow_provider_to_controller(provider))
except:
# skip deleted tools
pass
labels = ToolLabelManager.get_tools_labels(tools)
result = []
for tool in tools:
user_tool_provider = ToolTransformService.workflow_provider_to_user_provider(
provider_controller=tool,
labels=labels.get(tool.provider_id, [])
)
ToolTransformService.repack_provider(user_tool_provider)
user_tool_provider.tools = [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=labels.get(tool.provider_id, [])
)
]
result.append(user_tool_provider)
return result
@classmethod
def delete_workflow_tool(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
"""
Delete a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
"""
db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).delete()
db.session.commit()
return {
'result': 'success'
}
@classmethod
def get_workflow_tool_by_tool_id(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> dict:
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy,
}
@classmethod
def get_workflow_tool_by_app_id(cls, user_id: str, tenant_id: str, workflow_app_id: str) -> dict:
"""
Get a workflow tool.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the tool
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == workflow_app_id
).first()
if db_tool is None:
raise ValueError(f'Tool {workflow_app_id} not found')
workflow_app: App = db.session.query(App).filter(
App.id == db_tool.app_id,
App.tenant_id == tenant_id
).first()
if workflow_app is None:
raise ValueError(f'App {db_tool.app_id} not found')
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return {
'name': db_tool.name,
'label': db_tool.label,
'workflow_tool_id': db_tool.id,
'workflow_app_id': db_tool.app_id,
'icon': json.loads(db_tool.icon),
'description': db_tool.description,
'parameters': jsonable_encoder(db_tool.parameter_configurations),
'tool': ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
),
'synced': workflow_app.workflow.version == db_tool.version,
'privacy_policy': db_tool.privacy_policy
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[dict]:
"""
List workflow tool provider tools.
:param user_id: the user id
:param tenant_id: the tenant id
:param workflow_app_id: the workflow app id
:return: the list of tools
"""
db_tool: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.id == workflow_tool_id
).first()
if db_tool is None:
raise ValueError(f'Tool {workflow_tool_id} not found')
tool = ToolTransformService.workflow_provider_to_controller(db_tool)
return [
ToolTransformService.tool_to_user_tool(
tool.get_tools(user_id, tenant_id)[0],
labels=ToolLabelManager.get_tool_labels(tool)
)
]