feat: backend model load balancing support (#4927)

This commit is contained in:
takatost
2024-06-05 00:13:04 +08:00
committed by GitHub
parent 52ec152dd3
commit d1dbbc1e33
47 changed files with 2191 additions and 256 deletions

View File

@@ -10,6 +10,7 @@ from flask import current_app
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.common_entities import I18nObject
@@ -31,7 +32,6 @@ from core.tools.utils.configuration import (
ToolParameterConfigurationManager,
)
from core.tools.utils.tool_parameter_converter import ToolParameterConverter
from core.utils.module_import_helper import load_single_subclass_from_source
from core.workflow.nodes.tool.entities import ToolEntity
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@@ -102,10 +102,10 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@classmethod
def get_tool_runtime(cls, provider_type: str,
def get_tool_runtime(cls, provider_type: str,
provider_id: str,
tool_name: str,
tenant_id: str,
tool_name: str,
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-> Union[BuiltinTool, ApiTool]:
@@ -222,7 +222,7 @@ class ToolManager:
get the agent tool runtime
"""
tool_entity = cls.get_tool_runtime(
provider_type=agent_tool.provider_type,
provider_type=agent_tool.provider_type,
provider_id=agent_tool.provider_id,
tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
@@ -235,7 +235,7 @@ class ToolManager:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
raise ValueError(f"file type parameter {parameter.name} not supported in agent")
if parameter.form == ToolParameter.ToolParameterForm.FORM:
# save tool parameter to tool entity memory
value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
@@ -403,7 +403,7 @@ class ToolManager:
# get builtin providers
builtin_providers = cls.list_builtin_providers()
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
filter(BuiltinToolProvider.tenant_id == tenant_id).all()
@@ -428,7 +428,7 @@ class ToolManager:
if 'api' in filters:
db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
filter(ApiToolProvider.tenant_id == tenant_id).all()
api_provider_controllers = [{
'provider': provider,
'controller': ToolTransformService.api_provider_to_controller(provider)
@@ -450,7 +450,7 @@ class ToolManager:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = db.session.query(WorkflowToolProvider). \
filter(WorkflowToolProvider.tenant_id == tenant_id).all()
workflow_provider_controllers = []
for provider in workflow_providers:
try:
@@ -460,7 +460,7 @@ class ToolManager:
except Exception as e:
# app has been deleted
pass
labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
for provider_controller in workflow_provider_controllers: