feat: mypy for all type check (#10921)
This commit is contained in:
@@ -32,11 +32,13 @@ class ApiTool(Tool):
|
||||
:param meta: the meta data of a tool call processing, tenant_id is required
|
||||
:return: the new tool
|
||||
"""
|
||||
if self.api_bundle is None:
|
||||
raise ValueError("api_bundle is required")
|
||||
return self.__class__(
|
||||
identity=self.identity.model_copy() if self.identity else None,
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.model_copy() if self.description else None,
|
||||
api_bundle=self.api_bundle.model_copy() if self.api_bundle else None,
|
||||
api_bundle=self.api_bundle.model_copy(),
|
||||
runtime=Tool.Runtime(**runtime),
|
||||
)
|
||||
|
||||
@@ -61,6 +63,8 @@ class ApiTool(Tool):
|
||||
|
||||
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||
headers = {}
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
credentials = self.runtime.credentials or {}
|
||||
|
||||
if "auth_type" not in credentials:
|
||||
@@ -88,7 +92,7 @@ class ApiTool(Tool):
|
||||
|
||||
headers[api_key_header] = credentials["api_key_value"]
|
||||
|
||||
needed_parameters = [parameter for parameter in self.api_bundle.parameters if parameter.required]
|
||||
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
|
||||
for parameter in needed_parameters:
|
||||
if parameter.required and parameter.name not in parameters:
|
||||
raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
|
||||
@@ -137,7 +141,8 @@ class ApiTool(Tool):
|
||||
|
||||
params = {}
|
||||
path_params = {}
|
||||
body = {}
|
||||
# FIXME: body should be a dict[str, Any] but it changed a lot in this function
|
||||
body: Any = {}
|
||||
cookies = {}
|
||||
files = []
|
||||
|
||||
@@ -198,7 +203,7 @@ class ApiTool(Tool):
|
||||
body = body
|
||||
|
||||
if method in {"get", "head", "post", "put", "delete", "patch"}:
|
||||
response = getattr(ssrf_proxy, method)(
|
||||
response: httpx.Response = getattr(ssrf_proxy, method)(
|
||||
url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
@@ -288,6 +293,7 @@ class ApiTool(Tool):
|
||||
"""
|
||||
invoke http request
|
||||
"""
|
||||
response: httpx.Response | str = ""
|
||||
# assemble request
|
||||
headers = self.assembling_request(tool_parameters)
|
||||
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import Optional
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
@@ -32,9 +32,12 @@ class BuiltinTool(Tool):
|
||||
:return: the model result
|
||||
"""
|
||||
# invoke model
|
||||
if self.runtime is None or self.identity is None:
|
||||
raise ValueError("runtime and identity are required")
|
||||
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
@@ -50,8 +53,11 @@ class BuiltinTool(Tool):
|
||||
:param model_config: the model config
|
||||
:return: the max tokens
|
||||
"""
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.get_max_llm_context_tokens(
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
)
|
||||
|
||||
def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
|
||||
@@ -61,7 +67,12 @@ class BuiltinTool(Tool):
|
||||
:param prompt_messages: the prompt messages
|
||||
:return: the tokens
|
||||
"""
|
||||
return ModelInvocationUtils.calculate_tokens(tenant_id=self.runtime.tenant_id, prompt_messages=prompt_messages)
|
||||
if self.runtime is None:
|
||||
raise ValueError("runtime is required")
|
||||
|
||||
return ModelInvocationUtils.calculate_tokens(
|
||||
tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
|
||||
)
|
||||
|
||||
def summary(self, user_id: str, content: str) -> str:
|
||||
max_tokens = self.get_max_tokens()
|
||||
@@ -81,7 +92,7 @@ class BuiltinTool(Tool):
|
||||
stop=[],
|
||||
)
|
||||
|
||||
return summary.message.content
|
||||
return cast(str, summary.message.content)
|
||||
|
||||
lines = content.split("\n")
|
||||
new_lines = []
|
||||
@@ -102,16 +113,16 @@ class BuiltinTool(Tool):
|
||||
|
||||
# merge lines into messages with max tokens
|
||||
messages: list[str] = []
|
||||
for i in new_lines:
|
||||
for j in new_lines:
|
||||
if len(messages) == 0:
|
||||
messages.append(i)
|
||||
messages.append(j)
|
||||
else:
|
||||
if len(messages[-1]) + len(i) < max_tokens * 0.5:
|
||||
messages[-1] += i
|
||||
if get_prompt_tokens(messages[-1] + i) > max_tokens * 0.7:
|
||||
messages.append(i)
|
||||
if len(messages[-1]) + len(j) < max_tokens * 0.5:
|
||||
messages[-1] += j
|
||||
if get_prompt_tokens(messages[-1] + j) > max_tokens * 0.7:
|
||||
messages.append(j)
|
||||
else:
|
||||
messages[-1] += i
|
||||
messages[-1] += j
|
||||
|
||||
summaries = []
|
||||
for i in range(len(messages)):
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,13 +8,14 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
from core.rag.models.document import Document as RagDocument
|
||||
from core.rag.rerank.rerank_model import RerankModelRunner
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
default_retrieval_model = {
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||
"reranking_enable": False,
|
||||
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||
@@ -44,12 +46,12 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
threads = []
|
||||
all_documents = []
|
||||
all_documents: list[RagDocument] = []
|
||||
for dataset_id in self.dataset_ids:
|
||||
retrieval_thread = threading.Thread(
|
||||
target=self._retriever,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(),
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"dataset_id": dataset_id,
|
||||
"query": query,
|
||||
"all_documents": all_documents,
|
||||
@@ -77,11 +79,11 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
|
||||
document_score_list = {}
|
||||
for item in all_documents:
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
|
||||
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
|
||||
segments = DocumentSegment.query.filter(
|
||||
DocumentSegment.dataset_id.in_(self.dataset_ids),
|
||||
DocumentSegment.completed_at.isnot(None),
|
||||
@@ -139,6 +141,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
|
||||
return str("\n".join(document_context_list))
|
||||
return ""
|
||||
|
||||
def _retriever(
|
||||
self,
|
||||
|
@@ -1,7 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from msal_extensions.persistence import ABC
|
||||
from msal_extensions.persistence import ABC # type: ignore
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
|
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@@ -69,25 +71,27 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
metadata=external_document.get("metadata"),
|
||||
provider="external",
|
||||
)
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
if document.metadata is not None:
|
||||
document.metadata["score"] = external_document.get("score")
|
||||
document.metadata["title"] = external_document.get("title")
|
||||
document.metadata["dataset_id"] = dataset.id
|
||||
document.metadata["dataset_name"] = dataset.name
|
||||
results.append(document)
|
||||
# deal with external documents
|
||||
context_list = []
|
||||
for position, item in enumerate(results, start=1):
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
if item.metadata is not None:
|
||||
source = {
|
||||
"position": position,
|
||||
"dataset_id": item.metadata.get("dataset_id"),
|
||||
"dataset_name": item.metadata.get("dataset_name"),
|
||||
"document_name": item.metadata.get("title"),
|
||||
"data_source_type": "external",
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": item.metadata.get("score"),
|
||||
"title": item.metadata.get("title"),
|
||||
"content": item.page_content,
|
||||
}
|
||||
context_list.append(source)
|
||||
for hit_callback in self.hit_callbacks:
|
||||
hit_callback.return_retriever_resource_info(context_list)
|
||||
@@ -95,7 +99,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
return str("\n".join([item.page_content for item in results]))
|
||||
else:
|
||||
# get retrieval model , if the model is not setting , using default
|
||||
retrieval_model = dataset.retrieval_model or default_retrieval_model
|
||||
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
|
||||
if dataset.indexing_technique == "economy":
|
||||
# use keyword table query
|
||||
documents = RetrievalService.retrieve(
|
||||
@@ -113,11 +117,11 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
score_threshold=retrieval_model.get("score_threshold", 0.0)
|
||||
if retrieval_model["score_threshold_enabled"]
|
||||
else 0.0,
|
||||
reranking_model=retrieval_model.get("reranking_model", None)
|
||||
reranking_model=retrieval_model.get("reranking_model")
|
||||
if retrieval_model["reranking_enable"]
|
||||
else None,
|
||||
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
|
||||
weights=retrieval_model.get("weights", None),
|
||||
weights=retrieval_model.get("weights"),
|
||||
)
|
||||
else:
|
||||
documents = []
|
||||
@@ -127,7 +131,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
document_score_list = {}
|
||||
if dataset.indexing_technique != "economy":
|
||||
for item in documents:
|
||||
if item.metadata.get("score"):
|
||||
if item.metadata is not None and item.metadata.get("score"):
|
||||
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
|
||||
document_context_list = []
|
||||
index_node_ids = [document.metadata["doc_id"] for document in documents]
|
||||
@@ -155,20 +159,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
context = {}
|
||||
document = Document.query.filter(
|
||||
document_segment = Document.query.filter(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
Document.archived == False,
|
||||
).first()
|
||||
if dataset and document:
|
||||
if not document_segment:
|
||||
continue
|
||||
if dataset and document_segment:
|
||||
source = {
|
||||
"position": resource_number,
|
||||
"dataset_id": dataset.id,
|
||||
"dataset_name": dataset.name,
|
||||
"document_id": document.id,
|
||||
"document_name": document.name,
|
||||
"data_source_type": document.data_source_type,
|
||||
"document_id": document_segment.id,
|
||||
"document_name": document_segment.name,
|
||||
"data_source_type": document_segment.data_source_type,
|
||||
"segment_id": segment.id,
|
||||
"retriever_from": self.retriever_from,
|
||||
"score": document_score_list.get(segment.index_node_id, None),
|
||||
|
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.app_config.entities import DatasetRetrieveConfigEntity
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -23,7 +23,7 @@ class DatasetRetrieverTool(Tool):
|
||||
def get_dataset_tools(
|
||||
tenant_id: str,
|
||||
dataset_ids: list[str],
|
||||
retrieve_config: DatasetRetrieveConfigEntity,
|
||||
retrieve_config: Optional[DatasetRetrieveConfigEntity],
|
||||
return_resource: bool,
|
||||
invoke_from: InvokeFrom,
|
||||
hit_callback: DatasetIndexToolCallbackHandler,
|
||||
@@ -51,6 +51,8 @@ class DatasetRetrieverTool(Tool):
|
||||
invoke_from=invoke_from,
|
||||
hit_callback=hit_callback,
|
||||
)
|
||||
if retrieval_tools is None:
|
||||
return []
|
||||
# restore retrieve strategy
|
||||
retrieve_config.retrieve_strategy = original_retriever_mode
|
||||
|
||||
@@ -83,6 +85,7 @@ class DatasetRetrieverTool(Tool):
|
||||
llm_description="Query for the dataset to be used to retrieve the dataset.",
|
||||
required=True,
|
||||
default="",
|
||||
placeholder=I18nObject(en_US="", zh_Hans=""),
|
||||
),
|
||||
]
|
||||
|
||||
@@ -102,7 +105,9 @@ class DatasetRetrieverTool(Tool):
|
||||
|
||||
return self.create_text_message(text=result)
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials for dataset retriever tool
|
||||
"""
|
||||
|
@@ -91,7 +91,7 @@ class Tool(BaseModel, ABC):
|
||||
:return: the tool provider type
|
||||
"""
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool | None) -> None:
|
||||
"""
|
||||
load variables from database
|
||||
|
||||
@@ -105,6 +105,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
if self.identity is None:
|
||||
return
|
||||
|
||||
self.variables.set_file(self.identity.name, variable_name, image_key)
|
||||
|
||||
@@ -114,6 +116,8 @@ class Tool(BaseModel, ABC):
|
||||
"""
|
||||
if not self.variables:
|
||||
return
|
||||
if self.identity is None:
|
||||
return
|
||||
|
||||
self.variables.set_text(self.identity.name, variable_name, text)
|
||||
|
||||
@@ -200,7 +204,11 @@ class Tool(BaseModel, ABC):
|
||||
def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
# TODO: Fix type error.
|
||||
if self.runtime is None:
|
||||
return []
|
||||
if self.runtime.runtime_parameters:
|
||||
# Convert Mapping to dict before updating
|
||||
tool_parameters = dict(tool_parameters)
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# try parse tool parameters into the correct type
|
||||
@@ -221,7 +229,7 @@ class Tool(BaseModel, ABC):
|
||||
Transform tool parameters type
|
||||
"""
|
||||
# Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
|
||||
result = deepcopy(tool_parameters)
|
||||
result: dict[str, Any] = deepcopy(dict(tool_parameters))
|
||||
for parameter in self.parameters or []:
|
||||
if parameter.name in tool_parameters:
|
||||
result[parameter.name] = parameter.type.cast_value(tool_parameters[parameter.name])
|
||||
@@ -234,12 +242,15 @@ class Tool(BaseModel, ABC):
|
||||
) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
|
||||
pass
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
|
||||
def validate_credentials(
|
||||
self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
|
||||
) -> str | None:
|
||||
"""
|
||||
validate the credentials
|
||||
|
||||
:param credentials: the credentials
|
||||
:param parameters: the parameters
|
||||
:param format_only: only return the formatted
|
||||
"""
|
||||
pass
|
||||
|
||||
|
@@ -68,20 +68,20 @@ class WorkflowTool(Tool):
|
||||
if data.get("error"):
|
||||
raise Exception(data.get("error"))
|
||||
|
||||
result = []
|
||||
r = []
|
||||
|
||||
outputs = data.get("outputs")
|
||||
if outputs == None:
|
||||
outputs = {}
|
||||
else:
|
||||
outputs, files = self._extract_files(outputs)
|
||||
for file in files:
|
||||
result.append(self.create_file_message(file))
|
||||
outputs, extracted_files = self._extract_files(outputs)
|
||||
for f in extracted_files:
|
||||
r.append(self.create_file_message(f))
|
||||
|
||||
result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
result.append(self.create_json_message(outputs))
|
||||
r.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
|
||||
r.append(self.create_json_message(outputs))
|
||||
|
||||
return result
|
||||
return r
|
||||
|
||||
def _get_user(self, user_id: str) -> Union[EndUser, Account]:
|
||||
"""
|
||||
|
Reference in New Issue
Block a user