feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -48,7 +48,10 @@ class ApiExternalDataTool(ExternalDataTool):
:return: the tool query result
"""
# get params from config
if not self.config:
raise ValueError("config is required, config: {}".format(self.config))
api_based_extension_id = self.config.get("api_based_extension_id")
assert api_based_extension_id is not None, "api_based_extension_id is required"
# get api_based_extension
api_based_extension = (

View File

@@ -1,7 +1,7 @@
import concurrent
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from collections.abc import Mapping
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from typing import Any, Optional
from flask import Flask, current_app
@@ -17,9 +17,9 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tools: list[ExternalDataVariableEntity],
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> dict:
) -> Mapping[str, Any]:
"""
Fill in variable inputs from external data tools if exists.
@@ -30,13 +30,14 @@ class ExternalDataFetch:
:param query: the query
:return: the filled inputs
"""
results = {}
results: dict[str, Any] = {}
inputs = dict(inputs)
with ThreadPoolExecutor() as executor:
futures = {}
for tool in external_data_tools:
future = executor.submit(
future: Future[tuple[str | None, str | None]] = executor.submit(
self._query_external_data_tool,
current_app._get_current_object(),
current_app._get_current_object(), # type: ignore
tenant_id,
app_id,
tool,
@@ -46,9 +47,10 @@ class ExternalDataFetch:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
for future in as_completed(futures):
tool_variable, result = future.result()
results[tool_variable] = result
if tool_variable is not None:
results[tool_variable] = result
inputs.update(results)
return inputs
@@ -59,7 +61,7 @@ class ExternalDataFetch:
tenant_id: str,
app_id: str,
external_data_tool: ExternalDataVariableEntity,
inputs: dict,
inputs: Mapping[str, Any],
query: str,
) -> tuple[Optional[str], Optional[str]]:
"""

View File

@@ -1,4 +1,5 @@
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
@@ -23,9 +24,10 @@ class ExternalDataToolFactory:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
# FIXME mypy issue here, figure out how to fix it
extension_class.validate_config(tenant_id, config) # type: ignore
def query(self, inputs: dict, query: Optional[str] = None) -> str:
def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
"""
Query the external data tool.
@@ -33,4 +35,4 @@ class ExternalDataToolFactory:
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)
return cast(str, self.__extension_instance.query(inputs, query))