chore(api/services): apply ruff reformatting (#7599)
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@@ -29,111 +29,107 @@ class ApiToolManageService:
|
||||
@staticmethod
|
||||
def parser_api_schema(schema: str) -> list[ApiToolBundle]:
|
||||
"""
|
||||
parse api schema to tool bundle
|
||||
parse api schema to tool bundle
|
||||
"""
|
||||
try:
|
||||
warnings = {}
|
||||
try:
|
||||
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
credentials_schema = [
|
||||
ToolProviderCredentials(
|
||||
name='auth_type',
|
||||
name="auth_type",
|
||||
type=ToolProviderCredentials.CredentialsType.SELECT,
|
||||
required=True,
|
||||
default='none',
|
||||
default="none",
|
||||
options=[
|
||||
ToolCredentialsOption(value='none', label=I18nObject(
|
||||
en_US='None',
|
||||
zh_Hans='无'
|
||||
)),
|
||||
ToolCredentialsOption(value='api_key', label=I18nObject(
|
||||
en_US='Api Key',
|
||||
zh_Hans='Api Key'
|
||||
)),
|
||||
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
|
||||
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
|
||||
],
|
||||
placeholder=I18nObject(
|
||||
en_US='Select auth type',
|
||||
zh_Hans='选择认证方式'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_header',
|
||||
name="api_key_header",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key header',
|
||||
zh_Hans='输入 api key header,如:X-API-KEY'
|
||||
),
|
||||
default='api_key',
|
||||
help=I18nObject(
|
||||
en_US='HTTP header name for api key',
|
||||
zh_Hans='HTTP 头部字段名,用于传递 api key'
|
||||
)
|
||||
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
|
||||
default="api_key",
|
||||
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
|
||||
),
|
||||
ToolProviderCredentials(
|
||||
name='api_key_value',
|
||||
name="api_key_value",
|
||||
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
|
||||
required=False,
|
||||
placeholder=I18nObject(
|
||||
en_US='Enter api key',
|
||||
zh_Hans='输入 api key'
|
||||
),
|
||||
default=''
|
||||
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
|
||||
default="",
|
||||
),
|
||||
]
|
||||
|
||||
return jsonable_encoder({
|
||||
'schema_type': schema_type,
|
||||
'parameters_schema': tool_bundles,
|
||||
'credentials_schema': credentials_schema,
|
||||
'warning': warnings
|
||||
})
|
||||
return jsonable_encoder(
|
||||
{
|
||||
"schema_type": schema_type,
|
||||
"parameters_schema": tool_bundles,
|
||||
"credentials_schema": credentials_schema,
|
||||
"warning": warnings,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
|
||||
"""
|
||||
convert schema to tool bundles
|
||||
convert schema to tool bundles
|
||||
|
||||
:return: the list of tool bundles, description
|
||||
:return: the list of tool bundles, description
|
||||
"""
|
||||
try:
|
||||
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
|
||||
return tool_bundles
|
||||
except Exception as e:
|
||||
raise ValueError(f'invalid schema: {str(e)}')
|
||||
raise ValueError(f"invalid schema: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def create_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
create api tool provider
|
||||
create api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is not None:
|
||||
raise ValueError(f'provider {provider_name} already exists')
|
||||
raise ValueError(f"provider {provider_name} already exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
if len(tool_bundles) > 100:
|
||||
raise ValueError('the number of apis should be less than 100')
|
||||
raise ValueError("the number of apis should be less than 100")
|
||||
|
||||
# create db provider
|
||||
db_provider = ApiToolProvider(
|
||||
@@ -142,19 +138,19 @@ class ApiToolManageService:
|
||||
name=provider_name,
|
||||
icon=json.dumps(icon),
|
||||
schema=schema,
|
||||
description=extra_info.get('description', ''),
|
||||
description=extra_info.get("description", ""),
|
||||
schema_type_str=schema_type,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str={},
|
||||
privacy_policy=privacy_policy,
|
||||
custom_disclaimer=custom_disclaimer
|
||||
custom_disclaimer=custom_disclaimer,
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@@ -172,14 +168,12 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider_remote_schema(
|
||||
user_id: str, tenant_id: str, url: str
|
||||
):
|
||||
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str):
|
||||
"""
|
||||
get api tool provider remote schema
|
||||
get api tool provider remote schema
|
||||
"""
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0",
|
||||
@@ -189,84 +183,98 @@ class ApiToolManageService:
|
||||
try:
|
||||
response = get(url, headers=headers, timeout=10)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f'Got status code {response.status_code}')
|
||||
raise ValueError(f"Got status code {response.status_code}")
|
||||
schema = response.text
|
||||
|
||||
# try to parse schema, avoid SSRF attack
|
||||
ApiToolManageService.parser_api_schema(schema)
|
||||
except Exception as e:
|
||||
logger.error(f"parse api schema error: {str(e)}")
|
||||
raise ValueError('invalid schema, please check the url you provided')
|
||||
|
||||
return {
|
||||
'schema': schema
|
||||
}
|
||||
raise ValueError("invalid schema, please check the url you provided")
|
||||
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
list api tool provider tools
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(controller)
|
||||
|
||||
|
||||
return [
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tool_bundle,
|
||||
labels=labels,
|
||||
) for tool_bundle in provider.tools
|
||||
)
|
||||
for tool_bundle in provider.tools
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def update_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, original_provider: str, icon: dict, credentials: dict,
|
||||
schema_type: str, schema: str, privacy_policy: str, custom_disclaimer: str, labels: list[str]
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
original_provider: str,
|
||||
icon: dict,
|
||||
credentials: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
privacy_policy: str,
|
||||
custom_disclaimer: str,
|
||||
labels: list[str],
|
||||
):
|
||||
"""
|
||||
update api tool provider
|
||||
update api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema}")
|
||||
|
||||
# check if the provider exists
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == original_provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'api provider {provider_name} does not exists')
|
||||
raise ValueError(f"api provider {provider_name} does not exists")
|
||||
|
||||
# parse openapi to tool bundle
|
||||
extra_info = {}
|
||||
# extra info like description will be set here
|
||||
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info)
|
||||
|
||||
|
||||
# update db provider
|
||||
provider.name = provider_name
|
||||
provider.icon = json.dumps(icon)
|
||||
provider.schema = schema
|
||||
provider.description = extra_info.get('description', '')
|
||||
provider.description = extra_info.get("description", "")
|
||||
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value
|
||||
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles))
|
||||
provider.privacy_policy = privacy_policy
|
||||
provider.custom_disclaimer = custom_disclaimer
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(provider, auth_type)
|
||||
@@ -295,84 +303,91 @@ class ApiToolManageService:
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def delete_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str
|
||||
):
|
||||
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str):
|
||||
"""
|
||||
delete tool provider
|
||||
delete tool provider
|
||||
"""
|
||||
provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f'you have not added provider {provider_name}')
|
||||
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
db.session.delete(provider)
|
||||
db.session.commit()
|
||||
|
||||
return { 'result': 'success' }
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_api_tool_provider(
|
||||
user_id: str, tenant_id: str, provider: str
|
||||
):
|
||||
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str):
|
||||
"""
|
||||
get api tool provider
|
||||
get api tool provider
|
||||
"""
|
||||
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def test_api_tool_preview(
|
||||
tenant_id: str,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str
|
||||
tool_name: str,
|
||||
credentials: dict,
|
||||
parameters: dict,
|
||||
schema_type: str,
|
||||
schema: str,
|
||||
):
|
||||
"""
|
||||
test api tool before adding api tool provider
|
||||
test api tool before adding api tool provider
|
||||
"""
|
||||
if schema_type not in [member.value for member in ApiProviderSchemaType]:
|
||||
raise ValueError(f'invalid schema type {schema_type}')
|
||||
|
||||
raise ValueError(f"invalid schema type {schema_type}")
|
||||
|
||||
try:
|
||||
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema)
|
||||
except Exception as e:
|
||||
raise ValueError('invalid schema')
|
||||
|
||||
raise ValueError("invalid schema")
|
||||
|
||||
# get tool bundle
|
||||
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None)
|
||||
if tool_bundle is None:
|
||||
raise ValueError(f'invalid tool name {tool_name}')
|
||||
|
||||
db_provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
).first()
|
||||
raise ValueError(f"invalid tool name {tool_name}")
|
||||
|
||||
db_provider: ApiToolProvider = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.filter(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider_name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not db_provider:
|
||||
# create a fake db provider
|
||||
db_provider = ApiToolProvider(
|
||||
tenant_id='', user_id='', name='', icon='',
|
||||
tenant_id="",
|
||||
user_id="",
|
||||
name="",
|
||||
icon="",
|
||||
schema=schema,
|
||||
description='',
|
||||
description="",
|
||||
schema_type_str=ApiProviderSchemaType.OPENAPI.value,
|
||||
tools_str=json.dumps(jsonable_encoder(tool_bundles)),
|
||||
credentials_str=json.dumps(credentials),
|
||||
)
|
||||
|
||||
if 'auth_type' not in credentials:
|
||||
raise ValueError('auth_type is required')
|
||||
if "auth_type" not in credentials:
|
||||
raise ValueError("auth_type is required")
|
||||
|
||||
# get auth type, none or api key
|
||||
auth_type = ApiProviderAuthType.value_of(credentials['auth_type'])
|
||||
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"])
|
||||
|
||||
# create provider entity
|
||||
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type)
|
||||
@@ -381,10 +396,7 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ToolConfigurationManager(
|
||||
tenant_id=tenant_id,
|
||||
provider_controller=provider_controller
|
||||
)
|
||||
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
|
||||
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
@@ -396,27 +408,27 @@ class ApiToolManageService:
|
||||
provider_controller.validate_credentials_format(credentials)
|
||||
# get tool
|
||||
tool = provider_controller.get_tool(tool_name)
|
||||
tool = tool.fork_tool_runtime(runtime={
|
||||
'credentials': credentials,
|
||||
'tenant_id': tenant_id,
|
||||
})
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime={
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
)
|
||||
result = tool.validate_credentials(credentials, parameters)
|
||||
except Exception as e:
|
||||
return { 'error': str(e) }
|
||||
|
||||
return { 'result': result or 'empty response' }
|
||||
|
||||
return {"error": str(e)}
|
||||
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(
|
||||
user_id: str, tenant_id: str
|
||||
) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
"""
|
||||
list api tools
|
||||
list api tools
|
||||
"""
|
||||
# get all api providers
|
||||
db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
|
||||
ApiToolProvider.tenant_id == tenant_id
|
||||
).all() or []
|
||||
db_providers: list[ApiToolProvider] = (
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
|
||||
@@ -425,26 +437,21 @@ class ApiToolManageService:
|
||||
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
|
||||
labels = ToolLabelManager.get_tool_labels(provider_controller)
|
||||
user_provider = ToolTransformService.api_provider_to_user_provider(
|
||||
provider_controller,
|
||||
db_provider=provider,
|
||||
decrypt_credentials=True
|
||||
provider_controller, db_provider=provider, decrypt_credentials=True
|
||||
)
|
||||
user_provider.labels = labels
|
||||
|
||||
# add icon
|
||||
ToolTransformService.repack_provider(user_provider)
|
||||
|
||||
tools = provider_controller.get_tools(
|
||||
user_id=user_id, tenant_id=tenant_id
|
||||
)
|
||||
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
|
||||
|
||||
for tool in tools:
|
||||
user_provider.tools.append(ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_provider.original_credentials,
|
||||
labels=labels
|
||||
))
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.tool_to_user_tool(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
result.append(user_provider)
|
||||
|
||||
|
Reference in New Issue
Block a user