chore(api/core): apply ruff reformatting (#7624)
This commit is contained in:
@@ -56,12 +56,13 @@ class ToolConfigurationManager(BaseModel):
|
||||
if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
|
||||
if field_name in credentials:
|
||||
if len(credentials[field_name]) > 6:
|
||||
credentials[field_name] = \
|
||||
credentials[field_name][:2] + \
|
||||
'*' * (len(credentials[field_name]) - 4) + \
|
||||
credentials[field_name][-2:]
|
||||
credentials[field_name] = (
|
||||
credentials[field_name][:2]
|
||||
+ "*" * (len(credentials[field_name]) - 4)
|
||||
+ credentials[field_name][-2:]
|
||||
)
|
||||
else:
|
||||
credentials[field_name] = '*' * len(credentials[field_name])
|
||||
credentials[field_name] = "*" * len(credentials[field_name])
|
||||
|
||||
return credentials
|
||||
|
||||
@@ -72,9 +73,9 @@ class ToolConfigurationManager(BaseModel):
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
if cached_credentials:
|
||||
@@ -95,16 +96,18 @@ class ToolConfigurationManager(BaseModel):
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
Tool parameter configuration manager
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
tool_runtime: Tool
|
||||
provider_name: str
|
||||
@@ -152,15 +155,19 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
current_parameters = self._merge_parameters()
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
if len(parameters[parameter.name]) > 6:
|
||||
parameters[parameter.name] = \
|
||||
parameters[parameter.name][:2] + \
|
||||
'*' * (len(parameters[parameter.name]) - 4) + \
|
||||
parameters[parameter.name][-2:]
|
||||
parameters[parameter.name] = (
|
||||
parameters[parameter.name][:2]
|
||||
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||
+ parameters[parameter.name][-2:]
|
||||
)
|
||||
else:
|
||||
parameters[parameter.name] = '*' * len(parameters[parameter.name])
|
||||
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||
|
||||
return parameters
|
||||
|
||||
@@ -176,7 +183,10 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
parameters = self._deep_copy(parameters)
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||
parameters[parameter.name] = encrypted
|
||||
@@ -191,10 +201,10 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
"""
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cached_parameters = cache.get()
|
||||
if cached_parameters:
|
||||
@@ -205,7 +215,10 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
has_secret_input = False
|
||||
|
||||
for parameter in current_parameters:
|
||||
if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
|
||||
if (
|
||||
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
):
|
||||
if parameter.name in parameters:
|
||||
try:
|
||||
has_secret_input = True
|
||||
@@ -221,9 +234,9 @@ class ToolParameterConfigurationManager(BaseModel):
|
||||
def delete_tool_parameters_cache(self):
|
||||
cache = ToolParameterCache(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=f'{self.provider_type}.{self.provider_name}',
|
||||
provider=f"{self.provider_type}.{self.provider_name}",
|
||||
tool_name=self.tool_runtime.identity.name,
|
||||
cache_type=ToolParameterCacheType.PARAMETER,
|
||||
identity_id=self.identity_id
|
||||
identity_id=self.identity_id,
|
||||
)
|
||||
cache.delete()
|
||||
|
@@ -17,8 +17,9 @@ class FeishuRequest:
|
||||
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
|
||||
return res.get("tenant_access_token")
|
||||
|
||||
def _send_request(self, url: str, method: str = "post", require_token: bool = True, payload: dict = None,
|
||||
params: dict = None):
|
||||
def _send_request(
|
||||
self, url: str, method: str = "post", require_token: bool = True, payload: dict = None, params: dict = None
|
||||
):
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"user-agent": "Dify",
|
||||
@@ -42,10 +43,7 @@ class FeishuRequest:
|
||||
}
|
||||
"""
|
||||
url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/access_token/get_tenant_access_token"
|
||||
payload = {
|
||||
"app_id": app_id,
|
||||
"app_secret": app_secret
|
||||
}
|
||||
payload = {"app_id": app_id, "app_secret": app_secret}
|
||||
res = self._send_request(url, require_token=False, payload=payload)
|
||||
return res
|
||||
|
||||
@@ -76,11 +74,7 @@ class FeishuRequest:
|
||||
|
||||
def write_document(self, document_id: str, content: str, position: str = "start") -> dict:
|
||||
url = "https://lark-plugin-api.solutionsuite.cn/lark-plugin/document/write_document"
|
||||
payload = {
|
||||
"document_id": document_id,
|
||||
"content": content,
|
||||
"position": position
|
||||
}
|
||||
payload = {"document_id": document_id, "content": content, "position": position}
|
||||
res = self._send_request(url, payload=payload)
|
||||
return res.get("data")
|
||||
|
||||
|
@@ -10,10 +10,9 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolFileMessageTransformer:
|
||||
@classmethod
|
||||
def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
conversation_id: str) -> list[ToolInvokeMessage]:
|
||||
def transform_tool_invoke_messages(
|
||||
cls, messages: list[ToolInvokeMessage], user_id: str, tenant_id: str, conversation_id: str
|
||||
) -> list[ToolInvokeMessage]:
|
||||
"""
|
||||
Transform tool message and handle file download
|
||||
"""
|
||||
@@ -28,78 +27,88 @@ class ToolFileMessageTransformer:
|
||||
# try to download image
|
||||
try:
|
||||
file = ToolFileManager.create_file_by_url(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_url=message.message
|
||||
user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, file_url=message.message
|
||||
)
|
||||
|
||||
url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
|
||||
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
))
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.TEXT,
|
||||
message=f"Failed to download image: {message.message}, you can try to download it yourself.",
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
save_as=message.save_as,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get mime type and save blob to storage
|
||||
mimetype = message.meta.get('mime_type', 'octet/stream')
|
||||
mimetype = message.meta.get("mime_type", "octet/stream")
|
||||
# if message is str, encode it to bytes
|
||||
if isinstance(message.message, str):
|
||||
message.message = message.message.encode('utf-8')
|
||||
message.message = message.message.encode("utf-8")
|
||||
|
||||
file = ToolFileManager.create_file_by_raw(
|
||||
user_id=user_id, tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=conversation_id,
|
||||
file_binary=message.message,
|
||||
mimetype=mimetype
|
||||
mimetype=mimetype,
|
||||
)
|
||||
|
||||
url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
|
||||
|
||||
# check if file is image
|
||||
if 'image' in mimetype:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
if "image" in mimetype:
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
|
||||
file_var = message.meta.get('file_var')
|
||||
file_var = message.meta.get("file_var")
|
||||
if file_var:
|
||||
if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
|
||||
if file_var.type == FileType.IMAGE:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
))
|
||||
result.append(
|
||||
ToolInvokeMessage(
|
||||
type=ToolInvokeMessage.MessageType.LINK,
|
||||
message=url,
|
||||
save_as=message.save_as,
|
||||
meta=message.meta.copy() if message.meta is not None else {},
|
||||
)
|
||||
)
|
||||
else:
|
||||
result.append(message)
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
For some reason, model will be used in tools like WebScraperTool, WikipediaSearchTool etc.
|
||||
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
Therefore, a model manager is needed to list/invoke/validate models.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -27,52 +27,49 @@ from models.tools import ToolModelInvoke
|
||||
class InvokeModelError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ModelInvocationUtils:
|
||||
@staticmethod
|
||||
def get_max_llm_context_tokens(
|
||||
tenant_id: str,
|
||||
) -> int:
|
||||
"""
|
||||
get max llm context tokens of the model
|
||||
get max llm context tokens of the model
|
||||
"""
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError('Model not found')
|
||||
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError('No model schema found')
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
|
||||
if max_tokens is None:
|
||||
return 2048
|
||||
|
||||
|
||||
return max_tokens
|
||||
|
||||
@staticmethod
|
||||
def calculate_tokens(
|
||||
tenant_id: str,
|
||||
prompt_messages: list[PromptMessage]
|
||||
) -> int:
|
||||
def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int:
|
||||
"""
|
||||
calculate tokens from prompt messages and model parameters
|
||||
calculate tokens from prompt messages and model parameters
|
||||
"""
|
||||
|
||||
# get model instance
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM
|
||||
)
|
||||
model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
|
||||
|
||||
if not model_instance:
|
||||
raise InvokeModelError('Model not found')
|
||||
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
# get tokens
|
||||
tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
@@ -80,9 +77,7 @@ class ModelInvocationUtils:
|
||||
|
||||
@staticmethod
|
||||
def invoke(
|
||||
user_id: str, tenant_id: str,
|
||||
tool_type: str, tool_name: str,
|
||||
prompt_messages: list[PromptMessage]
|
||||
user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage]
|
||||
) -> LLMResult:
|
||||
"""
|
||||
invoke model with parameters in user's own context
|
||||
@@ -103,15 +98,16 @@ class ModelInvocationUtils:
|
||||
model_manager = ModelManager()
|
||||
# get model instance
|
||||
model_instance = model_manager.get_default_model_instance(
|
||||
tenant_id=tenant_id, model_type=ModelType.LLM,
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
# get prompt tokens
|
||||
prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
model_parameters = {
|
||||
'temperature': 0.8,
|
||||
'top_p': 0.8,
|
||||
"temperature": 0.8,
|
||||
"top_p": 0.8,
|
||||
}
|
||||
|
||||
# create tool model invoke
|
||||
@@ -123,14 +119,14 @@ class ModelInvocationUtils:
|
||||
tool_name=tool_name,
|
||||
model_parameters=json.dumps(model_parameters),
|
||||
prompt_messages=json.dumps(jsonable_encoder(prompt_messages)),
|
||||
model_response='',
|
||||
model_response="",
|
||||
prompt_tokens=prompt_tokens,
|
||||
answer_tokens=0,
|
||||
answer_unit_price=0,
|
||||
answer_price_unit=0,
|
||||
provider_response_latency=0,
|
||||
total_price=0,
|
||||
currency='USD',
|
||||
currency="USD",
|
||||
)
|
||||
|
||||
db.session.add(tool_model_invoke)
|
||||
@@ -140,20 +136,24 @@ class ModelInvocationUtils:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=[], stop=[], stream=False, user=user_id, callbacks=[]
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=False,
|
||||
user=user_id,
|
||||
callbacks=[],
|
||||
)
|
||||
except InvokeRateLimitError as e:
|
||||
raise InvokeModelError(f'Invoke rate limit error: {e}')
|
||||
raise InvokeModelError(f"Invoke rate limit error: {e}")
|
||||
except InvokeBadRequestError as e:
|
||||
raise InvokeModelError(f'Invoke bad request error: {e}')
|
||||
raise InvokeModelError(f"Invoke bad request error: {e}")
|
||||
except InvokeConnectionError as e:
|
||||
raise InvokeModelError(f'Invoke connection error: {e}')
|
||||
raise InvokeModelError(f"Invoke connection error: {e}")
|
||||
except InvokeAuthorizationError as e:
|
||||
raise InvokeModelError('Invoke authorization error')
|
||||
raise InvokeModelError("Invoke authorization error")
|
||||
except InvokeServerUnavailableError as e:
|
||||
raise InvokeModelError(f'Invoke server unavailable error: {e}')
|
||||
raise InvokeModelError(f"Invoke server unavailable error: {e}")
|
||||
except Exception as e:
|
||||
raise InvokeModelError(f'Invoke error: {e}')
|
||||
raise InvokeModelError(f"Invoke error: {e}")
|
||||
|
||||
# update tool model invoke
|
||||
tool_model_invoke.model_response = response.message.content
|
||||
|
@@ -1,4 +1,3 @@
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from json import dumps as json_dumps
|
||||
@@ -16,54 +15,56 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
|
||||
|
||||
class ApiBasedToolSchemaParser:
|
||||
@staticmethod
|
||||
def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openapi_to_tool_bundle(
|
||||
openapi: dict, extra_info: dict = None, warning: dict = None
|
||||
) -> list[ApiToolBundle]:
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
# set description to extra_info
|
||||
extra_info['description'] = openapi['info'].get('description', '')
|
||||
extra_info["description"] = openapi["info"].get("description", "")
|
||||
|
||||
if len(openapi['servers']) == 0:
|
||||
raise ToolProviderNotFoundError('No server found in the openapi yaml.')
|
||||
if len(openapi["servers"]) == 0:
|
||||
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||
|
||||
server_url = openapi['servers'][0]['url']
|
||||
server_url = openapi["servers"][0]["url"]
|
||||
|
||||
# list all interfaces
|
||||
interfaces = []
|
||||
for path, path_item in openapi['paths'].items():
|
||||
methods = ['get', 'post', 'put', 'delete', 'patch', 'head', 'options', 'trace']
|
||||
for path, path_item in openapi["paths"].items():
|
||||
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||
for method in methods:
|
||||
if method in path_item:
|
||||
interfaces.append({
|
||||
'path': path,
|
||||
'method': method,
|
||||
'operation': path_item[method],
|
||||
})
|
||||
interfaces.append(
|
||||
{
|
||||
"path": path,
|
||||
"method": method,
|
||||
"operation": path_item[method],
|
||||
}
|
||||
)
|
||||
|
||||
# get all parameters
|
||||
bundles = []
|
||||
for interface in interfaces:
|
||||
# convert parameters
|
||||
parameters = []
|
||||
if 'parameters' in interface['operation']:
|
||||
for parameter in interface['operation']['parameters']:
|
||||
if "parameters" in interface["operation"]:
|
||||
for parameter in interface["operation"]["parameters"]:
|
||||
tool_parameter = ToolParameter(
|
||||
name=parameter['name'],
|
||||
label=I18nObject(
|
||||
en_US=parameter['name'],
|
||||
zh_Hans=parameter['name']
|
||||
),
|
||||
name=parameter["name"],
|
||||
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||
human_description=I18nObject(
|
||||
en_US=parameter.get('description', ''),
|
||||
zh_Hans=parameter.get('description', '')
|
||||
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=parameter.get('required', False),
|
||||
required=parameter.get("required", False),
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=parameter.get('description'),
|
||||
default=parameter['schema']['default'] if 'schema' in parameter and 'default' in parameter['schema'] else None,
|
||||
llm_description=parameter.get("description"),
|
||||
default=parameter["schema"]["default"]
|
||||
if "schema" in parameter and "default" in parameter["schema"]
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
# check if there is a type
|
||||
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||
if typ:
|
||||
@@ -72,44 +73,40 @@ class ApiBasedToolSchemaParser:
|
||||
parameters.append(tool_parameter)
|
||||
# create tool bundle
|
||||
# check if there is a request body
|
||||
if 'requestBody' in interface['operation']:
|
||||
request_body = interface['operation']['requestBody']
|
||||
if 'content' in request_body:
|
||||
for content_type, content in request_body['content'].items():
|
||||
if "requestBody" in interface["operation"]:
|
||||
request_body = interface["operation"]["requestBody"]
|
||||
if "content" in request_body:
|
||||
for content_type, content in request_body["content"].items():
|
||||
# if there is a reference, get the reference and overwrite the content
|
||||
if 'schema' not in content:
|
||||
if "schema" not in content:
|
||||
continue
|
||||
|
||||
if '$ref' in content['schema']:
|
||||
if "$ref" in content["schema"]:
|
||||
# get the reference
|
||||
root = openapi
|
||||
reference = content['schema']['$ref'].split('/')[1:]
|
||||
reference = content["schema"]["$ref"].split("/")[1:]
|
||||
for ref in reference:
|
||||
root = root[ref]
|
||||
# overwrite the content
|
||||
interface['operation']['requestBody']['content'][content_type]['schema'] = root
|
||||
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||
|
||||
# parse body parameters
|
||||
if 'schema' in interface['operation']['requestBody']['content'][content_type]:
|
||||
body_schema = interface['operation']['requestBody']['content'][content_type]['schema']
|
||||
required = body_schema.get('required', [])
|
||||
properties = body_schema.get('properties', {})
|
||||
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||
required = body_schema.get("required", [])
|
||||
properties = body_schema.get("properties", {})
|
||||
for name, property in properties.items():
|
||||
tool = ToolParameter(
|
||||
name=name,
|
||||
label=I18nObject(
|
||||
en_US=name,
|
||||
zh_Hans=name
|
||||
),
|
||||
label=I18nObject(en_US=name, zh_Hans=name),
|
||||
human_description=I18nObject(
|
||||
en_US=property.get('description', ''),
|
||||
zh_Hans=property.get('description', '')
|
||||
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||
),
|
||||
type=ToolParameter.ToolParameterType.STRING,
|
||||
required=name in required,
|
||||
form=ToolParameter.ToolParameterForm.LLM,
|
||||
llm_description=property.get('description', ''),
|
||||
default=property.get('default', None),
|
||||
llm_description=property.get("description", ""),
|
||||
default=property.get("default", None),
|
||||
)
|
||||
|
||||
# check if there is a type
|
||||
@@ -127,172 +124,176 @@ class ApiBasedToolSchemaParser:
|
||||
parameters_count[parameter.name] += 1
|
||||
for name, count in parameters_count.items():
|
||||
if count > 1:
|
||||
warning['duplicated_parameter'] = f'Parameter {name} is duplicated.'
|
||||
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||
|
||||
# check if there is a operation id, use $path_$method as operation id if not
|
||||
if 'operationId' not in interface['operation']:
|
||||
if "operationId" not in interface["operation"]:
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = interface['path']
|
||||
if interface['path'].startswith('/'):
|
||||
path = interface['path'][1:]
|
||||
path = interface["path"]
|
||||
if interface["path"].startswith("/"):
|
||||
path = interface["path"][1:]
|
||||
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||
path = re.sub(r'[^a-zA-Z0-9_-]', '', path)
|
||||
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||
if not path:
|
||||
path = str(uuid.uuid4())
|
||||
|
||||
interface['operation']['operationId'] = f'{path}_{interface["method"]}'
|
||||
|
||||
bundles.append(ApiToolBundle(
|
||||
server_url=server_url + interface['path'],
|
||||
method=interface['method'],
|
||||
summary=interface['operation']['description'] if 'description' in interface['operation'] else
|
||||
interface['operation'].get('summary', None),
|
||||
operation_id=interface['operation']['operationId'],
|
||||
parameters=parameters,
|
||||
author='',
|
||||
icon=None,
|
||||
openapi=interface['operation'],
|
||||
))
|
||||
interface["operation"]["operationId"] = f'{path}_{interface["method"]}'
|
||||
|
||||
bundles.append(
|
||||
ApiToolBundle(
|
||||
server_url=server_url + interface["path"],
|
||||
method=interface["method"],
|
||||
summary=interface["operation"]["description"]
|
||||
if "description" in interface["operation"]
|
||||
else interface["operation"].get("summary", None),
|
||||
operation_id=interface["operation"]["operationId"],
|
||||
parameters=parameters,
|
||||
author="",
|
||||
icon=None,
|
||||
openapi=interface["operation"],
|
||||
)
|
||||
)
|
||||
|
||||
return bundles
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
|
||||
parameter = parameter or {}
|
||||
typ = None
|
||||
if 'type' in parameter:
|
||||
typ = parameter['type']
|
||||
elif 'schema' in parameter and 'type' in parameter['schema']:
|
||||
typ = parameter['schema']['type']
|
||||
|
||||
if typ == 'integer' or typ == 'number':
|
||||
if "type" in parameter:
|
||||
typ = parameter["type"]
|
||||
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||
typ = parameter["schema"]["type"]
|
||||
|
||||
if typ == "integer" or typ == "number":
|
||||
return ToolParameter.ToolParameterType.NUMBER
|
||||
elif typ == 'boolean':
|
||||
elif typ == "boolean":
|
||||
return ToolParameter.ToolParameterType.BOOLEAN
|
||||
elif typ == 'string':
|
||||
elif typ == "string":
|
||||
return ToolParameter.ToolParameterType.STRING
|
||||
|
||||
@staticmethod
|
||||
def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openapi_yaml_to_tool_bundle(
|
||||
yaml: str, extra_info: dict = None, warning: dict = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi yaml to tool bundle
|
||||
parse openapi yaml to tool bundle
|
||||
|
||||
:param yaml: the yaml string
|
||||
:return: the tool bundle
|
||||
:param yaml: the yaml string
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
openapi: dict = safe_load(yaml)
|
||||
if openapi is None:
|
||||
raise ToolApiSchemaError('Invalid openapi yaml.')
|
||||
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
|
||||
"""
|
||||
parse swagger to openapi
|
||||
parse swagger to openapi
|
||||
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
:param swagger: the swagger dict
|
||||
:return: the openapi dict
|
||||
"""
|
||||
# convert swagger to openapi
|
||||
info = swagger.get('info', {
|
||||
'title': 'Swagger',
|
||||
'description': 'Swagger',
|
||||
'version': '1.0.0'
|
||||
})
|
||||
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||
|
||||
servers = swagger.get('servers', [])
|
||||
servers = swagger.get("servers", [])
|
||||
|
||||
if len(servers) == 0:
|
||||
raise ToolApiSchemaError('No server found in the swagger yaml.')
|
||||
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||
|
||||
openapi = {
|
||||
'openapi': '3.0.0',
|
||||
'info': {
|
||||
'title': info.get('title', 'Swagger'),
|
||||
'description': info.get('description', 'Swagger'),
|
||||
'version': info.get('version', '1.0.0')
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"title": info.get("title", "Swagger"),
|
||||
"description": info.get("description", "Swagger"),
|
||||
"version": info.get("version", "1.0.0"),
|
||||
},
|
||||
'servers': swagger['servers'],
|
||||
'paths': {},
|
||||
'components': {
|
||||
'schemas': {}
|
||||
}
|
||||
"servers": swagger["servers"],
|
||||
"paths": {},
|
||||
"components": {"schemas": {}},
|
||||
}
|
||||
|
||||
# check paths
|
||||
if 'paths' not in swagger or len(swagger['paths']) == 0:
|
||||
raise ToolApiSchemaError('No paths found in the swagger yaml.')
|
||||
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||
|
||||
# convert paths
|
||||
for path, path_item in swagger['paths'].items():
|
||||
openapi['paths'][path] = {}
|
||||
for path, path_item in swagger["paths"].items():
|
||||
openapi["paths"][path] = {}
|
||||
for method, operation in path_item.items():
|
||||
if 'operationId' not in operation:
|
||||
raise ToolApiSchemaError(f'No operationId found in operation {method} {path}.')
|
||||
|
||||
if ('summary' not in operation or len(operation['summary']) == 0) and \
|
||||
('description' not in operation or len(operation['description']) == 0):
|
||||
warning['missing_summary'] = f'No summary or description found in operation {method} {path}.'
|
||||
|
||||
openapi['paths'][path][method] = {
|
||||
'operationId': operation['operationId'],
|
||||
'summary': operation.get('summary', ''),
|
||||
'description': operation.get('description', ''),
|
||||
'parameters': operation.get('parameters', []),
|
||||
'responses': operation.get('responses', {}),
|
||||
if "operationId" not in operation:
|
||||
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||
|
||||
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||
"description" not in operation or len(operation["description"]) == 0
|
||||
):
|
||||
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||
|
||||
openapi["paths"][path][method] = {
|
||||
"operationId": operation["operationId"],
|
||||
"summary": operation.get("summary", ""),
|
||||
"description": operation.get("description", ""),
|
||||
"parameters": operation.get("parameters", []),
|
||||
"responses": operation.get("responses", {}),
|
||||
}
|
||||
|
||||
if 'requestBody' in operation:
|
||||
openapi['paths'][path][method]['requestBody'] = operation['requestBody']
|
||||
if "requestBody" in operation:
|
||||
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||
|
||||
# convert definitions
|
||||
for name, definition in swagger['definitions'].items():
|
||||
openapi['components']['schemas'][name] = definition
|
||||
for name, definition in swagger["definitions"].items():
|
||||
openapi["components"]["schemas"][name] = definition
|
||||
|
||||
return openapi
|
||||
|
||||
@staticmethod
|
||||
def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
|
||||
def parse_openai_plugin_json_to_tool_bundle(
|
||||
json: str, extra_info: dict = None, warning: dict = None
|
||||
) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse openapi plugin yaml to tool bundle
|
||||
parse openapi plugin yaml to tool bundle
|
||||
|
||||
:param json: the json string
|
||||
:return: the tool bundle
|
||||
:param json: the json string
|
||||
:return: the tool bundle
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
|
||||
try:
|
||||
openai_plugin = json_loads(json)
|
||||
api = openai_plugin['api']
|
||||
api_url = api['url']
|
||||
api_type = api['type']
|
||||
api = openai_plugin["api"]
|
||||
api_url = api["url"]
|
||||
api_type = api["type"]
|
||||
except:
|
||||
raise ToolProviderNotFoundError('Invalid openai plugin json.')
|
||||
|
||||
if api_type != 'openapi':
|
||||
raise ToolNotSupportedError('Only openapi is supported now.')
|
||||
|
||||
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||
|
||||
if api_type != "openapi":
|
||||
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||
|
||||
# get openapi yaml
|
||||
response = get(api_url, headers={
|
||||
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) '
|
||||
}, timeout=5)
|
||||
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ToolProviderNotFoundError('cannot get openapi yaml from url.')
|
||||
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||
|
||||
:param content: the content
|
||||
:return: tools bundle, schema_type
|
||||
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||
response.text, extra_info=extra_info, warning=warning
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_parse_to_tool_bundle(
|
||||
content: str, extra_info: dict = None, warning: dict = None
|
||||
) -> tuple[list[ApiToolBundle], str]:
|
||||
"""
|
||||
auto parse to tool bundle
|
||||
|
||||
:param content: the content
|
||||
:return: tools bundle, schema_type
|
||||
"""
|
||||
warning = warning if warning is not None else {}
|
||||
extra_info = extra_info if extra_info is not None else {}
|
||||
@@ -301,7 +302,7 @@ class ApiBasedToolSchemaParser:
|
||||
loaded_content = None
|
||||
json_error = None
|
||||
yaml_error = None
|
||||
|
||||
|
||||
try:
|
||||
loaded_content = json_loads(content)
|
||||
except JSONDecodeError as e:
|
||||
@@ -313,34 +314,46 @@ class ApiBasedToolSchemaParser:
|
||||
except YAMLError as e:
|
||||
yaml_error = e
|
||||
if loaded_content is None:
|
||||
raise ToolApiSchemaError(f'Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}')
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)}, yaml error: {str(yaml_error)}"
|
||||
)
|
||||
|
||||
swagger_error = None
|
||||
openapi_error = None
|
||||
openapi_plugin_error = None
|
||||
schema_type = None
|
||||
|
||||
|
||||
try:
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(loaded_content, extra_info=extra_info, warning=warning)
|
||||
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||
return openapi, schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
openapi_error = e
|
||||
|
||||
|
||||
# openai parse error, fallback to swagger
|
||||
try:
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(loaded_content, extra_info=extra_info, warning=warning)
|
||||
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||
loaded_content, extra_info=extra_info, warning=warning
|
||||
)
|
||||
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(converted_swagger, extra_info=extra_info, warning=warning), schema_type
|
||||
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||
converted_swagger, extra_info=extra_info, warning=warning
|
||||
), schema_type
|
||||
except ToolApiSchemaError as e:
|
||||
swagger_error = e
|
||||
|
||||
|
||||
# swagger parse error, fallback to openai plugin
|
||||
try:
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(json_dumps(loaded_content), extra_info=extra_info, warning=warning)
|
||||
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||
)
|
||||
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||
except ToolNotSupportedError as e:
|
||||
# maybe it's not plugin at all
|
||||
openapi_plugin_error = e
|
||||
|
||||
raise ToolApiSchemaError(f'Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}')
|
||||
raise ToolApiSchemaError(
|
||||
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)}, openapi plugin error: {str(openapi_plugin_error)}"
|
||||
)
|
||||
|
@@ -7,16 +7,18 @@ class ToolParameterConverter:
|
||||
@staticmethod
|
||||
def get_parameter_type(parameter_type: str | ToolParameter.ToolParameterType) -> str:
|
||||
match parameter_type:
|
||||
case ToolParameter.ToolParameterType.STRING \
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT \
|
||||
| ToolParameter.ToolParameterType.SELECT:
|
||||
return 'string'
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
return "string"
|
||||
|
||||
case ToolParameter.ToolParameterType.BOOLEAN:
|
||||
return 'boolean'
|
||||
return "boolean"
|
||||
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
return 'number'
|
||||
return "number"
|
||||
|
||||
case _:
|
||||
raise ValueError(f"Unsupported parameter type {parameter_type}")
|
||||
@@ -26,11 +28,13 @@ class ToolParameterConverter:
|
||||
# convert tool parameter config to correct type
|
||||
try:
|
||||
match parameter_type:
|
||||
case ToolParameter.ToolParameterType.STRING \
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT \
|
||||
| ToolParameter.ToolParameterType.SELECT:
|
||||
case (
|
||||
ToolParameter.ToolParameterType.STRING
|
||||
| ToolParameter.ToolParameterType.SECRET_INPUT
|
||||
| ToolParameter.ToolParameterType.SELECT
|
||||
):
|
||||
if value is None:
|
||||
return ''
|
||||
return ""
|
||||
else:
|
||||
return value if isinstance(value, str) else str(value)
|
||||
|
||||
@@ -41,9 +45,9 @@ class ToolParameterConverter:
|
||||
# Allowed YAML boolean value strings: https://yaml.org/type/bool.html
|
||||
# and also '0' for False and '1' for True
|
||||
match value.lower():
|
||||
case 'true' | 'yes' | 'y' | '1':
|
||||
case "true" | "yes" | "y" | "1":
|
||||
return True
|
||||
case 'false' | 'no' | 'n' | '0':
|
||||
case "false" | "no" | "n" | "0":
|
||||
return False
|
||||
case _:
|
||||
return bool(value)
|
||||
@@ -53,8 +57,8 @@ class ToolParameterConverter:
|
||||
case ToolParameter.ToolParameterType.NUMBER:
|
||||
if isinstance(value, int) | isinstance(value, float):
|
||||
return value
|
||||
elif isinstance(value, str) and value != '':
|
||||
if '.' in value:
|
||||
elif isinstance(value, str) and value != "":
|
||||
if "." in value:
|
||||
return float(value)
|
||||
else:
|
||||
return int(value)
|
||||
|
@@ -32,7 +32,7 @@ TEXT:
|
||||
|
||||
def page_result(text: str, cursor: int, max_length: int) -> str:
|
||||
"""Page through `text` and return a substring of `max_length` characters starting from `cursor`."""
|
||||
return text[cursor: cursor + max_length]
|
||||
return text[cursor : cursor + max_length]
|
||||
|
||||
|
||||
def get_url(url: str, user_agent: str = None) -> str:
|
||||
@@ -49,15 +49,15 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
|
||||
if response.status_code == 200:
|
||||
# check content-type
|
||||
content_type = response.headers.get('Content-Type')
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if content_type:
|
||||
main_content_type = response.headers.get('Content-Type').split(';')[0].strip()
|
||||
main_content_type = response.headers.get("Content-Type").split(";")[0].strip()
|
||||
else:
|
||||
content_disposition = response.headers.get('Content-Disposition', '')
|
||||
content_disposition = response.headers.get("Content-Disposition", "")
|
||||
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
|
||||
if filename_match:
|
||||
filename = unquote(filename_match.group(1))
|
||||
extension = re.search(r'\.(\w+)$', filename)
|
||||
extension = re.search(r"\.(\w+)$", filename)
|
||||
if extension:
|
||||
main_content_type = mimetypes.guess_type(filename)[0]
|
||||
|
||||
@@ -78,7 +78,7 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
|
||||
# Detect encoding using chardet
|
||||
detected_encoding = chardet.detect(response.content)
|
||||
encoding = detected_encoding['encoding']
|
||||
encoding = detected_encoding["encoding"]
|
||||
if encoding:
|
||||
try:
|
||||
content = response.content.decode(encoding)
|
||||
@@ -89,29 +89,29 @@ def get_url(url: str, user_agent: str = None) -> str:
|
||||
|
||||
a = extract_using_readabilipy(content)
|
||||
|
||||
if not a['plain_text'] or not a['plain_text'].strip():
|
||||
return ''
|
||||
if not a["plain_text"] or not a["plain_text"].strip():
|
||||
return ""
|
||||
|
||||
res = FULL_TEMPLATE.format(
|
||||
title=a['title'],
|
||||
authors=a['byline'],
|
||||
publish_date=a['date'],
|
||||
title=a["title"],
|
||||
authors=a["byline"],
|
||||
publish_date=a["date"],
|
||||
top_image="",
|
||||
text=a['plain_text'] if a['plain_text'] else "",
|
||||
text=a["plain_text"] if a["plain_text"] else "",
|
||||
)
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def extract_using_readabilipy(html):
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode='w+') as f_html:
|
||||
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html:
|
||||
f_html.write(html)
|
||||
f_html.close()
|
||||
html_path = f_html.name
|
||||
|
||||
# Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file
|
||||
article_json_path = html_path + ".json"
|
||||
jsdir = os.path.join(find_module_path('readabilipy'), 'javascript')
|
||||
jsdir = os.path.join(find_module_path("readabilipy"), "javascript")
|
||||
with chdir(jsdir):
|
||||
subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path])
|
||||
|
||||
@@ -129,7 +129,7 @@ def extract_using_readabilipy(html):
|
||||
"date": None,
|
||||
"content": None,
|
||||
"plain_content": None,
|
||||
"plain_text": None
|
||||
"plain_text": None,
|
||||
}
|
||||
# Populate article fields from readability fields where present
|
||||
if input_json:
|
||||
@@ -145,7 +145,7 @@ def extract_using_readabilipy(html):
|
||||
article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"])
|
||||
if input_json.get("textContent"):
|
||||
article_json["plain_text"] = input_json["textContent"]
|
||||
article_json["plain_text"] = re.sub(r'\n\s*\n', '\n', article_json["plain_text"])
|
||||
article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"])
|
||||
|
||||
return article_json
|
||||
|
||||
@@ -158,6 +158,7 @@ def find_module_path(module_name):
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def chdir(path):
|
||||
"""Change directory in context and return to original on exit"""
|
||||
@@ -172,12 +173,14 @@ def chdir(path):
|
||||
|
||||
def extract_text_blocks_as_plain_text(paragraph_html):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(paragraph_html, 'html.parser')
|
||||
soup = BeautifulSoup(paragraph_html, "html.parser")
|
||||
# Select all lists
|
||||
list_elements = soup.find_all(['ul', 'ol'])
|
||||
list_elements = soup.find_all(["ul", "ol"])
|
||||
# Prefix text in all list items with "* " and make lists paragraphs
|
||||
for list_element in list_elements:
|
||||
plain_items = "".join(list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all('li')])))
|
||||
plain_items = "".join(
|
||||
list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")]))
|
||||
)
|
||||
list_element.string = plain_items
|
||||
list_element.name = "p"
|
||||
# Select all text blocks
|
||||
@@ -204,7 +207,7 @@ def plain_text_leaf_node(element):
|
||||
|
||||
def plain_content(readability_content, content_digests, node_indexes):
|
||||
# Load article as DOM
|
||||
soup = BeautifulSoup(readability_content, 'html.parser')
|
||||
soup = BeautifulSoup(readability_content, "html.parser")
|
||||
# Make all elements plain
|
||||
elements = plain_elements(soup.contents, content_digests, node_indexes)
|
||||
if node_indexes:
|
||||
@@ -217,8 +220,7 @@ def plain_content(readability_content, content_digests, node_indexes):
|
||||
|
||||
def plain_elements(elements, content_digests, node_indexes):
|
||||
# Get plain content versions of all elements
|
||||
elements = [plain_element(element, content_digests, node_indexes)
|
||||
for element in elements]
|
||||
elements = [plain_element(element, content_digests, node_indexes) for element in elements]
|
||||
if content_digests:
|
||||
# Add content digest attribute to nodes
|
||||
elements = [add_content_digest(element) for element in elements]
|
||||
@@ -258,11 +260,9 @@ def add_node_indexes(element, node_index="0"):
|
||||
# Add index to current element
|
||||
element["data-node-index"] = node_index
|
||||
# Add index to child elements
|
||||
for local_idx, child in enumerate(
|
||||
[c for c in element.contents if not is_text(c)], start=1):
|
||||
for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1):
|
||||
# Can't add attributes to leaf string types
|
||||
child_index = "{stem}.{local}".format(
|
||||
stem=node_index, local=local_idx)
|
||||
child_index = "{stem}.{local}".format(stem=node_index, local=local_idx)
|
||||
add_node_indexes(child, node_index=child_index)
|
||||
return element
|
||||
|
||||
@@ -284,11 +284,16 @@ def strip_control_characters(text):
|
||||
# [Cn]: Other, Not Assigned
|
||||
# [Co]: Other, Private Use
|
||||
# [Cs]: Other, Surrogate
|
||||
control_chars = {'Cc', 'Cf', 'Cn', 'Co', 'Cs'}
|
||||
retained_chars = ['\t', '\n', '\r', '\f']
|
||||
control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"}
|
||||
retained_chars = ["\t", "\n", "\r", "\f"]
|
||||
|
||||
# Remove non-printing control characters
|
||||
return "".join(["" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char for char in text])
|
||||
return "".join(
|
||||
[
|
||||
"" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char
|
||||
for char in text
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def normalize_unicode(text):
|
||||
@@ -305,8 +310,9 @@ def normalize_whitespace(text):
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
def is_leaf(element):
|
||||
return (element.name in ['p', 'li'])
|
||||
return element.name in ["p", "li"]
|
||||
|
||||
|
||||
def is_text(element):
|
||||
@@ -330,7 +336,7 @@ def content_digest(element):
|
||||
if trimmed_string == "":
|
||||
digest = ""
|
||||
else:
|
||||
digest = hashlib.sha256(trimmed_string.encode('utf-8')).hexdigest()
|
||||
digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest()
|
||||
else:
|
||||
contents = element.contents
|
||||
num_contents = len(contents)
|
||||
@@ -343,9 +349,8 @@ def content_digest(element):
|
||||
else:
|
||||
# Build content digest from the "non-empty" digests of child nodes
|
||||
digest = hashlib.sha256()
|
||||
child_digests = list(
|
||||
filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents]))
|
||||
for child in child_digests:
|
||||
digest.update(child.encode('utf-8'))
|
||||
digest.update(child.encode("utf-8"))
|
||||
digest = digest.hexdigest()
|
||||
return digest
|
||||
|
@@ -10,27 +10,25 @@ class WorkflowToolConfigurationUtils:
|
||||
"""
|
||||
for configuration in configurations:
|
||||
if not WorkflowToolParameterConfiguration(**configuration):
|
||||
raise ValueError('invalid parameter configuration')
|
||||
raise ValueError("invalid parameter configuration")
|
||||
|
||||
@classmethod
|
||||
def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
|
||||
"""
|
||||
get workflow graph variables
|
||||
"""
|
||||
nodes = graph.get('nodes', [])
|
||||
start_node = next(filter(lambda x: x.get('data', {}).get('type') == 'start', nodes), None)
|
||||
nodes = graph.get("nodes", [])
|
||||
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||
|
||||
if not start_node:
|
||||
return []
|
||||
|
||||
return [
|
||||
VariableEntity(**variable) for variable in start_node.get('data', {}).get('variables', [])
|
||||
]
|
||||
|
||||
return [VariableEntity(**variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(cls,
|
||||
variables: list[VariableEntity],
|
||||
tool_configurations: list[WorkflowToolParameterConfiguration]) -> None:
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
) -> None:
|
||||
"""
|
||||
check is synced
|
||||
|
||||
@@ -39,10 +37,10 @@ class WorkflowToolConfigurationUtils:
|
||||
variable_names = [variable.variable for variable in variables]
|
||||
|
||||
if len(tool_configurations) != len(variables):
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
for parameter in tool_configurations:
|
||||
if parameter.name not in variable_names:
|
||||
raise ValueError('parameter configuration mismatch, please republish the tool to update')
|
||||
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||
|
||||
return True
|
||||
return True
|
||||
|
@@ -18,12 +18,12 @@ def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any
|
||||
:return: an object of the YAML content
|
||||
"""
|
||||
try:
|
||||
with open(file_path, encoding='utf-8') as yaml_file:
|
||||
with open(file_path, encoding="utf-8") as yaml_file:
|
||||
try:
|
||||
yaml_content = yaml.safe_load(yaml_file)
|
||||
return yaml_content if yaml_content else default_value
|
||||
except Exception as e:
|
||||
raise YAMLError(f'Failed to load YAML file {file_path}: {e}')
|
||||
raise YAMLError(f"Failed to load YAML file {file_path}: {e}")
|
||||
except Exception as e:
|
||||
if ignore_error:
|
||||
return default_value
|
||||
|
Reference in New Issue
Block a user