fix: tool provider deadlock (#24532)

Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
This commit is contained in:
Maries
2025-08-27 12:27:20 +08:00
committed by GitHub
parent ddf6192643
commit c06cfcbb5a
2 changed files with 48 additions and 47 deletions

View File

@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
import sqlalchemy as sa import sqlalchemy as sa
from pydantic import TypeAdapter from pydantic import TypeAdapter
from sqlalchemy.orm import Session
from yarl import URL from yarl import URL
import contexts import contexts
@@ -617,8 +618,9 @@ class ToolManager:
WHERE tenant_id = :tenant_id WHERE tenant_id = :tenant_id
ORDER BY tenant_id, provider, is_default DESC, created_at DESC ORDER BY tenant_id, provider, is_default DESC, created_at DESC
""" """
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()] with Session(db.engine, autoflush=False) as session:
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all() ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
@classmethod @classmethod
def list_providers_from_api( def list_providers_from_api(

View File

@@ -453,7 +453,7 @@ class BuiltinToolManageService:
check if oauth system client exists check if oauth system client exists
""" """
tool_provider = ToolProviderID(provider_name) tool_provider = ToolProviderID(provider_name)
with Session(db.engine).no_autoflush as session: with Session(db.engine, autoflush=False) as session:
system_client: ToolOAuthSystemClient | None = ( system_client: ToolOAuthSystemClient | None = (
session.query(ToolOAuthSystemClient) session.query(ToolOAuthSystemClient)
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
@@ -467,7 +467,7 @@ class BuiltinToolManageService:
check if oauth custom client is enabled check if oauth custom client is enabled
""" """
tool_provider = ToolProviderID(provider) tool_provider = ToolProviderID(provider)
with Session(db.engine).no_autoflush as session: with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = ( user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
@@ -492,7 +492,7 @@ class BuiltinToolManageService:
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(), cache=NoOpProviderCredentialCache(),
) )
with Session(db.engine).no_autoflush as session: with Session(db.engine, autoflush=False) as session:
user_client: ToolOAuthTenantClient | None = ( user_client: ToolOAuthTenantClient | None = (
session.query(ToolOAuthTenantClient) session.query(ToolOAuthTenantClient)
.filter_by( .filter_by(
@@ -546,54 +546,53 @@ class BuiltinToolManageService:
# get all builtin providers # get all builtin providers
provider_controllers = ToolManager.list_builtin_providers(tenant_id) provider_controllers = ToolManager.list_builtin_providers(tenant_id)
with db.session.no_autoflush: # get all user added providers
# get all user added providers db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers # rewrite db_providers
for db_provider in db_providers: for db_provider in db_providers:
db_provider.provider = str(ToolProviderID(db_provider.provider)) db_provider.provider = str(ToolProviderID(db_provider.provider))
# find provider # find provider
def find_provider(provider): def find_provider(provider):
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None) return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
result: list[ToolProviderApiEntity] = [] result: list[ToolProviderApiEntity] = []
for provider_controller in provider_controllers: for provider_controller in provider_controllers:
try: try:
# handle include, exclude # handle include, exclude
if is_filtered( if is_filtered(
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
data=provider_controller, data=provider_controller,
name_func=lambda x: x.identity.name, name_func=lambda x: x.identity.name,
): ):
continue continue
# convert provider controller to user provider # convert provider controller to user provider
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider( user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller, provider_controller=provider_controller,
db_provider=find_provider(provider_controller.entity.identity.name), db_provider=find_provider(provider_controller.entity.identity.name),
decrypt_credentials=True, decrypt_credentials=True,
)
# add icon
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
) )
# add icon result.append(user_builtin_provider)
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider) except Exception as e:
raise e
tools = provider_controller.get_tools()
for tool in tools or []:
user_builtin_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
result.append(user_builtin_provider)
except Exception as e:
raise e
return BuiltinToolProviderSort.sort(result) return BuiltinToolProviderSort.sort(result)
@@ -604,7 +603,7 @@ class BuiltinToolManageService:
1.if the default provider exists, return the default provider 1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider 2.if the default provider does not exist, return the oldest provider
""" """
with Session(db.engine) as session: with Session(db.engine, autoflush=False) as session:
try: try:
full_provider_name = provider_name full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name) provider_id_entity = ToolProviderID(provider_name)