fix: tool provider deadlock (#24532)
Co-authored-by: Wu Tianwei <30284043+WTW0313@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy.orm import Session
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
@@ -617,8 +618,9 @@ class ToolManager:
|
||||
WHERE tenant_id = :tenant_id
|
||||
ORDER BY tenant_id, provider, is_default DESC, created_at DESC
|
||||
"""
|
||||
ids = [row.id for row in db.session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
ids = [row.id for row in session.execute(sa.text(sql), {"tenant_id": tenant_id}).all()]
|
||||
return session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
|
||||
|
||||
@classmethod
|
||||
def list_providers_from_api(
|
||||
|
@@ -453,7 +453,7 @@ class BuiltinToolManageService:
|
||||
check if oauth system client exists
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider_name)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
system_client: ToolOAuthSystemClient | None = (
|
||||
session.query(ToolOAuthSystemClient)
|
||||
.filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
|
||||
@@ -467,7 +467,7 @@ class BuiltinToolManageService:
|
||||
check if oauth custom client is enabled
|
||||
"""
|
||||
tool_provider = ToolProviderID(provider)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@@ -492,7 +492,7 @@ class BuiltinToolManageService:
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
with Session(db.engine).no_autoflush as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
user_client: ToolOAuthTenantClient | None = (
|
||||
session.query(ToolOAuthTenantClient)
|
||||
.filter_by(
|
||||
@@ -546,54 +546,53 @@ class BuiltinToolManageService:
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
with db.session.no_autoflush:
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
|
||||
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
# rewrite db_providers
|
||||
for db_provider in db_providers:
|
||||
db_provider.provider = str(ToolProviderID(db_provider.provider))
|
||||
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
# find provider
|
||||
def find_provider(provider):
|
||||
return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
|
||||
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
# handle include, exclude
|
||||
if is_filtered(
|
||||
include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
|
||||
exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
|
||||
data=provider_controller,
|
||||
name_func=lambda x: x.identity.name,
|
||||
):
|
||||
continue
|
||||
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
# convert provider controller to user provider
|
||||
user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
db_provider=find_provider(provider_controller.entity.identity.name),
|
||||
decrypt_credentials=True,
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
|
||||
|
||||
tools = provider_controller.get_tools()
|
||||
for tool in tools or []:
|
||||
user_builtin_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
result.append(user_builtin_provider)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@@ -604,7 +603,7 @@ class BuiltinToolManageService:
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
with Session(db.engine, autoflush=False) as session:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
|
Reference in New Issue
Block a user