chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -56,12 +56,13 @@ class ToolConfigurationManager(BaseModel):
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
if field_name in credentials:
if len(credentials[field_name]) > 6:
credentials[field_name] = \
credentials[field_name][:2] + \
'*' * (len(credentials[field_name]) - 4) + \
credentials[field_name][-2:]
credentials[field_name] = (
credentials[field_name][:2]
+ "*" * (len(credentials[field_name]) - 4)
+ credentials[field_name][-2:]
)
else:
credentials[field_name] = '*' * len(credentials[field_name])
credentials[field_name] = "*" * len(credentials[field_name])
return credentials
@@ -72,9 +73,9 @@ class ToolConfigurationManager(BaseModel):
return a deep copy of credentials with decrypted values
"""
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
if cached_credentials:
@@ -95,16 +96,18 @@ class ToolConfigurationManager(BaseModel):
def delete_tool_credentials_cache(self):
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
cache_type=ToolProviderCredentialsCacheType.PROVIDER
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
class ToolParameterConfigurationManager(BaseModel):
"""
Tool parameter configuration manager
"""
tenant_id: str
tool_runtime: Tool
provider_name: str
@@ -152,15 +155,19 @@ class ToolParameterConfigurationManager(BaseModel):
current_parameters = self._merge_parameters()
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
if len(parameters[parameter.name]) > 6:
parameters[parameter.name] = \
parameters[parameter.name][:2] + \
'*' * (len(parameters[parameter.name]) - 4) + \
parameters[parameter.name][-2:]
parameters[parameter.name] = (
parameters[parameter.name][:2]
+ "*" * (len(parameters[parameter.name]) - 4)
+ parameters[parameter.name][-2:]
)
else:
parameters[parameter.name] = '*' * len(parameters[parameter.name])
parameters[parameter.name] = "*" * len(parameters[parameter.name])
return parameters
@@ -176,7 +183,10 @@ class ToolParameterConfigurationManager(BaseModel):
parameters = self._deep_copy(parameters)
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
parameters[parameter.name] = encrypted
@@ -191,10 +201,10 @@ class ToolParameterConfigurationManager(BaseModel):
"""
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
provider=f"{self.provider_type}.{self.provider_name}",
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
identity_id=self.identity_id,
)
cached_parameters = cache.get()
if cached_parameters:
@@ -205,7 +215,10 @@ class ToolParameterConfigurationManager(BaseModel):
has_secret_input = False
for parameter in current_parameters:
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
if (
parameter.form == ToolParameter.ToolParameterForm.FORM
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
):
if parameter.name in parameters:
try:
has_secret_input = True
@@ -221,9 +234,9 @@ class ToolParameterConfigurationManager(BaseModel):
def delete_tool_parameters_cache(self):
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f'{self.provider_type}.{self.provider_name}',
provider=f"{self.provider_type}.{self.provider_name}",
tool_name=self.tool_runtime.identity.name,
cache_type=ToolParameterCacheType.PARAMETER,
identity_id=self.identity_id
identity_id=self.identity_id,
)
cache.delete()

View File

@@ -17,8 +17,9 @@ class FeishuRequest:
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
return res.get("tenant_access_token")
def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None,
params: dict = None):
def _send_request(
self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None
):
headers = {
"Content-Type": "application/json",
"user-agent": "Dify",
@@ -42,10 +43,7 @@ class FeishuRequest:
}
"""
url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token"
payload = {
"app_id": app_id,
"app_secret": app_secret
}
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
return res
@@ -76,11 +74,7 @@ class FeishuRequest:
def write_document(self, document_id: str, content: str, position: str = "start") -> dict:
url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document"
payload = {
"document_id": document_id,
"content": content,
"position": position
}
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
return res.get("data")

View File

@@ -10,10 +10,9 @@ logger = logging.getLogger(__name__)
class ToolFileMessageTransformer:
@classmethod
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
user_id: str,
tenant_id: str,
conversation_id: str) -> list[ToolInvokeMessage]:
def transform_tool_invoke_messages(
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
) -> list[ToolInvokeMessage]:
"""
Transform tool message and handle file download
"""
@@ -28,78 +27,88 @@ class ToolFileMessageTransformer:
# try to download image
try:
file = ToolFileManager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_url=message.message
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
)
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
)
except Exception as e:
logger.exception(e)
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
))
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT,
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
meta=message.meta.copy() if message.meta is not None else {},
save_as=message.save_as,
)
)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
# get mime type and save blob to storage
mimetype = message.meta.get('mime_type', 'octet/stream')
mimetype = message.meta.get("mime_type", "octet/stream")
# if message is str, encode it to bytes
if isinstance(message.message, str):
message.message = message.message.encode('utf-8')
message.message = message.message.encode("utf-8")
file = ToolFileManager.create_file_by_raw(
user_id=user_id, tenant_id=tenant_id,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
file_binary=message.message,
mimetype=mimetype
mimetype=mimetype,
)
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
# check if file is image
if 'image' in mimetype:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
if "image" in mimetype:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
)
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
file_var = message.meta.get('file_var')
file_var = message.meta.get("file_var")
if file_var:
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
if file_var.type == FileType.IMAGE:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
))
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.LINK,
message=url,
save_as=message.save_as,
meta=message.meta.copy() if message.meta is not None else {},
)
)
else:
result.append(message)

View File

@@ -1,7 +1,7 @@
"""
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
Therefore, a model manager is needed to list/invoke/validate models.
Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
@@ -27,52 +27,49 @@ from models.tools import ToolModelInvoke
class InvokeModelError(Exception):
pass
class ModelInvocationUtils:
@staticmethod
def get_max_llm_context_tokens(
tenant_id: str,
) -> int:
"""
get max llm context tokens of the model
get max llm context tokens of the model
"""
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
if not model_instance:
raise InvokeModelError('Model not found')
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
if not schema:
raise InvokeModelError('No model schema found')
raise InvokeModelError("No model schema found")
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
return max_tokens
@staticmethod
def calculate_tokens(
tenant_id: str,
prompt_messages: list[PromptMessage]
) -> int:
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
"""
calculate tokens from prompt messages and model parameters
calculate tokens from prompt messages and model parameters
"""
# get model instance
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM
)
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
if not model_instance:
raise InvokeModelError('Model not found')
raise InvokeModelError("Model not found")
# get tokens
tokens = model_instance.get_llm_num_tokens(prompt_messages)
@@ -80,9 +77,7 @@ class ModelInvocationUtils:
@staticmethod
def invoke(
user_id: str, tenant_id: str,
tool_type: str, tool_name: str,
prompt_messages: list[PromptMessage]
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
) -> LLMResult:
"""
invoke model with parameters in user's own context
@@ -103,15 +98,16 @@ class ModelInvocationUtils:
model_manager = ModelManager()
# get model instance
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id, model_type=ModelType.LLM,
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
# get prompt tokens
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
model_parameters = {
'temperature': 0.8,
'top_p': 0.8,
"temperature": 0.8,
"top_p": 0.8,
}
# create tool model invoke
@@ -123,14 +119,14 @@ class ModelInvocationUtils:
tool_name=tool_name,
model_parameters=json.dumps(model_parameters),
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
model_response='',
model_response="",
prompt_tokens=prompt_tokens,
answer_tokens=0,
answer_unit_price=0,
answer_price_unit=0,
provider_response_latency=0,
total_price=0,
currency='USD',
currency="USD",
)
db.session.add(tool_model_invoke)
@@ -140,20 +136,24 @@ class ModelInvocationUtils:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
)
except InvokeRateLimitError as e:
raise InvokeModelError(f'Invoke rate limit error: {e}')
raise InvokeModelError(f"Invoke rate limit error: {e}")
except InvokeBadRequestError as e:
raise InvokeModelError(f'Invoke bad request error: {e}')
raise InvokeModelError(f"Invoke bad request error: {e}")
except InvokeConnectionError as e:
raise InvokeModelError(f'Invoke connection error: {e}')
raise InvokeModelError(f"Invoke connection error: {e}")
except InvokeAuthorizationError as e:
raise InvokeModelError('Invoke authorization error')
raise InvokeModelError("Invoke authorization error")
except InvokeServerUnavailableError as e:
raise InvokeModelError(f'Invoke server unavailable error: {e}')
raise InvokeModelError(f"Invoke server unavailable error: {e}")
except Exception as e:
raise InvokeModelError(f'Invoke error: {e}')
raise InvokeModelError(f"Invoke error: {e}")
# update tool model invoke
tool_model_invoke.model_response = response.message.content

View File

@@ -1,4 +1,3 @@
import re
import uuid
from json import dumps as json_dumps
@@ -16,54 +15,56 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
class ApiBasedToolSchemaParser:
@staticmethod
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_to_tool_bundle(
openapi: dict, extra_info: dict = None, warning: dict = None
) -> list[ApiToolBundle]:
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
# set description to extra_info
extra_info['description'] = openapi['info'].get('description', '')
extra_info["description"] = openapi["info"].get("description", "")
if len(openapi['servers']) == 0:
raise ToolProviderNotFoundError('No server found in the openapi yaml.')
if len(openapi["servers"]) == 0:
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
server_url = openapi['servers'][0]['url']
server_url = openapi["servers"][0]["url"]
# list all interfaces
interfaces = []
for path, path_item in openapi['paths'].items():
methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace']
for path, path_item in openapi["paths"].items():
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
for method in methods:
if method in path_item:
interfaces.append({
'path': path,
'method': method,
'operation': path_item[method],
})
interfaces.append(
{
"path": path,
"method": method,
"operation": path_item[method],
}
)
# get all parameters
bundles = []
for interface in interfaces:
# convert parameters
parameters = []
if 'parameters' in interface['operation']:
for parameter in interface['operation']['parameters']:
if "parameters" in interface["operation"]:
for parameter in interface["operation"]["parameters"]:
tool_parameter = ToolParameter(
name=parameter['name'],
label=I18nObject(
en_US=parameter['name'],
zh_Hans=parameter['name']
),
name=parameter["name"],
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
human_description=I18nObject(
en_US=parameter.get('description', ''),
zh_Hans=parameter.get('description', '')
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=parameter.get('required', False),
required=parameter.get("required", False),
form=ToolParameter.ToolParameterForm.LLM,
llm_description=parameter.get('description'),
default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None,
llm_description=parameter.get("description"),
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
)
# check if there is a type
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
if typ:
@@ -72,44 +73,40 @@ class ApiBasedToolSchemaParser:
parameters.append(tool_parameter)
# create tool bundle
# check if there is a request body
if 'requestBody' in interface['operation']:
request_body = interface['operation']['requestBody']
if 'content' in request_body:
for content_type, content in request_body['content'].items():
if "requestBody" in interface["operation"]:
request_body = interface["operation"]["requestBody"]
if "content" in request_body:
for content_type, content in request_body["content"].items():
# if there is a reference, get the reference and overwrite the content
if 'schema' not in content:
if "schema" not in content:
continue
if '$ref' in content['schema']:
if "$ref" in content["schema"]:
# get the reference
root = openapi
reference = content['schema']['$ref'].split('/')[1:]
reference = content["schema"]["$ref"].split("/")[1:]
for ref in reference:
root = root[ref]
# overwrite the content
interface['operation']['requestBody']['content'][content_type]['schema'] = root
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
# parse body parameters
if 'schema' in interface['operation']['requestBody']['content'][content_type]:
body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
required = body_schema.get('required', [])
properties = body_schema.get('properties', {})
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
required = body_schema.get("required", [])
properties = body_schema.get("properties", {})
for name, property in properties.items():
tool = ToolParameter(
name=name,
label=I18nObject(
en_US=name,
zh_Hans=name
),
label=I18nObject(en_US=name, zh_Hans=name),
human_description=I18nObject(
en_US=property.get('description', ''),
zh_Hans=property.get('description', '')
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
),
type=ToolParameter.ToolParameterType.STRING,
required=name in required,
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get('description', ''),
default=property.get('default', None),
llm_description=property.get("description", ""),
default=property.get("default", None),
)
# check if there is a type
@@ -127,172 +124,176 @@ class ApiBasedToolSchemaParser:
parameters_count[parameter.name] += 1
for name, count in parameters_count.items():
if count > 1:
warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
# check if there is a operation id, use $path_$method as operation id if not
if 'operationId' not in interface['operation']:
if "operationId" not in interface["operation"]:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = interface['path']
if interface['path'].startswith('/'):
path = interface['path'][1:]
path = interface["path"]
if interface["path"].startswith("/"):
path = interface["path"][1:]
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r'[^a-zA-Z0-9_-]', '', path)
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
path = str(uuid.uuid4())
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
bundles.append(ApiToolBundle(
server_url=server_url + interface['path'],
method=interface['method'],
summary=interface['operation']['description'] if 'description' in interface['operation'] else
interface['operation'].get('summary', None),
operation_id=interface['operation']['operationId'],
parameters=parameters,
author='',
icon=None,
openapi=interface['operation'],
))
interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
bundles.append(
ApiToolBundle(
server_url=server_url + interface["path"],
method=interface["method"],
summary=interface["operation"]["description"]
if "description" in interface["operation"]
else interface["operation"].get("summary", None),
operation_id=interface["operation"]["operationId"],
parameters=parameters,
author="",
icon=None,
openapi=interface["operation"],
)
)
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
parameter = parameter or {}
typ = None
if 'type' in parameter:
typ = parameter['type']
elif 'schema' in parameter and 'type' in parameter['schema']:
typ = parameter['schema']['type']
if typ == 'integer' or typ == 'number':
if "type" in parameter:
typ = parameter["type"]
elif "schema" in parameter and "type" in parameter["schema"]:
typ = parameter["schema"]["type"]
if typ == "integer" or typ == "number":
return ToolParameter.ToolParameterType.NUMBER
elif typ == 'boolean':
elif typ == "boolean":
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == 'string':
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
@staticmethod
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openapi_yaml_to_tool_bundle(
yaml: str, extra_info: dict = None, warning: dict = None
) -> list[ApiToolBundle]:
"""
parse openapi yaml to tool bundle
parse openapi yaml to tool bundle
:param yaml: the yaml string
:return: the tool bundle
:param yaml: the yaml string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
openapi: dict = safe_load(yaml)
if openapi is None:
raise ToolApiSchemaError('Invalid openapi yaml.')
raise ToolApiSchemaError("Invalid openapi yaml.")
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
@staticmethod
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
"""
parse swagger to openapi
parse swagger to openapi
:param swagger: the swagger dict
:return: the openapi dict
:param swagger: the swagger dict
:return: the openapi dict
"""
# convert swagger to openapi
info = swagger.get('info', {
'title': 'Swagger',
'description': 'Swagger',
'version': '1.0.0'
})
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
servers = swagger.get('servers', [])
servers = swagger.get("servers", [])
if len(servers) == 0:
raise ToolApiSchemaError('No server found in the swagger yaml.')
raise ToolApiSchemaError("No server found in the swagger yaml.")
openapi = {
'openapi': '3.0.0',
'info': {
'title': info.get('title', 'Swagger'),
'description': info.get('description', 'Swagger'),
'version': info.get('version', '1.0.0')
"openapi": "3.0.0",
"info": {
"title": info.get("title", "Swagger"),
"description": info.get("description", "Swagger"),
"version": info.get("version", "1.0.0"),
},
'servers': swagger['servers'],
'paths': {},
'components': {
'schemas': {}
}
"servers": swagger["servers"],
"paths": {},
"components": {"schemas": {}},
}
# check paths
if 'paths' not in swagger or len(swagger['paths']) == 0:
raise ToolApiSchemaError('No paths found in the swagger yaml.')
if "paths" not in swagger or len(swagger["paths"]) == 0:
raise ToolApiSchemaError("No paths found in the swagger yaml.")
# convert paths
for path, path_item in swagger['paths'].items():
openapi['paths'][path] = {}
for path, path_item in swagger["paths"].items():
openapi["paths"][path] = {}
for method, operation in path_item.items():
if 'operationId' not in operation:
raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')
if ('summary' not in operation or len(operation['summary']) == 0) and \
('description' not in operation or len(operation['description']) == 0):
warning['missing_summary'] = f'No summary or description found in operation {method} {path}.'
openapi['paths'][path][method] = {
'operationId': operation['operationId'],
'summary': operation.get('summary', ''),
'description': operation.get('description', ''),
'parameters': operation.get('parameters', []),
'responses': operation.get('responses', {}),
if "operationId" not in operation:
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],
"summary": operation.get("summary", ""),
"description": operation.get("description", ""),
"parameters": operation.get("parameters", []),
"responses": operation.get("responses", {}),
}
if 'requestBody' in operation:
openapi['paths'][path][method]['requestBody'] = operation['requestBody']
if "requestBody" in operation:
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
# convert definitions
for name, definition in swagger['definitions'].items():
openapi['components']['schemas'][name] = definition
for name, definition in swagger["definitions"].items():
openapi["components"]["schemas"][name] = definition
return openapi
@staticmethod
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
def parse_openai_plugin_json_to_tool_bundle(
json: str, extra_info: dict = None, warning: dict = None
) -> list[ApiToolBundle]:
"""
parse openapi plugin yaml to tool bundle
parse openapi plugin yaml to tool bundle
:param json: the json string
:return: the tool bundle
:param json: the json string
:return: the tool bundle
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
try:
openai_plugin = json_loads(json)
api = openai_plugin['api']
api_url = api['url']
api_type = api['type']
api = openai_plugin["api"]
api_url = api["url"]
api_type = api["type"]
except:
raise ToolProviderNotFoundError('Invalid openai plugin json.')
if api_type != 'openapi':
raise ToolNotSupportedError('Only openapi is supported now.')
raise ToolProviderNotFoundError("Invalid openai plugin json.")
if api_type != "openapi":
raise ToolNotSupportedError("Only openapi is supported now.")
# get openapi yaml
response = get(api_url, headers={
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
}, timeout=5)
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
if response.status_code != 200:
raise ToolProviderNotFoundError('cannot get openapi yaml from url.')
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
@staticmethod
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
:param content: the content
:return: tools bundle, schema_type
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
response.text, extra_info=extra_info, warning=warning
)
@staticmethod
def auto_parse_to_tool_bundle(
content: str, extra_info: dict = None, warning: dict = None
) -> tuple[list[ApiToolBundle], str]:
"""
auto parse to tool bundle
:param content: the content
:return: tools bundle, schema_type
"""
warning = warning if warning is not None else {}
extra_info = extra_info if extra_info is not None else {}
@@ -301,7 +302,7 @@ class ApiBasedToolSchemaParser:
loaded_content = None
json_error = None
yaml_error = None
try:
loaded_content = json_loads(content)
except JSONDecodeError as e:
@@ -313,34 +314,46 @@ class ApiBasedToolSchemaParser:
except YAMLError as e:
yaml_error = e
if loaded_content is None:
raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}')
raise ToolApiSchemaError(
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}"
)
swagger_error = None
openapi_error = None
openapi_plugin_error = None
schema_type = None
try:
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning)
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.OPENAPI.value
return openapi, schema_type
except ToolApiSchemaError as e:
openapi_error = e
# openai parse error, fallback to swagger
try:
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning)
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
loaded_content, extra_info=extra_info, warning=warning
)
schema_type = ApiProviderSchemaType.SWAGGER.value
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
converted_swagger, extra_info=extra_info, warning=warning
), schema_type
except ToolApiSchemaError as e:
swagger_error = e
# swagger parse error, fallback to openai plugin
try:
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning)
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
json_dumps(loaded_content), extra_info=extra_info, warning=warning
)
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
except ToolNotSupportedError as e:
# maybe it's not plugin at all
openapi_plugin_error = e
raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}')
raise ToolApiSchemaError(
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}"
)

View File

@@ -7,16 +7,18 @@ class ToolParameterConverter:
@staticmethod
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
match parameter_type:
case ToolParameter.ToolParameterType.STRING \
| ToolParameter.ToolParameterType.SECRET_INPUT \
| ToolParameter.ToolParameterType.SELECT:
return 'string'
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
return "string"
case ToolParameter.ToolParameterType.BOOLEAN:
return 'boolean'
return "boolean"
case ToolParameter.ToolParameterType.NUMBER:
return 'number'
return "number"
case _:
raise ValueError(f"Unsupported parameter type {parameter_type}")
@@ -26,11 +28,13 @@ class ToolParameterConverter:
# convert tool parameter config to correct type
try:
match parameter_type:
case ToolParameter.ToolParameterType.STRING \
| ToolParameter.ToolParameterType.SECRET_INPUT \
| ToolParameter.ToolParameterType.SELECT:
case (
ToolParameter.ToolParameterType.STRING
| ToolParameter.ToolParameterType.SECRET_INPUT
| ToolParameter.ToolParameterType.SELECT
):
if value is None:
return ''
return ""
else:
return value if isinstance(value, str) else str(value)
@@ -41,9 +45,9 @@ class ToolParameterConverter:
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
# and also '0' for False and '1' for True
match value.lower():
case 'true' | 'yes' | 'y' | '1':
case "true" | "yes" | "y" | "1":
return True
case 'false' | 'no' | 'n' | '0':
case "false" | "no" | "n" | "0":
return False
case _:
return bool(value)
@@ -53,8 +57,8 @@ class ToolParameterConverter:
case ToolParameter.ToolParameterType.NUMBER:
if isinstance(value, int) | isinstance(value, float):
return value
elif isinstance(value, str) and value != '':
if '.' in value:
elif isinstance(value, str) and value != "":
if "." in value:
return float(value)
else:
return int(value)

View File

@@ -32,7 +32,7 @@ TEXT:
def page_result(text: str, cursor: int, max_length: int) -> str:
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
return text[cursor: cursor + max_length]
return text[cursor : cursor + max_length]
def get_url(url: str, user_agent: str = None) -> str:
@@ -49,15 +49,15 @@ def get_url(url: str, user_agent: str = None) -> str:
if response.status_code == 200:
# check content-type
content_type = response.headers.get('Content-Type')
content_type = response.headers.get("Content-Type")
if content_type:
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
else:
content_disposition = response.headers.get('Content-Disposition', '')
content_disposition = response.headers.get("Content-Disposition", "")
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
extension = re.search(r'\.(\w+)$', filename)
extension = re.search(r"\.(\w+)$", filename)
if extension:
main_content_type = mimetypes.guess_type(filename)[0]
@@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str:
# Detect encoding using chardet
detected_encoding = chardet.detect(response.content)
encoding = detected_encoding['encoding']
encoding = detected_encoding["encoding"]
if encoding:
try:
content = response.content.decode(encoding)
@@ -89,29 +89,29 @@ def get_url(url: str, user_agent: str = None) -> str:
a = extract_using_readabilipy(content)
if not a['plain_text'] or not a['plain_text'].strip():
return ''
if not a["plain_text"] or not a["plain_text"].strip():
return ""
res = FULL_TEMPLATE.format(
title=a['title'],
authors=a['byline'],
publish_date=a['date'],
title=a["title"],
authors=a["byline"],
publish_date=a["date"],
top_image="",
text=a['plain_text'] if a['plain_text'] else "",
text=a["plain_text"] if a["plain_text"] else "",
)
return res
def extract_using_readabilipy(html):
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
f_html.write(html)
f_html.close()
html_path = f_html.name
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
article_json_path = html_path + ".json"
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
with chdir(jsdir):
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
@@ -129,7 +129,7 @@ def extract_using_readabilipy(html):
"date": None,
"content": None,
"plain_content": None,
"plain_text": None
"plain_text": None,
}
# Populate article fields from readability fields where present
if input_json:
@@ -145,7 +145,7 @@ def extract_using_readabilipy(html):
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
if input_json.get("textContent"):
article_json["plain_text"] = input_json["textContent"]
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
return article_json
@@ -158,6 +158,7 @@ def find_module_path(module_name):
return None
@contextmanager
def chdir(path):
"""Change directory in context and return to original on exit"""
@@ -172,12 +173,14 @@ def chdir(path):
def extract_text_blocks_as_plain_text(paragraph_html):
# Load article as DOM
soup = BeautifulSoup(paragraph_html, 'html.parser')
soup = BeautifulSoup(paragraph_html, "html.parser")
# Select all lists
list_elements = soup.find_all(['ul', 'ol'])
list_elements = soup.find_all(["ul", "ol"])
# Prefix text in all list items with "* " and make lists paragraphs
for list_element in list_elements:
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
plain_items = "".join(
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
)
list_element.string = plain_items
list_element.name = "p"
# Select all text blocks
@@ -204,7 +207,7 @@ def plain_text_leaf_node(element):
def plain_content(readability_content, content_digests, node_indexes):
# Load article as DOM
soup = BeautifulSoup(readability_content, 'html.parser')
soup = BeautifulSoup(readability_content, "html.parser")
# Make all elements plain
elements = plain_elements(soup.contents, content_digests, node_indexes)
if node_indexes:
@@ -217,8 +220,7 @@ def plain_content(readability_content, content_digests, node_indexes):
def plain_elements(elements, content_digests, node_indexes):
# Get plain content versions of all elements
elements = [plain_element(element, content_digests, node_indexes)
for element in elements]
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
if content_digests:
# Add content digest attribute to nodes
elements = [add_content_digest(element) for element in elements]
@@ -258,11 +260,9 @@ def add_node_indexes(element, node_index="0"):
# Add index to current element
element["data-node-index"] = node_index
# Add index to child elements
for local_idx, child in enumerate(
[c for c in element.contents if not is_text(c)], start=1):
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
# Can't add attributes to leaf string types
child_index = "{stem}.{local}".format(
stem=node_index, local=local_idx)
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
add_node_indexes(child, node_index=child_index)
return element
@@ -284,11 +284,16 @@ def strip_control_characters(text):
# [Cn]: Other, Not Assigned
# [Co]: Other, Private Use
# [Cs]: Other, Surrogate
control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'}
retained_chars = ['\t', '\n', '\r', '\f']
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
retained_chars = ["\t", "\n", "\r", "\f"]
# Remove non-printing control characters
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
return "".join(
[
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
for char in text
]
)
def normalize_unicode(text):
@@ -305,8 +310,9 @@ def normalize_whitespace(text):
text = text.strip()
return text
def is_leaf(element):
return (element.name in ['p', 'li'])
return element.name in ["p", "li"]
def is_text(element):
@@ -330,7 +336,7 @@ def content_digest(element):
if trimmed_string == "":
digest = ""
else:
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
else:
contents = element.contents
num_contents = len(contents)
@@ -343,9 +349,8 @@ def content_digest(element):
else:
# Build content digest from the "non-empty" digests of child nodes
digest = hashlib.sha256()
child_digests = list(
filter(lambda x: x != "", [content_digest(content) for content in contents]))
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
for child in child_digests:
digest.update(child.encode('utf-8'))
digest.update(child.encode("utf-8"))
digest = digest.hexdigest()
return digest

View File

@@ -10,27 +10,25 @@ class WorkflowToolConfigurationUtils:
"""
for configuration in configurations:
if not WorkflowToolParameterConfiguration(**configuration):
raise ValueError('invalid parameter configuration')
raise ValueError("invalid parameter configuration")
@classmethod
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
"""
get workflow graph variables
"""
nodes = graph.get('nodes', [])
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
nodes = graph.get("nodes", [])
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
if not start_node:
return []
return [
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
]
return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])]
@classmethod
def check_is_synced(cls,
variables: list[VariableEntity],
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> None:
"""
check is synced
@@ -39,10 +37,10 @@ class WorkflowToolConfigurationUtils:
variable_names = [variable.variable for variable in variables]
if len(tool_configurations) != len(variables):
raise ValueError('parameter configuration mismatch, please republish the tool to update')
raise ValueError("parameter configuration mismatch, please republish the tool to update")
for parameter in tool_configurations:
if parameter.name not in variable_names:
raise ValueError('parameter configuration mismatch, please republish the tool to update')
raise ValueError("parameter configuration mismatch, please republish the tool to update")
return True
return True

View File

@@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
:return: an object of the YAML content
"""
try:
with open(file_path, encoding='utf-8') as yaml_file:
with open(file_path, encoding="utf-8") as yaml_file:
try:
yaml_content = yaml.safe_load(yaml_file)
return yaml_content if yaml_content else default_value
except Exception as e:
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
except Exception as e:
if ignore_error:
return default_value