generalize helper for loading module from source (#2862)
This commit is contained in:
@@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
@@ -16,6 +15,7 @@ from core.tools.errors import (
|
||||
from core.tools.provider.tool_provider import ToolProviderController
|
||||
from core.tools.tool.builtin_tool import BuiltinTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
@@ -63,16 +63,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tool_name = tool_file.split(".")[0]
|
||||
tool = load(f.read(), FullLoader)
|
||||
# get tool class, import the module
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, 'tools', f'{tool_name}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.tools.{tool_name}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# get all the classes in the module
|
||||
classes = [x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x not in [BuiltinTool, Tool] and issubclass(x, BuiltinTool)
|
||||
]
|
||||
assistant_tool_class = classes[0]
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.tools.{tool_name}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'builtin', provider, 'tools', f'{tool_name}.py'),
|
||||
parent_type=BuiltinTool)
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
|
||||
self.tools = tools
|
||||
|
@@ -1,4 +1,3 @@
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -34,6 +33,7 @@ from core.tools.utils.configuration import (
|
||||
ToolParameterConfigurationManager,
|
||||
)
|
||||
from core.tools.utils.encoder import serialize_base_model_dict
|
||||
from core.utils.module_import_helper import load_single_subclass_from_source
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider
|
||||
|
||||
@@ -72,21 +72,11 @@ class ToolManager:
|
||||
|
||||
if provider_entity is None:
|
||||
# fetch the provider from .provider.builtin
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# get all the classes in the module
|
||||
classes = [ x for _, x in vars(mod).items()
|
||||
if isinstance(x, type) and x != ToolProviderController and issubclass(x, ToolProviderController)
|
||||
]
|
||||
if len(classes) == 0:
|
||||
raise ToolProviderNotFoundError(f'provider {provider} not found')
|
||||
if len(classes) > 1:
|
||||
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
|
||||
|
||||
provider_entity = classes[0]()
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)), 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=ToolProviderController)
|
||||
provider_entity = provider_class()
|
||||
|
||||
return provider_entity.invoke(tool_id, tool_name, tool_parameters, credentials, prompt_messages)
|
||||
|
||||
@@ -330,23 +320,12 @@ class ToolManager:
|
||||
if provider.startswith('__'):
|
||||
continue
|
||||
|
||||
py_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, f'{provider}.py')
|
||||
spec = importlib.util.spec_from_file_location(f'core.tools.provider.builtin.{provider}.{provider}', py_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(mod)
|
||||
|
||||
# load all classes
|
||||
classes = [
|
||||
obj for name, obj in vars(mod).items()
|
||||
if isinstance(obj, type) and obj != BuiltinToolProviderController and issubclass(obj, BuiltinToolProviderController)
|
||||
]
|
||||
if len(classes) == 0:
|
||||
raise ToolProviderNotFoundError(f'provider {provider} not found')
|
||||
if len(classes) > 1:
|
||||
raise ToolProviderNotFoundError(f'multiple providers found for {provider}')
|
||||
|
||||
# init provider
|
||||
provider_class = classes[0]
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
|
||||
# cache the builtin providers
|
||||
|
Reference in New Issue
Block a user