feat: mypy for all type check (#10921)
This commit is contained in:
@@ -48,7 +48,10 @@ class ApiExternalDataTool(ExternalDataTool):
|
||||
:return: the tool query result
|
||||
"""
|
||||
# get params from config
|
||||
if not self.config:
|
||||
raise ValueError("config is required, config: {}".format(self.config))
|
||||
api_based_extension_id = self.config.get("api_based_extension_id")
|
||||
assert api_based_extension_id is not None, "api_based_extension_id is required"
|
||||
|
||||
# get api_based_extension
|
||||
api_based_extension = (
|
||||
|
@@ -1,7 +1,7 @@
|
||||
import concurrent
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
@@ -17,9 +17,9 @@ class ExternalDataFetch:
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tools: list[ExternalDataVariableEntity],
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
) -> dict:
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Fill in variable inputs from external data tools if exists.
|
||||
|
||||
@@ -30,13 +30,14 @@ class ExternalDataFetch:
|
||||
:param query: the query
|
||||
:return: the filled inputs
|
||||
"""
|
||||
results = {}
|
||||
results: dict[str, Any] = {}
|
||||
inputs = dict(inputs)
|
||||
with ThreadPoolExecutor() as executor:
|
||||
futures = {}
|
||||
for tool in external_data_tools:
|
||||
future = executor.submit(
|
||||
future: Future[tuple[str | None, str | None]] = executor.submit(
|
||||
self._query_external_data_tool,
|
||||
current_app._get_current_object(),
|
||||
current_app._get_current_object(), # type: ignore
|
||||
tenant_id,
|
||||
app_id,
|
||||
tool,
|
||||
@@ -46,9 +47,10 @@ class ExternalDataFetch:
|
||||
|
||||
futures[future] = tool
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
for future in as_completed(futures):
|
||||
tool_variable, result = future.result()
|
||||
results[tool_variable] = result
|
||||
if tool_variable is not None:
|
||||
results[tool_variable] = result
|
||||
|
||||
inputs.update(results)
|
||||
return inputs
|
||||
@@ -59,7 +61,7 @@ class ExternalDataFetch:
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
external_data_tool: ExternalDataVariableEntity,
|
||||
inputs: dict,
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
) -> tuple[Optional[str], Optional[str]]:
|
||||
"""
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.extension.extensible import ExtensionModule
|
||||
from extensions.ext_code_based_extension import code_based_extension
|
||||
@@ -23,9 +24,10 @@ class ExternalDataToolFactory:
|
||||
"""
|
||||
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
|
||||
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
|
||||
extension_class.validate_config(tenant_id, config)
|
||||
# FIXME mypy issue here, figure out how to fix it
|
||||
extension_class.validate_config(tenant_id, config) # type: ignore
|
||||
|
||||
def query(self, inputs: dict, query: Optional[str] = None) -> str:
|
||||
def query(self, inputs: Mapping[str, Any], query: Optional[str] = None) -> str:
|
||||
"""
|
||||
Query the external data tool.
|
||||
|
||||
@@ -33,4 +35,4 @@ class ExternalDataToolFactory:
|
||||
:param query: the query of chat app
|
||||
:return: the tool query result
|
||||
"""
|
||||
return self.__extension_instance.query(inputs, query)
|
||||
return cast(str, self.__extension_instance.query(inputs, query))
|
||||
|
Reference in New Issue
Block a user