feat: mypy for all type check (#10921)

This commit is contained in:
yihong
2024-12-24 18:38:51 +08:00
committed by GitHub
parent c91e8b1737
commit 56e15d09a9
584 changed files with 3975 additions and 2826 deletions

View File

@@ -1,6 +1,7 @@
import json
import logging
from typing import Optional
from collections.abc import Mapping
from typing import Any, Optional, cast
from httpx import get
@@ -28,12 +29,12 @@ logger = logging.getLogger(__name__)
class ApiToolManageService:
@staticmethod
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
def parser_api_schema(schema: str) -> Mapping[str, Any]:
"""
parse api schema to tool bundle
"""
try:
warnings = {}
warnings: dict[str, str] = {}
try:
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
except Exception as e:
@@ -68,13 +69,16 @@ class ApiToolManageService:
),
]
return jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
return cast(
Mapping,
jsonable_encoder(
{
"schema_type": schema_type,
"parameters_schema": tool_bundles,
"credentials_schema": credentials_schema,
"warning": warnings,
}
),
)
except Exception as e:
raise ValueError(f"invalid schema: {str(e)}")
@@ -129,7 +133,7 @@ class ApiToolManageService:
raise ValueError(f"provider {provider_name} already exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -262,9 +266,8 @@ class ApiToolManageService:
if provider is None:
raise ValueError(f"api provider {provider_name} does not exists")
# parse openapi to tool bundle
extra_info = {}
extra_info: dict[str, str] = {}
# extra info like description will be set here
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
@@ -416,7 +419,7 @@ class ApiToolManageService:
provider_controller.validate_credentials_format(credentials)
# get tool
tool = provider_controller.get_tool(tool_name)
tool = tool.fork_tool_runtime(
runtime_tool = tool.fork_tool_runtime(
runtime={
"credentials": credentials,
"tenant_id": tenant_id,
@@ -454,7 +457,7 @@ class ApiToolManageService:
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
for tool in tools:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.tool_to_user_tool(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels