From 063191889d86a4d593337603321e6f785990809e Mon Sep 17 00:00:00 2001
From: Bowen Liang
Date: Fri, 9 Feb 2024 15:21:33 +0800
Subject: [PATCH] chore: apply ruff's pyupgrade linter rules to modernize
 Python code with targeted version (#2419)

---
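Note: the api/pyproject.toml hunk itself falls outside this excerpt (the diffstat
below records 3 added lines for it). As a rough sketch of the kind of ruff
configuration that enables these rewrites — the exact rule selection and target
version are assumptions here, not copied from the commit:

    [tool.ruff]
    target-version = "py310"  # assumed target; pyupgrade keys its rewrites to this
    select = [
        "UP",  # pyupgrade: builtin generics, PEP 604 unions, yield from, etc.
    ]

With a configuration like this, the mechanical changes below can be reproduced
by running `ruff check --fix .` inside api/.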
 api/app.py | 1 -
 api/config.py | 1 -
 api/controllers/console/app/app.py | 1 -
 api/controllers/console/app/audio.py | 1 -
 api/controllers/console/app/completion.py | 7 +-
 api/controllers/console/app/message.py | 6 +-
 api/controllers/console/app/model_config.py | 1 -
 api/controllers/console/app/site.py | 1 -
 api/controllers/console/app/statistic.py | 1 -
 api/controllers/console/auth/login.py | 1 -
 api/controllers/console/datasets/datasets.py | 1 -
 .../console/datasets/datasets_document.py | 4 +-
 .../console/datasets/datasets_segments.py | 1 -
 api/controllers/console/explore/audio.py | 1 -
 api/controllers/console/explore/completion.py | 7 +-
 .../console/explore/conversation.py | 1 -
 api/controllers/console/explore/error.py | 1 -
 .../console/explore/installed_app.py | 1 -
 api/controllers/console/explore/message.py | 7 +-
 api/controllers/console/explore/parameter.py | 1 -
 .../console/explore/recommended_app.py | 1 -
 api/controllers/console/setup.py | 1 -
 api/controllers/console/version.py | 1 -
 api/controllers/console/workspace/account.py | 1 -
 api/controllers/console/workspace/members.py | 1 -
 .../console/workspace/workspace.py | 1 -
 api/controllers/console/wraps.py | 1 -
 api/controllers/service_api/app/app.py | 1 -
 api/controllers/service_api/app/completion.py | 6 +-
 .../service_api/app/conversation.py | 1 -
 api/controllers/service_api/app/error.py | 1 -
 api/controllers/service_api/app/message.py | 1 -
 api/controllers/service_api/wraps.py | 1 -
 api/controllers/web/app.py | 1 -
 api/controllers/web/audio.py | 1 -
 api/controllers/web/completion.py | 7 +-
 api/controllers/web/conversation.py | 1 -
 api/controllers/web/error.py | 1 -
 api/controllers/web/message.py | 7 +-
 api/controllers/web/passport.py | 1 -
 api/controllers/web/site.py | 1 -
 api/controllers/web/wraps.py | 1 -
 api/core/agent/agent/agent_llm_callback.py | 10 +-
 api/core/agent/agent/calc_token_mixin.py | 4 +-
 .../agent/agent/multi_dataset_router_agent.py | 11 ++-
 api/core/agent/agent/openai_function_call.py | 15 +--
 .../structed_multi_dataset_router_agent.py | 17 ++--
 api/core/agent/agent/structured_chat.py | 21 ++--
 api/core/app_runner/app_runner.py | 9 +-
 api/core/app_runner/generate_task_pipeline.py | 9 +-
 api/core/app_runner/moderation_handler.py | 4 +-
 api/core/application_manager.py | 5 +-
 api/core/application_queue_manager.py | 3 +-
 .../agent_loop_gather_callback_handler.py | 12 +--
 .../agent_tool_callback_handler.py | 6 +-
 .../index_tool_callback_handler.py | 5 +-
 .../std_out_callback_handler.py | 14 +--
 api/core/chain/llm_chain.py | 6 +-
 api/core/data_loader/file_extractor.py | 8 +-
 api/core/data_loader/loader/csv_loader.py | 6 +-
 api/core/data_loader/loader/excel.py | 3 +-
 api/core/data_loader/loader/html.py | 3 +-
 api/core/data_loader/loader/markdown.py | 12 +--
 api/core/data_loader/loader/notion.py | 20 ++--
 api/core/data_loader/loader/pdf.py | 4 +-
 .../loader/unstructured/unstructured_eml.py | 3 +-
 .../unstructured/unstructured_markdown.py | 3 +-
 .../loader/unstructured/unstructured_msg.py | 3 +-
 .../loader/unstructured/unstructured_ppt.py | 3 +-
 .../loader/unstructured/unstructured_pptx.py | 3 +-
 .../loader/unstructured/unstructured_text.py | 3 +-
 .../loader/unstructured/unstructured_xml.py | 3 +-
 api/core/docstore/dataset_docstore.py | 9 +-
 api/core/embedding/cached_embedding.py | 6 +-
 api/core/entities/provider_configuration.py | 11 ++-
 api/core/extension/extensible.py | 4 +-
 api/core/features/assistant_base_runner.py | 20 ++--
 api/core/features/assistant_cot_runner.py | 21 ++--
 api/core/features/assistant_fc_runner.py | 15 +--
 api/core/features/dataset_retrieval.py | 4 +-
 api/core/features/external_data_fetch.py | 4 +-
 api/core/features/moderation.py | 3 +-
 api/core/file/message_file_parser.py | 14 +--
 api/core/index/base.py | 4 +-
 .../jieba_keyword_table_handler.py | 5 +-
 .../keyword_table_index.py | 16 +--
 api/core/index/vector_index/base.py | 6 +-
 .../index/vector_index/milvus_vector_index.py | 4 +-
 .../index/vector_index/qdrant_vector_index.py | 4 +-
 .../vector_index/weaviate_vector_index.py | 4 +-
 api/core/indexing_runner.py | 28 +++---
 api/core/model_manager.py | 5 +-
 .../model_runtime/callbacks/base_callback.py | 10 +-
 .../callbacks/logging_callback.py | 10 +-
 api/core/model_runtime/entities/defaults.py | 3 +-
 .../model_providers/__base/ai_model.py | 4 +-
 .../__base/large_language_model.py | 21 ++--
 .../model_providers/__base/model_provider.py | 5 +-
 .../model_providers/anthropic/llm/llm.py | 9 +-
 .../model_providers/azure_openai/llm/llm.py | 11 ++-
 .../text_embedding/text_embedding.py | 4 +-
 .../baichuan/llm/baichuan_tokenizer.py | 2 +-
 .../baichuan/llm/baichuan_turbo.py | 21 ++--
 .../model_providers/baichuan/llm/llm.py | 9 +-
 .../baichuan/text_embedding/text_embedding.py | 4 +-
 .../model_providers/bedrock/llm/llm.py | 11 ++-
 .../model_providers/chatglm/llm/llm.py | 9 +-
 .../model_providers/cohere/llm/llm.py | 17 ++--
 .../cohere/text_embedding/text_embedding.py | 4 +-
 .../model_providers/google/llm/llm.py | 7 +-
 .../huggingface_hub/llm/llm.py | 5 +-
 .../model_providers/localai/llm/llm.py | 11 ++-
 .../minimax/llm/chat_completion.py | 9 +-
 .../minimax/llm/chat_completion_pro.py | 9 +-
 .../model_providers/minimax/llm/llm.py | 8 +-
 .../model_providers/minimax/llm/types.py | 8 +-
 .../model_providers/model_provider_factory.py | 2 +-
 .../model_providers/moonshot/llm/llm.py | 5 +-
 .../model_providers/ollama/llm/llm.py | 9 +-
 .../model_providers/openai/llm/llm.py | 11 ++-
 .../openai/text_embedding/text_embedding.py | 4 +-
 .../openai_api_compatible/llm/llm.py | 9 +-
 .../model_providers/openllm/llm/llm.py | 8 +-
 .../openllm/llm/openllm_generate.py | 13 +--
 .../model_providers/replicate/llm/llm.py | 5 +-
 .../model_providers/spark/llm/llm.py | 9 +-
 .../model_providers/togetherai/llm/llm.py | 7 +-
 .../model_providers/tongyi/llm/_client.py | 10 +-
 .../model_providers/tongyi/llm/llm.py | 9 +-
 .../model_providers/wenxin/llm/ernie_bot.py | 41 ++++----
 .../model_providers/wenxin/llm/llm.py | 9 +-
 .../model_providers/xinference/llm/llm.py | 11 ++-
 .../xinference/xinference_helper.py | 7 +-
 .../model_providers/zhipuai/llm/llm.py | 15 +--
 .../zhipuai/text_embedding/text_embedding.py | 6 +-
 .../zhipuai/zhipuai_sdk/_client.py | 3 +-
 .../api_resource/chat/async_completions.py | 9 +-
 .../api_resource/chat/completions.py | 9 +-
 .../zhipuai_sdk/api_resource/embeddings.py | 6 +-
 .../zhipuai/zhipuai_sdk/api_resource/files.py | 2 +-
 .../api_resource/fine_tuning/jobs.py | 2 +-
 .../zhipuai_sdk/api_resource/images.py | 2 +-
 .../zhipuai/zhipuai_sdk/core/_base_type.py | 27 ++---
 .../zhipuai/zhipuai_sdk/core/_files.py | 4 +-
 .../zhipuai/zhipuai_sdk/core/_http_client.py | 20 ++--
 .../zhipuai/zhipuai_sdk/core/_jwt_token.py | 1 -
 .../zhipuai/zhipuai_sdk/core/_request_opt.py | 6 +-
 .../zhipuai/zhipuai_sdk/core/_response.py | 8 +-
 .../zhipuai/zhipuai_sdk/core/_sse_client.py | 9 +-
 .../zhipuai/zhipuai_sdk/core/_utils.py | 3 +-
 .../types/chat/async_chat_completion.py | 4 +-
 .../zhipuai_sdk/types/chat/chat_completion.py | 6 +-
 .../types/chat/chat_completion_chunk.py | 6 +-
 .../zhipuai/zhipuai_sdk/types/embeddings.py | 6 +-
 .../zhipuai/zhipuai_sdk/types/file_object.py | 4 +-
 .../types/fine_tuning/fine_tuning_job.py | 6 +-
 .../fine_tuning/fine_tuning_job_event.py | 4 +-
 .../types/fine_tuning/job_create_params.py | 4 +-
 .../zhipuai/zhipuai_sdk/types/image.py | 4 +-
 api/core/model_runtime/utils/_compat.py | 3 +-
 api/core/model_runtime/utils/encoders.py | 21 ++--
 api/core/prompt/prompt_transform.py | 36 +++----
 api/core/rerank/rerank.py | 6 +-
 api/core/splitter/fixed_text_splitter.py | 12 +--
 api/core/third_party/langchain/llms/fake.py | 11 ++-
 api/core/tool/current_datetime_tool.py | 3 +-
 api/core/tool/web_reader_tool.py | 6 +-
 api/core/tools/entities/tool_bundle.py | 6 +-
 api/core/tools/entities/tool_entities.py | 12 +--
 api/core/tools/entities/user_entities.py | 6 +-
 api/core/tools/model/tool_model_manager.py | 6 +-
 api/core/tools/provider/api_tool_provider.py | 14 +--
 api/core/tools/provider/app_tool_provider.py | 12 +--
 api/core/tools/provider/builtin/_positions.py | 5 +-
 .../provider/builtin/azuredalle/azuredalle.py | 4 +-
 .../builtin/azuredalle/tools/dalle3.py | 6 +-
 api/core/tools/provider/builtin/bing/bing.py | 4 +-
 .../builtin/bing/tools/bing_web_search.py | 6 +-
 .../tools/provider/builtin/chart/tools/bar.py | 6 +-
 .../provider/builtin/chart/tools/line.py | 6 +-
 .../tools/provider/builtin/chart/tools/pie.py | 6 +-
 .../tools/provider/builtin/dalle/dalle.py | 4 +-
 .../provider/builtin/dalle/tools/dalle2.py | 6 +-
 .../provider/builtin/dalle/tools/dalle3.py | 6 +-
 .../builtin/gaode/tools/gaode_weather.py | 4 +-
 .../github/tools/github_repositories.py | 4 +-
 .../tools/provider/builtin/google/google.py | 4 +-
 .../builtin/google/tools/google_search.py | 8 +-
 .../tools/provider/builtin/maths/maths.py | 4 +-
 .../builtin/maths/tools/eval_expression.py | 6 +-
 .../stablediffusion/stablediffusion.py | 4 +-
 .../stablediffusion/tools/stable_diffusion.py | 14 +--
 api/core/tools/provider/builtin/time/time.py | 4 +-
 .../builtin/time/tools/current_time.py | 6 +-
 .../builtin/vectorizer/tools/vectorizer.py | 8 +-
 .../provider/builtin/vectorizer/vectorizer.py | 4 +-
 .../builtin/webscraper/tools/webscraper.py | 6 +-
 .../provider/builtin/webscraper/webscraper.py | 4 +-
 .../wikipedia/tools/wikipedia_search.py | 6 +-
 .../wolframalpha/tools/wolframalpha.py | 6 +-
 .../builtin/wolframalpha/wolframalpha.py | 4 +-
 .../provider/builtin/yahoo/tools/analytics.py | 6 +-
 .../provider/builtin/yahoo/tools/news.py | 6 +-
 .../provider/builtin/yahoo/tools/ticker.py | 6 +-
 .../provider/builtin/youtube/tools/videos.py | 6 +-
 .../tools/provider/builtin_tool_provider.py | 28 +++---
 api/core/tools/provider/tool_provider.py | 26 ++---
 api/core/tools/tool/api_tool.py | 12 +--
 api/core/tools/tool/builtin_tool.py | 7 +-
 .../dataset_multi_retriever_tool.py | 14 +--
 .../dataset_retriever_tool.py | 6 +-
 api/core/tools/tool/dataset_retriever_tool.py | 10 +-
 api/core/tools/tool/tool.py | 24 ++--
 api/core/tools/tool_file_manager.py | 9 +-
 api/core/tools/tool_manager.py | 26 ++---
 api/core/tools/utils/configuration.py | 10 +-
 api/core/tools/utils/encoder.py | 5 +-
 api/core/tools/utils/parser.py | 15 ++-
 api/core/tools/utils/web_reader_tool.py | 6 +-
 api/core/vector_store/vector/milvus.py | 29 +++---
 api/core/vector_store/vector/qdrant.py | 99 ++++++++++---------
 api/core/vector_store/vector/weaviate.py | 45 ++++-----
 api/extensions/ext_storage.py | 3 +-
 api/libs/gmpy2_pkcs10aep_cipher.py | 1 -
 api/libs/helper.py | 9 +-
 api/libs/infinite_scroll_pagination.py | 1 -
 api/libs/json_in_md_parser.py | 3 +-
 api/libs/passport.py | 1 -
 api/libs/password.py | 1 -
 api/libs/rsa.py | 1 -
 api/models/account.py | 5 +-
 api/models/tools.py | 3 +-
 api/pyproject.toml | 3 +
 api/services/account_service.py | 7 +-
 api/services/completion_service.py | 3 +-
 api/services/dataset_service.py | 6 +-
 api/services/file_service.py | 5 +-
 api/services/hit_testing_service.py | 3 +-
 api/services/message_service.py | 4 +-
 api/services/model_provider_service.py | 4 +-
 api/services/tools_manage_service.py | 5 +-
 api/services/vector_service.py | 6 +-
 .../batch_create_segment_to_index_task.py | 4 +-
 api/tasks/clean_notion_document_task.py | 3 +-
 api/tasks/create_segment_to_index_task.py | 4 +-
 api/tasks/update_segment_index_task.py | 4 +-
 246 files changed, 912 insertions(+), 937 deletions(-)

diff --git a/api/app.py b/api/app.py
index 255c1dbc0..bcf3856c1 100644
--- a/api/app.py
+++ b/api/app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os
 
 from werkzeug.exceptions import Unauthorized
diff --git a/api/config.py b/api/config.py
index 84572ff59..b37a559e0 100644
--- a/api/config.py
+++ b/api/config.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import os
 
 import dotenv
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index c06193f91..87cad0746 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
 from datetime import datetime
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 775b3315a..d95b3d03c 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index be8d3bf08..f01d2afa0 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 import flask_login
 from flask import Response, stream_with_context
@@ -169,8 +169,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index d29d826b6..0064dbe66 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -246,8 +247,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index fd526b393..f67fff4b0 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import request
 from flask_login import current_user
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 8d6231cba..4e9d9ed9b 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
 from werkzeug.exceptions import Forbidden, NotFound
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index d6ced934a..7aed7da40 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from decimal import Decimal
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 646f672c7..cec022ed5 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_login
 from flask import current_app, request
 from flask_restful import Resource, reqparse
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 5a71ccd6e..2d26d0ecf 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import flask_restful
 from flask import current_app, request
 from flask_login import current_user
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 612838a31..3fb6f16cd 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,4 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
-from typing import List
 
 from flask import request
 from flask_login import current_user
@@ -71,7 +69,7 @@ class DocumentResource(Resource):
 
         return document
 
-    def get_batch_documents(self, dataset_id: str, batch: str) -> List[Document]:
+    def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]:
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
             raise NotFound('Dataset not found.')
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 319b78b6d..1395963f1 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 from datetime import datetime
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 48d58524b..d6afee0d6 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 924578f7b..6406d5b3b 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,8 +1,8 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
+from collections.abc import Generator
 from datetime import datetime
-from typing import Generator, Union
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -164,8 +164,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 8a3fb3a20..34a5904ec 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py
index e3180bf98..89c4d113a 100644
--- a/api/controllers/console/explore/error.py
+++ b/api/controllers/console/explore/error.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 6e914ef3a..920d9141a 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 
 from flask_login import current_user
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index 75c3cdd5c..47af28425 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_login import current_user
@@ -123,8 +123,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index 4b18be6dc..c4afb0b92 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index 3c2c80666..fd90be03b 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal_with
 from sqlalchemy import and_
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 58c285347..a8d0dd434 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps
 
 from flask import current_app, request
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 519fa2551..a50e4c41a 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index c511c9778..b7cfba9d0 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 
 import pytz
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 1b7d08a87..6ee018882 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_login import current_user
 from flask_restful import Resource, abort, fields, marshal_with, reqparse
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index dbeb712bc..7b3f08f46 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index 1e20265c4..d5777a330 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 from functools import wraps
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 89d99d66f..9cd9770c0 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index d47bb089d..5331f796e 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import reqparse
@@ -182,8 +183,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index d275552d0..3c157bed9 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import request
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py
index 56beb5694..eb953d095 100644
--- a/api/controllers/service_api/app/error.py
+++ b/api/controllers/service_api/app/error.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index a0257b3ed..d90f536a4 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 0cc63a2ad..a0d89fe62 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from datetime import datetime
 from functools import wraps
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index 6e62c042d..25492b114 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import json
 
 from flask import current_app
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index b3d7280b6..673aa9ad8 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import logging
 
 from flask import request
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index c61995b72..61d4f8c36 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import reqparse
@@ -154,8 +154,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index b0d7747d6..c287f2a87 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask_restful import marshal_with, reqparse
 from flask_restful.inputs import int_range
 from werkzeug.exceptions import NotFound
diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py
index 4566c323a..9cb3c8f23 100644
--- a/api/controllers/web/error.py
+++ b/api/controllers/web/error.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from libs.exception import BaseHTTPException
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 1a084fe53..e03bdd63b 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -1,7 +1,7 @@
-# -*- coding:utf-8 -*-
 import json
 import logging
-from typing import Generator, Union
+from collections.abc import Generator
+from typing import Union
 
 from flask import Response, stream_with_context
 from flask_restful import fields, marshal_with, reqparse
@@ -160,8 +160,7 @@ def compact_response(response: Union[dict, Generator]) -> Response:
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
         def generate() -> Generator:
-            for chunk in response:
-                yield chunk
+            yield from response
 
         return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index 188cc4125..92b28d812 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 import uuid
 
 from flask import request
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 8ce3a8108..d8e2d5970 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from flask import current_app
 from flask_restful import fields, marshal_with
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index ebf661178..bdaa476f3 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -1,4 +1,3 @@
-# -*- coding:utf-8 -*-
 from functools import wraps
 
 from flask import request
diff --git a/api/core/agent/agent/agent_llm_callback.py b/api/core/agent/agent/agent_llm_callback.py
index 833173120..5ec549de8 100644
--- a/api/core/agent/agent/agent_llm_callback.py
+++ b/api/core/agent/agent/agent_llm_callback.py
@@ -1,5 +1,5 @@
 import logging
-from typing import List, Optional
+from typing import Optional
 
 from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
 from core.model_runtime.callbacks.base_callback import Callback
@@ -17,7 +17,7 @@ class AgentLLMCallback(Callback):
 
     def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict,
                          prompt_messages: list[PromptMessage], model_parameters: dict,
-                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                         tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                          stream: bool = True, user: Optional[str] = None) -> None:
         """
         Before invoke callback
@@ -38,7 +38,7 @@ class AgentLLMCallback(Callback):
 
     def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict,
                      prompt_messages: list[PromptMessage], model_parameters: dict,
-                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                     tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                      stream: bool = True, user: Optional[str] = None):
         """
         On new chunk callback
@@ -58,7 +58,7 @@ class AgentLLMCallback(Callback):
 
     def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         After invoke callback
@@ -80,7 +80,7 @@ class AgentLLMCallback(Callback):
 
     def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict,
                         prompt_messages: list[PromptMessage], model_parameters: dict,
-                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                        tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                         stream: bool = True, user: Optional[str] = None) -> None:
         """
         Invoke error callback
diff --git a/api/core/agent/agent/calc_token_mixin.py b/api/core/agent/agent/calc_token_mixin.py
index b25ab2d88..9c0f9c5b3 100644
--- a/api/core/agent/agent/calc_token_mixin.py
+++ b/api/core/agent/agent/calc_token_mixin.py
@@ -1,4 +1,4 @@
-from typing import List, cast
+from typing import cast
 
 from core.entities.application_entities import ModelConfigEntity
 from core.model_runtime.entities.message_entities import PromptMessage
@@ -8,7 +8,7 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
 
 
 class CalcTokenMixin:
 
-    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: List[PromptMessage], **kwargs) -> int:
+    def get_message_rest_tokens(self, model_config: ModelConfigEntity, messages: list[PromptMessage], **kwargs) -> int:
         """
         Got the rest tokens available for the model after excluding messages tokens and completion max tokens
diff --git a/api/core/agent/agent/multi_dataset_router_agent.py b/api/core/agent/agent/multi_dataset_router_agent.py
index 201421910..eb594c3d2 100644
--- a/api/core/agent/agent/multi_dataset_router_agent.py
+++ b/api/core/agent/agent/multi_dataset_router_agent.py
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -42,7 +43,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -85,7 +86,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     def real_plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -146,7 +147,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
 
     async def aplan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -158,7 +159,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
         model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
diff --git a/api/core/agent/agent/openai_function_call.py b/api/core/agent/agent/openai_function_call.py
index 3dafa4517..1f2d5f24b 100644
--- a/api/core/agent/agent/openai_function_call.py
+++ b/api/core/agent/agent/openai_function_call.py
@@ -1,4 +1,5 @@
-from typing import Any, List, Optional, Sequence, Tuple, Union
+from collections.abc import Sequence
+from typing import Any, Optional, Union
 
 from langchain.agents import BaseSingleActionAgent, OpenAIFunctionsAgent
 from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
@@ -51,7 +52,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         model_config: ModelConfigEntity,
         tools: Sequence[BaseTool],
         callback_manager: Optional[BaseCallbackManager] = None,
-        extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
+        extra_prompt_messages: Optional[list[BaseMessagePromptTemplate]] = None,
         system_message: Optional[SystemMessage] = SystemMessage(
             content="You are a helpful AI assistant."
         ),
@@ -125,7 +126,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -207,7 +208,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
     def return_stopped_response(
         self,
         early_stopping_method: str,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         **kwargs: Any,
     ) -> AgentFinish:
         try:
@@ -215,7 +216,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         except ValueError:
             return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
 
-    def summarize_messages_if_needed(self, messages: List[PromptMessage], **kwargs) -> List[PromptMessage]:
+    def summarize_messages_if_needed(self, messages: list[PromptMessage], **kwargs) -> list[PromptMessage]:
         # calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
         rest_tokens = self.get_message_rest_tokens(
             self.model_config,
@@ -264,7 +265,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         return new_messages
 
     def predict_new_summary(
-        self, messages: List[BaseMessage], existing_summary: str
+        self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,
@@ -275,7 +276,7 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixi
         chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
         return chain.predict(summary=existing_summary, new_lines=new_lines)
 
-    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: List[BaseMessage], **kwargs) -> int:
+    def get_num_tokens_from_messages(self, model_config: ModelConfigEntity, messages: list[BaseMessage], **kwargs) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
 
         Official documentation: https://github.com/openai/openai-cookbook/blob/
diff --git a/api/core/agent/agent/structed_multi_dataset_router_agent.py b/api/core/agent/agent/structed_multi_dataset_router_agent.py
index 9d36e01d7..e104bb01f 100644
--- a/api/core/agent/agent/structed_multi_dataset_router_agent.py
+++ b/api/core/agent/agent/structed_multi_dataset_router_agent.py
@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast
 
 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -68,7 +69,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -125,8 +126,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:
@@ -153,7 +154,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
         tools: Sequence[BaseTool],
         prefix: str = PREFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
@@ -180,7 +181,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -213,8 +214,8 @@ Thought: {agent_scratchpad}
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
         **kwargs: Any,
     ) -> Agent:
         """Construct an agent from an LLM and tools."""
diff --git a/api/core/agent/agent/structured_chat.py b/api/core/agent/agent/structured_chat.py
index 03fea8c27..e1be62420 100644
--- a/api/core/agent/agent/structured_chat.py
+++ b/api/core/agent/agent/structured_chat.py
@@ -1,5 +1,6 @@
 import re
-from typing import Any, List, Optional, Sequence, Tuple, Union, cast
+from collections.abc import Sequence
+from typing import Any, Optional, Union, cast
 
 from langchain import BasePromptTemplate, PromptTemplate
 from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
@@ -82,7 +83,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
 
     def plan(
         self,
-        intermediate_steps: List[Tuple[AgentAction, str]],
+        intermediate_steps: list[tuple[AgentAction, str]],
         callbacks: Callbacks = None,
         **kwargs: Any,
     ) -> Union[AgentAction, AgentFinish]:
@@ -127,7 +128,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
             return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
                                           "I don't know how to respond to that."}, "")
 
-    def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
+    def summarize_messages(self, intermediate_steps: list[tuple[AgentAction, str]], **kwargs):
         if len(intermediate_steps) >= 2 and self.summary_model_config:
             should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
             should_summary_messages = [AIMessage(content=observation)
@@ -154,7 +155,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
 
     def predict_new_summary(
-        self, messages: List[BaseMessage], existing_summary: str
+        self, messages: list[BaseMessage], existing_summary: str
     ) -> str:
         new_lines = get_buffer_string(
             messages,
@@ -173,8 +174,8 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
     ) -> BasePromptTemplate:
         tool_strings = []
         for tool in tools:
@@ -200,7 +201,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
         tools: Sequence[BaseTool],
         prefix: str = PREFIX,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
+        input_variables: Optional[list[str]] = None,
     ) -> PromptTemplate:
         """Create prompt in the style of the zero shot agent.
@@ -227,7 +228,7 @@ Thought: {agent_scratchpad}
         return PromptTemplate(template=template, input_variables=input_variables)
 
     def _construct_scratchpad(
-        self, intermediate_steps: List[Tuple[AgentAction, str]]
+        self, intermediate_steps: list[tuple[AgentAction, str]]
     ) -> str:
         agent_scratchpad = ""
         for action, observation in intermediate_steps:
@@ -260,8 +261,8 @@ Thought: {agent_scratchpad}
         suffix: str = SUFFIX,
         human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
         format_instructions: str = FORMAT_INSTRUCTIONS,
-        input_variables: Optional[List[str]] = None,
-        memory_prompts: Optional[List[BasePromptTemplate]] = None,
+        input_variables: Optional[list[str]] = None,
+        memory_prompts: Optional[list[BasePromptTemplate]] = None,
         agent_llm_callback: Optional[AgentLLMCallback] = None,
         **kwargs: Any,
     ) -> Agent:
diff --git a/api/core/app_runner/app_runner.py b/api/core/app_runner/app_runner.py
index 457cae828..2b8ddc5d4 100644
--- a/api/core/app_runner/app_runner.py
+++ b/api/core/app_runner/app_runner.py
@@ -1,5 +1,6 @@
 import time
-from typing import Generator, List, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 from core.application_queue_manager import ApplicationQueueManager, PublishFrom
 from core.entities.application_entities import (
@@ -84,7 +85,7 @@ class AppRunner:
         return rest_tokens
 
     def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
-                              prompt_messages: List[PromptMessage]):
+                              prompt_messages: list[PromptMessage]):
         # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -126,7 +127,7 @@ class AppRunner:
                                  query: Optional[str] = None,
                                  context: Optional[str] = None,
                                  memory: Optional[TokenBufferMemory] = None) \
-            -> Tuple[List[PromptMessage], Optional[List[str]]]:
+            -> tuple[list[PromptMessage], Optional[list[str]]]:
         """
         Organize prompt messages
         :param context:
@@ -295,7 +296,7 @@ class AppRunner:
                                           tenant_id: str,
                                           app_orchestration_config_entity: AppOrchestrationConfigEntity,
                                           inputs: dict,
-                                          query: str) -> Tuple[bool, dict, str]:
+                                          query: str) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
diff --git a/api/core/app_runner/generate_task_pipeline.py b/api/core/app_runner/generate_task_pipeline.py
index 39f51ee1b..20e4bc799 100644
--- a/api/core/app_runner/generate_task_pipeline.py
+++ b/api/core/app_runner/generate_task_pipeline.py
@@ -1,7 +1,8 @@
 import json
 import logging
 import time
-from typing import Generator, Optional, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 from pydantic import BaseModel
 
@@ -118,7 +119,7 @@ class GenerateTaskPipeline:
                 }
 
                 self._task_state.llm_result.message.content = annotation.content
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
@@ -201,7 +202,7 @@ class GenerateTaskPipeline:
                 data = self._error_to_stream_response_data(self._handle_error(event))
                 yield self._yield_response(data)
                 break
-            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
+            elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
                 if isinstance(event, QueueMessageEndEvent):
                     self._task_state.llm_result = event.llm_result
                 else:
@@ -353,7 +354,7 @@ class GenerateTaskPipeline:
 
                     yield self._yield_response(response)
 
-            elif isinstance(event, (QueueMessageEvent, QueueAgentMessageEvent)):
+            elif isinstance(event, QueueMessageEvent | QueueAgentMessageEvent):
                 chunk = event.chunk
                 delta_text = chunk.delta.message.content
                 if delta_text is None:
diff --git a/api/core/app_runner/moderation_handler.py b/api/core/app_runner/moderation_handler.py
index 392425ed8..b2098344c 100644
--- a/api/core/app_runner/moderation_handler.py
+++ b/api/core/app_runner/moderation_handler.py
@@ -1,7 +1,7 @@
 import logging
 import threading
 import time
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from flask import Flask, current_app
 from pydantic import BaseModel
@@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
 
 class ModerationRule(BaseModel):
     type: str
-    config: Dict[str, Any]
+    config: dict[str, Any]
 
 
 class OutputModerationHandler(BaseModel):
diff --git a/api/core/application_manager.py b/api/core/application_manager.py
index b718cefab..d2f4326b4 100644
--- a/api/core/application_manager.py
+++ b/api/core/application_manager.py
@@ -2,7 +2,8 @@ import json
 import logging
 import threading
 import uuid
-from typing import Any, Generator, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Any, Optional, Union, cast
 
 from flask import Flask, current_app
 from pydantic import ValidationError
@@ -585,7 +586,7 @@ class ApplicationManager:
         return AppOrchestrationConfigEntity(**properties)
 
     def _init_generate_records(self, application_generate_entity: ApplicationGenerateEntity) \
-            -> Tuple[Conversation, Message]:
+            -> tuple[Conversation, Message]:
         """
         Initialize generate records
         :param application_generate_entity: application generate entity
diff --git a/api/core/application_queue_manager.py b/api/core/application_queue_manager.py
index 75a56d670..9590a1e72 100644
--- a/api/core/application_queue_manager.py
+++ b/api/core/application_queue_manager.py
@@ -1,7 +1,8 @@
 import queue
 import time
+from collections.abc import Generator
 from enum import Enum
-from typing import Any, Generator
+from typing import Any
 
 from sqlalchemy.orm import DeclarativeMeta
 
diff --git a/api/core/callback_handler/agent_loop_gather_callback_handler.py b/api/core/callback_handler/agent_loop_gather_callback_handler.py
index f9347198d..1d25b8ab6 100644
--- a/api/core/callback_handler/agent_loop_gather_callback_handler.py
+++ b/api/core/callback_handler/agent_loop_gather_callback_handler.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import time
-from typing import Any, Dict, List, Optional, Union, cast
+from typing import Any, Optional, Union, cast
 
 from langchain.agents import openai_functions_agent, openai_functions_multi_agent
 from langchain.callbacks.base import BaseCallbackHandler
@@ -37,7 +37,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
         self._message_agent_thought = None
 
     @property
-    def agent_loops(self) -> List[AgentLoop]:
+    def agent_loops(self) -> list[AgentLoop]:
         return self._agent_loops
 
     def clear_agent_loops(self) -> None:
@@ -95,14 +95,14 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
 
     def on_chat_model_start(
             self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
+            serialized: dict[str, Any],
+            messages: list[list[BaseMessage]],
             **kwargs: Any
     ) -> Any:
         pass
 
     def on_llm_start(
-            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+            self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         pass
 
@@ -120,7 +120,7 @@ class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
 
     def on_tool_start(
             self,
-            serialized: Dict[str, Any],
+            serialized: dict[str, Any],
             input_str: str,
             **kwargs: Any,
     ) -> None:
diff --git a/api/core/callback_handler/agent_tool_callback_handler.py b/api/core/callback_handler/agent_tool_callback_handler.py
index ae77bf6cd..3fed7d0ad 100644
--- a/api/core/callback_handler/agent_tool_callback_handler.py
+++ b/api/core/callback_handler/agent_tool_callback_handler.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
@@ -21,7 +21,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
     def on_tool_start(
         self,
        tool_name: str,
-        tool_inputs: Dict[str, Any],
+        tool_inputs: dict[str, Any],
     ) -> None:
         """Do nothing."""
         print_text("\n[on_tool_start] ToolCall:" + tool_name + "\n" + str(tool_inputs) + "\n", color=self.color)
@@ -29,7 +29,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
     def on_tool_end(
         self,
         tool_name: str,
-        tool_inputs: Dict[str, Any],
+        tool_inputs: dict[str, Any],
         tool_outputs: str,
     ) -> None:
         """If not the final action, print out observation."""
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index 63c9bbe41..7c8a3ce47 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -1,4 +1,3 @@
-from typing import List
 
 from langchain.schema import Document
 
@@ -40,7 +39,7 @@ class DatasetIndexToolCallbackHandler:
             db.session.add(dataset_query)
             db.session.commit()
 
-    def on_tool_end(self, documents: List[Document]) -> None:
+    def on_tool_end(self, documents: list[Document]) -> None:
         """Handle tool end."""
         for document in documents:
             doc_id = document.metadata['doc_id']
@@ -55,7 +54,7 @@ class DatasetIndexToolCallbackHandler:
 
             db.session.commit()
 
-    def return_retriever_resource_info(self, resource: List):
+    def return_retriever_resource_info(self, resource: list):
         """Handle return_retriever_resource_info."""
         if resource and len(resource) > 0:
             for item in resource:
diff --git a/api/core/callback_handler/std_out_callback_handler.py b/api/core/callback_handler/std_out_callback_handler.py
index 9f586d2c9..1f95471af 100644
--- a/api/core/callback_handler/std_out_callback_handler.py
+++ b/api/core/callback_handler/std_out_callback_handler.py
@@ -1,6 +1,6 @@
 import os
 import sys
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.input import print_text
@@ -16,8 +16,8 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
 
     def on_chat_model_start(
             self,
-            serialized: Dict[str, Any],
-            messages: List[List[BaseMessage]],
+            serialized: dict[str, Any],
+            messages: list[list[BaseMessage]],
             **kwargs: Any
     ) -> Any:
         print_text("\n[on_chat_model_start]\n", color='blue')
@@ -26,7 +26,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
             print_text(str(sub_message) + "\n", color='blue')
 
     def on_llm_start(
-            self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+            self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
     ) -> None:
         """Print out the prompts."""
         print_text("\n[on_llm_start]\n", color='blue')
@@ -48,13 +48,13 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
         print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
 
     def on_chain_start(
-        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
+        self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
     ) -> None:
         """Print out that we are entering a chain."""
         chain_type = serialized['id'][-1]
         print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
 
-    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
+    def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
         """Print out that we finished a chain."""
         print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
 
@@ -66,7 +66,7 @@ class DifyStdOutCallbackHandler(BaseCallbackHandler):
 
     def on_tool_start(
         self,
-        serialized: Dict[str, Any],
+        serialized: dict[str, Any],
         input_str: str,
         **kwargs: Any,
     ) -> None:
diff --git a/api/core/chain/llm_chain.py b/api/core/chain/llm_chain.py
index a5d160c99..86fb15629 100644
--- a/api/core/chain/llm_chain.py
+++ b/api/core/chain/llm_chain.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain import LLMChain as LCLLMChain
 from langchain.callbacks.manager import CallbackManagerForChainRun
@@ -16,12 +16,12 @@ class LLMChain(LCLLMChain):
     model_config: ModelConfigEntity
     """The language model instance to use."""
     llm: BaseLanguageModel = FakeLLM(response="")
-    parameters: Dict[str, Any] = {}
+    parameters: dict[str, Any] = {}
     agent_llm_callback: Optional[AgentLLMCallback] = None
 
     def generate(
         self,
-        input_list: List[Dict[str, Any]],
+        input_list: list[dict[str, Any]],
         run_manager: Optional[CallbackManagerForChainRun] = None,
     ) -> LLMResult:
         """Generate LLM result from inputs."""
diff --git a/api/core/data_loader/file_extractor.py b/api/core/data_loader/file_extractor.py
index af0fb1d35..4a6eb3654 100644
--- a/api/core/data_loader/file_extractor.py
+++ b/api/core/data_loader/file_extractor.py
@@ -1,6 +1,6 @@
 import tempfile
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Optional, Union
 
 import requests
 from flask import current_app
@@ -28,7 +28,7 @@ USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTM
 class FileExtractor:
 
     @classmethod
-    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[List[Document], str]:
+    def load(cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False) -> Union[list[Document], str]:
         with tempfile.TemporaryDirectory() as temp_dir:
             suffix = Path(upload_file.key).suffix
             file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"
@@ -37,7 +37,7 @@ class FileExtractor:
             return cls.load_from_file(file_path, return_text, upload_file, is_automatic)
 
     @classmethod
-    def load_from_url(cls, url: str, return_text: bool = False) -> Union[List[Document], str]:
+    def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
         response = requests.get(url, headers={
             "User-Agent": USER_AGENT
         })
@@ -53,7 +53,7 @@ class FileExtractor:
     @classmethod
     def load_from_file(cls, file_path: str, return_text: bool = False,
                        upload_file: Optional[UploadFile] = None,
-                       is_automatic: bool = False) -> Union[List[Document], str]:
+                       is_automatic: bool = False) -> Union[list[Document], str]:
         input_file = Path(file_path)
         delimiter = '\n'
         file_extension = input_file.suffix.lower()
diff --git a/api/core/data_loader/loader/csv_loader.py b/api/core/data_loader/loader/csv_loader.py
index a4d4ed2b3..ce252c157 100644
--- a/api/core/data_loader/loader/csv_loader.py
+++ b/api/core/data_loader/loader/csv_loader.py
@@ -1,6 +1,6 @@
 import csv
 import logging
-from typing import Dict, List, Optional
+from typing import Optional
 
 from langchain.document_loaders import CSVLoader as LCCSVLoader
 from langchain.document_loaders.helpers import detect_file_encodings
@@ -14,7 +14,7 @@ class CSVLoader(LCCSVLoader):
             self,
             file_path: str,
             source_column: Optional[str] = None,
-            csv_args: Optional[Dict] = None,
+            csv_args: Optional[dict] = None,
             encoding: Optional[str] = None,
             autodetect_encoding: bool = True,
     ):
@@ -24,7 +24,7 @@ class CSVLoader(LCCSVLoader):
         self.csv_args = csv_args or {}
         self.autodetect_encoding = autodetect_encoding
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         """Load data into document objects."""
         try:
             with open(self.file_path, newline="", encoding=self.encoding) as csvfile:
diff --git a/api/core/data_loader/loader/excel.py b/api/core/data_loader/loader/excel.py
index f5f6b2d69..cddb29854 100644
--- a/api/core/data_loader/loader/excel.py
+++ b/api/core/data_loader/loader/excel.py
@@ -1,5 +1,4 @@
 import logging
-from typing import List
 
 from langchain.document_loaders.base import BaseLoader
 from langchain.schema import Document
@@ -23,7 +22,7 @@ class ExcelLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         data = []
         keys = []
         wb = load_workbook(filename=self._file_path, read_only=True)
diff --git a/api/core/data_loader/loader/html.py b/api/core/data_loader/loader/html.py
index 414975007..6a9b48a5b 100644
--- a/api/core/data_loader/loader/html.py
+++ b/api/core/data_loader/loader/html.py
@@ -1,5 +1,4 @@
 import logging
-from typing import List
 
 from bs4 import BeautifulSoup
 from langchain.document_loaders.base import BaseLoader
@@ -23,7 +22,7 @@ class HTMLLoader(BaseLoader):
         """Initialize with file path."""
         self._file_path = file_path
 
-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         return [Document(page_content=self._load_as_text())]
 
     def _load_as_text(self) -> str:
diff --git a/api/core/data_loader/loader/markdown.py b/api/core/data_loader/loader/markdown.py
index 545c6b10e..ecbc6d548 100644
--- a/api/core/data_loader/loader/markdown.py
a/api/core/data_loader/loader/markdown.py +++ b/api/core/data_loader/loader/markdown.py @@ -1,6 +1,6 @@ import logging import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast from langchain.document_loaders.base import BaseLoader from langchain.document_loaders.helpers import detect_file_encodings @@ -42,7 +42,7 @@ class MarkdownLoader(BaseLoader): self._encoding = encoding self._autodetect_encoding = autodetect_encoding - def load(self) -> List[Document]: + def load(self) -> list[Document]: tups = self.parse_tups(self._file_path) documents = [] for header, value in tups: @@ -54,13 +54,13 @@ class MarkdownLoader(BaseLoader): return documents - def markdown_to_tups(self, markdown_text: str) -> List[Tuple[Optional[str], str]]: + def markdown_to_tups(self, markdown_text: str) -> list[tuple[Optional[str], str]]: """Convert a markdown file to a dictionary. The keys are the headers and the values are the text under each header. """ - markdown_tups: List[Tuple[Optional[str], str]] = [] + markdown_tups: list[tuple[Optional[str], str]] = [] lines = markdown_text.split("\n") current_header = None @@ -103,11 +103,11 @@ class MarkdownLoader(BaseLoader): content = re.sub(pattern, r"\1", content) return content - def parse_tups(self, filepath: str) -> List[Tuple[Optional[str], str]]: + def parse_tups(self, filepath: str) -> list[tuple[Optional[str], str]]: """Parse file into tuples.""" content = "" try: - with open(filepath, "r", encoding=self._encoding) as f: + with open(filepath, encoding=self._encoding) as f: content = f.read() except UnicodeDecodeError as e: if self._autodetect_encoding: diff --git a/api/core/data_loader/loader/notion.py b/api/core/data_loader/loader/notion.py index 9f9198c3c..f8d883768 100644 --- a/api/core/data_loader/loader/notion.py +++ b/api/core/data_loader/loader/notion.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import requests from flask import current_app @@ -67,7 +67,7 @@ class NotionLoader(BaseLoader): document_model=document_model ) - def load(self) -> List[Document]: + def load(self) -> list[Document]: self.update_last_edited_time( self._document_model ) @@ -78,7 +78,7 @@ class NotionLoader(BaseLoader): def _load_data_as_documents( self, notion_obj_id: str, notion_page_type: str - ) -> List[Document]: + ) -> list[Document]: docs = [] if notion_page_type == 'database': # get all the pages in the database @@ -94,8 +94,8 @@ class NotionLoader(BaseLoader): return docs def _get_notion_database_data( - self, database_id: str, query_dict: Dict[str, Any] = {} - ) -> List[Document]: + self, database_id: str, query_dict: dict[str, Any] = {} + ) -> list[Document]: """Get all the pages from a Notion database.""" res = requests.post( DATABASE_URL_TMPL.format(database_id=database_id), @@ -149,12 +149,12 @@ class NotionLoader(BaseLoader): return database_content_list - def _get_notion_block_data(self, page_id: str) -> List[str]: + def _get_notion_block_data(self, page_id: str) -> list[str]: result_lines_arr = [] cur_block_id = page_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -216,7 +216,7 @@ class NotionLoader(BaseLoader): cur_block_id = block_id while True: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -280,7 +280,7 
@@ class NotionLoader(BaseLoader): cur_block_id = block_id while not done: block_url = BLOCK_CHILD_URL_TMPL.format(block_id=cur_block_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", @@ -346,7 +346,7 @@ class NotionLoader(BaseLoader): else: retrieve_page_url = RETRIEVE_PAGE_URL_TMPL.format(page_id=obj_id) - query_dict: Dict[str, Any] = {} + query_dict: dict[str, Any] = {} res = requests.request( "GET", diff --git a/api/core/data_loader/loader/pdf.py b/api/core/data_loader/loader/pdf.py index 881d0026b..a3452b367 100644 --- a/api/core/data_loader/loader/pdf.py +++ b/api/core/data_loader/loader/pdf.py @@ -1,5 +1,5 @@ import logging -from typing import List, Optional +from typing import Optional from langchain.document_loaders import PyPDFium2Loader from langchain.document_loaders.base import BaseLoader @@ -28,7 +28,7 @@ class PdfLoader(BaseLoader): self._file_path = file_path self._upload_file = upload_file - def load(self) -> List[Document]: + def load(self) -> list[Document]: plaintext_file_key = '' plaintext_file_exists = False if self._upload_file: diff --git a/api/core/data_loader/loader/unstructured/unstructured_eml.py b/api/core/data_loader/loader/unstructured/unstructured_eml.py index 26e0ce8cd..2fa3aac13 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_eml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_eml.py @@ -1,6 +1,5 @@ import base64 import logging -from typing import List from bs4 import BeautifulSoup from langchain.document_loaders.base import BaseLoader @@ -24,7 +23,7 @@ class UnstructuredEmailLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.email import partition_email elements = partition_email(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_markdown.py b/api/core/data_loader/loader/unstructured/unstructured_markdown.py index cf6e7c9c8..036a2afd2 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_markdown.py +++ b/api/core/data_loader/loader/unstructured/unstructured_markdown.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -34,7 +33,7 @@ class UnstructuredMarkdownLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.md import partition_md elements = partition_md(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_msg.py b/api/core/data_loader/loader/unstructured/unstructured_msg.py index 5a9813237..495be328e 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_msg.py +++ b/api/core/data_loader/loader/unstructured/unstructured_msg.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredMsgLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.msg import partition_msg elements = partition_msg(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_ppt.py 
b/api/core/data_loader/loader/unstructured/unstructured_ppt.py index 9b1e6b5ab..cfac91cc7 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_ppt.py +++ b/api/core/data_loader/loader/unstructured/unstructured_ppt.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -23,7 +22,7 @@ class UnstructuredPPTLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.ppt import partition_ppt elements = partition_ppt(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_pptx.py b/api/core/data_loader/loader/unstructured/unstructured_pptx.py index 0eecee9ff..41e3bfcb5 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_pptx.py +++ b/api/core/data_loader/loader/unstructured/unstructured_pptx.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -22,7 +21,7 @@ class UnstructuredPPTXLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.pptx import partition_pptx elements = partition_pptx(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_text.py b/api/core/data_loader/loader/unstructured/unstructured_text.py index dd684b37f..09d14fdb1 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_text.py +++ b/api/core/data_loader/loader/unstructured/unstructured_text.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredTextLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.text import partition_text elements = partition_text(filename=self._file_path, api_url=self._api_url) diff --git a/api/core/data_loader/loader/unstructured/unstructured_xml.py b/api/core/data_loader/loader/unstructured/unstructured_xml.py index 0ddbb74b9..cca6e1b0b 100644 --- a/api/core/data_loader/loader/unstructured/unstructured_xml.py +++ b/api/core/data_loader/loader/unstructured/unstructured_xml.py @@ -1,5 +1,4 @@ import logging -from typing import List from langchain.document_loaders.base import BaseLoader from langchain.schema import Document @@ -24,7 +23,7 @@ class UnstructuredXmlLoader(BaseLoader): self._file_path = file_path self._api_url = api_url - def load(self) -> List[Document]: + def load(self) -> list[Document]: from unstructured.partition.xml import partition_xml elements = partition_xml(filename=self._file_path, xml_keep_tags=True, api_url=self._api_url) diff --git a/api/core/docstore/dataset_docstore.py b/api/core/docstore/dataset_docstore.py index 77a5dde9e..556b3aced 100644 --- a/api/core/docstore/dataset_docstore.py +++ b/api/core/docstore/dataset_docstore.py @@ -1,4 +1,5 @@ -from typing import Any, Dict, Optional, Sequence, cast +from collections.abc import Sequence +from typing import Any, Optional, cast from langchain.schema import Document from sqlalchemy import func @@ -22,10 +23,10 @@ class DatasetDocumentStore: self._document_id = document_id @classmethod - def 
from_dict(cls, config_dict: Dict[str, Any]) -> "DatasetDocumentStore": + def from_dict(cls, config_dict: dict[str, Any]) -> "DatasetDocumentStore": return cls(**config_dict) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Serialize to dict.""" return { "dataset_id": self._dataset.id, @@ -40,7 +41,7 @@ class DatasetDocumentStore: return self._user_id @property - def docs(self) -> Dict[str, Document]: + def docs(self) -> dict[str, Document]: document_segments = db.session.query(DocumentSegment).filter( DocumentSegment.dataset_id == self._dataset.id ).all() diff --git a/api/core/embedding/cached_embedding.py b/api/core/embedding/cached_embedding.py index 4f7b3a153..a86afd817 100644 --- a/api/core/embedding/cached_embedding.py +++ b/api/core/embedding/cached_embedding.py @@ -1,6 +1,6 @@ import base64 import logging -from typing import List, Optional, cast +from typing import Optional, cast import numpy as np from langchain.embeddings.base import Embeddings @@ -21,7 +21,7 @@ class CacheEmbedding(Embeddings): self._model_instance = model_instance self._user = user - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" text_embeddings = [] try: @@ -52,7 +52,7 @@ class CacheEmbedding(Embeddings): return text_embeddings - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text.""" # use doc embedding cache or store if not exists hash = helper.generate_text_hash(text) diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index fd6164763..b83ae0c8e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,8 +1,9 @@ import datetime import json import logging +from collections.abc import Iterator from json import JSONDecodeError -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Optional from pydantic import BaseModel @@ -135,7 +136,7 @@ class ProviderConfiguration(BaseModel): if self.provider.provider_credential_schema else [] ) - def custom_credentials_validate(self, credentials: dict) -> Tuple[Provider, dict]: + def custom_credentials_validate(self, credentials: dict) -> tuple[Provider, dict]: """ Validate custom credentials. :param credentials: provider credentials @@ -282,7 +283,7 @@ class ProviderConfiguration(BaseModel): return None def custom_model_credentials_validate(self, model_type: ModelType, model: str, credentials: dict) \ - -> Tuple[ProviderModel, dict]: + -> tuple[ProviderModel, dict]: """ Validate custom model credentials. @@ -711,7 +712,7 @@ class ProviderConfigurations(BaseModel): Model class for provider configuration dict. """ tenant_id: str - configurations: Dict[str, ProviderConfiguration] = {} + configurations: dict[str, ProviderConfiguration] = {} def __init__(self, tenant_id: str): super().__init__(tenant_id=tenant_id) @@ -759,7 +760,7 @@ class ProviderConfigurations(BaseModel): return all_models - def to_list(self) -> List[ProviderConfiguration]: + def to_list(self) -> list[ProviderConfiguration]: """ Convert to list. 
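The dataset_docstore.py and provider_configuration.py hunks above show the second recurring pattern in this patch: abstract container types such as Sequence, Iterator, and Generator now come from collections.abc rather than typing, since the typing re-exports of those names are deprecated for the targeted Python version. A short sketch of the idiom; the function below is illustrative only, not part of the patch:

from collections.abc import Iterator, Sequence

def adjacent_pairs(items: Sequence[str]) -> Iterator[tuple[str, str]]:
    # accepts any sequence type (list, tuple, ...) and lazily yields pairs
    for left, right in zip(items, items[1:]):
        yield (left, right)

Annotating against the abstract types from collections.abc also keeps signatures permissive, which is why the rewrite is safe for callers.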
diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py
index 6b27062f1..c19aaefe9 100644
--- a/api/core/extension/extensible.py
+++ b/api/core/extension/extensible.py
@@ -61,7 +61,7 @@ class Extensible:
 
                 builtin_file_path = os.path.join(subdir_path, '__builtin__')
                 if os.path.exists(builtin_file_path):
-                    with open(builtin_file_path, 'r', encoding='utf-8') as f:
+                    with open(builtin_file_path, encoding='utf-8') as f:
                         position = int(f.read().strip())
 
                 if (extension_name + '.py') not in file_names:
@@ -93,7 +93,7 @@ class Extensible:
                 json_path = os.path.join(subdir_path, 'schema.json')
                 json_data = {}
                 if os.path.exists(json_path):
-                    with open(json_path, 'r', encoding='utf-8') as f:
+                    with open(json_path, encoding='utf-8') as f:
                         json_data = json.load(f)
 
                 extensions[extension_name] = ModuleExtension(
diff --git a/api/core/features/assistant_base_runner.py b/api/core/features/assistant_base_runner.py
index 4c0bde989..c62028eaf 100644
--- a/api/core/features/assistant_base_runner.py
+++ b/api/core/features/assistant_base_runner.py
@@ -2,7 +2,7 @@ import json
 import logging
 from datetime import datetime
 from mimetypes import guess_extension
-from typing import List, Optional, Tuple, Union, cast
+from typing import Optional, Union, cast
 
 from core.app_runner.app_runner import AppRunner
 from core.application_queue_manager import ApplicationQueueManager
@@ -50,7 +50,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                  message: Message,
                  user_id: str,
                  memory: Optional[TokenBufferMemory] = None,
-                 prompt_messages: Optional[List[PromptMessage]] = None,
+                 prompt_messages: Optional[list[PromptMessage]] = None,
                  variables_pool: Optional[ToolRuntimeVariablePool] = None,
                  db_variables: Optional[ToolConversationVariables] = None,
                  model_instance: ModelInstance = None
@@ -122,7 +122,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         return app_orchestration_config
 
-    def _convert_tool_response_to_str(self, tool_response: List[ToolInvokeMessage]) -> str:
+    def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
         """
         Handle tool response
         """
@@ -140,7 +140,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         return result
 
-    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> Tuple[PromptMessageTool, Tool]:
+    def _convert_tool_to_prompt_message_tool(self, tool: AgentToolEntity) -> tuple[PromptMessageTool, Tool]:
         """
         convert tool to prompt message tool
         """
@@ -325,7 +325,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         return prompt_tool
 
-    def extract_tool_response_binary(self, tool_response: List[ToolInvokeMessage]) -> List[ToolInvokeMessageBinary]:
+    def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
         """
         Extract tool response binary
         """
@@ -356,7 +356,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         return result
 
-    def create_message_files(self, messages: List[ToolInvokeMessageBinary]) -> List[Tuple[MessageFile, bool]]:
+    def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
         """
         Create message file
 
@@ -404,7 +404,7 @@ class BaseAssistantApplicationRunner(AppRunner):
         return result
 
     def create_agent_thought(self, message_id: str, message: str,
-                             tool_name: str, tool_input: str, messages_ids: List[str]
+                             tool_name: str, tool_input: str, messages_ids: list[str]
                              ) -> MessageAgentThought:
         """
         Create agent thought
@@ -449,7 +449,7 @@ class BaseAssistantApplicationRunner(AppRunner):
                            thought: str,
                            observation: str,
                            answer: str,
-                           messages_ids: List[str],
+                           messages_ids: list[str],
                            llm_usage: LLMUsage = None) -> MessageAgentThought:
         """
         Save agent thought
@@ -505,7 +505,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         db.session.commit()
 
-    def get_history_prompt_messages(self) -> List[PromptMessage]:
+    def get_history_prompt_messages(self) -> list[PromptMessage]:
         """
         Get history prompt messages
         """
@@ -516,7 +516,7 @@ class BaseAssistantApplicationRunner(AppRunner):
 
         return self.history_prompt_messages
 
-    def transform_tool_invoke_messages(self, messages: List[ToolInvokeMessage]) -> List[ToolInvokeMessage]:
+    def transform_tool_invoke_messages(self, messages: list[ToolInvokeMessage]) -> list[ToolInvokeMessage]:
         """
         Transform tool message into agent thought
         """
diff --git a/api/core/features/assistant_cot_runner.py b/api/core/features/assistant_cot_runner.py
index 546406983..b8d08bb5d 100644
--- a/api/core/features/assistant_cot_runner.py
+++ b/api/core/features/assistant_cot_runner.py
@@ -1,6 +1,7 @@
 import json
 import re
-from typing import Dict, Generator, List, Literal, Union
+from collections.abc import Generator
+from typing import Literal, Union
 
 from core.application_queue_manager import PublishFrom
 from core.entities.application_entities import AgentPromptEntity, AgentScratchpadUnit
@@ -29,7 +30,7 @@ class AssistantCotApplicationRunner(BaseAssistantApplicationRunner):
     def run(self, conversation: Conversation,
         message: Message,
         query: str,
-        inputs: Dict[str, str],
+        inputs: dict[str, str],
     ) -> Union[Generator, LLMResult]:
         """
         Run Cot agent application
@@ -37,7 +38,7 @@
         app_orchestration_config = self.app_orchestration_config
         self._repack_app_orchestration_config(app_orchestration_config)
 
-        agent_scratchpad: List[AgentScratchpadUnit] = []
+        agent_scratchpad: list[AgentScratchpadUnit] = []
 
         # check model mode
         if self.app_orchestration_config.model_config.mode == "completion":
@@ -56,7 +57,7 @@
         prompt_messages = self.history_prompt_messages
 
         # convert tools into ModelRuntime Tool format
-        prompt_messages_tools: List[PromptMessageTool] = []
+        prompt_messages_tools: list[PromptMessageTool] = []
         tool_instances = {}
         for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
             try:
@@ -83,7 +84,7 @@
         }
         final_answer = ''
 
-        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
+        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             if not final_llm_usage_dict['usage']:
                 final_llm_usage_dict['usage'] = usage
             else:
@@ -493,7 +494,7 @@
             if not next_iteration.find("{{observation}}") >= 0:
                 raise ValueError("{{observation}} is required in next_iteration")
 
-    def _convert_scratchpad_list_to_str(self, agent_scratchpad: List[AgentScratchpadUnit]) -> str:
+    def _convert_scratchpad_list_to_str(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
         """
         convert agent scratchpad list to str
         """
@@ -506,13 +507,13 @@
         return result
 
     def _organize_cot_prompt_messages(self, mode: Literal["completion", "chat"],
-                                      prompt_messages: List[PromptMessage],
-                                      tools: List[PromptMessageTool],
-                                      agent_scratchpad: List[AgentScratchpadUnit],
+                                      prompt_messages: list[PromptMessage],
+                                      tools: list[PromptMessageTool],
+                                      agent_scratchpad: list[AgentScratchpadUnit],
                                       agent_prompt_message: AgentPromptEntity,
                                       instruction: str,
                                       input: str,
-        ) -> List[PromptMessage]:
+        ) -> list[PromptMessage]:
         """
         organize chain of thought prompt messages, a standard prompt message is like:
             Respond to the human as helpfully and accurately as possible.
diff --git a/api/core/features/assistant_fc_runner.py b/api/core/features/assistant_fc_runner.py
index b0e3d3a7a..7ad9d7bd2 100644
--- a/api/core/features/assistant_fc_runner.py
+++ b/api/core/features/assistant_fc_runner.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Any, Dict, Generator, List, Tuple, Union
+from collections.abc import Generator
+from typing import Any, Union
 
 from core.application_queue_manager import PublishFrom
 from core.features.assistant_base_runner import BaseAssistantApplicationRunner
@@ -44,7 +45,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
         )
 
         # convert tools into ModelRuntime Tool format
-        prompt_messages_tools: List[PromptMessageTool] = []
+        prompt_messages_tools: list[PromptMessageTool] = []
         tool_instances = {}
         for tool in self.app_orchestration_config.agent.tools if self.app_orchestration_config.agent else []:
             try:
@@ -70,13 +71,13 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
 
         # continue to run until there is not any tool call
         function_call_state = True
-        agent_thoughts: List[MessageAgentThought] = []
+        agent_thoughts: list[MessageAgentThought] = []
         llm_usage = {
             'usage': None
         }
         final_answer = ''
 
-        def increase_usage(final_llm_usage_dict: Dict[str, LLMUsage], usage: LLMUsage):
+        def increase_usage(final_llm_usage_dict: dict[str, LLMUsage], usage: LLMUsage):
             if not final_llm_usage_dict['usage']:
                 final_llm_usage_dict['usage'] = usage
             else:
@@ -117,7 +118,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 callbacks=[],
             )
 
-            tool_calls: List[Tuple[str, str, Dict[str, Any]]] = []
+            tool_calls: list[tuple[str, str, dict[str, Any]]] = []
 
             # save full response
            response = ''
@@ -364,7 +365,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
                 return True
         return False
 
-    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+    def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract tool calls from llm result chunk
 
@@ -381,7 +382,7 @@ class AssistantFunctionCallApplicationRunner(BaseAssistantApplicationRunner):
 
         return tool_calls
 
-    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, List[Tuple[str, str, Dict[str, Any]]]]:
+    def extract_blocking_tool_calls(self, llm_result: LLMResult) -> Union[None, list[tuple[str, str, dict[str, Any]]]]:
         """
         Extract blocking tool calls from llm result
diff --git a/api/core/features/dataset_retrieval.py b/api/core/features/dataset_retrieval.py
index 159428aad..488a8ca8d 100644
--- a/api/core/features/dataset_retrieval.py
+++ b/api/core/features/dataset_retrieval.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, cast
+from typing import Optional, cast
 
 from langchain.tools import BaseTool
 
@@ -96,7 +96,7 @@ class DatasetRetrievalFeature:
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler) \
-            -> Optional[List[BaseTool]]:
+            -> Optional[list[BaseTool]]:
         """
         A dataset tool is a tool that can be used to retrieve information from a dataset
         :param tenant_id: tenant id
diff --git a/api/core/features/external_data_fetch.py b/api/core/features/external_data_fetch.py
index 33154d838..7f23c8ed7 100644
--- a/api/core/features/external_data_fetch.py
+++ b/api/core/features/external_data_fetch.py
@@ -2,7 +2,7 @@ import concurrent
 import json
 import logging
 from concurrent.futures import ThreadPoolExecutor
-from typing import Optional, Tuple
+from typing import Optional
 
 from flask import Flask, current_app
 
@@ -62,7 +62,7 @@ class ExternalDataFetchFeature:
                              app_id: str,
                              external_data_tool: ExternalDataVariableEntity,
                              inputs: dict,
-                             query: str) -> Tuple[Optional[str], Optional[str]]:
+                             query: str) -> tuple[Optional[str], Optional[str]]:
         """
         Query external data tool.
         :param flask_app: flask app
diff --git a/api/core/features/moderation.py b/api/core/features/moderation.py
index 9735fad0e..a9d65f56e 100644
--- a/api/core/features/moderation.py
+++ b/api/core/features/moderation.py
@@ -1,5 +1,4 @@
 import logging
-from typing import Tuple
 
 from core.entities.application_entities import AppOrchestrationConfigEntity
 from core.moderation.base import ModerationAction, ModerationException
@@ -13,7 +12,7 @@ class ModerationFeature:
               tenant_id: str,
               app_orchestration_config_entity: AppOrchestrationConfigEntity,
               inputs: dict,
-              query: str) -> Tuple[bool, dict, str]:
+              query: str) -> tuple[bool, dict, str]:
         """
         Process sensitive_word_avoidance.
         :param app_id: app id
diff --git a/api/core/file/message_file_parser.py b/api/core/file/message_file_parser.py
index ce783d8fb..1b7b8b87d 100644
--- a/api/core/file/message_file_parser.py
+++ b/api/core/file/message_file_parser.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Union
+from typing import Optional, Union
 
 import requests
 
@@ -15,8 +15,8 @@ class MessageFileParser:
         self.tenant_id = tenant_id
         self.app_id = app_id
 
-    def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
-                                         user: Union[Account, EndUser]) -> List[FileObj]:
+    def validate_and_transform_files_arg(self, files: list[dict], app_model_config: AppModelConfig,
+                                         user: Union[Account, EndUser]) -> list[FileObj]:
         """
         validate and transform files arg
 
@@ -96,7 +96,7 @@ class MessageFileParser:
         # return all file objs
         return new_files
 
-    def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
+    def transform_message_files(self, files: list[MessageFile], app_model_config: Optional[AppModelConfig]) -> list[FileObj]:
         """
         transform message files
 
@@ -110,8 +110,8 @@ class MessageFileParser:
         # return all file objs
         return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
 
-    def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
-                      file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
+    def _to_file_objs(self, files: list[Union[dict, MessageFile]],
+                      file_upload_config: dict) -> dict[FileType, list[FileObj]]:
         """
         transform files to file objs
 
@@ -119,7 +119,7 @@ class MessageFileParser:
         :param file_upload_config:
         :return:
         """
-        type_file_objs: Dict[FileType, List[FileObj]] = {
+        type_file_objs: dict[FileType, list[FileObj]] = {
             # Currently only support image
             FileType.IMAGE: []
         }
diff --git a/api/core/index/base.py b/api/core/index/base.py
index 1dc7cfdcc..f8eb1a134 100644
--- a/api/core/index/base.py
+++ b/api/core/index/base.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import Any, List
+from typing import Any
 
 from langchain.schema import BaseRetriever, Document
 
@@ -53,7 +53,7 @@ class BaseIndex(ABC):
     def search(
             self, query: str,
             **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         raise NotImplementedError
 
     def delete(self) -> None:
diff --git a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py
index db9fd027a..df93a1903 100644
--- a/api/core/index/keyword_table_index/jieba_keyword_table_handler.py
+++ b/api/core/index/keyword_table_index/jieba_keyword_table_handler.py
@@ -1,5 +1,4 @@
 import re
-from typing import Set
 
 import jieba
 from jieba.analyse import default_tfidf
@@ -12,7 +11,7 @@ class JiebaKeywordTableHandler:
     def __init__(self):
         default_tfidf.stop_words = STOPWORDS
 
-    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> Set[str]:
+    def extract_keywords(self, text: str, max_keywords_per_chunk: int = 10) -> set[str]:
         """Extract keywords with JIEBA tfidf."""
         keywords = jieba.analyse.extract_tags(
             sentence=text,
@@ -21,7 +20,7 @@ class JiebaKeywordTableHandler:
 
         return set(self._expand_tokens_with_subtokens(keywords))
 
-    def _expand_tokens_with_subtokens(self, tokens: Set[str]) -> Set[str]:
+    def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
         """Get subtokens from a list of tokens., filtering for stopwords."""
         results = set()
         for token in tokens:
diff --git a/api/core/index/keyword_table_index/keyword_table_index.py b/api/core/index/keyword_table_index/keyword_table_index.py
index 9ad8b8d64..8bf0b1334 100644
--- a/api/core/index/keyword_table_index/keyword_table_index.py
+++ b/api/core/index/keyword_table_index/keyword_table_index.py
@@ -1,6 +1,6 @@
 import json
 from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain.schema import BaseRetriever, Document
 from pydantic import BaseModel, Extra, Field
@@ -116,7 +116,7 @@ class KeywordTableIndex(BaseIndex):
     def search(
             self, query: str,
             **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         keyword_table = self._get_dataset_keyword_table()
 
         search_kwargs = kwargs.get('search_kwargs') if kwargs.get('search_kwargs') else {}
@@ -221,7 +221,7 @@ class KeywordTableIndex(BaseIndex):
         keywords = keyword_table_handler.extract_keywords(query)
 
         # go through text chunks in order of most matching keywords
-        chunk_indices_count: Dict[str, int] = defaultdict(int)
+        chunk_indices_count: dict[str, int] = defaultdict(int)
         keywords = [keyword for keyword in keywords if keyword in set(keyword_table.keys())]
         for keyword in keywords:
             for node_id in keyword_table[keyword]:
@@ -235,7 +235,7 @@ class KeywordTableIndex(BaseIndex):
 
         return sorted_chunk_indices[: k]
 
-    def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: List[str]):
+    def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
         document_segment = db.session.query(DocumentSegment).filter(
             DocumentSegment.dataset_id == dataset_id,
             DocumentSegment.index_node_id == node_id
@@ -244,7 +244,7 @@ class KeywordTableIndex(BaseIndex):
             document_segment.keywords = keywords
             db.session.commit()
 
-    def create_segment_keywords(self, node_id: str, keywords: List[str]):
+    def create_segment_keywords(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
         self._update_segment_keywords(self.dataset.id, node_id, keywords)
         keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
@@ -266,7 +266,7 @@ class KeywordTableIndex(BaseIndex):
             keyword_table = self._add_text_to_keyword_table(keyword_table, segment.index_node_id, list(keywords))
         self._save_dataset_keyword_table(keyword_table)
 
-    def update_segment_keywords_index(self, node_id: str, keywords: List[str]):
+    def update_segment_keywords_index(self, node_id: str, keywords: list[str]):
         keyword_table = self._get_dataset_keyword_table()
         keyword_table = self._add_text_to_keyword_table(keyword_table, node_id, keywords)
         self._save_dataset_keyword_table(keyword_table)
@@ -282,7 +282,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
         extra = Extra.forbid
         arbitrary_types_allowed = True
 
-    def get_relevant_documents(self, query: str) -> List[Document]:
+    def get_relevant_documents(self, query: str) -> list[Document]:
         """Get documents relevant for a query.
 
         Args:
@@ -293,7 +293,7 @@ class KeywordTableRetriever(BaseRetriever, BaseModel):
         """
         return self.index.search(query, **self.search_kwargs)
 
-    async def aget_relevant_documents(self, query: str) -> List[Document]:
+    async def aget_relevant_documents(self, query: str) -> list[Document]:
         raise NotImplementedError("KeywordTableRetriever does not support async")
diff --git a/api/core/index/vector_index/base.py b/api/core/index/vector_index/base.py
index b9b8e6d3d..36aa1917a 100644
--- a/api/core/index/vector_index/base.py
+++ b/api/core/index/vector_index/base.py
@@ -1,7 +1,7 @@
 import json
 import logging
 from abc import abstractmethod
-from typing import Any, List, cast
+from typing import Any, cast
 
 from langchain.embeddings.base import Embeddings
 from langchain.schema import BaseRetriever, Document
@@ -43,13 +43,13 @@ class BaseVectorIndex(BaseIndex):
     def search_by_full_text_index(
             self, query: str,
             **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         raise NotImplementedError
 
     def search(
             self, query: str,
             **kwargs: Any
-    ) -> List[Document]:
+    ) -> list[Document]:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
diff --git a/api/core/index/vector_index/milvus_vector_index.py b/api/core/index/vector_index/milvus_vector_index.py
index a0b6f5d20..a18cf35a2 100644
--- a/api/core/index/vector_index/milvus_vector_index.py
+++ b/api/core/index/vector_index/milvus_vector_index.py
@@ -1,4 +1,4 @@
-from typing import Any, List, cast
+from typing import Any, cast
 
 from langchain.embeddings.base import Embeddings
 from langchain.schema import Document
@@ -160,6 +160,6 @@ class MilvusVectorIndex(BaseVectorIndex):
                 ],
             ))
 
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
         # milvus/zilliz doesn't support bm25 search
         return []
diff --git a/api/core/index/vector_index/qdrant_vector_index.py b/api/core/index/vector_index/qdrant_vector_index.py
index f182c4c0e..046260d2f 100644
--- a/api/core/index/vector_index/qdrant_vector_index.py
+++ b/api/core/index/vector_index/qdrant_vector_index.py
@@ -1,5 +1,5 @@
 import os
-from typing import Any, List, Optional, cast
+from typing import Any, Optional, cast
 
 import qdrant_client
 from langchain.embeddings.base import Embeddings
@@ -210,7 +210,7 @@ class QdrantVectorIndex(BaseVectorIndex):
 
         return False
 
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
diff --git a/api/core/index/vector_index/weaviate_vector_index.py b/api/core/index/vector_index/weaviate_vector_index.py
index 8af3c5926..72a74a039 100644
--- a/api/core/index/vector_index/weaviate_vector_index.py
+++ b/api/core/index/vector_index/weaviate_vector_index.py
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional, cast
+from typing import Any, Optional, cast
 
 import requests
 import weaviate
@@ -172,7 +172,7 @@ class WeaviateVectorIndex(BaseVectorIndex):
 
         return False
 
-    def search_by_full_text_index(self, query: str, **kwargs: Any) -> List[Document]:
+    def search_by_full_text_index(self, query: str, **kwargs: Any) -> list[Document]:
         vector_store = self._get_vector_store()
         vector_store = cast(self._get_vector_store_class(), vector_store)
         return vector_store.similarity_search_by_bm25(query, kwargs.get('top_k', 2), **kwargs)
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 1f36362a8..a14001d04 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -5,7 +5,7 @@ import re
 import threading
 import time
 import uuid
-from typing import List, Optional, cast
+from typing import Optional, cast
 
 from flask import Flask, current_app
 from flask_login import current_user
@@ -40,7 +40,7 @@ class IndexingRunner:
         self.storage = storage
         self.model_manager = ModelManager()
 
-    def run(self, dataset_documents: List[DatasetDocument]):
+    def run(self, dataset_documents: list[DatasetDocument]):
         """Run the indexing process."""
         for dataset_document in dataset_documents:
             try:
@@ -238,7 +238,7 @@
             dataset_document.stopped_at = datetime.datetime.utcnow()
             db.session.commit()
 
-    def file_indexing_estimate(self, tenant_id: str, file_details: List[UploadFile], tmp_processing_rule: dict,
+    def file_indexing_estimate(self, tenant_id: str, file_details: list[UploadFile], tmp_processing_rule: dict,
                                doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                                indexing_technique: str = 'economy') -> dict:
         """
@@ -494,7 +494,7 @@
             "preview": preview_texts
         }
 
-    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> List[Document]:
+    def _load_data(self, dataset_document: DatasetDocument, automatic: bool = False) -> list[Document]:
         # load file
         if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
             return []
@@ -526,7 +526,7 @@
         )
 
         # replace doc id to document model id
-        text_docs = cast(List[Document], text_docs)
+        text_docs = cast(list[Document], text_docs)
         for text_doc in text_docs:
             # remove invalid symbol
             text_doc.page_content = self.filter_string(text_doc.page_content)
@@ -540,7 +540,7 @@
         text = re.sub(r'\|>', '>', text)
         text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
         # Unicode  U+FFFE
-        text = re.sub(u'\uFFFE', '', text)
+        text = re.sub('\uFFFE', '', text)
         return text
 
     def _get_splitter(self, processing_rule: DatasetProcessRule,
@@ -577,9 +577,9 @@
 
         return character_splitter
 
-    def _step_split(self, text_docs: List[Document], splitter: TextSplitter,
-                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
-            -> List[Document]:
+    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
+                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
+            -> list[Document]:
         """
         Split the text documents into documents and save them to the document segment.
         """
@@ -624,9 +624,9 @@
 
         return documents
 
-    def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter,
-                            processing_rule: DatasetProcessRule, tenant_id: str,
-                            document_form: str, document_language: str) -> List[Document]:
+    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
+                            processing_rule: DatasetProcessRule, tenant_id: str,
+                            document_form: str, document_language: str) -> list[Document]:
         """
         Split the text documents into nodes.
         """
@@ -699,8 +699,8 @@
 
                 all_qa_documents.extend(format_documents)
 
-    def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
-                                         processing_rule: DatasetProcessRule) -> List[Document]:
+    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
+                                         processing_rule: DatasetProcessRule) -> list[Document]:
         """
         Split the text documents into nodes.
         """
@@ -770,7 +770,7 @@
             for q, a in matches if q and a
         ]
 
-    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
+    def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None:
         """
         Build the index for the document.
         """
@@ -877,7 +877,7 @@
         DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
         db.session.commit()
 
-    def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset):
+    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
         """
         Batch add segments index processing
         """
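Besides the annotation changes, the extensible.py and indexing_runner.py hunks above carry two smaller, behavior-preserving cleanups: open() no longer passes an explicit 'r' flag, because reading text is already its default mode, and the Python 2-era u'' string prefix is dropped, because every Python 3 str is unicode. A sketch of both, with the file name invented for illustration:

import re

# Before: open(path, 'r', encoding='utf-8') and re.sub(u'\uFFFE', ...)
# carry a redundant mode flag and a redundant u-prefix.

# After: identical behavior, less noise
with open('schema.json', encoding='utf-8') as f:  # 'r' is the default mode
    data = f.read()
cleaned = re.sub('\uFFFE', '', data)  # every str literal is already unicode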
""" @@ -624,9 +624,9 @@ class IndexingRunner: return documents - def _split_to_documents(self, text_docs: List[Document], splitter: TextSplitter, + def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule, tenant_id: str, - document_form: str, document_language: str) -> List[Document]: + document_form: str, document_language: str) -> list[Document]: """ Split the text documents into nodes. """ @@ -699,8 +699,8 @@ class IndexingRunner: all_qa_documents.extend(format_documents) - def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter, - processing_rule: DatasetProcessRule) -> List[Document]: + def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter, + processing_rule: DatasetProcessRule) -> list[Document]: """ Split the text documents into nodes. """ @@ -770,7 +770,7 @@ class IndexingRunner: for q, a in matches if q and a ] - def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None: + def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: list[Document]) -> None: """ Build the index for the document. """ @@ -877,7 +877,7 @@ class IndexingRunner: DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() - def batch_add_segments(self, segments: List[DocumentSegment], dataset: Dataset): + def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset): """ Batch add segments index processing """ diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 68df0ac31..8e36ab7ee 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -1,4 +1,5 @@ -from typing import IO, Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import IO, Optional, Union, cast from core.entities.provider_configuration import ProviderModelBundle from core.errors.error import ProviderTokenNotInitError @@ -47,7 +48,7 @@ class ModelInstance: return credentials def invoke_llm(self, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/core/model_runtime/callbacks/base_callback.py index 58150ef4d..51af9786f 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/core/model_runtime/callbacks/base_callback.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import List, Optional +from typing import Optional from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool @@ -23,7 +23,7 @@ class Callback(ABC): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -42,7 +42,7 @@ class Callback(ABC): def on_new_chunk(self, llm_instance: AIModel, 
chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -62,7 +62,7 @@ class Callback(ABC): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -82,7 +82,7 @@ class Callback(ABC): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/core/model_runtime/callbacks/logging_callback.py index 486485844..0406853b8 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/core/model_runtime/callbacks/logging_callback.py @@ -1,7 +1,7 @@ import json import logging import sys -from typing import List, Optional +from typing import Optional from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) class LoggingCallback(Callback): def on_before_invoke(self, llm_instance: AIModel, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Before invoke callback @@ -60,7 +60,7 @@ class LoggingCallback(Callback): def on_new_chunk(self, llm_instance: AIModel, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None): """ On new chunk callback @@ -81,7 +81,7 @@ class LoggingCallback(Callback): def on_after_invoke(self, llm_instance: AIModel, result: LLMResult, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ After invoke callback @@ -113,7 +113,7 @@ class LoggingCallback(Callback): def on_invoke_error(self, llm_instance: AIModel, ex: Exception, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = 
None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> None: """ Invoke error callback diff --git a/api/core/model_runtime/entities/defaults.py b/api/core/model_runtime/entities/defaults.py index b39427dcc..856f4ce7d 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/core/model_runtime/entities/defaults.py @@ -1,8 +1,7 @@ -from typing import Dict from core.model_runtime.entities.model_entities import DefaultParameterName -PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = { +PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { 'label': { 'en_US': 'Temperature', diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index eb811ab22..a9f7a539e 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -153,7 +153,7 @@ class AIModel(ABC): # read _position.yaml file position_map = {} if os.path.exists(position_file_path): - with open(position_file_path, 'r', encoding='utf-8') as f: + with open(position_file_path, encoding='utf-8') as f: positions = yaml.safe_load(f) # convert list to dict with key as model provider name, value as index position_map = {position: index for index, position in enumerate(positions)} @@ -161,7 +161,7 @@ class AIModel(ABC): # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: # read yaml data from yaml file - with open(model_schema_yaml_path, 'r', encoding='utf-8') as f: + with open(model_schema_yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) new_parameter_rules = [] diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 173b4dcab..1f7edd245 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -3,7 +3,8 @@ import os import re import time from abc import abstractmethod -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.callbacks.logging_callback import LoggingCallback @@ -29,7 +30,7 @@ class LargeLanguageModel(AIModel): def invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) \ -> Union[LLMResult, Generator]: """ @@ -122,7 +123,7 @@ class LargeLanguageModel(AIModel): def _invoke_result_generator(self, model: str, result: Generator, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> Generator: """ Invoke result generator @@ -186,7 +187,7 @@ class LargeLanguageModel(AIModel): @abstractmethod def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: 
Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -218,7 +219,7 @@ class LargeLanguageModel(AIModel): """ raise NotImplementedError - def enforce_stop_tokens(self, text: str, stop: List[str]) -> str: + def enforce_stop_tokens(self, text: str, stop: list[str]) -> str: """Cut off the text as soon as any stop words occur.""" return re.split("|".join(stop), text, maxsplit=1)[0] @@ -329,7 +330,7 @@ class LargeLanguageModel(AIModel): def _trigger_before_invoke_callbacks(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger before invoke callbacks @@ -367,7 +368,7 @@ class LargeLanguageModel(AIModel): def _trigger_new_chunk_callbacks(self, chunk: LLMResultChunk, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger new chunk callbacks @@ -406,7 +407,7 @@ class LargeLanguageModel(AIModel): def _trigger_after_invoke_callbacks(self, model: str, result: LLMResult, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger after invoke callbacks @@ -446,7 +447,7 @@ class LargeLanguageModel(AIModel): def _trigger_invoke_error_callbacks(self, model: str, ex: Exception, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, tools: Optional[list[PromptMessageTool]] = None, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None, callbacks: list[Callback] = None) -> None: """ Trigger invoke error callbacks @@ -527,7 +528,7 @@ class LargeLanguageModel(AIModel): raise ValueError( f"Model Parameter {parameter_name} should be less than or equal to {parameter_rule.max}.") elif parameter_rule.type == ParameterType.FLOAT: - if not isinstance(parameter_value, (float, int)): + if not isinstance(parameter_value, float | int): raise ValueError(f"Model Parameter {parameter_name} should be float.") # validate parameter value precision diff --git a/api/core/model_runtime/model_providers/__base/model_provider.py b/api/core/model_runtime/model_providers/__base/model_provider.py index f3d71670f..97ce07d35 100644 --- a/api/core/model_runtime/model_providers/__base/model_provider.py +++ b/api/core/model_runtime/model_providers/__base/model_provider.py @@ -1,7 +1,6 @@ import importlib import os from abc import ABC, abstractmethod -from typing import Dict import yaml @@ -12,7 +11,7 @@ from core.model_runtime.model_providers.__base.ai_model import AIModel class ModelProvider(ABC): provider_schema: ProviderEntity = None - model_instance_map: Dict[str, AIModel] = {} + model_instance_map: 
dict[str, AIModel] = {} @abstractmethod def validate_provider_credentials(self, credentials: dict) -> None: @@ -47,7 +46,7 @@ class ModelProvider(ABC): yaml_path = os.path.join(current_path, f'{provider_name}.yaml') yaml_data = {} if os.path.exists(yaml_path): - with open(yaml_path, 'r', encoding='utf-8') as f: + with open(yaml_path, encoding='utf-8') as f: yaml_data = yaml.safe_load(f) try: diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 3f689a724..c74370889 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -1,4 +1,5 @@ -from typing import Generator, List, Optional, Union +from collections.abc import Generator +from typing import Optional, Union import anthropic from anthropic import Anthropic, Stream @@ -29,7 +30,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: """ @@ -90,7 +91,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - stop: Optional[List[str]] = None, stream: bool = True, + stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: """ Invoke large language model @@ -255,7 +256,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return message_text - def _convert_messages_to_prompt_anthropic(self, messages: List[PromptMessage]) -> str: + def _convert_messages_to_prompt_anthropic(self, messages: list[PromptMessage]) -> str: """ Format a list of messages into a full prompt for the Anthropic model diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 1bab34edd..4b89adaa4 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -1,6 +1,7 @@ import copy import logging -from typing import Generator, List, Optional, Union, cast +from collections.abc import Generator +from typing import Optional, Union, cast import tiktoken from openai import AzureOpenAI, Stream @@ -34,7 +35,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) \ -> Union[LLMResult, Generator]: @@ -121,7 +122,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return ai_model_entity.entity if ai_model_entity else None def _generate(self, model: str, credentials: dict, - prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, + prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> 
Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -239,7 +240,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): def _chat_generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict, - tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, + tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]: client = AzureOpenAI(**self._to_credential_kwargs(credentials)) @@ -537,7 +538,7 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): return num_tokens - def _num_tokens_from_messages(self, credentials: dict, messages: List[PromptMessage], + def _num_tokens_from_messages(self, credentials: dict, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> int: """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. diff --git a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py index 606a898db..e073bef01 100644 --- a/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/azure_openai/text_embedding/text_embedding.py @@ -1,7 +1,7 @@ import base64 import copy import time -from typing import Optional, Tuple, Union +from typing import Optional, Union import numpy as np import tiktoken @@ -149,7 +149,7 @@ class AzureOpenAITextEmbeddingModel(_CommonAzureOpenAI, TextEmbeddingModel): @staticmethod def _embedding_invoke(model: str, client: AzureOpenAI, texts: Union[list[str], str], - extra_model_kwargs: dict) -> Tuple[list[list[float]], int]: + extra_model_kwargs: dict) -> tuple[list[list[float]], int]: response = client.embeddings.create( input=texts, model=model, diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py index 4562bb2be..7549b2fb6 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_tokenizer.py @@ -1,7 +1,7 @@ import re -class BaichuanTokenizer(object): +class BaichuanTokenizer: @classmethod def count_chinese_characters(cls, text: str) -> int: return len(re.findall(r'[\u4e00-\u9fa5]', text)) diff --git a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py index 46ba0cffa..ae73c1735 100644 --- a/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py +++ b/api/core/model_runtime/model_providers/baichuan/llm/baichuan_turbo.py @@ -1,7 +1,8 @@ +from collections.abc import Generator from enum import Enum from hashlib import md5 from json import dumps, loads -from typing import Any, Dict, Generator, List, Union +from typing import Any, Union from requests import post @@ -24,10 +25,10 @@ class BaichuanMessage: role: str = Role.USER.value content: str - usage: Dict[str, int] = None + usage: dict[str, int] = None stop_reason: str = '' - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { 'role': self.role, 'content': self.content, @@ -37,7 +38,7 @@ class BaichuanMessage: self.content = content self.role = role -class BaichuanModel(object): 
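Two final modernizations appear in the hunks above and below: large_language_model.py switches isinstance to the PEP 604 union syntax, which isinstance accepts natively on Python 3.10+, and the Baichuan classes drop the explicit object base, which every Python 3 class already has implicitly. An illustrative sketch; the class and value here are invented, not from the patch:

class Tokenizer:  # equivalent to the old "class Tokenizer(object):"
    pass

# On Python 3.10+, isinstance takes X | Y in place of the tuple (X, Y)
is_numeric = isinstance(3.14, float | int)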
+class BaichuanModel:
     api_key: str
     secret_key: str
 
@@ -106,9 +107,9 @@ class BaichuanModel(object):
                 message.stop_reason = stop_reason
                 yield message
 
-    def _build_parameters(self, model: str, stream: bool, messages: List[BaichuanMessage],
-                          parameters: Dict[str, Any]) \
-        -> Dict[str, Any]:
+    def _build_parameters(self, model: str, stream: bool, messages: list[BaichuanMessage],
+                          parameters: dict[str, Any]) \
+        -> dict[str, Any]:
         if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
             prompt_messages = []
             for message in messages:
@@ -139,7 +140,7 @@ class BaichuanModel(object):
         else:
             raise BadRequestError(f"Unknown model: {model}")
 
-    def _build_headers(self, model: str, data: Dict[str, Any]) -> Dict[str, Any]:
+    def _build_headers(self, model: str, data: dict[str, Any]) -> dict[str, Any]:
         if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
             # there is no secret key for turbo api
             return {
@@ -153,8 +154,8 @@ class BaichuanModel(object):
     def _calculate_md5(self, input_string):
         return md5(input_string.encode('utf-8')).hexdigest()
 
-    def generate(self, model: str, stream: bool, messages: List[BaichuanMessage],
-                 parameters: Dict[str, Any], timeout: int) \
+    def generate(self, model: str, stream: bool, messages: list[BaichuanMessage],
+                 parameters: dict[str, Any], timeout: int) \
         -> Union[Generator, BaichuanMessage]:
         if model == 'baichuan2-turbo' or model == 'baichuan2-turbo-192k' or model == 'baichuan2-53b':
 
diff --git a/api/core/model_runtime/model_providers/baichuan/llm/llm.py b/api/core/model_runtime/model_providers/baichuan/llm/llm.py
index a7c6119d1..707355fa7 100644
--- a/api/core/model_runtime/model_providers/baichuan/llm/llm.py
+++ b/api/core/model_runtime/model_providers/baichuan/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, cast
+from collections.abc import Generator
+from typing import cast
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -33,7 +34,7 @@ from core.model_runtime.model_providers.baichuan.llm.baichuan_turbo_errors impor
 class BaichuanLarguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
@@ -43,7 +44,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
                        tools: list[PromptMessageTool] | None = None) -> int:
         return self._num_tokens_from_messages(prompt_messages)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
         """Calculate num tokens for baichuan model"""
         def tokens(text: str):
             return BaichuanTokenizer._get_num_tokens(text)
@@ -107,7 +108,7 @@ class BaichuanLarguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         if tools is not None and len(tools) > 0:
             raise InvokeBadRequestError("Baichuan model doesn't support tools")
diff --git a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
index 5020c5899..da4ba5588 100644
--- a/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/baichuan/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
 import time
 from json import dumps
-from typing import Optional, Tuple
+from typing import Optional
 
 from requests import post
 
@@ -84,7 +84,7 @@ class BaichuanTextEmbeddingModel(TextEmbeddingModel):
         return result
 
     def embedding(self, model: str, api_key, texts: list[str], user: Optional[str] = None) \
-        -> Tuple[list[list[float]], int]:
+        -> tuple[list[list[float]], int]:
         """
         Embed given texts
 
diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
index 7a2faae89..c6aaa24ad 100644
--- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py
+++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py
@@ -1,6 +1,7 @@
 import json
 import logging
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 import boto3
 from botocore.config import Config
@@ -37,7 +38,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -159,7 +160,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage], model_prefix: str) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], model_prefix: str) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic, Amazon and Llama models
 
@@ -181,7 +182,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
         # trim off the trailing ' ' that might come from the "Assistant: "
         return text.rstrip()
 
-    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None, stream: bool = True):
+    def _create_payload(self, model_prefix: str, prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None, stream: bool = True):
         """
         Create payload for bedrock api call depending on model provider
         """
@@ -231,7 +232,7 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
diff --git a/api/core/model_runtime/model_providers/chatglm/llm/llm.py b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
index fd2bcd5ec..12dc75aec 100644
--- a/api/core/model_runtime/model_providers/chatglm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/chatglm/llm/llm.py
@@ -1,6 +1,7 @@
 import logging
+from collections.abc import Generator
 from os.path import join
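Two further rules show up in these files: UP004 drops the redundant `object` base class (`class BaichuanTokenizer(object)` becomes `class BaichuanTokenizer:`), and UP015 drops the redundant `'r'` mode argument to `open()`. An illustrative sketch with made-up names:

```python
class ExampleTokenizer:  # Python 3 classes inherit from object implicitly
    pass


def load_yaml_text(path: str) -> str:
    # 'r' is open()'s default mode, so it can be omitted
    with open(path, encoding='utf-8') as f:
        return f.read()
```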
-from typing import Generator, List, Optional, cast
+from typing import Optional, cast
 
 from httpx import Timeout
 from openai import (
@@ -45,7 +46,7 @@ logger = logging.getLogger(__name__)
 class ChatGLMLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -138,7 +139,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict,
-                  tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
+                  tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                   stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
@@ -394,7 +395,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
 
         return num_tokens
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage],
+    def _num_tokens_from_messages(self, messages: list[PromptMessage],
                                   tools: Optional[list[PromptMessageTool]] = None) -> int:
         """Calculate num tokens for chatglm2 and chatglm3 with GPT2 tokenizer.
 
diff --git a/api/core/model_runtime/model_providers/cohere/llm/llm.py b/api/core/model_runtime/model_providers/cohere/llm/llm.py
index 95d3252b1..667ba4c78 100644
--- a/api/core/model_runtime/model_providers/cohere/llm/llm.py
+++ b/api/core/model_runtime/model_providers/cohere/llm/llm.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Generator, List, Optional, Tuple, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 import cohere
 from cohere.responses import Chat, Generations
@@ -38,7 +39,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -138,7 +139,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm model
@@ -264,7 +265,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 break
 
     def _chat_generate(self, model: str, credentials: dict,
-                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                       prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -306,7 +307,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
         return self._handle_chat_generate_response(model, credentials, response, prompt_messages, stop)
 
     def _handle_chat_generate_response(self, model: str, credentials: dict, response: Chat,
-                                       prompt_messages: list[PromptMessage], stop: Optional[List[str]] = None) \
+                                       prompt_messages: list[PromptMessage], stop: Optional[list[str]] = None) \
             -> LLMResult:
         """
         Handle llm chat response
@@ -352,7 +353,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
     def _handle_chat_generate_stream_response(self, model: str, credentials: dict, response: StreamingChat,
                                               prompt_messages: list[PromptMessage],
-                                              stop: Optional[List[str]] = None) -> Generator:
+                                              stop: Optional[list[str]] = None) -> Generator:
         """
         Handle llm chat stream response
 
@@ -427,7 +428,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
                 index += 1
 
     def _convert_prompt_messages_to_message_and_chat_histories(self, prompt_messages: list[PromptMessage]) \
-            -> Tuple[str, list[dict]]:
+            -> tuple[str, list[dict]]:
         """
         Convert prompt messages to message and chat histories
         :param prompt_messages: prompt messages
@@ -495,7 +496,7 @@ class CohereLargeLanguageModel(LargeLanguageModel):
 
         return response.length
 
-    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: List[PromptMessage]) -> int:
+    def _num_tokens_from_messages(self, model: str, credentials: dict, messages: list[PromptMessage]) -> int:
         """Calculate num tokens Cohere model."""
         messages = [self._convert_prompt_message_to_dict(m) for m in messages]
         message_strs = [f"{message['role']}: {message['message']}" for message in messages]
diff --git a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
index fda8b27de..5eec72184 100644
--- a/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/cohere/text_embedding/text_embedding.py
@@ -1,5 +1,5 @@
 import time
-from typing import Optional, Tuple
+from typing import Optional
 
 import cohere
 import numpy as np
@@ -168,7 +168,7 @@ class CohereTextEmbeddingModel(TextEmbeddingModel):
         except Exception as ex:
             raise CredentialsValidateFailedError(str(ex))
 
-    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> Tuple[list[list[float]], int]:
+    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
         """
         Invoke embedding model
 
diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py
index e376e72c0..686761ab5 100644
--- a/api/core/model_runtime/model_providers/google/llm/llm.py
+++ b/api/core/model_runtime/model_providers/google/llm/llm.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 import google.api_core.exceptions as exceptions
 import google.generativeai as genai
@@ -34,7 +35,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -103,7 +104,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
diff --git a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
index 381d29c7e..f43a8aeda 100644
--- a/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
+++ b/api/core/model_runtime/model_providers/huggingface_hub/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from huggingface_hub import InferenceClient
 from huggingface_hub.hf_api import HfApi
@@ -29,7 +30,7 @@ from core.model_runtime.model_providers.huggingface_hub._common import _CommonHu
 class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
                 user: Optional[str] = None) -> Union[LLMResult, Generator]:
 
         client = InferenceClient(token=credentials['huggingfacehub_api_token'])
diff --git a/api/core/model_runtime/model_providers/localai/llm/llm.py b/api/core/model_runtime/model_providers/localai/llm/llm.py
index 8d571d20b..694f5891f 100644
--- a/api/core/model_runtime/model_providers/localai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/localai/llm/llm.py
@@ -1,5 +1,6 @@
+from collections.abc import Generator
 from os.path import join
-from typing import Generator, List, cast
+from typing import cast
 
 from httpx import Timeout
 from openai import (
@@ -52,7 +53,7 @@ from core.model_runtime.utils import helper
 class LocalAILarguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
@@ -63,7 +64,7 @@ class LocalAILarguageModel(LargeLanguageModel):
         # tools is not supported yet
         return self._num_tokens_from_messages(prompt_messages, tools=tools)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
         """
             Calculate num tokens for baichuan model
             LocalAI does not supports
@@ -241,7 +242,7 @@ class LocalAILarguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
 
         kwargs = self._to_client_kwargs(credentials)
@@ -346,7 +347,7 @@ class LocalAILarguageModel(LargeLanguageModel):
 
         return message_dict
 
-    def _convert_prompt_message_to_completion_prompts(self, messages: List[PromptMessage]) -> str:
+    def _convert_prompt_message_to_completion_prompts(self, messages: list[PromptMessage]) -> str:
         """
             Convert PromptMessage to completion prompts
         """
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
index ee73005bd..6c41e0d2a 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion.py
@@ -1,5 +1,6 @@
+from collections.abc import Generator
 from json import dumps, loads
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Union
 
 from requests import Response, post
 
@@ -14,13 +15,13 @@ from core.model_runtime.model_providers.minimax.llm.errors import (
 from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
 
 
-class MinimaxChatCompletion(object):
+class MinimaxChatCompletion:
     """
         Minimax Chat Completion API
     """
     def generate(self, model: str, api_key: str, group_id: str,
-                 prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
+                 prompt_messages: list[MinimaxMessage], model_parameters: dict,
+                 tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
diff --git a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
index 2497a9d7b..81ea2e165 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/chat_completion_pro.py
@@ -1,5 +1,6 @@
+from collections.abc import Generator
 from json import dumps, loads
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Union
 
 from requests import Response, post
 
@@ -14,14 +15,14 @@ from core.model_runtime.model_providers.minimax.llm.errors import (
 from core.model_runtime.model_providers.minimax.llm.types import MinimaxMessage
 
 
-class MinimaxChatCompletionPro(object):
+class MinimaxChatCompletionPro:
     """
         Minimax Chat Completion Pro API, supports function calling
        however, we do not have enough time and energy to implement it, but the parameters are reserved
     """
     def generate(self, model: str, api_key: str, group_id: str,
-                 prompt_messages: List[MinimaxMessage], model_parameters: dict,
-                 tools: List[Dict[str, Any]], stop: List[str] | None, stream: bool, user: str) \
+                 prompt_messages: list[MinimaxMessage], model_parameters: dict,
+                 tools: list[dict[str, Any]], stop: list[str] | None, stream: bool, user: str) \
         -> Union[MinimaxMessage, Generator[MinimaxMessage, None, None]]:
         """
             generate chat completion
diff --git a/api/core/model_runtime/model_providers/minimax/llm/llm.py b/api/core/model_runtime/model_providers/minimax/llm/llm.py
index bc65e756e..cc88d1573 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/llm.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/llm.py
@@ -1,4 +1,4 @@
-from typing import Generator, List
+from collections.abc import Generator
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -42,7 +42,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
@@ -79,7 +79,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
                        tools: list[PromptMessageTool] | None = None) -> int:
         return self._num_tokens_from_messages(prompt_messages, tools)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
         """
             Calculate num tokens for minimax model
 
@@ -94,7 +94,7 @@ class MinimaxLargeLanguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             use MinimaxChatCompletionPro as the type of client, anyway, MinimaxChatCompletion has the same interface
diff --git a/api/core/model_runtime/model_providers/minimax/llm/types.py b/api/core/model_runtime/model_providers/minimax/llm/types.py
index 622931244..b33a7ca9a 100644
--- a/api/core/model_runtime/model_providers/minimax/llm/types.py
+++ b/api/core/model_runtime/model_providers/minimax/llm/types.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, Dict
+from typing import Any
 
 
 class MinimaxMessage:
@@ -11,11 +11,11 @@ class MinimaxMessage:
 
     role: str = Role.USER.value
     content: str
-    usage: Dict[str, int] = None
+    usage: dict[str, int] = None
     stop_reason: str = ''
-    function_call: Dict[str, Any] = None
+    function_call: dict[str, Any] = None
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         if self.function_call and self.role == MinimaxMessage.Role.ASSISTANT.value:
             return {
                 'sender_type': 'BOT',
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index e1e74ea80..185ff6271 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -220,7 +220,7 @@ class ModelProviderFactory:
         # read _position.yaml file
         position_map = {}
         if os.path.exists(position_file_path):
-            with open(position_file_path, 'r', encoding='utf-8') as f:
+            with open(position_file_path, encoding='utf-8') as f:
                 positions = yaml.safe_load(f)
                 # convert list to dict with key as model provider name, value as index
                 position_map = {position: index for index, position in enumerate(positions)}
diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
index 40618b7fb..5db3e2827 100644
--- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py
+++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -8,7 +9,7 @@ from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAI
 class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel):
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         self._add_custom_parameters(credentials)
diff --git a/api/core/model_runtime/model_providers/ollama/llm/llm.py b/api/core/model_runtime/model_providers/ollama/llm/llm.py
index 848ac76d3..e4388699e 100644
--- a/api/core/model_runtime/model_providers/ollama/llm/llm.py
+++ b/api/core/model_runtime/model_providers/ollama/llm/llm.py
@@ -1,8 +1,9 @@
 import json
 import logging
 import re
+from collections.abc import Generator
 from decimal import Decimal
-from typing import Generator, List, Optional, Union, cast
+from typing import Optional, Union, cast
 from urllib.parse import urljoin
 
 import requests
@@ -51,7 +52,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -131,7 +132,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
             raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
 
     def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -398,7 +399,7 @@ class OllamaLargeLanguageModel(LargeLanguageModel):
 
         return message_dict
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage]) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage]) -> int:
         """
         Calculate num tokens.
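The minimax and localai signatures above keep their PEP 604 unions (`X | None`) but swap `List[str]` for `list[str]`. Assuming the project targets Python 3.10+, where `|` unions are also valid at runtime, a minimal illustrative equivalent:

```python
from collections.abc import Generator
from typing import Union


def invoke(prompt: str, stop: list[str] | None = None, stream: bool = True) -> Union[str, Generator[str, None, None]]:
    # return the whole answer, or a generator of chunks when streaming
    if not stream:
        return prompt.upper()
    return (chunk for chunk in prompt.split())
```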
diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py
index 92a370e04..2a1137d44 100644
--- a/api/core/model_runtime/model_providers/openai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai/llm/llm.py
@@ -1,5 +1,6 @@
 import logging
-from typing import Generator, List, Optional, Union, cast
+from collections.abc import Generator
+from typing import Optional, Union, cast
 
 import tiktoken
 from openai import OpenAI, Stream
@@ -35,7 +36,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -215,7 +216,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
         return ai_model_entities
 
     def _generate(self, model: str, credentials: dict,
-                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[List[str]] = None,
+                  prompt_messages: list[PromptMessage], model_parameters: dict, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm completion model
@@ -366,7 +367,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
     def _chat_generate(self, model: str, credentials: dict,
                        prompt_messages: list[PromptMessage], model_parameters: dict,
-                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                       tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                        stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke llm chat model
@@ -706,7 +707,7 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel):
 
         return num_tokens
 
-    def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage],
+    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                   tools: Optional[list[PromptMessageTool]] = None) -> int:
         """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
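The retyped `_num_tokens_from_messages` above refers to tiktoken-based counting for gpt-3.5-turbo/gpt-4. For context only, a minimal sketch in the style of OpenAI's cookbook, not Dify's implementation; the per-message overhead constants are the commonly published values:

```python
import tiktoken


def num_tokens_from_messages(messages: list[dict[str, str]], model: str = 'gpt-3.5-turbo') -> int:
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        encoding = tiktoken.get_encoding('cl100k_base')
    num_tokens = 0
    for message in messages:
        num_tokens += 3  # every message carries <|start|>role ... <|end|> framing
        for value in message.values():
            num_tokens += len(encoding.encode(value))
    return num_tokens + 3  # every reply is primed with <|start|>assistant
```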
diff --git a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
index 28ab5c30f..e23a2edf8 100644
--- a/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
+++ b/api/core/model_runtime/model_providers/openai/text_embedding/text_embedding.py
@@ -1,6 +1,6 @@
 import base64
 import time
-from typing import Optional, Tuple, Union
+from typing import Optional, Union
 
 import numpy as np
 import tiktoken
@@ -162,7 +162,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
             raise CredentialsValidateFailedError(str(ex))
 
     def _embedding_invoke(self, model: str, client: OpenAI, texts: Union[list[str], str],
-                          extra_model_kwargs: dict) -> Tuple[list[list[float]], int]:
+                          extra_model_kwargs: dict) -> tuple[list[list[float]], int]:
         """
         Invoke embedding model
 
diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
index 9a26f3dc0..ae856c5ce 100644
--- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py
@@ -1,7 +1,8 @@
 import json
 import logging
+from collections.abc import Generator
 from decimal import Decimal
-from typing import Generator, List, Optional, Union, cast
+from typing import Optional, Union, cast
 from urllib.parse import urljoin
 
 import requests
@@ -46,7 +47,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -245,7 +246,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
     # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                   stream: bool = True, \
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
@@ -567,7 +568,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
 
         return num_tokens
 
-    def _num_tokens_from_messages(self, model: str, messages: List[PromptMessage],
+    def _num_tokens_from_messages(self, model: str, messages: list[PromptMessage],
                                   tools: Optional[list[PromptMessageTool]] = None) -> int:
         """
         Approximate num tokens with GPT2 tokenizer.
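The embedding hunks change only the return annotation, from `typing.Tuple` to the builtin `tuple` generic. The shape both spell out is `(vectors, used_tokens)`; a self-contained sketch with fake vectors, illustrative only:

```python
def embedding_invoke(texts: list[str]) -> tuple[list[list[float]], int]:
    # one (dummy) vector per input text, plus a crude token estimate
    vectors = [[float(len(text))] for text in texts]
    used_tokens = sum(len(text.split()) for text in texts)
    return vectors, used_tokens
```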
diff --git a/api/core/model_runtime/model_providers/openllm/llm/llm.py b/api/core/model_runtime/model_providers/openllm/llm/llm.py
index 3491f107a..8ea5819bd 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/llm.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/llm.py
@@ -1,4 +1,4 @@
-from typing import Generator, List
+from collections.abc import Generator
 
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
@@ -40,7 +40,7 @@ from core.model_runtime.model_providers.openllm.llm.openllm_generate_errors impo
 class OpenLLMLargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)
 
@@ -77,7 +77,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
                        tools: list[PromptMessageTool] | None = None) -> int:
         return self._num_tokens_from_messages(prompt_messages, tools)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool]) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool]) -> int:
         """
             Calculate num tokens for OpenLLM model
             it's a generate model, so we just join them by spe
@@ -87,7 +87,7 @@ class OpenLLMLargeLanguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         client = OpenLLMGenerate()
         response = client.generate(
diff --git a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
index 06453cb3f..43258d1e5 100644
--- a/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
+++ b/api/core/model_runtime/model_providers/openllm/llm/openllm_generate.py
@@ -1,6 +1,7 @@
+from collections.abc import Generator
 from enum import Enum
 from json import dumps, loads
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Union
 
 from requests import Response, post
 from requests.exceptions import ConnectionError, InvalidSchema, MissingSchema
@@ -19,10 +20,10 @@ class OpenLLMGenerateMessage:
 
     role: str = Role.USER.value
     content: str
-    usage: Dict[str, int] = None
+    usage: dict[str, int] = None
     stop_reason: str = ''
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             'role': self.role,
             'content': self.content,
@@ -33,10 +34,10 @@ class OpenLLMGenerateMessage:
         self.role = role
 
 
-class OpenLLMGenerate(object):
+class OpenLLMGenerate:
     def generate(
-        self, server_url: str, model_name: str, stream: bool, model_parameters: Dict[str, Any],
-        stop: List[str], prompt_messages: List[OpenLLMGenerateMessage], user: str,
+        self, server_url: str, model_name: str, stream: bool, model_parameters: dict[str, Any],
+        stop: list[str], prompt_messages: list[OpenLLMGenerateMessage], user: str,
     ) -> Union[Generator[OpenLLMGenerateMessage, None, None], OpenLLMGenerateMessage]:
         if not server_url:
             raise InvalidAuthenticationError('Invalid server URL')
diff --git a/api/core/model_runtime/model_providers/replicate/llm/llm.py b/api/core/model_runtime/model_providers/replicate/llm/llm.py
index ce69c6798..ee2de8560 100644
--- a/api/core/model_runtime/model_providers/replicate/llm/llm.py
+++ b/api/core/model_runtime/model_providers/replicate/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from replicate import Client as ReplicateClient
 from replicate.exceptions import ReplicateError
@@ -29,7 +30,7 @@ from core.model_runtime.model_providers.replicate._common import _CommonReplicat
 class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None, stream: bool = True,
                 user: Optional[str] = None) -> Union[LLMResult, Generator]:
 
         version = credentials['model_version']
diff --git a/api/core/model_runtime/model_providers/spark/llm/llm.py b/api/core/model_runtime/model_providers/spark/llm/llm.py
index 6dfa1e3a6..65beae517 100644
--- a/api/core/model_runtime/model_providers/spark/llm/llm.py
+++ b/api/core/model_runtime/model_providers/spark/llm/llm.py
@@ -1,5 +1,6 @@
 import threading
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -27,7 +28,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -86,7 +87,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -244,7 +245,7 @@ class SparkLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic model
 
diff --git a/api/core/model_runtime/model_providers/togetherai/llm/llm.py b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
index 89198fe4b..b312d99b1 100644
--- a/api/core/model_runtime/model_providers/togetherai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/togetherai/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult
 from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -14,7 +15,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
@@ -27,7 +28,7 @@ class TogetherAILargeLanguageModel(OAIAPICompatLargeLanguageModel):
         return super().validate_credentials(model, cred_with_endpoint)
 
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                   stream: bool = True, user: Optional[str] = None) -> Union[LLMResult, Generator]:
         cred_with_endpoint = self._update_endpoint_url(credentials=credentials)
 
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/_client.py b/api/core/model_runtime/model_providers/tongyi/llm/_client.py
index 2aab69af7..cfe33558e 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/_client.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/_client.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 
 from langchain.callbacks.manager import CallbackManagerForLLMRun
 from langchain.llms import Tongyi
@@ -8,7 +8,7 @@ from langchain.schema import Generation, LLMResult
 class EnhanceTongyi(Tongyi):
     @property
-    def _default_params(self) -> Dict[str, Any]:
+    def _default_params(self) -> dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
         normal_params = {
             "top_p": self.top_p,
@@ -19,13 +19,13 @@ class EnhanceTongyi(Tongyi):
 
     def _generate(
         self,
-        prompts: List[str],
-        stop: Optional[List[str]] = None,
+        prompts: list[str],
+        stop: Optional[list[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
         generations = []
-        params: Dict[str, Any] = {
+        params: dict[str, Any] = {
             **{"model": self.model_name},
             **self._default_params,
             **kwargs,
diff --git a/api/core/model_runtime/model_providers/tongyi/llm/llm.py b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
index 8aac4412f..7ae8b8776 100644
--- a/api/core/model_runtime/model_providers/tongyi/llm/llm.py
+++ b/api/core/model_runtime/model_providers/tongyi/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from dashscope import get_tokenizer
 from dashscope.api_entities.dashscope_response import DashScopeAPIResponse
@@ -38,7 +39,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -100,7 +101,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
     def _generate(self, model: str, credentials: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -268,7 +269,7 @@ class TongyiLargeLanguageModel(LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage]) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
         """
         Format a list of messages into a full prompt for the Anthropic model
 
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
index af04eca59..81868aeed 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/ernie_bot.py
@@ -1,8 +1,9 @@
+from collections.abc import Generator
 from datetime import datetime, timedelta
 from enum import Enum
 from json import dumps, loads
 from threading import Lock
-from typing import Any, Dict, Generator, List, Union
+from typing import Any, Union
 
 from requests import Response, post
 
@@ -16,7 +17,7 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
 )
 
 # map api_key to access_token
-baidu_access_tokens: Dict[str, 'BaiduAccessToken'] = {}
+baidu_access_tokens: dict[str, 'BaiduAccessToken'] = {}
 baidu_access_tokens_lock = Lock()
 
 class BaiduAccessToken:
@@ -105,10 +106,10 @@ class ErnieMessage:
 
     role: str = Role.USER.value
     content: str
-    usage: Dict[str, int] = None
+    usage: dict[str, int] = None
     stop_reason: str = ''
 
-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         return {
             'role': self.role,
             'content': self.content,
@@ -118,7 +119,7 @@ class ErnieMessage:
         self.content = content
         self.role = role
 
-class ErnieBotModel(object):
+class ErnieBotModel:
     api_bases = {
         'ernie-bot': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions',
         'ernie-bot-4': 'https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro',
@@ -138,9 +139,9 @@ class ErnieBotModel(object):
         self.api_key = api_key
         self.secret_key = secret_key
 
-    def generate(self, model: str, stream: bool, messages: List[ErnieMessage],
-                 parameters: Dict[str, Any], timeout: int, tools: List[PromptMessageTool], \
-                 stop: List[str], user: str) \
+    def generate(self, model: str, stream: bool, messages: list[ErnieMessage],
+                 parameters: dict[str, Any], timeout: int, tools: list[PromptMessageTool], \
+                 stop: list[str], user: str) \
         -> Union[Generator[ErnieMessage, None, None], ErnieMessage]:
 
         # check parameters
@@ -216,11 +217,11 @@ class ErnieBotModel(object):
         token = BaiduAccessToken.get_access_token(self.api_key, self.secret_key)
         return token.access_token
 
-    def _copy_messages(self, messages: List[ErnieMessage]) -> List[ErnieMessage]:
+    def _copy_messages(self, messages: list[ErnieMessage]) -> list[ErnieMessage]:
         return [ErnieMessage(message.content, message.role) for message in messages]
 
-    def _check_parameters(self, model: str, parameters: Dict[str, Any],
-                          tools: List[PromptMessageTool], stop: List[str]) -> None:
+    def _check_parameters(self, model: str, parameters: dict[str, Any],
+                          tools: list[PromptMessageTool], stop: list[str]) -> None:
         if model not in self.api_bases:
             raise BadRequestError(f'Invalid model: {model}')
 
@@ -241,16 +242,16 @@ class ErnieBotModel(object):
             if len(s) > 20:
                 raise BadRequestError('stop item should not exceed 20 characters.')
 
-    def _build_request_body(self, model: str, messages: List[ErnieMessage], stream: bool, parameters: Dict[str, Any],
-                            tools: List[PromptMessageTool], stop: List[str], user: str) -> Dict[str, Any]:
+    def _build_request_body(self, model: str, messages: list[ErnieMessage], stream: bool, parameters: dict[str, Any],
+                            tools: list[PromptMessageTool], stop: list[str], user: str) -> dict[str, Any]:
         # if model in self.function_calling_supports:
         #     return self._build_function_calling_request_body(model, messages, parameters, tools, stop, user)
         return self._build_chat_request_body(model, messages, stream, parameters, stop, user)
 
-    def _build_function_calling_request_body(self, model: str, messages: List[ErnieMessage], stream: bool,
-                                             parameters: Dict[str, Any], tools: List[PromptMessageTool],
-                                             stop: List[str], user: str) \
-        -> Dict[str, Any]:
+    def _build_function_calling_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
+                                             parameters: dict[str, Any], tools: list[PromptMessageTool],
+                                             stop: list[str], user: str) \
+        -> dict[str, Any]:
         if len(messages) % 2 == 0:
             raise BadRequestError('The number of messages should be odd.')
         if messages[0].role == 'function':
@@ -260,9 +261,9 @@ class ErnieBotModel(object):
             TODO: implement function calling
         """
 
-    def _build_chat_request_body(self, model: str, messages: List[ErnieMessage], stream: bool,
-                                 parameters: Dict[str, Any], stop: List[str], user: str) \
-        -> Dict[str, Any]:
+    def _build_chat_request_body(self, model: str, messages: list[ErnieMessage], stream: bool,
+                                 parameters: dict[str, Any], stop: list[str], user: str) \
+        -> dict[str, Any]:
         if len(messages) == 0:
             raise BadRequestError('The number of messages should not be zero.')
 
diff --git a/api/core/model_runtime/model_providers/wenxin/llm/llm.py b/api/core/model_runtime/model_providers/wenxin/llm/llm.py
index b13e340d9..51b3c9749 100644
--- a/api/core/model_runtime/model_providers/wenxin/llm/llm.py
+++ b/api/core/model_runtime/model_providers/wenxin/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, cast
+from collections.abc import Generator
+from typing import cast
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -32,7 +33,7 @@ from core.model_runtime.model_providers.wenxin.llm.ernie_bot_errors import (
 class ErnieBotLarguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict,
-                tools: list[PromptMessageTool] | None = None, stop: List[str] | None = None,
+                tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None,
                 stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         return self._generate(model=model, credentials=credentials, prompt_messages=prompt_messages,
@@ -43,7 +44,7 @@ class ErnieBotLarguageModel(LargeLanguageModel):
         # tools is not supported yet
         return self._num_tokens_from_messages(prompt_messages)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage],) -> int:
+    def _num_tokens_from_messages(self, messages: list[PromptMessage],) -> int:
         """Calculate num tokens for baichuan model"""
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -78,7 +79,7 @@ class ErnieBotLarguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         instance = ErnieBotModel(
             api_key=credentials['api_key'],
diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py
index 7da1b0065..83c003d05 100644
--- a/api/core/model_runtime/model_providers/xinference/llm/llm.py
+++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, Iterator, List, cast
+from collections.abc import Generator, Iterator
+from typing import cast
 
 from openai import (
     APIConnectionError,
@@ -62,7 +63,7 @@ from core.model_runtime.utils import helper
 class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _invoke(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                 model_parameters: dict, tools: list[PromptMessageTool] | None = None,
-                stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             invoke LLM
@@ -131,7 +132,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
         """
         return self._num_tokens_from_messages(prompt_messages, tools)
 
-    def _num_tokens_from_messages(self, messages: List[PromptMessage], tools: list[PromptMessageTool],
+    def _num_tokens_from_messages(self, messages: list[PromptMessage], tools: list[PromptMessageTool],
                                   is_completion_model: bool = False) -> int:
         def tokens(text: str):
             return self._get_num_tokens_by_gpt2(text)
@@ -359,7 +360,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
     def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                   model_parameters: dict, extra_model_kwargs: XinferenceModelExtraParameter,
                   tools: list[PromptMessageTool] | None = None,
-                  stop: List[str] | None = None, stream: bool = True, user: str | None = None) \
+                  stop: list[str] | None = None, stream: bool = True, user: str | None = None) \
             -> LLMResult | Generator:
         """
             generate text from LLM
@@ -404,7 +405,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
                 }
                 for tool in tools
             ]
 
-        if isinstance(xinference_model, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle)):
+        if isinstance(xinference_model, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
             resp = client.chat.completions.create(
                 model=credentials['model_uid'],
                 messages=[self._convert_prompt_message_to_dict(message) for message in prompt_messages],
diff --git a/api/core/model_runtime/model_providers/xinference/xinference_helper.py b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
index 089ffd691..24a91af62 100644
--- a/api/core/model_runtime/model_providers/xinference/xinference_helper.py
+++ b/api/core/model_runtime/model_providers/xinference/xinference_helper.py
@@ -1,22 +1,21 @@
 from os import path
 from threading import Lock
 from time import time
-from typing import List
 
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.sessions import Session
 
 
-class XinferenceModelExtraParameter(object):
+class XinferenceModelExtraParameter:
     model_format: str
     model_handle_type: str
-    model_ability: List[str]
+    model_ability: list[str]
     max_tokens: int = 512
     context_length: int = 2048
     support_function_call: bool = False
 
-    def __init__(self, model_format: str, model_handle_type: str, model_ability: List[str],
+    def __init__(self, model_format: str, model_handle_type: str, model_ability: list[str],
                  support_function_call: bool, max_tokens: int, context_length: int) -> None:
         self.model_format = model_format
         self.model_handle_type = model_handle_type
diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
index 6d1f462d0..c62422dfb 100644
--- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
+++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py
@@ -1,4 +1,5 @@
-from typing import Generator, List, Optional, Union
+from collections.abc import Generator
+from typing import Optional, Union
 
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
@@ -23,7 +24,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
     def _invoke(self, model: str, credentials: dict,
                 prompt_messages: list[PromptMessage], model_parameters: dict,
-                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                 stream: bool = True, user: Optional[str] = None) \
             -> Union[LLMResult, Generator]:
         """
@@ -89,7 +90,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
     def _generate(self, model: str, credentials_kwargs: dict,
                   prompt_messages: list[PromptMessage], model_parameters: dict,
                   tools: Optional[list[PromptMessageTool]] = None,
-                  stop: Optional[List[str]] = None, stream: bool = True,
+                  stop: Optional[list[str]] = None, stream: bool = True,
                   user: Optional[str] = None) -> Union[LLMResult, Generator]:
         """
         Invoke large language model
@@ -119,7 +120,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             prompt_messages = prompt_messages[1:]
 
         # resolve zhipuai model not support system message and user message, assistant message must be in sequence
-        new_prompt_messages: List[PromptMessage] = []
+        new_prompt_messages: list[PromptMessage] = []
         for prompt_message in prompt_messages:
             copy_prompt_message = prompt_message.copy()
             if copy_prompt_message.role in [PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL]:
@@ -275,7 +276,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
         :return: llm response
         """
         text = ''
-        assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+        assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
        for choice in response.choices:
             if choice.message.tool_calls:
                 for tool_call in choice.message.tool_calls:
@@ -335,7 +336,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
             if delta.finish_reason is None and (delta.delta.content is None or delta.delta.content == ''):
                 continue
 
-            assistant_tool_calls: List[AssistantPromptMessage.ToolCall] = []
+            assistant_tool_calls: list[AssistantPromptMessage.ToolCall] = []
             for tool_call in delta.delta.tool_calls or []:
                 if tool_call.type == 'function':
                     assistant_tool_calls.append(
@@ -409,7 +410,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
 
         return message_text
 
-    def _convert_messages_to_prompt(self, messages: List[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
+    def _convert_messages_to_prompt(self, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None) -> str:
         """
         :param messages: List of PromptMessage to combine.
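The xinference hunk above also applies UP038: `isinstance` with a tuple of classes becomes `isinstance` with a PEP 604 union, valid at runtime on Python 3.10+. A standalone sketch, with stub classes standing in for the real xinference handles:

```python
class RESTfulChatModelHandle:  # stub stand-in, not the real xinference class
    pass


class RESTfulChatglmCppChatModelHandle:  # stub stand-in
    pass


handle = RESTfulChatModelHandle()

# before: isinstance(handle, (RESTfulChatModelHandle, RESTfulChatglmCppChatModelHandle))
if isinstance(handle, RESTfulChatModelHandle | RESTfulChatglmCppChatModelHandle):
    print('chat-capable handle')
```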
:return: Combined string with necessary human_prompt and ai_prompt tags. diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index 30c373729..0f9fecfc7 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -1,5 +1,5 @@ import time -from typing import List, Optional, Tuple +from typing import Optional from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -81,7 +81,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): except Exception as ex: raise CredentialsValidateFailedError(str(ex)) - def embed_documents(self, model: str, client: ZhipuAI, texts: List[str]) -> Tuple[List[List[float]], int]: + def embed_documents(self, model: str, client: ZhipuAI, texts: list[str]) -> tuple[list[list[float]], int]: """Call out to ZhipuAI's embedding endpoint. Args: @@ -101,7 +101,7 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Call out to ZhipuAI's embedding endpoint. Args: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py index 23fd968f3..29b174635 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/_client.py @@ -1,7 +1,8 @@ from __future__ import annotations import os -from typing import Mapping, Union +from collections.abc import Mapping +from typing import Union import httpx from httpx import Timeout diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py index 16c4b54f1..dab6dac5f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/async_completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -15,7 +14,7 @@ if TYPE_CHECKING: class AsyncCompletions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) @@ -29,8 +28,8 @@ class AsyncCompletions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, List[str], List[int], List[List[int]], None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], list[list[int]], None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = 
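In `embed_documents`, `Tuple[List[List[float]], int]` becomes `tuple[list[list[float]], int]` with no typing import at all. A toy function with the same shape of signature (the embedding logic here is fabricated purely so the example runs; it is not ZhipuAI's endpoint):

def embed_documents(texts: list[str]) -> tuple[list[list[float]], int]:
    # one fake vector per text, plus a pretend token count; only the annotations matter here
    vectors = [[float(len(text))] for text in texts]
    used_tokens = sum(len(text.split()) for text in texts)
    return vectors, used_tokens

vectors, used_tokens = embed_documents(["hello world", "dify"])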
NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py index e5bb8cdf6..5c4ed4d1b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/chat/completions.py @@ -1,9 +1,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Literal, Optional, Union import httpx -from typing_extensions import Literal from ...core._base_api import BaseAPI from ...core._base_type import NOT_GIVEN, Headers, NotGiven @@ -17,7 +16,7 @@ if TYPE_CHECKING: class Completions(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( @@ -31,8 +30,8 @@ class Completions(BaseAPI): top_p: Optional[float] | NotGiven = NOT_GIVEN, max_tokens: int | NotGiven = NOT_GIVEN, seed: int | NotGiven = NOT_GIVEN, - messages: Union[str, List[str], List[int], object, None], - stop: Optional[Union[str, List[str], None]] | NotGiven = NOT_GIVEN, + messages: Union[str, list[str], list[int], object, None], + stop: Optional[Union[str, list[str], None]] | NotGiven = NOT_GIVEN, sensitive_word_check: Optional[object] | NotGiven = NOT_GIVEN, tools: Optional[object] | NotGiven = NOT_GIVEN, tool_choice: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py index d5db469de..35d54592f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional, Union import httpx @@ -14,13 +14,13 @@ if TYPE_CHECKING: class Embeddings(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( self, *, - input: Union[str, List[str], List[int], List[List[int]]], + input: Union[str, list[str], list[int], list[list[int]]], model: Union[str], encoding_format: str | NotGiven = NOT_GIVEN, user: str | NotGiven = NOT_GIVEN, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py index 7796b778a..5deb8d08f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/files.py @@ -18,7 +18,7 @@ __all__ = ["Files"] class Files(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py index ead6cdae2..b860de192 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/fine_tuning/jobs.py @@ -17,7 
+17,7 @@ __all__ = ["Jobs"] class Jobs(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def create( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py index ce852a48c..3201426df 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/api_resource/images.py @@ -14,7 +14,7 @@ if TYPE_CHECKING: class Images(BaseAPI): - def __init__(self, client: "ZhipuAI") -> None: + def __init__(self, client: ZhipuAI) -> None: super().__init__(client) def generations( diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py index f3dde8461..b7cf6bb7f 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_base_type.py @@ -1,21 +1,22 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from os import PathLike -from typing import IO, TYPE_CHECKING, Any, List, Mapping, Sequence, Tuple, Type, TypeVar, Union +from typing import IO, TYPE_CHECKING, Any, Literal, TypeVar, Union import pydantic -from typing_extensions import Literal, override +from typing_extensions import override Query = Mapping[str, object] Body = object AnyMapping = Mapping[str, object] PrimitiveData = Union[str, int, float, bool, None] -Data = Union[PrimitiveData, List[Any], Tuple[Any], "Mapping[str, Any]"] +Data = Union[PrimitiveData, list[Any], tuple[Any], "Mapping[str, Any]"] ModelT = TypeVar("ModelT", bound=pydantic.BaseModel) _T = TypeVar("_T") if TYPE_CHECKING: - NoneType: Type[None] + NoneType: type[None] else: NoneType = type(None) @@ -74,7 +75,7 @@ Headers = Mapping[str, Union[str, Omit]] ResponseT = TypeVar( "ResponseT", - bound="Union[str, None, BaseModel, List[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", + bound="Union[str, None, BaseModel, list[Any], Dict[str, Any], Response, UnknownResponse, ModelBuilderProtocol, BinaryResponseContent]", ) # for user input files @@ -85,21 +86,21 @@ else: FileTypes = Union[ FileContent, # file content - Tuple[str, FileContent], # (filename, file) - Tuple[str, FileContent, str], # (filename, file , content_type) - Tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, FileContent], # (filename, file) + tuple[str, FileContent, str], # (filename, file , content_type) + tuple[str, FileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) ] -RequestFiles = Union[Mapping[str, FileTypes], Sequence[Tuple[str, FileTypes]]] +RequestFiles = Union[Mapping[str, FileTypes], Sequence[tuple[str, FileTypes]]] # for httpx client supported files HttpxFileContent = Union[bytes, IO[bytes]] HttpxFileTypes = Union[ FileContent, # file content - Tuple[str, HttpxFileContent], # (filename, file) - Tuple[str, HttpxFileContent, str], # (filename, file , content_type) - Tuple[str, HttpxFileContent, str, Mapping[str, str]], # (filename, file , content_type, headers) + tuple[str, HttpxFileContent], # (filename, file) + tuple[str, HttpxFileContent, str], # (filename, file , content_type) + tuple[str, HttpxFileContent, str, Mapping[str, str]], # 
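Dropping the quotes in `def __init__(self, client: "ZhipuAI")` is only safe because these SDK modules begin with `from __future__ import annotations`, which defers annotation evaluation (PEP 563). A self-contained sketch of the pattern with placeholder classes:

from __future__ import annotations

class BaseAPI:
    def __init__(self, client: ZhipuAI) -> None:  # no quotes: evaluation is deferred
        self._client = client

class ZhipuAI:
    def __init__(self) -> None:
        self.chat = BaseAPI(self)

ZhipuAI()  # works even though BaseAPI's annotation mentioned ZhipuAI before its definition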
(filename, file , content_type, headers) ] -HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[Tuple[str, HttpxFileTypes]]] +HttpxRequestFiles = Union[Mapping[str, HttpxFileTypes], Sequence[tuple[str, HttpxFileTypes]]] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py index e41ede128..0796bfe11 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_files.py @@ -2,14 +2,14 @@ from __future__ import annotations import io import os +from collections.abc import Mapping, Sequence from pathlib import Path -from typing import Mapping, Sequence from ._base_type import FileTypes, HttpxFileTypes, HttpxRequestFiles, RequestFiles def is_file_content(obj: object) -> bool: - return isinstance(obj, (bytes, tuple, io.IOBase, os.PathLike)) + return isinstance(obj, bytes | tuple | io.IOBase | os.PathLike) def _transform_file(file: FileTypes) -> HttpxFileTypes: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py index 5227d2061..e13d2b023 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_http_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import inspect -from typing import Any, Mapping, Type, Union, cast +from collections.abc import Mapping +from typing import Any, Union, cast import httpx import pydantic @@ -140,7 +140,7 @@ class HttpClient: for k, v in value.items(): items.extend(self._object_to_formfata(f"{key}[{k}]", v)) return items - if isinstance(value, (list, tuple)): + if isinstance(value, list | tuple): for v in value: items.extend(self._object_to_formfata(key + "[]", v)) return items @@ -175,7 +175,7 @@ class HttpClient: def _parse_response( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], response: httpx.Response, enable_stream: bool, request_param: ClientRequestParam, @@ -224,7 +224,7 @@ class HttpClient: def request( self, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], params: ClientRequestParam, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, @@ -259,7 +259,7 @@ class HttpClient: self, path: str, *, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, enable_stream: bool = False, ) -> ResponseT | StreamResponse: @@ -274,7 +274,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, enable_stream: bool = False, @@ -294,7 +294,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, ) -> ResponseT: opts = ClientRequestParam.construct(method="patch", url=path, json_data=body, **options) @@ -308,7 +308,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: UserRequestInput = {}, files: RequestFiles | None = None, ) -> ResponseT | StreamResponse: @@ -324,7 +324,7 @@ class HttpClient: path: str, *, body: Body | None = None, - cast_type: Type[ResponseT], + cast_type: type[ResponseT], options: 
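`Type[ResponseT]` becoming `type[ResponseT]` throughout the HTTP client is the PEP 585 builtin spelling of the same idea: a parameter that accepts a class and returns an instance of it. A simplified sketch (no httpx involved, names invented):

from typing import TypeVar

ResponseT = TypeVar("ResponseT")

def parse_response(cast_type: type[ResponseT], raw: str) -> ResponseT:
    # the caller passes a class; the checker infers the return type from it
    return cast_type(raw)

count = parse_response(int, "42")  # inferred as int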
UserRequestInput = {}, ) -> ResponseT | StreamResponse: opts = ClientRequestParam.construct(method="delete", url=path, json_data=body, **options) diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py index bbf2e72e6..b0a91d04a 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_jwt_token.py @@ -1,4 +1,3 @@ -# -*- coding:utf-8 -*- import time import cachetools.func diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py index 2406e5782..a3f49ba84 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_request_opt.py @@ -1,10 +1,10 @@ from __future__ import annotations -from typing import Any, Union +from typing import Any, ClassVar, Union from httpx import Timeout from pydantic import ConfigDict -from typing_extensions import ClassVar, TypedDict, Unpack +from typing_extensions import TypedDict, Unpack from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query from ._utils import remove_notgiven_indict @@ -17,7 +17,7 @@ class UserRequestInput(TypedDict, total=False): params: Query | None -class ClientRequestParam(): +class ClientRequestParam: method: str url: str max_retries: Union[int, NotGiven] = NotGiven() diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py index 116246e64..2f831b6fc 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_response.py @@ -1,11 +1,11 @@ from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast +from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast, get_args, get_origin import httpx import pydantic -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec from ._base_type import NoneType from ._sse_client import StreamResponse @@ -19,7 +19,7 @@ R = TypeVar("R") class HttpResponse(Generic[R]): _cast_type: type[R] - _client: "HttpClient" + _client: HttpClient _parsed: R | None _enable_stream: bool _stream_cls: type[StreamResponse[Any]] @@ -30,7 +30,7 @@ class HttpResponse(Generic[R]): *, raw_response: httpx.Response, cast_type: type[R], - client: "HttpClient", + client: HttpClient, enable_stream: bool = False, stream_cls: type[StreamResponse[Any]] | None = None, ) -> None: diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py index 6efe20edc..66afbfd10 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_sse_client.py @@ -1,8 +1,8 @@ -# -*- coding:utf-8 -*- from __future__ import annotations import json -from typing import TYPE_CHECKING, Generic, Iterator, Mapping +from collections.abc import Iterator, Mapping +from typing import TYPE_CHECKING, Generic import httpx @@ -36,8 +36,7 @@ class StreamResponse(Generic[ResponseT]): return 
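`ClassVar` now comes from `typing` rather than `typing_extensions`; it has lived in `typing` since Python 3.5.3, so only newer extras such as `Unpack` still need the backport on older interpreters. For the `UserRequestInput` pattern, a minimal TypedDict sketch (the keys are invented):

from typing import TypedDict

class RequestOptions(TypedDict, total=False):
    # total=False makes every key optional, matching UserRequestInput above
    timeout: float
    headers: dict[str, str]

opts: RequestOptions = {"timeout": 30.0}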
self._stream_chunks.__next__() def __iter__(self) -> Iterator[ResponseT]: - for item in self._stream_chunks: - yield item + yield from self._stream_chunks def __stream__(self) -> Iterator[ResponseT]: @@ -62,7 +61,7 @@ class StreamResponse(Generic[ResponseT]): pass -class Event(object): +class Event: def __init__( self, event: str | None = None, diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py index 78c97af65..6b610567d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/core/_utils.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Iterable, Mapping, TypeVar +from collections.abc import Iterable, Mapping +from typing import TypeVar from ._base_type import NotGiven diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py index bae4197c5..f22f32d25 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/async_chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,5 +19,5 @@ class AsyncCompletion(BaseModel): request_id: Optional[str] = None model: Optional[str] = None task_status: str - choices: List[CompletionChoice] + choices: list[CompletionChoice] usage: CompletionUsage \ No newline at end of file diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py index 524e218d3..b2a847c50 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -19,7 +19,7 @@ class CompletionMessageToolCall(BaseModel): class CompletionMessage(BaseModel): content: Optional[str] = None role: str - tool_calls: Optional[List[CompletionMessageToolCall]] = None + tool_calls: Optional[list[CompletionMessageToolCall]] = None class CompletionUsage(BaseModel): @@ -37,7 +37,7 @@ class CompletionChoice(BaseModel): class Completion(BaseModel): model: Optional[str] = None created: Optional[int] = None - choices: List[CompletionChoice] + choices: list[CompletionChoice] request_id: Optional[str] = None id: Optional[str] = None usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py index c2e0c5766..c25069974 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/chat/chat_completion_chunk.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -32,7 +32,7 @@ class ChoiceDeltaToolCall(BaseModel): class ChoiceDelta(BaseModel): content: Optional[str] = None role: Optional[str] = None - tool_calls: 
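The `__iter__` change in the SSE client is pyupgrade's `yield from` rule: delegate to the sub-iterator directly instead of looping over it. An equivalent toy version:

from collections.abc import Iterator

def iter_chunks(chunks: list[str]) -> Iterator[str]:
    yield from chunks  # replaces: for item in chunks: yield item

assert list(iter_chunks(["data: 1", "data: 2"])) == ["data: 1", "data: 2"]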
Optional[List[ChoiceDeltaToolCall]] = None + tool_calls: Optional[list[ChoiceDeltaToolCall]] = None class Choice(BaseModel): @@ -49,7 +49,7 @@ class CompletionUsage(BaseModel): class ChatCompletionChunk(BaseModel): id: Optional[str] = None - choices: List[Choice] + choices: list[Choice] created: Optional[int] = None model: Optional[str] = None usage: Optional[CompletionUsage] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py index a8737cf8d..e01f2c815 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/embeddings.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -12,11 +12,11 @@ __all__ = ["Embedding", "EmbeddingsResponded"] class Embedding(BaseModel): object: str index: Optional[int] = None - embedding: List[float] + embedding: list[float] class EmbeddingsResponded(BaseModel): object: str - data: List[Embedding] + data: list[Embedding] model: str usage: CompletionUsage diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py index 94db046bd..917bda757 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/file_object.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -20,5 +20,5 @@ class FileObject(BaseModel): class ListOfFileObject(BaseModel): object: Optional[str] = None - data: List[FileObject] + data: list[FileObject] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py index 6197b6faa..71c00eaff 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel @@ -34,7 +34,7 @@ class FineTuningJob(BaseModel): object: Optional[str] = None - result_files: List[str] + result_files: list[str] status: str @@ -47,5 +47,5 @@ class FineTuningJob(BaseModel): class ListOfFineTuningJob(BaseModel): object: Optional[str] = None - data: List[FineTuningJob] + data: list[FineTuningJob] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py index 6ff3f77fd..e26b44853 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/fine_tuning_job_event.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import Optional, Union from pydantic import BaseModel @@ -31,5 +31,5 @@ class JobEvent(BaseModel): class FineTuningJobEvent(BaseModel): object: Optional[str] = None - data: 
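The SDK's pydantic models keep working unchanged with builtin generics, since pydantic resolves `list[X]` the same way as `typing.List[X]` on Python 3.9+. A pared-down model in the same style (the fields are invented, not the SDK's real schema):

from typing import Optional
from pydantic import BaseModel

class Choice(BaseModel):
    index: int
    content: Optional[str] = None

class Chunk(BaseModel):
    id: Optional[str] = None
    choices: list[Choice]  # builtin generic in a pydantic field

print(Chunk(choices=[Choice(index=0, content="hi")]))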
List[JobEvent] + data: list[JobEvent] has_more: Optional[bool] = None diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py index c661f7cdd..e1ebc352b 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/fine_tuning/job_create_params.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Union +from typing import Literal, Union -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict __all__ = ["Hyperparameters"] diff --git a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py index 429a7e25b..b352ce095 100644 --- a/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py +++ b/api/core/model_runtime/model_providers/zhipuai/zhipuai_sdk/types/image.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import List, Optional +from typing import Optional from pydantic import BaseModel @@ -15,4 +15,4 @@ class GeneratedImage(BaseModel): class ImagesResponded(BaseModel): created: int - data: List[GeneratedImage] + data: list[GeneratedImage] diff --git a/api/core/model_runtime/utils/_compat.py b/api/core/model_runtime/utils/_compat.py index 305edcac8..5c3415275 100644 --- a/api/core/model_runtime/utils/_compat.py +++ b/api/core/model_runtime/utils/_compat.py @@ -1,8 +1,7 @@ -from typing import Any +from typing import Any, Literal from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION -from typing_extensions import Literal PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index d0d93c34b..cf6c98e01 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -1,13 +1,14 @@ import dataclasses import datetime from collections import defaultdict, deque +from collections.abc import Callable from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network from pathlib import Path, PurePath from re import Pattern from types import GeneratorType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union from uuid import UUID from pydantic import BaseModel @@ -46,7 +47,7 @@ def decimal_encoder(dec_value: Decimal) -> Union[int, float]: return float(dec_value) -ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { +ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { bytes: lambda o: o.decode(), Color: str, datetime.date: isoformat, @@ -77,9 +78,9 @@ ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = { def generate_encoders_by_class_tuples( - type_encoder_map: Dict[Any, Callable[[Any], Any]] -) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]: - encoders_by_class_tuples: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict( + type_encoder_map: dict[Any, Callable[[Any], Any]] +) -> dict[Callable[[Any], Any], tuple[Any, ...]]: + encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict( tuple ) for type_, encoder in type_encoder_map.items(): @@ -96,7 +97,7 @@ def jsonable_encoder( 
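In `encoders.py` the dispatch table's annotation becomes `dict[type[Any], Callable[[Any], Any]]`, with `Callable` now imported from `collections.abc`. A pared-down version of that type-dispatch pattern (only a few entries, for illustration):

import datetime
from collections.abc import Callable
from decimal import Decimal
from typing import Any

ENCODERS: dict[type[Any], Callable[[Any], Any]] = {
    bytes: lambda o: o.decode(),
    datetime.date: datetime.date.isoformat,
    Decimal: lambda d: format(d, "f"),
}

def encode(obj: Any) -> Any:
    encoder = ENCODERS.get(type(obj))
    return encoder(obj) if encoder is not None else obj

print(encode(Decimal("1.50")), encode(b"ok"))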
exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, - custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None, + custom_encoder: Optional[dict[Any, Callable[[Any], Any]]] = None, sqlalchemy_safe: bool = True, ) -> Any: custom_encoder = custom_encoder or {} @@ -109,7 +110,7 @@ def jsonable_encoder( return encoder_instance(obj) if isinstance(obj, BaseModel): # TODO: remove when deprecating Pydantic v1 - encoders: Dict[Any, Any] = {} + encoders: dict[Any, Any] = {} if not PYDANTIC_V2: encoders = getattr(obj.__config__, "json_encoders", {}) # type: ignore[attr-defined] if custom_encoder: @@ -149,7 +150,7 @@ def jsonable_encoder( return obj.value if isinstance(obj, PurePath): return str(obj) - if isinstance(obj, (str, int, float, type(None))): + if isinstance(obj, str | int | float | type(None)): return obj if isinstance(obj, Decimal): return format(obj, 'f') @@ -184,7 +185,7 @@ def jsonable_encoder( ) encoded_dict[encoded_key] = encoded_value return encoded_dict - if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)): + if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] for item in obj: encoded_list.append( @@ -209,7 +210,7 @@ def jsonable_encoder( try: data = dict(obj) except Exception as e: - errors: List[Exception] = [] + errors: list[Exception] = [] errors.append(e) try: data = vars(obj) diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 5ffcaaec6..0a373b7c4 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -2,7 +2,7 @@ import enum import json import os import re -from typing import List, Optional, Tuple, cast +from typing import Optional, cast from core.entities.application_entities import ( AdvancedCompletionPromptTemplateEntity, @@ -67,11 +67,11 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], model_config: ModelConfigEntity) -> \ - Tuple[List[PromptMessage], Optional[List[str]]]: + tuple[list[PromptMessage], Optional[list[str]]]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -115,10 +115,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: app_mode = AppMode.value_of(app_mode) model_mode = ModelMode.value_of(model_config.mode) @@ -182,7 +182,7 @@ class PromptTransform: ) def _get_history_messages_list_from_memory(self, memory: TokenBufferMemory, - max_token_limit: int) -> List[PromptMessage]: + max_token_limit: int) -> list[PromptMessage]: """Get memory messages.""" return memory.get_history_prompt_messages( max_token_limit=max_token_limit @@ -217,7 +217,7 @@ class PromptTransform: json_file_path = os.path.join(prompt_path, f'{prompt_name}.json') # Open the JSON file and read its content - with open(json_file_path, 'r', encoding='utf-8') as json_file: + with open(json_file_path, encoding='utf-8') as json_file: return json.load(json_file) def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, @@ -225,9 +225,9 @@ class PromptTransform: inputs: dict, query: str, context: Optional[str], - files: 
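The `open()` change in `prompt_transform.py` drops the redundant `'r'`: reading text is already the default mode, so only the encoding needs to be spelled out. A sketch of the resulting idiom:

import json

def load_prompt_rules(json_file_path: str) -> dict:
    with open(json_file_path, encoding="utf-8") as json_file:  # "r" is the default mode
        return json.load(json_file)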
List[FileObj], + files: list[FileObj], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: prompt_messages = [] context_prompt_content = '' @@ -280,8 +280,8 @@ class PromptTransform: query: str, context: Optional[str], memory: Optional[TokenBufferMemory], - files: List[FileObj], - model_config: ModelConfigEntity) -> List[PromptMessage]: + files: list[FileObj], + model_config: ModelConfigEntity) -> list[PromptMessage]: context_prompt_content = '' if context and 'context_prompt' in prompt_rules: prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt']) @@ -451,10 +451,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt role_prefix = prompt_template_entity.advanced_completion_prompt_template.role_prefix @@ -494,10 +494,10 @@ class PromptTransform: prompt_template_entity: PromptTemplateEntity, inputs: dict, query: str, - files: List[FileObj], + files: list[FileObj], context: Optional[str], memory: Optional[TokenBufferMemory], - model_config: ModelConfigEntity) -> List[PromptMessage]: + model_config: ModelConfigEntity) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] @@ -535,7 +535,7 @@ class PromptTransform: def _get_completion_app_completion_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - context: Optional[str]) -> List[PromptMessage]: + context: Optional[str]) -> list[PromptMessage]: raw_prompt = prompt_template_entity.advanced_completion_prompt_template.prompt prompt_messages = [] @@ -554,8 +554,8 @@ class PromptTransform: def _get_completion_app_chat_model_prompt_messages(self, prompt_template_entity: PromptTemplateEntity, inputs: dict, - files: List[FileObj], - context: Optional[str]) -> List[PromptMessage]: + files: list[FileObj], + context: Optional[str]) -> list[PromptMessage]: raw_prompt_list = prompt_template_entity.advanced_chat_prompt_template.messages prompt_messages = [] diff --git a/api/core/rerank/rerank.py b/api/core/rerank/rerank.py index 984cdb400..a675dfc56 100644 --- a/api/core/rerank/rerank.py +++ b/api/core/rerank/rerank.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Optional from langchain.schema import Document @@ -9,8 +9,8 @@ class RerankRunner: def __init__(self, rerank_model_instance: ModelInstance) -> None: self.rerank_model_instance = rerank_model_instance - def run(self, query: str, documents: List[Document], score_threshold: Optional[float] = None, - top_n: Optional[int] = None, user: Optional[str] = None) -> List[Document]: + def run(self, query: str, documents: list[Document], score_threshold: Optional[float] = None, + top_n: Optional[int] = None, user: Optional[str] = None) -> list[Document]: """ Run rerank model :param query: search query diff --git a/api/core/splitter/fixed_text_splitter.py b/api/core/splitter/fixed_text_splitter.py index babb360a5..285a7ba14 100644 --- a/api/core/splitter/fixed_text_splitter.py +++ b/api/core/splitter/fixed_text_splitter.py @@ -1,7 +1,7 @@ """Functionality for splitting text.""" from __future__ import 
annotations -from typing import Any, List, Optional, cast +from typing import Any, Optional, cast from langchain.text_splitter import ( TS, @@ -28,8 +28,8 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): def from_encoder( cls: Type[TS], embedding_model_instance: Optional[ModelInstance], - allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), - disallowed_special: Union[Literal["all"], Collection[str]] = "all", + allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), + disallowed_special: Union[Literal["all"], Collection[str]] = "all", **kwargs: Any, ): def _token_encoder(text: str) -> int: @@ -59,13 +59,13 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): - def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any): + def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] - def split_text(self, text: str) -> List[str]: + def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) @@ -81,7 +81,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) return final_chunks - def recursive_split_text(self, text: str) -> List[str]: + def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] # Get appropriate separator to use diff --git a/api/core/third_party/langchain/llms/fake.py b/api/core/third_party/langchain/llms/fake.py index 64117477e..ab5152b38 100644 --- a/api/core/third_party/langchain/llms/fake.py +++ b/api/core/third_party/langchain/llms/fake.py @@ -1,5 +1,6 @@ import time -from typing import Any, List, Mapping, Optional +from collections.abc import Mapping +from typing import Any, Optional from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import SimpleChatModel @@ -19,8 +20,8 @@ class FakeLLM(SimpleChatModel): def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: @@ -36,8 +37,8 @@ class FakeLLM(SimpleChatModel): def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, + messages: list[BaseMessage], + stop: Optional[list[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> ChatResult: diff --git a/api/core/tool/current_datetime_tool.py b/api/core/tool/current_datetime_tool.py index 3bb2bb5ea..208490a5b 100644 --- a/api/core/tool/current_datetime_tool.py +++ b/api/core/tool/current_datetime_tool.py @@ -1,5 +1,4 @@ from datetime import datetime -from typing import Type from langchain.tools import BaseTool from pydantic import BaseModel, Field @@ -12,7 +11,7 @@ class DatetimeToolInput(BaseModel): class DatetimeTool(BaseTool): """Tool for querying current datetime.""" name: str = "current_datetime" - args_schema: Type[BaseModel] = DatetimeToolInput + args_schema: type[BaseModel] = DatetimeToolInput description: str = "A tool when you want to get the current date, time, week, month or year, " \ "and the time 
zone is UTC. Result is \"
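One caution that the `fixed_text_splitter.py` hunk above illustrates: the quotes in `Literal["all"]` are part of the value, not a removable forward reference. An unquoted `Literal[all]` would silently refer to the builtin `all` function, and under `from __future__ import annotations` no error surfaces until a tool evaluates the annotation. A sketch of the intended annotation (simplified signature, not langchain's):

from typing import Literal, Union

AllowedSpecial = Union[Literal["all"], frozenset[str]]

def count_tokens(text: str, allowed_special: AllowedSpecial = frozenset()) -> int:
    # "all" here is a string literal; Literal[all] would capture builtins.all instead
    return len(text.split())

print(count_tokens("hello world", allowed_special="all"))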