This commit is contained in:
Ricky
2024-01-31 11:58:07 +08:00
committed by GitHub
parent 9e37702d24
commit 2660fbaa20
58 changed files with 312 additions and 312 deletions

View File

@@ -123,12 +123,12 @@ class ApiBasedToolProviderController(ToolProviderController):
return self.tools
def get_tools(self, user_id: str, tanent_id: str) -> List[ApiTool]:
def get_tools(self, user_id: str, tenant_id: str) -> List[ApiTool]:
"""
fetch tools from database
:param user_id: the user id
:param tanent_id: the tanent id
:param tenant_id: the tenant id
:return: the tools
"""
if self.tools is not None:
@@ -136,9 +136,9 @@ class ApiBasedToolProviderController(ToolProviderController):
tools: List[Tool] = []
# get tanent api providers
# get tenant api providers
db_providers: List[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
ApiToolProvider.tenant_id == tanent_id,
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == self.identity.name
).all()

View File

@@ -1,5 +1,5 @@
from typing import Any, Dict, List
from core.tools.entities.tool_entities import ToolProviderType, ToolParamter, ToolParamterOption
from core.tools.entities.tool_entities import ToolProviderType, ToolParameter, ToolParameterOption
from core.tools.tool.tool import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.provider.tool_provider import ToolProviderController
@@ -71,7 +71,7 @@ class AppBasedToolProviderEntity(ToolProviderController):
variable_name = input_form[form_type]['variable_name']
options = input_form[form_type].get('options', [])
if form_type == 'paragraph' or form_type == 'text-input':
tool['parameters'].append(ToolParamter(
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
@@ -82,13 +82,13 @@ class AppBasedToolProviderEntity(ToolProviderController):
zh_Hans=label
),
llm_description=label,
form=ToolParamter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.STRING,
required=required,
default=default
))
elif form_type == 'select':
tool['parameters'].append(ToolParamter(
tool['parameters'].append(ToolParameter(
name=variable_name,
label=I18nObject(
en_US=label,
@@ -99,11 +99,11 @@ class AppBasedToolProviderEntity(ToolProviderController):
zh_Hans=label
),
llm_description=label,
form=ToolParamter.ToolParameterForm.FORM,
type=ToolParamter.ToolParameterType.SELECT,
form=ToolParameter.ToolParameterForm.FORM,
type=ToolParameter.ToolParameterType.SELECT,
required=required,
default=default,
options=[ToolParamterOption(
options=[ToolParameterOption(
value=option,
label=I18nObject(
en_US=option,

View File

@@ -13,7 +13,7 @@ class AzureDALLEProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "square",
"n": 1

View File

@@ -10,7 +10,7 @@ from openai import AzureOpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
@@ -28,19 +28,19 @@ class DallE3Tool(BuiltinTool):
}
# prompt
prompt = tool_paramters.get('prompt', '')
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
# get n
n = tool_paramters.get('n', 1)
n = tool_parameters.get('n', 1)
# get quality
quality = tool_paramters.get('quality', 'standard')
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
# get style
style = tool_paramters.get('style', 'vivid')
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')

View File

@@ -16,7 +16,7 @@ class ChartProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"data": "1,3,5,7,9,2,4,6,8,10",
},
)

View File

@@ -6,9 +6,9 @@ import io
from typing import Any, Dict, List, Union
class BarChartTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
data = tool_parameters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
@@ -19,7 +19,7 @@ class BarChartTool(BuiltinTool):
else:
data = [float(i) for i in data]
axis = tool_paramters.get('x_axis', None) or None
axis = tool_parameters.get('x_axis', None) or None
if axis:
axis = axis.split(';')
if len(axis) != len(data):

View File

@@ -8,14 +8,14 @@ from typing import Any, Dict, List, Union
class LinearChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
data = tool_parameters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
axis = tool_paramters.get('x_axis', None) or None
axis = tool_parameters.get('x_axis', None) or None
if axis:
axis = axis.split(';')
if len(axis) != len(data):

View File

@@ -8,13 +8,13 @@ from typing import Any, Dict, List, Union
class PieChartTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
data = tool_paramters.get('data', '')
data = tool_parameters.get('data', '')
if not data:
return self.create_text_message('Please input data')
data = data.split(';')
categories = tool_paramters.get('categories', None) or None
categories = tool_parameters.get('categories', None) or None
# if all data is int, convert to int
if all([i.isdigit() for i in data]):

View File

@@ -13,7 +13,7 @@ class DALLEProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"prompt": "cute girl, blue eyes, white hair, anime style",
"size": "small",
"n": 1

View File

@@ -10,7 +10,7 @@ from openai import OpenAI
class DallE2Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
@@ -37,15 +37,15 @@ class DallE2Tool(BuiltinTool):
}
# prompt
prompt = tool_paramters.get('prompt', '')
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'large')]
size = SIZE_MAPPING[tool_parameters.get('size', 'large')]
# get n
n = tool_paramters.get('n', 1)
n = tool_parameters.get('n', 1)
# call openapi dalle2
response = client.images.generate(

View File

@@ -10,7 +10,7 @@ from openai import OpenAI
class DallE3Tool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
@@ -37,19 +37,19 @@ class DallE3Tool(BuiltinTool):
}
# prompt
prompt = tool_paramters.get('prompt', '')
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get size
size = SIZE_MAPPING[tool_paramters.get('size', 'square')]
size = SIZE_MAPPING[tool_parameters.get('size', 'square')]
# get n
n = tool_paramters.get('n', 1)
n = tool_parameters.get('n', 1)
# get quality
quality = tool_paramters.get('quality', 'standard')
quality = tool_parameters.get('quality', 'standard')
if quality not in ['standard', 'hd']:
return self.create_text_message('Invalid quality')
# get style
style = tool_paramters.get('style', 'vivid')
style = tool_parameters.get('style', 'vivid')
if style not in ['natural', 'vivid']:
return self.create_text_message('Invalid style')

View File

@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Union
class GaodeRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
city = tool_paramters.get('city', '')
city = tool_parameters.get('city', '')
if not city:
return self.create_text_message('Please tell me your city')

View File

@@ -9,12 +9,12 @@ from typing import Any, Dict, List, Union
class GihubRepositoriesTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
top_n = tool_paramters.get('top_n', 5)
query = tool_paramters.get('query', '')
top_n = tool_parameters.get('top_n', 5)
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Please input symbol')

View File

@@ -14,7 +14,7 @@ class GoogleProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"query": "test",
"result_type": "link"
},

View File

@@ -148,13 +148,13 @@ class SerpAPI:
class GoogleSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters['query']
result_type = tool_paramters['result_type']
query = tool_parameters['query']
result_type = tool_parameters['result_type']
api_key = self.runtime.credentials['serpapi_api_key']
result = SerpAPI(api_key).run(query, result_type=result_type)
if result_type == 'text':

View File

@@ -14,7 +14,7 @@ class StableDiffusionProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"prompt": "cat",
"lora": "",
"steps": 1,

View File

@@ -1,5 +1,5 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter, ToolParamterOption
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
from core.tools.entities.common_entities import I18nObject
from core.tools.errors import ToolProviderCredentialValidationError
@@ -60,7 +60,7 @@ DRAW_TEXT_OPTIONS = {
}
class StableDiffusionTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
@@ -86,25 +86,25 @@ class StableDiffusionTool(BuiltinTool):
# prompt
prompt = tool_paramters.get('prompt', '')
prompt = tool_parameters.get('prompt', '')
if not prompt:
return self.create_text_message('Please input prompt')
# get negative prompt
negative_prompt = tool_paramters.get('negative_prompt', '')
negative_prompt = tool_parameters.get('negative_prompt', '')
# get size
width = tool_paramters.get('width', 1024)
height = tool_paramters.get('height', 1024)
width = tool_parameters.get('width', 1024)
height = tool_parameters.get('height', 1024)
# get steps
steps = tool_paramters.get('steps', 1)
steps = tool_parameters.get('steps', 1)
# get lora
lora = tool_paramters.get('lora', '')
lora = tool_parameters.get('lora', '')
# get image id
image_id = tool_paramters.get('image_id', '')
image_id = tool_parameters.get('image_id', '')
if image_id.strip():
image_variable = self.get_default_image_variable()
if image_variable:
@@ -212,32 +212,32 @@ class StableDiffusionTool(BuiltinTool):
return self.create_text_message('Failed to generate image')
def get_runtime_parameters(self) -> List[ToolParamter]:
def get_runtime_parameters(self) -> List[ToolParameter]:
parameters = [
ToolParamter(name='prompt',
ToolParameter(name='prompt',
label=I18nObject(en_US='Prompt', zh_Hans='Prompt'),
human_description=I18nObject(
en_US='Image prompt, you can check the official documentation of Stable Diffusion',
zh_Hans='图像提示词,您可以查看 Stable Diffusion 的官方文档',
),
type=ToolParamter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM,
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Image prompt of Stable Diffusion, you should describe the image you want to generate as a list of words as possible as detailed, the prompt must be written in English.',
required=True),
]
if len(self.list_default_image_variables()) != 0:
parameters.append(
ToolParamter(name='image_id',
ToolParameter(name='image_id',
label=I18nObject(en_US='image_id', zh_Hans='image_id'),
human_description=I18nObject(
en_US='Image id of the image you want to generate based on, if you want to generate image based on the default image, you can leave this field empty.',
zh_Hans='您想要生成的图像的图像 ID如果您想要基于默认图像生成图像则可以将此字段留空。',
),
type=ToolParamter.ToolParameterType.STRING,
form=ToolParamter.ToolParameterForm.LLM,
type=ToolParameter.ToolParameterType.STRING,
form=ToolParameter.ToolParameterForm.LLM,
llm_description='Image id of the original image, you can leave this field empty if you want to generate a new image.',
required=True,
options=[ToolParamterOption(
options=[ToolParameterOption(
value=i.name,
label=I18nObject(en_US=i.name, zh_Hans=i.name)
) for i in self.list_default_image_variables()])

View File

@@ -10,7 +10,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
try:
CurrentTimeTool().invoke(
user_id='',
tool_paramters={},
tool_parameters={},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timezone
class CurrentTimeTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools

View File

@@ -1,5 +1,5 @@
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParamter
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.provider.builtin.vectorizer.tools.test_data import VECTORIZER_ICON_PNG
from core.tools.errors import ToolProviderCredentialValidationError
@@ -8,21 +8,21 @@ from httpx import post
from base64 import b64decode
class VectorizerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
api_key_name = self.runtime.credentials.get('api_key_name', None)
api_key_value = self.runtime.credentials.get('api_key_value', None)
mode = tool_paramters.get('mode', 'test')
mode = tool_parameters.get('mode', 'test')
if mode == 'production':
mode = 'preview'
if not api_key_name or not api_key_value:
raise ToolProviderCredentialValidationError('Please input api key name and value')
image_id = tool_paramters.get('image_id', '')
image_id = tool_parameters.get('image_id', '')
if not image_id:
return self.create_text_message('Please input image id')
@@ -54,21 +54,21 @@ class VectorizerTool(BuiltinTool):
meta={'mime_type': 'image/svg+xml'})
]
def get_runtime_parameters(self) -> List[ToolParamter]:
def get_runtime_parameters(self) -> List[ToolParameter]:
"""
override the runtime parameters
"""
return [
ToolParamter.get_simple_instance(
ToolParameter.get_simple_instance(
name='image_id',
llm_description=f'the image id that you want to vectorize, \
and the image id should be specified in \
{[i.name for i in self.list_default_image_variables()]}',
type=ToolParamter.ToolParameterType.SELECT,
type=ToolParameter.ToolParameterType.SELECT,
required=True,
options=[i.name for i in self.list_default_image_variables()]
)
]
def is_tool_avaliable(self) -> bool:
def is_tool_available(self) -> bool:
return len(self.list_default_image_variables()) > 0

View File

@@ -14,7 +14,7 @@ class VectorizerProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"mode": "test",
"image_id": "__test_123"
},

View File

@@ -7,14 +7,14 @@ from typing import Any, Dict, List, Union
class WebscraperTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
try:
url = tool_paramters.get('url', '')
user_agent = tool_paramters.get('user_agent', '')
url = tool_parameters.get('url', '')
user_agent = tool_parameters.get('user_agent', '')
if not url:
return self.create_text_message('Please input url')

View File

@@ -14,7 +14,7 @@ class WebscraperProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
'url': 'https://www.google.com',
'user_agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
},

View File

@@ -14,12 +14,12 @@ class WikipediaInput(BaseModel):
class WikiPediaSearchTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('query', '')
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Please input query')

View File

@@ -12,7 +12,7 @@ class WikiPediaProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"query": "misaka mikoto",
},
)

View File

@@ -11,12 +11,12 @@ class WolframAlphaTool(BuiltinTool):
def _invoke(self,
user_id: str,
tool_paramters: Dict[str, Any],
tool_parameters: Dict[str, Any],
) -> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('query', '')
query = tool_parameters.get('query', '')
if not query:
return self.create_text_message('Please input query')
appid = self.runtime.credentials.get('appid', '')

View File

@@ -16,7 +16,7 @@ class GoogleProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"query": "1+2+....+111",
},
)

View File

@@ -9,23 +9,23 @@ from yfinance import download
import pandas as pd
class YahooFinanceAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
symbol = tool_paramters.get('symbol', '')
symbol = tool_parameters.get('symbol', '')
if not symbol:
return self.create_text_message('Please input symbol')
time_range = [None, None]
start_date = tool_paramters.get('start_date', '')
start_date = tool_parameters.get('start_date', '')
if start_date:
time_range[0] = start_date
else:
time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '')
end_date = tool_parameters.get('end_date', '')
if end_date:
time_range[1] = end_date
else:

