feat: backend model load balancing support (#4927)
This commit is contained in:
@@ -10,6 +10,7 @@ from flask import current_app
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from core.workflow.nodes.tool.entities import ToolEntity
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@@ -102,10 +102,10 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@classmethod
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
def get_tool_runtime(cls, provider_type: str,
|
||||
provider_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
tool_name: str,
|
||||
tenant_id: str,
|
||||
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
|
||||
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
@@ -222,7 +222,7 @@ class ToolManager:
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = cls.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_type=agent_tool.provider_type,
|
||||
provider_id=agent_tool.provider_id,
|
||||
tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
@@ -235,7 +235,7 @@ class ToolManager:
|
||||
# check file types
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
|
||||
|
||||
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
# save tool parameter to tool entity memory
|
||||
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
|
||||
@@ -403,7 +403,7 @@ class ToolManager:
|
||||
|
||||
# get builtin providers
|
||||
builtin_providers = cls.list_builtin_providers()
|
||||
|
||||
|
||||
# get db builtin providers
|
||||
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
|
||||
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
|
||||
@@ -428,7 +428,7 @@ class ToolManager:
|
||||
if 'api' in filters:
|
||||
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
|
||||
filter(ApiToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
api_provider_controllers = [{
|
||||
'provider': provider,
|
||||
'controller': ToolTransformService.api_provider_to_controller(provider)
|
||||
@@ -450,7 +450,7 @@ class ToolManager:
|
||||
# get workflow providers
|
||||
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
|
||||
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
|
||||
|
||||
|
||||
workflow_provider_controllers = []
|
||||
for provider in workflow_providers:
|
||||
try:
|
||||
@@ -460,7 +460,7 @@ class ToolManager:
|
||||
except Exception as e:
|
||||
# app has been deleted
|
||||
pass
|
||||
|
||||
|
||||
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
|
||||
|
||||
for provider_controller in workflow_provider_controllers:
|
||||
|
Reference in New Issue
Block a user