feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -72,9 +72,13 @@ class ToolConfigurationManager(BaseModel):
return a deep copy of credentials with decrypted values
"""
identity_id = ""
if self.provider_controller.identity:
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cached_credentials = cache.get()
@@ -95,9 +99,13 @@ class ToolConfigurationManager(BaseModel):
return credentials
def delete_tool_credentials_cache(self):
identity_id = ""
if self.provider_controller.identity:
identity_id = f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}"
cache = ToolProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=f"{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}",
identity_id=identity_id,
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
)
cache.delete()
@@ -199,6 +207,9 @@ class ToolParameterConfigurationManager(BaseModel):
return a deep copy of parameters with decrypted values
"""
if self.tool_runtime is None or self.tool_runtime.identity is None:
raise ValueError("tool_runtime is required")
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",
@@ -232,6 +243,9 @@ class ToolParameterConfigurationManager(BaseModel):
return parameters
def delete_tool_parameters_cache(self):
if self.tool_runtime is None or self.tool_runtime.identity is None:
raise ValueError("tool_runtime is required")
cache = ToolParameterCache(
tenant_id=self.tenant_id,
provider=f"{self.provider_type}.{self.provider_name}",

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional, cast
import httpx
@@ -101,7 +101,7 @@ class FeishuRequest:
"""
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -126,15 +126,16 @@ class FeishuRequest:
"content": content,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str:
@@ -155,9 +156,9 @@ class FeishuRequest:
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data").get("content")
return cast(str, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -173,9 +174,10 @@ class FeishuRequest:
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -191,9 +193,10 @@ class FeishuRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -203,7 +206,7 @@ class FeishuRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -227,9 +230,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_thread_messages(
@@ -245,9 +249,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -260,9 +265,10 @@ class FeishuRequest:
"completed_at": completed_time,
"description": description,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_task(
@@ -278,9 +284,10 @@ class FeishuRequest:
"completed_time": completed_time,
"description": description,
}
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -289,7 +296,7 @@ class FeishuRequest:
payload = {
"task_guid": task_guid,
}
res = self._send_request(url, method="DELETE", payload=payload)
res: dict = self._send_request(url, method="DELETE", payload=payload)
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -300,7 +307,7 @@ class FeishuRequest:
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -312,9 +319,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -322,9 +330,10 @@ class FeishuRequest:
params = {
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_event(
@@ -347,9 +356,10 @@ class FeishuRequest:
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_event(
@@ -363,7 +373,7 @@ class FeishuRequest:
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
payload = {}
payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -376,7 +386,7 @@ class FeishuRequest:
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -384,7 +394,7 @@ class FeishuRequest:
params = {
"need_notification": need_notification,
}
res = self._send_request(url, method="DELETE", params=params)
res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -395,9 +405,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_events(
@@ -418,9 +429,10 @@ class FeishuRequest:
"user_id_type": user_id_type,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -431,9 +443,10 @@ class FeishuRequest:
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_spreadsheet(
@@ -447,9 +460,10 @@ class FeishuRequest:
"title": title,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_spreadsheet(
@@ -463,9 +477,10 @@ class FeishuRequest:
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_spreadsheet_sheets(
@@ -477,9 +492,10 @@ class FeishuRequest:
params = {
"spreadsheet_token": spreadsheet_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_rows(
@@ -499,9 +515,10 @@ class FeishuRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_cols(
@@ -521,9 +538,10 @@ class FeishuRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_rows(
@@ -545,9 +563,10 @@ class FeishuRequest:
"num_rows": num_rows,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_cols(
@@ -569,9 +588,10 @@ class FeishuRequest:
"num_cols": num_cols,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_table(
@@ -593,9 +613,10 @@ class FeishuRequest:
"query": query,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_base(
@@ -609,9 +630,10 @@ class FeishuRequest:
"name": name,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_records(
@@ -633,9 +655,10 @@ class FeishuRequest:
payload = {
"records": convert_add_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_records(
@@ -657,9 +680,10 @@ class FeishuRequest:
payload = {
"records": convert_update_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_records(
@@ -686,9 +710,10 @@ class FeishuRequest:
payload = {
"records": record_id_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_record(
@@ -740,7 +765,7 @@ class FeishuRequest:
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {}
payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -752,10 +777,11 @@ class FeishuRequest:
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_base_info(
@@ -767,9 +793,10 @@ class FeishuRequest:
params = {
"app_token": app_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_table(
@@ -797,9 +824,10 @@ class FeishuRequest:
}
if default_view_name:
payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_tables(
@@ -834,9 +862,10 @@ class FeishuRequest:
"table_names": table_name_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_tables(
@@ -852,9 +881,10 @@ class FeishuRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_records(
@@ -882,7 +912,8 @@ class FeishuRequest:
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params, payload=payload)
res: dict = self._send_request(url, method="GET", params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res

View File

@@ -1,5 +1,5 @@
import json
from typing import Optional
from typing import Any, Optional, cast
import httpx
@@ -62,12 +62,10 @@ class LarkRequest:
def tenant_access_token(self) -> str:
feishu_tenant_access_token = f"tools:{self.app_id}:feishu_tenant_access_token"
if redis_client.exists(feishu_tenant_access_token):
return redis_client.get(feishu_tenant_access_token).decode()
res = self.get_tenant_access_token(self.app_id, self.app_secret)
return str(redis_client.get(feishu_tenant_access_token).decode())
res: dict[str, str] = self.get_tenant_access_token(self.app_id, self.app_secret)
redis_client.setex(feishu_tenant_access_token, res.get("expire"), res.get("tenant_access_token"))
if "tenant_access_token" in res:
return res.get("tenant_access_token")
return ""
return res.get("tenant_access_token", "")
def _send_request(
self,
@@ -91,7 +89,7 @@ class LarkRequest:
def get_tenant_access_token(self, app_id: str, app_secret: str) -> dict:
url = f"{self.API_BASE_URL}/access_token/get_tenant_access_token"
payload = {"app_id": app_id, "app_secret": app_secret}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def create_document(self, title: str, content: str, folder_token: str) -> dict:
@@ -101,15 +99,16 @@ class LarkRequest:
"content": content,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def write_document(self, document_id: str, content: str, position: str = "end") -> dict:
url = f"{self.API_BASE_URL}/document/write_document"
payload = {"document_id": document_id, "content": content, "position": position}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
return res
def get_document_content(self, document_id: str, mode: str = "markdown", lang: str = "0") -> str | dict:
@@ -119,9 +118,9 @@ class LarkRequest:
"lang": lang,
}
url = f"{self.API_BASE_URL}/document/get_document_content"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data").get("content")
return cast(dict, res.get("data", {}).get("content"))
return ""
def list_document_blocks(
@@ -134,9 +133,10 @@ class LarkRequest:
"page_token": page_token,
}
url = f"{self.API_BASE_URL}/document/list_document_blocks"
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_bot_message(self, receive_id_type: str, receive_id: str, msg_type: str, content: str) -> dict:
@@ -149,9 +149,10 @@ class LarkRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def send_webhook_message(self, webhook: str, msg_type: str, content: str) -> dict:
@@ -161,7 +162,7 @@ class LarkRequest:
"msg_type": msg_type,
"content": content.strip('"').replace(r"\"", '"').replace(r"\\", "\\"),
}
res = self._send_request(url, require_token=False, payload=payload)
res: dict = self._send_request(url, require_token=False, payload=payload)
return res
def get_chat_messages(
@@ -182,9 +183,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_thread_messages(
@@ -197,9 +199,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_task(self, summary: str, start_time: str, end_time: str, completed_time: str, description: str) -> dict:
@@ -211,9 +214,10 @@ class LarkRequest:
"completed_at": completed_time,
"description": description,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_task(
@@ -228,9 +232,10 @@ class LarkRequest:
"completed_time": completed_time,
"description": description,
}
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_task(self, task_guid: str) -> dict:
@@ -238,9 +243,10 @@ class LarkRequest:
payload = {
"task_guid": task_guid,
}
res = self._send_request(url, method="DELETE", payload=payload)
res: dict = self._send_request(url, method="DELETE", payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_members(self, task_guid: str, member_phone_or_email: str, member_role: str) -> dict:
@@ -250,9 +256,10 @@ class LarkRequest:
"member_phone_or_email": member_phone_or_email,
"member_role": member_role,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_wiki_nodes(self, space_id: str, parent_node_token: str, page_token: str, page_size: int = 20) -> dict:
@@ -263,9 +270,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_primary_calendar(self, user_id_type: str = "open_id") -> dict:
@@ -273,9 +281,10 @@ class LarkRequest:
params = {
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_event(
@@ -298,9 +307,10 @@ class LarkRequest:
"auto_record": auto_record,
"attendee_ability": attendee_ability,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_event(
@@ -314,7 +324,7 @@ class LarkRequest:
auto_record: bool,
) -> dict:
url = f"{self.API_BASE_URL}/calendar/update_event/{event_id}"
payload = {}
payload: dict[str, Any] = {}
if summary:
payload["summary"] = summary
if description:
@@ -327,7 +337,7 @@ class LarkRequest:
payload["need_notification"] = need_notification
if auto_record:
payload["auto_record"] = auto_record
res = self._send_request(url, method="PATCH", payload=payload)
res: dict = self._send_request(url, method="PATCH", payload=payload)
return res
def delete_event(self, event_id: str, need_notification: bool = True) -> dict:
@@ -335,7 +345,7 @@ class LarkRequest:
params = {
"need_notification": need_notification,
}
res = self._send_request(url, method="DELETE", params=params)
res: dict = self._send_request(url, method="DELETE", params=params)
return res
def list_events(self, start_time: str, end_time: str, page_token: str, page_size: int = 50) -> dict:
@@ -346,9 +356,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_events(
@@ -369,9 +380,10 @@ class LarkRequest:
"user_id_type": user_id_type,
"page_size": page_size,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_event_attendees(self, event_id: str, attendee_phone_or_email: str, need_notification: bool = True) -> dict:
@@ -381,9 +393,10 @@ class LarkRequest:
"attendee_phone_or_email": attendee_phone_or_email,
"need_notification": need_notification,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_spreadsheet(
@@ -396,9 +409,10 @@ class LarkRequest:
"title": title,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_spreadsheet(
@@ -411,9 +425,10 @@ class LarkRequest:
"spreadsheet_token": spreadsheet_token,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_spreadsheet_sheets(
@@ -424,9 +439,10 @@ class LarkRequest:
params = {
"spreadsheet_token": spreadsheet_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_rows(
@@ -445,9 +461,10 @@ class LarkRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_cols(
@@ -466,9 +483,10 @@ class LarkRequest:
"length": length,
"values": values,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_rows(
@@ -489,9 +507,10 @@ class LarkRequest:
"num_rows": num_rows,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_cols(
@@ -512,9 +531,10 @@ class LarkRequest:
"num_cols": num_cols,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_table(
@@ -535,9 +555,10 @@ class LarkRequest:
"query": query,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_base(
@@ -550,9 +571,10 @@ class LarkRequest:
"name": name,
"folder_token": folder_token,
}
res = self._send_request(url, payload=payload)
res: dict = self._send_request(url, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def add_records(
@@ -573,9 +595,10 @@ class LarkRequest:
payload = {
"records": self.convert_add_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def update_records(
@@ -596,9 +619,10 @@ class LarkRequest:
payload = {
"records": self.convert_update_records(records),
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_records(
@@ -624,9 +648,10 @@ class LarkRequest:
payload = {
"records": record_id_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def search_record(
@@ -678,7 +703,7 @@ class LarkRequest:
except json.JSONDecodeError:
raise ValueError("The input string is not valid JSON")
payload = {}
payload: dict[str, Any] = {}
if view_id:
payload["view_id"] = view_id
@@ -690,9 +715,10 @@ class LarkRequest:
payload["filter"] = filter_dict
if automatic_fields:
payload["automatic_fields"] = automatic_fields
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def get_base_info(
@@ -703,9 +729,10 @@ class LarkRequest:
params = {
"app_token": app_token,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def create_table(
@@ -732,9 +759,10 @@ class LarkRequest:
}
if default_view_name:
payload["default_view_name"] = default_view_name
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def delete_tables(
@@ -767,9 +795,10 @@ class LarkRequest:
"table_ids": table_id_list,
"table_names": table_name_list,
}
res = self._send_request(url, params=params, payload=payload)
res: dict = self._send_request(url, params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def list_tables(
@@ -784,9 +813,10 @@ class LarkRequest:
"page_token": page_token,
"page_size": page_size,
}
res = self._send_request(url, method="GET", params=params)
res: dict = self._send_request(url, method="GET", params=params)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res
def read_records(
@@ -814,7 +844,8 @@ class LarkRequest:
"record_ids": record_id_list,
"user_id_type": user_id_type,
}
res = self._send_request(url, method="POST", params=params, payload=payload)
res: dict = self._send_request(url, method="POST", params=params, payload=payload)
if "data" in res:
return res.get("data")
data: dict = res.get("data", {})
return data
return res

View File

@@ -90,12 +90,12 @@ class ToolFileMessageTransformer:
)
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
file = message.meta.get("file")
if isinstance(file, File):
if file.transfer_method == FileTransferMethod.TOOL_FILE:
assert file.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension)
if file.type == FileType.IMAGE:
file_mata = message.meta.get("file")
if isinstance(file_mata, File):
if file_mata.transfer_method == FileTransferMethod.TOOL_FILE:
assert file_mata.related_id is not None
url = cls.get_tool_file_url(tool_file_id=file_mata.related_id, extension=file_mata.extension)
if file_mata.type == FileType.IMAGE:
result.append(
ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.IMAGE_LINK,

View File

@@ -5,7 +5,7 @@ Therefore, a model manager is needed to list/invoke/validate models.
"""
import json
from typing import cast
from typing import Optional, cast
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMResult
@@ -51,7 +51,7 @@ class ModelInvocationUtils:
if not schema:
raise InvokeModelError("No model schema found")
max_tokens = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
max_tokens: Optional[int] = schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE, None)
if max_tokens is None:
return 2048
@@ -133,14 +133,17 @@ class ModelInvocationUtils:
db.session.commit()
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
response: LLMResult = cast(
LLMResult,
model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=model_parameters,
tools=[],
stop=[],
stream=False,
user=user_id,
callbacks=[],
),
)
except InvokeRateLimitError as e:
raise InvokeModelError(f"Invoke rate limit error: {e}")

View File

@@ -6,7 +6,7 @@ from json.decoder import JSONDecodeError
from typing import Optional
from requests import get
from yaml import YAMLError, safe_load
from yaml import YAMLError, safe_load # type: ignore
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
@@ -64,6 +64,9 @@ class ApiBasedToolSchemaParser:
default=parameter["schema"]["default"]
if "schema" in parameter and "default" in parameter["schema"]
else None,
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@@ -108,6 +111,9 @@ class ApiBasedToolSchemaParser:
form=ToolParameter.ToolParameterForm.LLM,
llm_description=property.get("description", ""),
default=property.get("default", None),
placeholder=I18nObject(
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
),
)
# check if there is a type
@@ -158,9 +164,9 @@ class ApiBasedToolSchemaParser:
return bundles
@staticmethod
def _get_tool_parameter_type(parameter: dict) -> ToolParameter.ToolParameterType:
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
parameter = parameter or {}
typ = None
typ: Optional[str] = None
if parameter.get("format") == "binary":
return ToolParameter.ToolParameterType.FILE
@@ -175,6 +181,8 @@ class ApiBasedToolSchemaParser:
return ToolParameter.ToolParameterType.BOOLEAN
elif typ == "string":
return ToolParameter.ToolParameterType.STRING
else:
return None
@staticmethod
def parse_openapi_yaml_to_tool_bundle(
@@ -236,7 +244,8 @@ class ApiBasedToolSchemaParser:
if ("summary" not in operation or len(operation["summary"]) == 0) and (
"description" not in operation or len(operation["description"]) == 0
):
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
if warning is not None:
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
openapi["paths"][path][method] = {
"operationId": operation["operationId"],

View File

@@ -9,13 +9,13 @@ import tempfile
import unicodedata
from contextlib import contextmanager
from pathlib import Path
from typing import Optional
from typing import Any, Literal, Optional, cast
from urllib.parse import unquote
import chardet
import cloudscraper
from bs4 import BeautifulSoup, CData, Comment, NavigableString
from regex import regex
import cloudscraper # type: ignore
from bs4 import BeautifulSoup, CData, Comment, NavigableString # type: ignore
from regex import regex # type: ignore
from core.helper import ssrf_proxy
from core.rag.extractor import extract_processor
@@ -68,7 +68,7 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str:
return "Unsupported content-type [{}] of URL.".format(main_content_type)
if main_content_type in extract_processor.SUPPORT_URL_CONTENT_TYPES:
return ExtractProcessor.load_from_url(url, return_text=True)
return cast(str, ExtractProcessor.load_from_url(url, return_text=True))
response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300))
elif response.status_code == 403:
@@ -125,7 +125,7 @@ def extract_using_readabilipy(html):
os.unlink(article_json_path)
os.unlink(html_path)
article_json = {
article_json: dict[str, Any] = {
"title": None,
"byline": None,
"date": None,
@@ -300,7 +300,7 @@ def strip_control_characters(text):
def normalize_unicode(text):
"""Normalize unicode such that things that are visually equivalent map to the same unicode string where possible."""
normal_form = "NFKC"
normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC"
text = unicodedata.normalize(normal_form, text)
return text
@@ -332,6 +332,7 @@ def add_content_digest(element):
def content_digest(element):
digest: Any
if is_text(element):
# Hash
trimmed_string = element.string.strip()

View File

@@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
class WorkflowToolConfigurationUtils:
@classmethod
def check_parameter_configurations(cls, configurations: Mapping[str, Any]):
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
for configuration in configurations:
WorkflowToolParameterConfiguration.model_validate(configuration)
@@ -27,7 +27,7 @@ class WorkflowToolConfigurationUtils:
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
) -> None:
) -> bool:
"""
check is synced

View File

@@ -2,7 +2,7 @@ import logging
from pathlib import Path
from typing import Any
import yaml
import yaml # type: ignore
from yaml import YAMLError
logger = logging.getLogger(__name__)