View File

@@ -7,13 +7,13 @@ from requests.exceptions import HTTPError, ReadTimeout
import yfinance
class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self,user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self,user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
'''
invoke tools
'''
query = tool_paramters.get('symbol', '')
query = tool_parameters.get('symbol', '')
if not query:
return self.create_text_message('Please input symbol')

View File

@@ -7,12 +7,12 @@ from requests.exceptions import HTTPError, ReadTimeout
from yfinance import Ticker
class YahooFinanceSearchTickerTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
query = tool_paramters.get('symbol', '')
query = tool_parameters.get('symbol', '')
if not query:
return self.create_text_message('Please input symbol')

View File

@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"ticker": "MSFT",
},
)

View File

@@ -7,23 +7,23 @@ from datetime import datetime
from googleapiclient.discovery import build
class YoutubeVideosAnalyticsTool(BuiltinTool):
def _invoke(self, user_id: str, tool_paramters: Dict[str, Any]) \
def _invoke(self, user_id: str, tool_parameters: Dict[str, Any]) \
-> Union[ToolInvokeMessage, List[ToolInvokeMessage]]:
"""
invoke tools
"""
channel = tool_paramters.get('channel', '')
channel = tool_parameters.get('channel', '')
if not channel:
return self.create_text_message('Please input symbol')
time_range = [None, None]
start_date = tool_paramters.get('start_date', '')
start_date = tool_parameters.get('start_date', '')
if start_date:
time_range[0] = start_date
else:
time_range[0] = '1800-01-01'
end_date = tool_paramters.get('end_date', '')
end_date = tool_parameters.get('end_date', '')
if end_date:
time_range[1] = end_date
else:

View File

@@ -12,7 +12,7 @@ class YahooFinanceProvider(BuiltinToolProviderController):
}
).invoke(
user_id='',
tool_paramters={
tool_parameters={
"channel": "TOKYO GIRLS COLLECTION",
"start_date": "2020-01-01",
"end_date": "2024-12-31",

View File

@@ -5,13 +5,13 @@ from os import path, listdir
from yaml import load, FullLoader
from core.tools.entities.tool_entities import ToolProviderType, \
ToolParamter, ToolProviderCredentials
ToolParameter, ToolProviderCredentials
from core.tools.tool.tool import Tool
from core.tools.tool.builtin_tool import BuiltinTool
from core.tools.provider.tool_provider import ToolProviderController
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError
ToolParameterValidationError, ToolProviderCredentialValidationError
import importlib
@@ -109,7 +109,7 @@ class BuiltinToolProviderController(ToolProviderController):
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
def get_parameters(self, tool_name: str) -> List[ToolParameter]:
"""
returns the parameters of the tool
@@ -148,62 +148,62 @@ class BuiltinToolProviderController(ToolProviderController):
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
tool_parameters_need_to_validate: Dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number')
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required')
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.default
# parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
default_value = bool(default_value)
tool_parameters[parameter] = default_value

View File

@@ -4,11 +4,11 @@ from typing import List, Dict, Any, Optional
from pydantic import BaseModel
from core.tools.entities.tool_entities import ToolProviderType, \
ToolProviderIdentity, ToolParamter, ToolProviderCredentials
ToolProviderIdentity, ToolParameter, ToolProviderCredentials
from core.tools.tool.tool import Tool
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, \
ToolParamterValidationError, ToolProviderCredentialValidationError
ToolParameterValidationError, ToolProviderCredentialValidationError
class ToolProviderController(BaseModel, ABC):
identity: Optional[ToolProviderIdentity] = None
@@ -50,7 +50,7 @@ class ToolProviderController(BaseModel, ABC):
"""
pass
def get_parameters(self, tool_name: str) -> List[ToolParamter]:
def get_parameters(self, tool_name: str) -> List[ToolParameter]:
"""
returns the parameters of the tool
@@ -80,62 +80,62 @@ class ToolProviderController(BaseModel, ABC):
"""
tool_parameters_schema = self.get_parameters(tool_name)
tool_parameters_need_to_validate: Dict[str, ToolParamter] = {}
tool_parameters_need_to_validate: Dict[str, ToolParameter] = {}
for parameter in tool_parameters_schema:
tool_parameters_need_to_validate[parameter.name] = parameter
for parameter in tool_parameters:
if parameter not in tool_parameters_need_to_validate:
raise ToolParamterValidationError(f'parameter {parameter} not found in tool {tool_name}')
raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
# check type
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.type == ToolParamter.ToolParameterType.STRING:
if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
raise ToolParameterValidationError(f'parameter {parameter} should be string')
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
if not isinstance(tool_parameters[parameter], (int, float)):
raise ToolParamterValidationError(f'parameter {parameter} should be number')
raise ToolParameterValidationError(f'parameter {parameter} should be number')
if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
raise ToolParamterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
raise ToolParamterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
if not isinstance(tool_parameters[parameter], bool):
raise ToolParamterValidationError(f'parameter {parameter} should be boolean')
raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
elif parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
if not isinstance(tool_parameters[parameter], str):
raise ToolParamterValidationError(f'parameter {parameter} should be string')
raise ToolParameterValidationError(f'parameter {parameter} should be string')
options = parameter_schema.options
if not isinstance(options, list):
raise ToolParamterValidationError(f'parameter {parameter} options should be list')
raise ToolParameterValidationError(f'parameter {parameter} options should be list')
if tool_parameters[parameter] not in [x.value for x in options]:
raise ToolParamterValidationError(f'parameter {parameter} should be one of {options}')
raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
tool_parameters_need_to_validate.pop(parameter)
for parameter in tool_parameters_need_to_validate:
parameter_schema = tool_parameters_need_to_validate[parameter]
if parameter_schema.required:
raise ToolParamterValidationError(f'parameter {parameter} is required')
raise ToolParameterValidationError(f'parameter {parameter} is required')
# the parameter is not set currently, set the default value if needed
if parameter_schema.default is not None:
default_value = parameter_schema.default
# parse default value into the correct type
if parameter_schema.type == ToolParamter.ToolParameterType.STRING or \
parameter_schema.type == ToolParamter.ToolParameterType.SELECT:
if parameter_schema.type == ToolParameter.ToolParameterType.STRING or \
parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
default_value = str(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.NUMBER:
elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
default_value = float(default_value)
elif parameter_schema.type == ToolParamter.ToolParameterType.BOOLEAN:
elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
default_value = bool(default_value)
tool_parameters[parameter] = default_value