diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 2089313b0..3454ec348 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast import sqlalchemy as sa from pydantic import TypeAdapter +from sqlalchemy.orm import Session from yarl import URL import contexts @@ -617,8 +618,9 @@ class ToolManager: WHERE tenant_id = :tenant_id ORDER BY tenant_id, provider, is_default DESC, created_at DESC """ - ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] - return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() + with Session(db.engine, autoflush=False) as session: + ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] + return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() @classmethod def list_providers_from_api( diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index da0fc5856..84b958023 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -453,7 +453,7 @@ class BuiltinToolManageService: check if oauth system client exists """ tool_provider = ToolProviderID(provider_name) - with Session(db.engine).no_autoflush as session: + with Session(db.engine, autoflush=False) as session: system_client: ToolOAuthSystemClient | None = ( session.query(ToolOAuthSystemClient) .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) @@ -467,7 +467,7 @@ class BuiltinToolManageService: check if oauth custom client is enabled """ tool_provider = ToolProviderID(provider) - with Session(db.engine).no_autoflush as session: + with Session(db.engine, autoflush=False) as session: user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( @@ -492,7 +492,7 @@ class BuiltinToolManageService: config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], cache=NoOpProviderCredentialCache(), ) - with Session(db.engine).no_autoflush as session: + with Session(db.engine, autoflush=False) as session: user_client: ToolOAuthTenantClient | None = ( session.query(ToolOAuthTenantClient) .filter_by( @@ -546,54 +546,53 @@ class BuiltinToolManageService: # get all builtin providers provider_controllers = ToolManager.list_builtin_providers(tenant_id) - with db.session.no_autoflush: - # get all user added providers - db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) + # get all user added providers + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) - # rewrite db_providers - for db_provider in db_providers: - db_provider.provider = str(ToolProviderID(db_provider.provider)) + # rewrite db_providers + for db_provider in db_providers: + db_provider.provider = str(ToolProviderID(db_provider.provider)) - # find provider - def find_provider(provider): - return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) + # find provider + def find_provider(provider): + return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) - result: list[ToolProviderApiEntity] = [] + result: list[ToolProviderApiEntity] = [] - for provider_controller in provider_controllers: - try: - # handle include, exclude - if is_filtered( - include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore - exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore - data=provider_controller, - name_func=lambda x: x.identity.name, - ): - continue + for provider_controller in provider_controllers: + try: + # handle include, exclude + if is_filtered( + include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore + exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore + data=provider_controller, + name_func=lambda x: x.identity.name, + ): + continue - # convert provider controller to user provider - user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( - provider_controller=provider_controller, - db_provider=find_provider(provider_controller.entity.identity.name), - decrypt_credentials=True, + # convert provider controller to user provider + user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( + provider_controller=provider_controller, + db_provider=find_provider(provider_controller.entity.identity.name), + decrypt_credentials=True, + ) + + # add icon + ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) + + tools = provider_controller.get_tools() + for tool in tools or []: + user_builtin_provider.tools.append( + ToolTransformService.convert_tool_entity_to_api_entity( + tenant_id=tenant_id, + tool=tool, + labels=ToolLabelManager.get_tool_labels(provider_controller), + ) ) - # add icon - ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) - - tools = provider_controller.get_tools() - for tool in tools or []: - user_builtin_provider.tools.append( - ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, - tool=tool, - labels=ToolLabelManager.get_tool_labels(provider_controller), - ) - ) - - result.append(user_builtin_provider) - except Exception as e: - raise e + result.append(user_builtin_provider) + except Exception as e: + raise e return BuiltinToolProviderSort.sort(result) @@ -604,7 +603,7 @@ class BuiltinToolManageService: 1.if the default provider exists, return the default provider 2.if the default provider does not exist, return the oldest provider """ - with Session(db.engine) as session: + with Session(db.engine, autoflush=False) as session: try: full_provider_name = provider_name provider_id_entity = ToolProviderID(provider_name)