feat: mypy for all type check (#10921)

Author: yihong
Date: 2024-12-24 18:38:51 +08:00
Committed by: GitHub
Parent: c91e8b1737
Commit: 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -32,11 +32,13 @@ class ApiTool(Tool):
         :param meta: the meta data of a tool call processing, tenant_id is required
         :return: the new tool
         """
+        if self.api_bundle is None:
+            raise ValueError("api_bundle is required")
         return self.__class__(
             identity=self.identity.model_copy() if self.identity else None,
             parameters=self.parameters.copy() if self.parameters else None,
             description=self.description.model_copy() if self.description else None,
-            api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
+            api_bundle=self.api_bundle.model_copy(),
             runtime=Tool.Runtime(**runtime),
         )
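
This hunk shows the strict-optional recipe the commit applies throughout: raise early when an attribute may be None so mypy narrows its type, after which fallbacks like `if self.api_bundle else None` can be dropped. A minimal sketch of the pattern, using illustrative stand-in classes rather than Dify's real types:

```python
from typing import Optional


class Bundle:
    def model_copy(self) -> "Bundle":
        return Bundle()


class ApiToolLike:
    def __init__(self, api_bundle: Optional[Bundle] = None) -> None:
        self.api_bundle = api_bundle

    def fork(self) -> Bundle:
        # Without this guard mypy reports:
        #   Item "None" of "Optional[Bundle]" has no attribute "model_copy"
        if self.api_bundle is None:
            raise ValueError("api_bundle is required")
        # After the raise, self.api_bundle is narrowed to Bundle for the
        # rest of this method body.
        return self.api_bundle.model_copy()


print(type(ApiToolLike(Bundle()).fork()).__name__)
```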
@@ -61,6 +63,8 @@ class ApiTool(Tool):
     def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
         headers = {}
+        if self.runtime is None:
+            raise ValueError("runtime is required")
         credentials = self.runtime.credentials or {}

         if "auth_type" not in credentials:
@@ -88,7 +92,7 @@ class ApiTool(Tool):
                 headers[api_key_header] = credentials["api_key_value"]

-        needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
+        needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
         for parameter in needed_parameters:
             if parameter.required and parameter.name not in parameters:
                 raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
@@ -137,7 +141,8 @@ class ApiTool(Tool):
         params = {}
         path_params = {}
-        body = {}
+        # FIXME: body should be a dict[str, Any] but it changed a lot in this function
+        body: Any = {}
         cookies = {}
         files = []
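
The `body: Any` annotation (and its FIXME) works around mypy pinning a variable to the type of its first assignment; later in this function `body` is rebound to values of other types. A small sketch of the error being silenced, under that assumption:

```python
from typing import Any

# mypy infers dict[str, Any] from the first assignment and would reject
# the later rebinding with "Incompatible types in assignment"; annotating
# the variable as Any opts it out of that check.
body: Any = {}
body = "text payload"
body = b"binary payload"
print(type(body).__name__)
```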
@@ -198,7 +203,7 @@ class ApiTool(Tool):
             body = body

         if method in {"get", "head", "post", "put", "delete", "patch"}:
-            response = getattr(ssrf_proxy, method)(
+            response: httpx.Response = getattr(ssrf_proxy, method)(
                 url,
                 params=params,
                 headers=headers,
@@ -288,6 +293,7 @@ class ApiTool(Tool):
"""
invoke http request
"""
response: httpx.Response | str = ""
# assemble request
headers = self.assembling_request(tool_parameters)
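
`getattr(module, name)` is typed as `Any`, so without an explicit `response: httpx.Response` annotation nothing done with the result of the dynamic dispatch would be checked. A sketch of the same idea calling httpx directly (Dify routes this call through its ssrf_proxy wrapper):

```python
import httpx


def fetch(method: str, url: str) -> httpx.Response:
    # getattr() erases the type; the annotation restores checking for
    # every later use of `response`.
    response: httpx.Response = getattr(httpx, method)(url)
    response.raise_for_status()
    return response


# Example: fetch("get", "https://example.com")
```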

View File

@@ -1,4 +1,4 @@
-from typing import Optional
+from typing import Optional, cast

 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
         :return: the model result
         """
         # invoke model
+        if self.runtime is None or self.identity is None:
+            raise ValueError("runtime and identity are required")
         return ModelInvocationUtils.invoke(
             user_id=user_id,
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
             tool_type="builtin",
             tool_name=self.identity.name,
             prompt_messages=prompt_messages,
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
         :param model_config: the model config
         :return: the max tokens
         """
+        if self.runtime is None:
+            raise ValueError("runtime is required")
         return ModelInvocationUtils.get_max_llm_context_tokens(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=self.runtime.tenant_id or "",
         )

     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
         :param prompt_messages: the prompt messages
         :return: the tokens
         """
-        return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
+        if self.runtime is None:
+            raise ValueError("runtime is required")
+        return ModelInvocationUtils.calculate_tokens(
+            tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+        )

     def summary(self, user_id: str, content: str) -> str:
         max_tokens = self.get_max_tokens()
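
The three hunks above all pass `tenant_id or ""` where the callee requires `str` but the field is `Optional[str]`. A compact illustration of the trade-off (the callee is a stand-in, not the real ModelInvocationUtils API):

```python
from typing import Optional


def calculate_tokens(tenant_id: str) -> int:
    # stand-in for a callee that demands a non-optional str
    return len(tenant_id)


tenant_id: Optional[str] = None
# Passing tenant_id directly fails with: Argument 1 has incompatible
# type "Optional[str]"; expected "str". The `or ""` fallback satisfies
# mypy but silently substitutes an empty id at runtime.
print(calculate_tokens(tenant_id or ""))
```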
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
                 stop=[],
             )

-            return summary.message.content
+            return cast(str, summary.message.content)

         lines = content.split("\n")
         new_lines = []
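
`cast()` asserts to mypy which member of a union this code path produces, without touching runtime behavior. Roughly:

```python
from typing import Union, cast

# The content field is declared as a union; this path is known to yield
# str only. cast() is a pure type-level assertion: no conversion, no check.
content: Union[str, list[str]] = "a short summary"
text: str = cast(str, content)
print(text.upper())
```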
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
         # merge lines into messages with max tokens
         messages: list[str] = []
-        for i in new_lines:
+        for j in new_lines:
             if len(messages) == 0:
-                messages.append(i)
+                messages.append(j)
             else:
-                if len(messages[-1]) + len(i) < max_tokens * 0.5:
-                    messages[-1] += i
-                if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
-                    messages.append(i)
+                if len(messages[-1]) + len(j) < max_tokens * 0.5:
+                    messages[-1] += j
+                if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
+                    messages.append(j)
                 else:
-                    messages[-1] += i
+                    messages[-1] += j

         summaries = []
         for i in range(len(messages)):
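
The `i` to `j` rename is needed because mypy infers one type per variable per scope: `i` is a `str` in this loop but an `int` index in the `range()` loop just below. In miniature:

```python
lines = ["alpha", "beta"]

# Reusing `i` for both loops makes mypy report "Incompatible types in
# assignment (expression has type "int", variable has type "str")",
# hence the rename.
for j in lines:
    print(j.upper())

for i in range(len(lines)):
    print(i + 1)
```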

View File

@@ -1,4 +1,5 @@
 import threading
+from typing import Any

 from flask import Flask, current_app
 from pydantic import BaseModel, Field
@@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
 from core.model_manager import ModelManager
 from core.model_runtime.entities.model_entities import ModelType
 from core.rag.datasource.retrieval_service import RetrievalService
+from core.rag.models.document import Document as RagDocument
 from core.rag.rerank.rerank_model import RerankModelRunner
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
 from extensions.ext_database import db
 from models.dataset import Dataset, Document, DocumentSegment

-default_retrieval_model = {
+default_retrieval_model: dict[str, Any] = {
     "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
     "reranking_enable": False,
     "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
@@ -44,12 +46,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
     def _run(self, query: str) -> str:
         threads = []
-        all_documents = []
+        all_documents: list[RagDocument] = []
         for dataset_id in self.dataset_ids:
             retrieval_thread = threading.Thread(
                 target=self._retriever,
                 kwargs={
-                    "flask_app": current_app._get_current_object(),
+                    "flask_app": current_app._get_current_object(),  # type: ignore
                     "dataset_id": dataset_id,
                     "query": query,
                     "all_documents": all_documents,
@@ -77,11 +79,11 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
         document_score_list = {}
         for item in all_documents:
-            if item.metadata.get("score"):
+            if item.metadata and item.metadata.get("score"):
                 document_score_list[item.metadata["doc_id"]] = item.metadata["score"]

         document_context_list = []
-        index_node_ids = [document.metadata["doc_id"] for document in all_documents]
+        index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
         segments = DocumentSegment.query.filter(
             DocumentSegment.dataset_id.in_(self.dataset_ids),
             DocumentSegment.completed_at.isnot(None),
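
Both fixes guard `metadata`, which is `Optional` on RAG documents; in the comprehension, the filter doubles as a type narrower. A self-contained sketch with a stand-in document class:

```python
from typing import Any, Optional


class RagDoc:
    def __init__(self, metadata: Optional[dict[str, Any]]) -> None:
        self.metadata = metadata


docs = [RagDoc({"doc_id": "a", "score": 0.9}), RagDoc(None)]
# `if d.metadata` narrows Optional[dict] to dict inside the comprehension,
# so the subscript type-checks and documents without metadata are skipped.
ids = [d.metadata["doc_id"] for d in docs if d.metadata]
print(ids)  # ['a']
```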
@@ -139,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
             hit_callback.return_retriever_resource_info(context_list)

             return str("\n".join(document_context_list))
+        return ""

     def _retriever(
         self,

View File

@@ -1,7 +1,7 @@
 from abc import abstractmethod
 from typing import Any, Optional

-from msal_extensions.persistence import ABC
+from msal_extensions.persistence import ABC  # type: ignore
 from pydantic import BaseModel, ConfigDict

 from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
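
`msal_extensions` ships without stubs or a `py.typed` marker, so mypy refuses to analyze the import; the trailing comment suppresses exactly that line (a per-module `ignore_missing_imports` entry in the mypy configuration is the usual alternative). For illustration, `# type: ignore` silences whatever mypy would otherwise report on its own line:

```python
# Deliberately ill-typed assignment: mypy would flag it, the comment
# suppresses it, and the line still runs normally.
x: int = "not an int"  # type: ignore
print(x)
```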

View File

@@ -1,3 +1,5 @@
+from typing import Any
+
 from pydantic import BaseModel, Field

 from core.rag.datasource.retrieval_service import RetrievalService
@@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     metadata=external_document.get("metadata"),
                     provider="external",
                 )
-                document.metadata["score"] = external_document.get("score")
-                document.metadata["title"] = external_document.get("title")
-                document.metadata["dataset_id"] = dataset.id
-                document.metadata["dataset_name"] = dataset.name
-                results.append(document)
+                if document.metadata is not None:
+                    document.metadata["score"] = external_document.get("score")
+                    document.metadata["title"] = external_document.get("title")
+                    document.metadata["dataset_id"] = dataset.id
+                    document.metadata["dataset_name"] = dataset.name
+                    results.append(document)
             # deal with external documents
             context_list = []
             for position, item in enumerate(results, start=1):
-                source = {
-                    "position": position,
-                    "dataset_id": item.metadata.get("dataset_id"),
-                    "dataset_name": item.metadata.get("dataset_name"),
-                    "document_name": item.metadata.get("title"),
-                    "data_source_type": "external",
-                    "retriever_from": self.retriever_from,
-                    "score": item.metadata.get("score"),
-                    "title": item.metadata.get("title"),
-                    "content": item.page_content,
-                }
+                if item.metadata is not None:
+                    source = {
+                        "position": position,
+                        "dataset_id": item.metadata.get("dataset_id"),
+                        "dataset_name": item.metadata.get("dataset_name"),
+                        "document_name": item.metadata.get("title"),
+                        "data_source_type": "external",
+                        "retriever_from": self.retriever_from,
+                        "score": item.metadata.get("score"),
+                        "title": item.metadata.get("title"),
+                        "content": item.page_content,
+                    }
                 context_list.append(source)
             for hit_callback in self.hit_callbacks:
                 hit_callback.return_retriever_resource_info(context_list)
@@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             return str("\n".join([item.page_content for item in results]))
         else:
             # get retrieval model , if the model is not setting , using default
-            retrieval_model = dataset.retrieval_model or default_retrieval_model
+            retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
             if dataset.indexing_technique == "economy":
                 # use keyword table query
                 documents = RetrievalService.retrieve(
@@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
                     score_threshold=retrieval_model.get("score_threshold", 0.0)
                     if retrieval_model["score_threshold_enabled"]
                     else 0.0,
-                    reranking_model=retrieval_model.get("reranking_model", None)
+                    reranking_model=retrieval_model.get("reranking_model")
                     if retrieval_model["reranking_enable"]
                     else None,
                     reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
-                    weights=retrieval_model.get("weights", None),
+                    weights=retrieval_model.get("weights"),
                 )
             else:
                 documents = []
@@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             document_score_list = {}
             if dataset.indexing_technique != "economy":
                 for item in documents:
-                    if item.metadata.get("score"):
+                    if item.metadata is not None and item.metadata.get("score"):
                         document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
             document_context_list = []
             index_node_ids = [document.metadata["doc_id"] for document in documents]
@@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
             context_list = []
             resource_number = 1
             for segment in sorted_segments:
-                context = {}
-                document = Document.query.filter(
+                document_segment = Document.query.filter(
                     Document.id == segment.document_id,
                     Document.enabled == True,
                     Document.archived == False,
                 ).first()
-                if dataset and document:
+                if not document_segment:
+                    continue
+                if dataset and document_segment:
                     source = {
                         "position": resource_number,
                         "dataset_id": dataset.id,
                         "dataset_name": dataset.name,
-                        "document_id": document.id,
-                        "document_name": document.name,
-                        "data_source_type": document.data_source_type,
+                        "document_id": document_segment.id,
+                        "document_name": document_segment.name,
+                        "data_source_type": document_segment.data_source_type,
                         "segment_id": segment.id,
                         "retriever_from": self.retriever_from,
                         "score": document_score_list.get(segment.index_node_id, None),

View File

@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional

 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool):
     def get_dataset_tools(
         tenant_id: str,
         dataset_ids: list[str],
-        retrieve_config: DatasetRetrieveConfigEntity,
+        retrieve_config: Optional[DatasetRetrieveConfigEntity],
         return_resource: bool,
         invoke_from: InvokeFrom,
         hit_callback: DatasetIndexToolCallbackHandler,
@@ -51,6 +51,8 @@ class DatasetRetrieverTool(Tool):
             invoke_from=invoke_from,
             hit_callback=hit_callback,
         )
+        if retrieval_tools is None:
+            return []

         # restore retrieve strategy
         retrieve_config.retrieve_strategy = original_retriever_mode
@@ -83,6 +85,7 @@ class DatasetRetrieverTool(Tool):
                 llm_description="Query for the dataset to be used to retrieve the dataset.",
                 required=True,
                 default="",
+                placeholder=I18nObject(en_US="", zh_Hans=""),
             ),
         ]
@@ -102,7 +105,9 @@ class DatasetRetrieverTool(Tool):
         return self.create_text_message(text=result)

-    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
+    def validate_credentials(
+        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+    ) -> str | None:
         """
         validate the credentials for dataset retriever tool
         """

View File

@@ -91,7 +91,7 @@ class Tool(BaseModel, ABC):
         :return: the tool provider type
         """

-    def load_variables(self, variables: ToolRuntimeVariablePool):
+    def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
         """
         load variables from database
@@ -105,6 +105,8 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_file(self.identity.name, variable_name, image_key)
@@ -114,6 +116,8 @@ class Tool(BaseModel, ABC):
"""
if not self.variables:
return
if self.identity is None:
return
self.variables.set_text(self.identity.name, variable_name, text)
@@ -200,7 +204,11 @@ class Tool(BaseModel, ABC):
     def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
         # update tool_parameters
         # TODO: Fix type error.
+        if self.runtime is None:
+            return []
         if self.runtime.runtime_parameters:
+            # Convert Mapping to dict before updating
+            tool_parameters = dict(tool_parameters)
             tool_parameters.update(self.runtime.runtime_parameters)

         # try parse tool parameters into the correct type
@@ -221,7 +229,7 @@ class Tool(BaseModel, ABC):
         Transform tool parameters type
         """
         # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
-        result = deepcopy(tool_parameters)
+        result: dict[str, Any] = deepcopy(dict(tool_parameters))
         for parameter in self.parameters or []:
             if parameter.name in tool_parameters:
                 result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
@@ -234,12 +242,15 @@ class Tool(BaseModel, ABC):
     ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
         pass

-    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
+    def validate_credentials(
+        self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
+    ) -> str | None:
         """
         validate the credentials

         :param credentials: the credentials
         :param parameters: the parameters
+        :param format_only: only return the formatted
         """
         pass
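
The `dict(tool_parameters)` conversion in `invoke` exists because the parameter is declared as `Mapping`, which intentionally lacks `update()`; the same reasoning underlies `deepcopy(dict(...))` in the transform hunk. A sketch:

```python
from collections.abc import Mapping
from typing import Any


def merged_parameters(params: Mapping[str, Any], overrides: Mapping[str, Any]) -> dict[str, Any]:
    # Mapping has no update(), so mypy rejects in-place mutation; copying
    # into a dict fixes the error and leaves the caller's object untouched.
    result = dict(params)
    result.update(overrides)
    return result


print(merged_parameters({"q": "hello"}, {"lang": "en"}))
```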

View File

@@ -68,20 +68,20 @@ class WorkflowTool(Tool):
         if data.get("error"):
             raise Exception(data.get("error"))

-        result = []
+        r = []

         outputs = data.get("outputs")
         if outputs == None:
             outputs = {}
         else:
-            outputs, files = self._extract_files(outputs)
-            for file in files:
-                result.append(self.create_file_message(file))
+            outputs, extracted_files = self._extract_files(outputs)
+            for f in extracted_files:
+                r.append(self.create_file_message(f))

-        result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
-        result.append(self.create_json_message(outputs))
+        r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
+        r.append(self.create_json_message(outputs))

-        return result
+        return r
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
"""