From 13be84e4d42d12000356c4143ac255bb4763b600 Mon Sep 17 00:00:00 2001 From: -LAN- Date: Mon, 26 Aug 2024 15:29:10 +0800 Subject: [PATCH] chore(api/controllers): Apply Ruff Formatter. (#7645) --- api/controllers/__init__.py | 2 - api/controllers/console/__init__.py | 2 +- api/controllers/console/admin.py | 95 +-- api/controllers/console/apikey.py | 116 ++-- .../console/app/advanced_prompt_template.py | 13 +- api/controllers/console/app/agent.py | 15 +- api/controllers/console/app/annotation.py | 132 ++-- api/controllers/console/app/app.py | 176 +++-- api/controllers/console/app/audio.py | 50 +- api/controllers/console/app/completion.py | 59 +- api/controllers/console/app/conversation.py | 193 +++--- .../console/app/conversation_variables.py | 24 +- api/controllers/console/app/error.py | 50 +- api/controllers/console/app/generator.py | 18 +- api/controllers/console/app/message.py | 147 +++-- api/controllers/console/app/model_config.py | 43 +- api/controllers/console/app/ops_trace.py | 33 +- api/controllers/console/app/site.py | 73 +-- api/controllers/console/app/statistic.py | 286 ++++---- api/controllers/console/app/workflow.py | 201 +++--- .../console/app/workflow_app_log.py | 13 +- api/controllers/console/app/workflow_run.py | 30 +- .../console/app/workflow_statistic.py | 187 +++--- api/controllers/console/app/wraps.py | 24 +- api/controllers/console/auth/activate.py | 53 +- .../console/auth/data_source_bearer_auth.py | 39 +- .../console/auth/data_source_oauth.py | 72 +- api/controllers/console/auth/error.py | 11 +- .../console/auth/forgot_password.py | 37 +- api/controllers/console/auth/login.py | 32 +- api/controllers/console/auth/oauth.py | 26 +- api/controllers/console/billing/billing.py | 18 +- .../console/datasets/data_source.py | 183 +++--- api/controllers/console/datasets/datasets.py | 474 ++++++------- .../console/datasets/datasets_document.py | 620 +++++++++--------- .../console/datasets/datasets_segments.py | 214 +++--- api/controllers/console/datasets/error.py | 30 +- api/controllers/console/datasets/file.py | 26 +- .../console/datasets/hit_testing.py | 18 +- api/controllers/console/datasets/website.py | 16 +- api/controllers/console/error.py | 26 +- api/controllers/console/explore/audio.py | 49 +- api/controllers/console/explore/completion.py | 67 +- .../console/explore/conversation.py | 55 +- api/controllers/console/explore/error.py | 8 +- .../console/explore/installed_app.py | 76 +-- api/controllers/console/explore/message.py | 55 +- api/controllers/console/explore/parameter.py | 100 +-- .../console/explore/recommended_app.py | 42 +- .../console/explore/saved_message.py | 56 +- api/controllers/console/explore/workflow.py | 20 +- api/controllers/console/explore/wraps.py | 24 +- api/controllers/console/extension.py | 44 +- api/controllers/console/feature.py | 5 +- api/controllers/console/init_validate.py | 32 +- api/controllers/console/ping.py | 7 +- api/controllers/console/setup.py | 37 +- api/controllers/console/tag/tags.py | 81 +-- api/controllers/console/version.py | 36 +- api/controllers/console/workspace/account.py | 156 +++-- api/controllers/console/workspace/error.py | 12 +- .../workspace/load_balancing_config.py | 60 +- api/controllers/console/workspace/members.py | 80 ++- .../console/workspace/model_providers.py | 134 ++-- api/controllers/console/workspace/models.py | 276 ++++---- .../console/workspace/tool_providers.py | 400 ++++++----- .../console/workspace/workspace.py | 143 ++-- api/controllers/console/wraps.py | 39 +- api/controllers/files/__init__.py | 2 +- api/controllers/files/image_preview.py | 27 +- api/controllers/files/tool_files.py | 29 +- api/controllers/inner_api/__init__.py | 3 +- .../inner_api/workspace/workspace.py | 21 +- api/controllers/inner_api/wraps.py | 20 +- api/controllers/service_api/__init__.py | 2 +- api/controllers/service_api/app/app.py | 104 +-- api/controllers/service_api/app/audio.py | 48 +- api/controllers/service_api/app/completion.py | 52 +- .../service_api/app/conversation.py | 42 +- api/controllers/service_api/app/error.py | 46 +- api/controllers/service_api/app/file.py | 8 +- api/controllers/service_api/app/message.py | 118 ++-- api/controllers/service_api/app/workflow.py | 49 +- .../service_api/dataset/dataset.py | 88 +-- .../service_api/dataset/document.py | 295 ++++----- api/controllers/service_api/dataset/error.py | 26 +- .../service_api/dataset/segment.py | 127 ++-- api/controllers/service_api/index.py | 2 +- api/controllers/service_api/wraps.py | 109 +-- api/controllers/web/__init__.py | 2 +- api/controllers/web/app.py | 96 +-- api/controllers/web/audio.py | 50 +- api/controllers/web/completion.py | 59 +- api/controllers/web/conversation.py | 51 +- api/controllers/web/error.py | 52 +- api/controllers/web/feature.py | 2 +- api/controllers/web/file.py | 7 +- api/controllers/web/message.py | 108 ++- api/controllers/web/passport.py | 33 +- api/controllers/web/saved_message.py | 48 +- api/controllers/web/site.py | 73 ++- api/controllers/web/workflow.py | 18 +- api/controllers/web/wraps.py | 42 +- api/pyproject.toml | 1 - 104 files changed, 3849 insertions(+), 3982 deletions(-) diff --git a/api/controllers/__init__.py b/api/controllers/__init__.py index b28b04f64..8b1378917 100644 --- a/api/controllers/__init__.py +++ b/api/controllers/__init__.py @@ -1,3 +1 @@ - - diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index b2b9d8d49..eb7c1464d 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('console', __name__, url_prefix='/console/api') +bp = Blueprint("console", __name__, url_prefix="/console/api") api = ExternalApi(bp) # Import other controllers diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 028be5de5..a4ceec266 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -15,24 +15,24 @@ from models.model import App, InstalledApp, RecommendedApp def admin_required(view): @wraps(view) def decorated(*args, **kwargs): - if not os.getenv('ADMIN_API_KEY'): - raise Unauthorized('API key is invalid.') + if not os.getenv("ADMIN_API_KEY"): + raise Unauthorized("API key is invalid.") - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - if os.getenv('ADMIN_API_KEY') != auth_token: - raise Unauthorized('API key is invalid.') + if os.getenv("ADMIN_API_KEY") != auth_token: + raise Unauthorized("API key is invalid.") return view(*args, **kwargs) @@ -44,37 +44,41 @@ class InsertExploreAppListApi(Resource): @admin_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, nullable=False, location='json') - parser.add_argument('desc', type=str, location='json') - parser.add_argument('copyright', type=str, location='json') - parser.add_argument('privacy_policy', type=str, location='json') - parser.add_argument('custom_disclaimer', type=str, location='json') - parser.add_argument('language', type=supported_language, required=True, nullable=False, location='json') - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('position', type=int, required=True, nullable=False, location='json') + parser.add_argument("app_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("desc", type=str, location="json") + parser.add_argument("copyright", type=str, location="json") + parser.add_argument("privacy_policy", type=str, location="json") + parser.add_argument("custom_disclaimer", type=str, location="json") + parser.add_argument("language", type=supported_language, required=True, nullable=False, location="json") + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - app = App.query.filter(App.id == args['app_id']).first() + app = App.query.filter(App.id == args["app_id"]).first() if not app: raise NotFound(f'App \'{args["app_id"]}\' is not found') site = app.site if not site: - desc = args['desc'] if args['desc'] else '' - copy_right = args['copyright'] if args['copyright'] else '' - privacy_policy = args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = args["desc"] if args["desc"] else "" + copy_right = args["copyright"] if args["copyright"] else "" + privacy_policy = args["privacy_policy"] if args["privacy_policy"] else "" + custom_disclaimer = args["custom_disclaimer"] if args["custom_disclaimer"] else "" else: - desc = site.description if site.description else \ - args['desc'] if args['desc'] else '' - copy_right = site.copyright if site.copyright else \ - args['copyright'] if args['copyright'] else '' - privacy_policy = site.privacy_policy if site.privacy_policy else \ - args['privacy_policy'] if args['privacy_policy'] else '' - custom_disclaimer = site.custom_disclaimer if site.custom_disclaimer else \ - args['custom_disclaimer'] if args['custom_disclaimer'] else '' + desc = site.description if site.description else args["desc"] if args["desc"] else "" + copy_right = site.copyright if site.copyright else args["copyright"] if args["copyright"] else "" + privacy_policy = ( + site.privacy_policy if site.privacy_policy else args["privacy_policy"] if args["privacy_policy"] else "" + ) + custom_disclaimer = ( + site.custom_disclaimer + if site.custom_disclaimer + else args["custom_disclaimer"] + if args["custom_disclaimer"] + else "" + ) - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if not recommended_app: recommended_app = RecommendedApp( @@ -83,9 +87,9 @@ class InsertExploreAppListApi(Resource): copyright=copy_right, privacy_policy=privacy_policy, custom_disclaimer=custom_disclaimer, - language=args['language'], - category=args['category'], - position=args['position'] + language=args["language"], + category=args["category"], + position=args["position"], ) db.session.add(recommended_app) @@ -93,21 +97,21 @@ class InsertExploreAppListApi(Resource): app.is_public = True db.session.commit() - return {'result': 'success'}, 201 + return {"result": "success"}, 201 else: recommended_app.description = desc recommended_app.copyright = copy_right recommended_app.privacy_policy = privacy_policy recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args['language'] - recommended_app.category = args['category'] - recommended_app.position = args['position'] + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] app.is_public = True db.session.commit() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): @@ -116,15 +120,14 @@ class InsertExploreAppApi(Resource): def delete(self, app_id): recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first() if not recommended_app: - return {'result': 'success'}, 204 + return {"result": "success"}, 204 app = App.query.filter(App.id == recommended_app.app_id).first() if app: app.is_public = False installed_apps = InstalledApp.query.filter( - InstalledApp.app_id == recommended_app.app_id, - InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id + InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id ).all() for installed_app in installed_apps: @@ -133,8 +136,8 @@ class InsertExploreAppApi(Resource): db.session.delete(recommended_app) db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(InsertExploreAppListApi, '/admin/insert-explore-apps') -api.add_resource(InsertExploreAppApi, '/admin/insert-explore-apps/') +api.add_resource(InsertExploreAppListApi, "/admin/insert-explore-apps") +api.add_resource(InsertExploreAppApi, "/admin/insert-explore-apps/") diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index 324b83117..3f5e1adca 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -14,26 +14,21 @@ from .setup import setup_required from .wraps import account_initialization_required api_key_fields = { - 'id': fields.String, - 'type': fields.String, - 'token': fields.String, - 'last_used_at': TimestampField, - 'created_at': TimestampField + "id": fields.String, + "type": fields.String, + "token": fields.String, + "last_used_at": TimestampField, + "created_at": TimestampField, } -api_key_list = { - 'data': fields.List(fields.Nested(api_key_fields), attribute="items") -} +api_key_list = {"data": fields.List(fields.Nested(api_key_fields), attribute="items")} def _get_resource(resource_id, tenant_id, resource_model): - resource = resource_model.query.filter_by( - id=resource_id, tenant_id=tenant_id - ).first() + resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first() if resource is None: - flask_restful.abort( - 404, message=f"{resource_model.__name__} not found.") + flask_restful.abort(404, message=f"{resource_model.__name__} not found.") return resource @@ -50,30 +45,32 @@ class BaseApiKeyListResource(Resource): @marshal_with(api_key_list) def get(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - all() + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .all() + ) return {"items": keys} @marshal_with(api_key_fields) def post(self, resource_id): resource_id = str(resource_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -97,79 +94,78 @@ class BaseApiKeyResource(Resource): def delete(self, resource_id, api_key_id): resource_id = str(resource_id) api_key_id = str(api_key_id) - _get_resource(resource_id, current_user.current_tenant_id, - self.resource_model) + _get_resource(resource_id, current_user.current_tenant_id, self.resource_model) # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + getattr(ApiToken, self.resource_id_field) == resource_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' - token_prefix = 'app-' + resource_id_field = "app_id" + token_prefix = "app-" class AppApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'app' + resource_type = "app" resource_model = App - resource_id_field = 'app_id' + resource_id_field = "app_id" class DatasetApiKeyListResource(BaseApiKeyListResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' - token_prefix = 'ds-' + resource_id_field = "dataset_id" + token_prefix = "ds-" class DatasetApiKeyResource(BaseApiKeyResource): - def after_request(self, resp): - resp.headers['Access-Control-Allow-Origin'] = '*' - resp.headers['Access-Control-Allow-Credentials'] = 'true' + resp.headers["Access-Control-Allow-Origin"] = "*" + resp.headers["Access-Control-Allow-Credentials"] = "true" return resp - resource_type = 'dataset' + + resource_type = "dataset" resource_model = Dataset - resource_id_field = 'dataset_id' + resource_id_field = "dataset_id" -api.add_resource(AppApiKeyListResource, '/apps//api-keys') -api.add_resource(AppApiKeyResource, - '/apps//api-keys/') -api.add_resource(DatasetApiKeyListResource, - '/datasets//api-keys') -api.add_resource(DatasetApiKeyResource, - '/datasets//api-keys/') +api.add_resource(AppApiKeyListResource, "/apps//api-keys") +api.add_resource(AppApiKeyResource, "/apps//api-keys/") +api.add_resource(DatasetApiKeyListResource, "/datasets//api-keys") +api.add_resource(DatasetApiKeyResource, "/datasets//api-keys/") diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py index fa2b3807e..e7346bdf1 100644 --- a/api/controllers/console/app/advanced_prompt_template.py +++ b/api/controllers/console/app/advanced_prompt_template.py @@ -8,19 +8,18 @@ from services.advanced_prompt_template_service import AdvancedPromptTemplateServ class AdvancedPromptTemplateList(Resource): - @setup_required @login_required @account_initialization_required def get(self): - parser = reqparse.RequestParser() - parser.add_argument('app_mode', type=str, required=True, location='args') - parser.add_argument('model_mode', type=str, required=True, location='args') - parser.add_argument('has_context', type=str, required=False, default='true', location='args') - parser.add_argument('model_name', type=str, required=True, location='args') + parser.add_argument("app_mode", type=str, required=True, location="args") + parser.add_argument("model_mode", type=str, required=True, location="args") + parser.add_argument("has_context", type=str, required=False, default="true", location="args") + parser.add_argument("model_name", type=str, required=True, location="args") args = parser.parse_args() return AdvancedPromptTemplateService.get_prompt(args) -api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates') \ No newline at end of file + +api.add_resource(AdvancedPromptTemplateList, "/app/prompt-templates") diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py index aee367276..51899da70 100644 --- a/api/controllers/console/app/agent.py +++ b/api/controllers/console/app/agent.py @@ -18,15 +18,12 @@ class AgentLogApi(Resource): def get(self, app_model): """Get agent logs""" parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='args') - parser.add_argument('conversation_id', type=uuid_value, required=True, location='args') + parser.add_argument("message_id", type=uuid_value, required=True, location="args") + parser.add_argument("conversation_id", type=uuid_value, required=True, location="args") args = parser.parse_args() - return AgentService.get_agent_logs( - app_model, - args['conversation_id'], - args['message_id'] - ) - -api.add_resource(AgentLogApi, '/apps//agent/logs') \ No newline at end of file + return AgentService.get_agent_logs(app_model, args["conversation_id"], args["message_id"]) + + +api.add_resource(AgentLogApi, "/apps//agent/logs") diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index bc15919a9..1ea1c8267 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -21,23 +21,23 @@ class AnnotationReplyActionApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id, action): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') - parser.add_argument('embedding_provider_name', required=True, type=str, location='json') - parser.add_argument('embedding_model_name', required=True, type=str, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") + parser.add_argument("embedding_provider_name", required=True, type=str, location="json") + parser.add_argument("embedding_model_name", required=True, type=str, location="json") args = parser.parse_args() - if action == 'enable': + if action == "enable": result = AppAnnotationService.enable_app_annotation(args, app_id) - elif action == 'disable': + elif action == "disable": result = AppAnnotationService.disable_app_annotation(app_id) else: - raise ValueError('Unsupported annotation reply action') + raise ValueError("Unsupported annotation reply action") return result, 200 @@ -66,7 +66,7 @@ class AppAnnotationSettingUpdateApi(Resource): annotation_setting_id = str(annotation_setting_id) parser = reqparse.RequestParser() - parser.add_argument('score_threshold', required=True, type=float, location='json') + parser.add_argument("score_threshold", required=True, type=float, location="json") args = parser.parse_args() result = AppAnnotationService.update_app_annotation_setting(app_id, annotation_setting_id, args) @@ -77,28 +77,24 @@ class AnnotationReplyActionStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id, action): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - app_annotation_job_key = '{}_app_annotation_job_{}'.format(action, str(job_id)) + app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - app_annotation_error_key = '{}_app_annotation_error_{}'.format(action, str(job_id)) + error_msg = "" + if job_status == "error": + app_annotation_error_key = "{}_app_annotation_error_{}".format(action, str(job_id)) error_msg = redis_client.get(app_annotation_error_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationListApi(Resource): @@ -109,18 +105,18 @@ class AnnotationListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - keyword = request.args.get('keyword', default=None, type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + keyword = request.args.get("keyword", default=None, type=str) app_id = str(app_id) annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword) response = { - 'data': marshal(annotation_list, annotation_fields), - 'has_more': len(annotation_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_list, annotation_fields), + "has_more": len(annotation_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 @@ -135,9 +131,7 @@ class AnnotationExportApi(Resource): app_id = str(app_id) annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id) - response = { - 'data': marshal(annotation_list, annotation_fields) - } + response = {"data": marshal(annotation_list, annotation_fields)} return response, 200 @@ -145,7 +139,7 @@ class AnnotationCreateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id): if not current_user.is_editor: @@ -153,8 +147,8 @@ class AnnotationCreateApi(Resource): app_id = str(app_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.insert_app_annotation_directly(args, app_id) return annotation @@ -164,7 +158,7 @@ class AnnotationUpdateDeleteApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @marshal_with(annotation_fields) def post(self, app_id, annotation_id): if not current_user.is_editor: @@ -173,8 +167,8 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) parser = reqparse.RequestParser() - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") args = parser.parse_args() annotation = AppAnnotationService.update_app_annotation_directly(args, app_id, annotation_id) return annotation @@ -189,29 +183,29 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class AnnotationBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def post(self, app_id): if not current_user.is_editor: raise Forbidden() app_id = str(app_id) # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) @@ -220,27 +214,23 @@ class AnnotationBatchImportStatusApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") def get(self, app_id, job_id): if not current_user.is_editor: raise Forbidden() job_id = str(job_id) - indexing_cache_key = 'app_annotation_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id)) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") job_status = cache_result.decode() - error_msg = '' - if job_status == 'error': - indexing_error_msg_key = 'app_annotation_batch_import_error_msg_{}'.format(str(job_id)) + error_msg = "" + if job_status == "error": + indexing_error_msg_key = "app_annotation_batch_import_error_msg_{}".format(str(job_id)) error_msg = redis_client.get(indexing_error_msg_key).decode() - return { - 'job_id': job_id, - 'job_status': job_status, - 'error_msg': error_msg - }, 200 + return {"job_id": job_id, "job_status": job_status, "error_msg": error_msg}, 200 class AnnotationHitHistoryListApi(Resource): @@ -251,30 +241,32 @@ class AnnotationHitHistoryListApi(Resource): if not current_user.is_editor: raise Forbidden() - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) app_id = str(app_id) annotation_id = str(annotation_id) - annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(app_id, annotation_id, - page, limit) + annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories( + app_id, annotation_id, page, limit + ) response = { - 'data': marshal(annotation_hit_history_list, annotation_hit_history_fields), - 'has_more': len(annotation_hit_history_list) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(annotation_hit_history_list, annotation_hit_history_fields), + "has_more": len(annotation_hit_history_list) == limit, + "limit": limit, + "total": total, + "page": page, } return response -api.add_resource(AnnotationReplyActionApi, '/apps//annotation-reply/') -api.add_resource(AnnotationReplyActionStatusApi, - '/apps//annotation-reply//status/') -api.add_resource(AnnotationListApi, '/apps//annotations') -api.add_resource(AnnotationExportApi, '/apps//annotations/export') -api.add_resource(AnnotationUpdateDeleteApi, '/apps//annotations/') -api.add_resource(AnnotationBatchImportApi, '/apps//annotations/batch-import') -api.add_resource(AnnotationBatchImportStatusApi, '/apps//annotations/batch-import-status/') -api.add_resource(AnnotationHitHistoryListApi, '/apps//annotations//hit-histories') -api.add_resource(AppAnnotationSettingDetailApi, '/apps//annotation-setting') -api.add_resource(AppAnnotationSettingUpdateApi, '/apps//annotation-settings/') +api.add_resource(AnnotationReplyActionApi, "/apps//annotation-reply/") +api.add_resource( + AnnotationReplyActionStatusApi, "/apps//annotation-reply//status/" +) +api.add_resource(AnnotationListApi, "/apps//annotations") +api.add_resource(AnnotationExportApi, "/apps//annotations/export") +api.add_resource(AnnotationUpdateDeleteApi, "/apps//annotations/") +api.add_resource(AnnotationBatchImportApi, "/apps//annotations/batch-import") +api.add_resource(AnnotationBatchImportStatusApi, "/apps//annotations/batch-import-status/") +api.add_resource(AnnotationHitHistoryListApi, "/apps//annotations//hit-histories") +api.add_resource(AppAnnotationSettingDetailApi, "/apps//annotation-setting") +api.add_resource(AppAnnotationSettingUpdateApi, "/apps//annotation-settings/") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8651597fd..cc9c8b31c 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -18,27 +18,35 @@ from libs.login import login_required from services.app_dsl_service import AppDslService from services.app_service import AppService -ALLOW_CREATE_APP_MODES = ['chat', 'agent-chat', 'advanced-chat', 'workflow', 'completion'] +ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] class AppListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): """Get app list""" + def uuid_list(value): try: - return [str(uuid.UUID(v)) for v in value.split(',')] + return [str(uuid.UUID(v)) for v in value.split(",")] except ValueError: abort(400, message="Invalid UUID format in tag_ids.") + parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('mode', type=str, choices=['chat', 'workflow', 'agent-chat', 'channel', 'all'], default='all', location='args', required=False) - parser.add_argument('name', type=str, location='args', required=False) - parser.add_argument('tag_ids', type=uuid_list, location='args', required=False) + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "mode", + type=str, + choices=["chat", "workflow", "agent-chat", "channel", "all"], + default="all", + location="args", + required=False, + ) + parser.add_argument("name", type=str, location="args", required=False) + parser.add_argument("tag_ids", type=uuid_list, location="args", required=False) args = parser.parse_args() @@ -46,7 +54,7 @@ class AppListApi(Resource): app_service = AppService() app_pagination = app_service.get_paginate_apps(current_user.current_tenant_id, args) if not app_pagination: - return {'data': [], 'total': 0, 'page': 1, 'limit': 20, 'has_more': False} + return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} return marshal(app_pagination, app_pagination_fields) @@ -54,23 +62,23 @@ class AppListApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Create app""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('mode', type=str, choices=ALLOW_CREATE_APP_MODES, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("mode", type=str, choices=ALLOW_CREATE_APP_MODES, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - if 'mode' not in args or args['mode'] is None: + if "mode" not in args or args["mode"] is None: raise BadRequest("mode is required") app_service = AppService() @@ -84,7 +92,7 @@ class AppImportApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app""" # The role of the current user in the ta table must be admin, owner, or editor @@ -92,19 +100,16 @@ class AppImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=args['data'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=args["data"], args=args, account=current_user ) return app, 201 @@ -115,7 +120,7 @@ class AppImportFromUrlApi(Resource): @login_required @account_initialization_required @marshal_with(app_detail_fields_with_site) - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): """Import app from url""" # The role of the current user in the ta table must be admin, owner, or editor @@ -123,25 +128,21 @@ class AppImportFromUrlApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("url", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app = AppDslService.import_and_create_new_app_from_url( - tenant_id=current_user.current_tenant_id, - url=args['url'], - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, url=args["url"], args=args, account=current_user ) return app, 201 class AppApi(Resource): - @setup_required @login_required @account_initialization_required @@ -165,14 +166,14 @@ class AppApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') - parser.add_argument('max_active_requests', type=int, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() @@ -193,7 +194,7 @@ class AppApi(Resource): app_service = AppService() app_service.delete_app(app_model) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class AppCopyApi(Resource): @@ -209,19 +210,16 @@ class AppCopyApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', type=str, location='json') - parser.add_argument('description', type=str, location='json') - parser.add_argument('icon_type', type=str, location='json') - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("name", type=str, location="json") + parser.add_argument("description", type=str, location="json") + parser.add_argument("icon_type", type=str, location="json") + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() data = AppDslService.export_dsl(app_model=app_model, include_secret=True) app = AppDslService.import_and_create_new_app( - tenant_id=current_user.current_tenant_id, - data=data, - args=args, - account=current_user + tenant_id=current_user.current_tenant_id, data=data, args=args, account=current_user ) return app, 201 @@ -240,12 +238,10 @@ class AppExportApi(Resource): # Add include_secret params parser = reqparse.RequestParser() - parser.add_argument('include_secret', type=inputs.boolean, default=False, location='args') + parser.add_argument("include_secret", type=inputs.boolean, default=False, location="args") args = parser.parse_args() - return { - "data": AppDslService.export_dsl(app_model=app_model, include_secret=args['include_secret']) - } + return {"data": AppDslService.export_dsl(app_model=app_model, include_secret=args["include_secret"])} class AppNameApi(Resource): @@ -258,13 +254,13 @@ class AppNameApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_name(app_model, args.get('name')) + app_model = app_service.update_app_name(app_model, args.get("name")) return app_model @@ -279,14 +275,14 @@ class AppIconApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('icon', type=str, location='json') - parser.add_argument('icon_background', type=str, location='json') + parser.add_argument("icon", type=str, location="json") + parser.add_argument("icon_background", type=str, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_icon(app_model, args.get('icon'), args.get('icon_background')) + app_model = app_service.update_app_icon(app_model, args.get("icon"), args.get("icon_background")) return app_model @@ -301,13 +297,13 @@ class AppSiteStatus(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_site', type=bool, required=True, location='json') + parser.add_argument("enable_site", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_site_status(app_model, args.get('enable_site')) + app_model = app_service.update_app_site_status(app_model, args.get("enable_site")) return app_model @@ -322,13 +318,13 @@ class AppApiStatus(Resource): # The role of the current user in the ta table must be admin or owner if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('enable_api', type=bool, required=True, location='json') + parser.add_argument("enable_api", type=bool, required=True, location="json") args = parser.parse_args() app_service = AppService() - app_model = app_service.update_app_api_status(app_model, args.get('enable_api')) + app_model = app_service.update_app_api_status(app_model, args.get("enable_api")) return app_model @@ -339,9 +335,7 @@ class AppTraceApi(Resource): @account_initialization_required def get(self, app_id): """Get app trace""" - app_trace_config = OpsTraceManager.get_app_tracing_config( - app_id=app_id - ) + app_trace_config = OpsTraceManager.get_app_tracing_config(app_id=app_id) return app_trace_config @@ -353,27 +347,27 @@ class AppTraceApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('enabled', type=bool, required=True, location='json') - parser.add_argument('tracing_provider', type=str, required=True, location='json') + parser.add_argument("enabled", type=bool, required=True, location="json") + parser.add_argument("tracing_provider", type=str, required=True, location="json") args = parser.parse_args() OpsTraceManager.update_app_tracing_config( app_id=app_id, - enabled=args['enabled'], - tracing_provider=args['tracing_provider'], + enabled=args["enabled"], + tracing_provider=args["tracing_provider"], ) return {"result": "success"} -api.add_resource(AppListApi, '/apps') -api.add_resource(AppImportApi, '/apps/import') -api.add_resource(AppImportFromUrlApi, '/apps/import/url') -api.add_resource(AppApi, '/apps/') -api.add_resource(AppCopyApi, '/apps//copy') -api.add_resource(AppExportApi, '/apps//export') -api.add_resource(AppNameApi, '/apps//name') -api.add_resource(AppIconApi, '/apps//icon') -api.add_resource(AppSiteStatus, '/apps//site-enable') -api.add_resource(AppApiStatus, '/apps//api-enable') -api.add_resource(AppTraceApi, '/apps//trace') +api.add_resource(AppListApi, "/apps") +api.add_resource(AppImportApi, "/apps/import") +api.add_resource(AppImportFromUrlApi, "/apps/import/url") +api.add_resource(AppApi, "/apps/") +api.add_resource(AppCopyApi, "/apps//copy") +api.add_resource(AppExportApi, "/apps//export") +api.add_resource(AppNameApi, "/apps//name") +api.add_resource(AppIconApi, "/apps//icon") +api.add_resource(AppSiteStatus, "/apps//site-enable") +api.add_resource(AppApiStatus, "/apps//api-enable") +api.add_resource(AppTraceApi, "/apps//trace") diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 1de08afa4..437a6a7b3 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -39,7 +39,7 @@ class ChatMessageAudioApi(Resource): @account_initialization_required @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]) def post(self, app_model): - file = request.files['file'] + file = request.files["file"] try: response = AudioService.transcript_asr( @@ -85,31 +85,31 @@ class ChatMessageTextApi(Resource): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get( - 'voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - text=text, - message_id=message_id, - voice=voice - ) + response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -145,12 +145,12 @@ class TextModesApi(Resource): def get(self, app_model): try: parser = reqparse.RequestParser() - parser.add_argument('language', type=str, required=True, location='args') + parser.add_argument("language", type=str, required=True, location="args") args = parser.parse_args() response = AudioService.transcript_tts_voices( tenant_id=app_model.tenant_id, - language=args['language'], + language=args["language"], ) return response @@ -179,6 +179,6 @@ class TextModesApi(Resource): raise InternalServerError() -api.add_resource(ChatMessageAudioApi, '/apps//audio-to-text') -api.add_resource(ChatMessageTextApi, '/apps//text-to-audio') -api.add_resource(TextModesApi, '/apps//text-to-audio/voices') +api.add_resource(ChatMessageAudioApi, "/apps//audio-to-text") +api.add_resource(ChatMessageTextApi, "/apps//text-to-audio") +api.add_resource(TextModesApi, "/apps//text-to-audio/voices") diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 61582536f..6fe52ec28 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -35,33 +35,28 @@ from services.app_generate_service import AppGenerateService # define completion message api for user class CompletionMessageApi(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=AppMode.COMPLETION) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -97,7 +92,7 @@ class CompletionMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatMessageApi(Resource): @@ -107,27 +102,23 @@ class ChatMessageApi(Resource): @get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT]) def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('model_config', type=dict, required=True, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("model_config", type=dict, required=True, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] != 'blocking' - args['auto_generate_name'] = False + streaming = args["response_mode"] != "blocking" + args["auto_generate_name"] = False account = flask_login.current_user try: response = AppGenerateService.generate( - app_model=app_model, - user=account, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=streaming + app_model=app_model, user=account, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=streaming ) return helper.compact_generate_response(response) @@ -163,10 +154,10 @@ class ChatMessageStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, account.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionMessageApi, '/apps//completion-messages') -api.add_resource(CompletionMessageStopApi, '/apps//completion-messages//stop') -api.add_resource(ChatMessageApi, '/apps//chat-messages') -api.add_resource(ChatMessageStopApi, '/apps//chat-messages//stop') +api.add_resource(CompletionMessageApi, "/apps//completion-messages") +api.add_resource(CompletionMessageStopApi, "/apps//completion-messages//stop") +api.add_resource(ChatMessageApi, "/apps//chat-messages") +api.add_resource(ChatMessageStopApi, "/apps//chat-messages//stop") diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 995264541..753a6be20 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -26,7 +26,6 @@ from models.model import AppMode, Conversation, EndUser, Message, MessageAnnotat class CompletionConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -36,24 +35,23 @@ class CompletionConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == 'completion') + query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion") - if args['keyword']: - query = query.join( - Message, Message.conversation_id == Conversation.id - ).filter( + if args["keyword"]: + query = query.join(Message, Message.conversation_id == Conversation.id).filter( or_( - Message.query.ilike('%{}%'.format(args['keyword'])), - Message.answer.ilike('%{}%'.format(args['keyword'])) + Message.query.ilike("%{}%".format(args["keyword"])), + Message.answer.ilike("%{}%".format(args["keyword"])), ) ) @@ -61,8 +59,8 @@ class CompletionConversationApi(Resource): timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -70,8 +68,8 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -79,29 +77,25 @@ class CompletionConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class CompletionConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -123,8 +117,11 @@ class CompletionConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -132,11 +129,10 @@ class CompletionConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ChatConversationApi(Resource): - @setup_required @login_required @account_initialization_required @@ -146,22 +142,28 @@ class ChatConversationApi(Resource): if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('annotation_status', type=str, - choices=['annotated', 'not_annotated', 'all'], default='all', location='args') - parser.add_argument('message_count_gte', type=int_range(1, 99999), required=False, location='args') - parser.add_argument('page', type=int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument( + "annotation_status", type=str, choices=["annotated", "not_annotated", "all"], default="all", location="args" + ) + parser.add_argument("message_count_gte", type=int_range(1, 99999), required=False, location="args") + parser.add_argument("page", type=int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() subquery = ( db.session.query( - Conversation.id.label('conversation_id'), - EndUser.session_id.label('from_end_user_session_id') + Conversation.id.label("conversation_id"), EndUser.session_id.label("from_end_user_session_id") ) .outerjoin(EndUser, Conversation.from_end_user_id == EndUser.id) .subquery() @@ -169,28 +171,31 @@ class ChatConversationApi(Resource): query = db.select(Conversation).where(Conversation.app_id == app_model.id) - if args['keyword']: - keyword_filter = '%{}%'.format(args['keyword']) - query = query.join( - Message, Message.conversation_id == Conversation.id, - ).join( - subquery, subquery.c.conversation_id == Conversation.id - ).filter( - or_( - Message.query.ilike(keyword_filter), - Message.answer.ilike(keyword_filter), - Conversation.name.ilike(keyword_filter), - Conversation.introduction.ilike(keyword_filter), - subquery.c.from_end_user_session_id.ilike(keyword_filter) - ), + if args["keyword"]: + keyword_filter = "%{}%".format(args["keyword"]) + query = ( + query.join( + Message, + Message.conversation_id == Conversation.id, + ) + .join(subquery, subquery.c.conversation_id == Conversation.id) + .filter( + or_( + Message.query.ilike(keyword_filter), + Message.answer.ilike(keyword_filter), + Conversation.name.ilike(keyword_filter), + Conversation.introduction.ilike(keyword_filter), + subquery.c.from_end_user_session_id.ilike(keyword_filter), + ), + ) ) account = current_user timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) @@ -198,8 +203,8 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at >= start_datetime_utc) - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=59) end_datetime_timezone = timezone.localize(end_datetime) @@ -207,50 +212,46 @@ class ChatConversationApi(Resource): query = query.where(Conversation.created_at < end_datetime_utc) - if args['annotation_status'] == "annotated": + if args["annotation_status"] == "annotated": query = query.options(joinedload(Conversation.message_annotations)).join( MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id ) - elif args['annotation_status'] == "not_annotated": - query = query.outerjoin( - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id - ).group_by(Conversation.id).having(func.count(MessageAnnotation.id) == 0) + elif args["annotation_status"] == "not_annotated": + query = ( + query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id) + .group_by(Conversation.id) + .having(func.count(MessageAnnotation.id) == 0) + ) - if args['message_count_gte'] and args['message_count_gte'] >= 1: + if args["message_count_gte"] and args["message_count_gte"] >= 1: query = ( query.options(joinedload(Conversation.messages)) .join(Message, Message.conversation_id == Conversation.id) .group_by(Conversation.id) - .having(func.count(Message.id) >= args['message_count_gte']) + .having(func.count(Message.id) >= args["message_count_gte"]) ) if app_model.mode == AppMode.ADVANCED_CHAT.value: query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER.value) - match args['sort_by']: - case 'created_at': + match args["sort_by"]: + case "created_at": query = query.order_by(Conversation.created_at.asc()) - case '-created_at': + case "-created_at": query = query.order_by(Conversation.created_at.desc()) - case 'updated_at': + case "updated_at": query = query.order_by(Conversation.updated_at.asc()) - case '-updated_at': + case "-updated_at": query = query.order_by(Conversation.updated_at.desc()) case _: query = query.order_by(Conversation.created_at.desc()) - conversations = db.paginate( - query, - page=args['page'], - per_page=args['limit'], - error_out=False - ) + conversations = db.paginate(query, page=args["page"], per_page=args["limit"], error_out=False) return conversations class ChatConversationDetailApi(Resource): - @setup_required @login_required @account_initialization_required @@ -272,8 +273,11 @@ class ChatConversationDetailApi(Resource): raise Forbidden() conversation_id = str(conversation_id) - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") @@ -281,18 +285,21 @@ class ChatConversationDetailApi(Resource): conversation.is_deleted = True db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 -api.add_resource(CompletionConversationApi, '/apps//completion-conversations') -api.add_resource(CompletionConversationDetailApi, '/apps//completion-conversations/') -api.add_resource(ChatConversationApi, '/apps//chat-conversations') -api.add_resource(ChatConversationDetailApi, '/apps//chat-conversations/') +api.add_resource(CompletionConversationApi, "/apps//completion-conversations") +api.add_resource(CompletionConversationDetailApi, "/apps//completion-conversations/") +api.add_resource(ChatConversationApi, "/apps//chat-conversations") +api.add_resource(ChatConversationDetailApi, "/apps//chat-conversations/") def _get_conversation(app_model, conversation_id): - conversation = db.session.query(Conversation) \ - .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py index aa0722ea3..23b234dac 100644 --- a/api/controllers/console/app/conversation_variables.py +++ b/api/controllers/console/app/conversation_variables.py @@ -21,7 +21,7 @@ class ConversationVariablesApi(Resource): @marshal_with(paginated_conversation_variable_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', type=str, location='args') + parser.add_argument("conversation_id", type=str, location="args") args = parser.parse_args() stmt = ( @@ -29,10 +29,10 @@ class ConversationVariablesApi(Resource): .where(ConversationVariable.app_id == app_model.id) .order_by(ConversationVariable.created_at) ) - if args['conversation_id']: - stmt = stmt.where(ConversationVariable.conversation_id == args['conversation_id']) + if args["conversation_id"]: + stmt = stmt.where(ConversationVariable.conversation_id == args["conversation_id"]) else: - raise ValueError('conversation_id is required') + raise ValueError("conversation_id is required") # NOTE: This is a temporary solution to avoid performance issues. page = 1 @@ -43,14 +43,14 @@ class ConversationVariablesApi(Resource): rows = session.scalars(stmt).all() return { - 'page': page, - 'limit': page_size, - 'total': len(rows), - 'has_more': False, - 'data': [ + "page": page, + "limit": page_size, + "total": len(rows), + "has_more": False, + "data": [ { - 'created_at': row.created_at, - 'updated_at': row.updated_at, + "created_at": row.created_at, + "updated_at": row.updated_at, **row.to_variable().model_dump(), } for row in rows @@ -58,4 +58,4 @@ class ConversationVariablesApi(Resource): } -api.add_resource(ConversationVariablesApi, '/apps//conversation-variables') +api.add_resource(ConversationVariablesApi, "/apps//conversation-variables") diff --git a/api/controllers/console/app/error.py b/api/controllers/console/app/error.py index f6feed122..33d30c205 100644 --- a/api/controllers/console/app/error.py +++ b/api/controllers/console/app/error.py @@ -2,116 +2,120 @@ from libs.exception import BaseHTTPException class AppNotFoundError(BaseHTTPException): - error_code = 'app_not_found' + error_code = "app_not_found" description = "App not found." code = 404 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted Model Provider has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted Model Provider has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class DraftWorkflowNotExist(BaseHTTPException): - error_code = 'draft_workflow_not_exist' + error_code = "draft_workflow_not_exist" description = "Draft workflow need to be initialized." code = 400 class DraftWorkflowNotSync(BaseHTTPException): - error_code = 'draft_workflow_not_sync' + error_code = "draft_workflow_not_sync" description = "Workflow graph might have been modified, please refresh and resubmit." code = 400 class TracingConfigNotExist(BaseHTTPException): - error_code = 'trace_config_not_exist' + error_code = "trace_config_not_exist" description = "Trace config not exist." code = 400 class TracingConfigIsExist(BaseHTTPException): - error_code = 'trace_config_is_exist' + error_code = "trace_config_is_exist" description = "Trace config is exist." code = 400 class TracingConfigCheckError(BaseHTTPException): - error_code = 'trace_config_check_error' + error_code = "trace_config_check_error" description = "Invalid Credentials." code = 400 diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 6803775e2..3d1e6b7a3 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -24,21 +24,21 @@ class RuleGenerateApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('instruction', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json') - parser.add_argument('no_variable', type=bool, required=True, default=False, location='json') + parser.add_argument("instruction", type=str, required=True, nullable=False, location="json") + parser.add_argument("model_config", type=dict, required=True, nullable=False, location="json") + parser.add_argument("no_variable", type=bool, required=True, default=False, location="json") args = parser.parse_args() account = current_user - PROMPT_GENERATION_MAX_TOKENS = int(os.getenv('PROMPT_GENERATION_MAX_TOKENS', '512')) + PROMPT_GENERATION_MAX_TOKENS = int(os.getenv("PROMPT_GENERATION_MAX_TOKENS", "512")) try: rules = LLMGenerator.generate_rule_config( tenant_id=account.current_tenant_id, - instruction=args['instruction'], - model_config=args['model_config'], - no_variable=args['no_variable'], - rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS + instruction=args["instruction"], + model_config=args["model_config"], + no_variable=args["no_variable"], + rule_config_max_tokens=PROMPT_GENERATION_MAX_TOKENS, ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -52,4 +52,4 @@ class RuleGenerateApi(Resource): return rules -api.add_resource(RuleGenerateApi, '/rule-generate') +api.add_resource(RuleGenerateApi, "/rule-generate") diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 056415f19..fe0620198 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -33,9 +33,9 @@ from services.message_service import MessageService class ChatMessageListApi(Resource): message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_detail_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_detail_fields)), } @setup_required @@ -45,55 +45,69 @@ class ChatMessageListApi(Resource): @marshal_with(message_infinite_scroll_pagination_fields) def get(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - conversation = db.session.query(Conversation).filter( - Conversation.id == args['conversation_id'], - Conversation.app_id == app_model.id - ).first() + conversation = ( + db.session.query(Conversation) + .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id) + .first() + ) if not conversation: raise NotFound("Conversation Not Exists.") - if args['first_id']: - first_message = db.session.query(Message) \ - .filter(Message.conversation_id == conversation.id, Message.id == args['first_id']).first() + if args["first_id"]: + first_message = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"]) + .first() + ) if not first_message: raise NotFound("First message not found") - history_messages = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < first_message.created_at, - Message.id != first_message.id - ) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < first_message.created_at, + Message.id != first_message.id, + ) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) else: - history_messages = db.session.query(Message).filter(Message.conversation_id == conversation.id) \ - .order_by(Message.created_at.desc()).limit(args['limit']).all() + history_messages = ( + db.session.query(Message) + .filter(Message.conversation_id == conversation.id) + .order_by(Message.created_at.desc()) + .limit(args["limit"]) + .all() + ) has_more = False - if len(history_messages) == args['limit']: + if len(history_messages) == args["limit"]: current_page_first_message = history_messages[-1] - rest_count = db.session.query(Message).filter( - Message.conversation_id == conversation.id, - Message.created_at < current_page_first_message.created_at, - Message.id != current_page_first_message.id - ).count() + rest_count = ( + db.session.query(Message) + .filter( + Message.conversation_id == conversation.id, + Message.created_at < current_page_first_message.created_at, + Message.id != current_page_first_message.id, + ) + .count() + ) if rest_count > 0: has_more = True history_messages = list(reversed(history_messages)) - return InfiniteScrollPagination( - data=history_messages, - limit=args['limit'], - has_more=has_more - ) + return InfiniteScrollPagination(data=history_messages, limit=args["limit"], has_more=has_more) class MessageFeedbackApi(Resource): @@ -103,49 +117,46 @@ class MessageFeedbackApi(Resource): @get_app_model def post(self, app_model): parser = reqparse.RequestParser() - parser.add_argument('message_id', required=True, type=uuid_value, location='json') - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("message_id", required=True, type=uuid_value, location="json") + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args['message_id']) + message_id = str(args["message_id"]) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") feedback = message.admin_feedback - if not args['rating'] and feedback: + if not args["rating"] and feedback: db.session.delete(feedback) - elif args['rating'] and feedback: - feedback.rating = args['rating'] - elif not args['rating'] and not feedback: - raise ValueError('rating cannot be None when feedback not exists') + elif args["rating"] and feedback: + feedback.rating = args["rating"] + elif not args["rating"] and not feedback: + raise ValueError("rating cannot be None when feedback not exists") else: feedback = MessageFeedback( app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=args['rating'], - from_source='admin', - from_account_id=current_user.id + rating=args["rating"], + from_source="admin", + from_account_id=current_user.id, ) db.session.add(feedback) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class MessageAnnotationApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('annotation') + @cloud_edition_billing_resource_check("annotation") @get_app_model @marshal_with(annotation_fields) def post(self, app_model): @@ -153,10 +164,10 @@ class MessageAnnotationApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('message_id', required=False, type=uuid_value, location='json') - parser.add_argument('question', required=True, type=str, location='json') - parser.add_argument('answer', required=True, type=str, location='json') - parser.add_argument('annotation_reply', required=False, type=dict, location='json') + parser.add_argument("message_id", required=False, type=uuid_value, location="json") + parser.add_argument("question", required=True, type=str, location="json") + parser.add_argument("answer", required=True, type=str, location="json") + parser.add_argument("annotation_reply", required=False, type=dict, location="json") args = parser.parse_args() annotation = AppAnnotationService.up_insert_app_annotation_from_message(args, app_model.id) @@ -169,11 +180,9 @@ class MessageAnnotationCountApi(Resource): @account_initialization_required @get_app_model def get(self, app_model): - count = db.session.query(MessageAnnotation).filter( - MessageAnnotation.app_id == app_model.id - ).count() + count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count() - return {'count': count} + return {"count": count} class MessageSuggestedQuestionApi(Resource): @@ -186,10 +195,7 @@ class MessageSuggestedQuestionApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - message_id=message_id, - user=current_user, - invoke_from=InvokeFrom.DEBUGGER + app_model=app_model, message_id=message_id, user=current_user, invoke_from=InvokeFrom.DEBUGGER ) except MessageNotExistsError: raise NotFound("Message not found") @@ -209,7 +215,7 @@ class MessageSuggestedQuestionApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} class MessageApi(Resource): @@ -221,10 +227,7 @@ class MessageApi(Resource): def get(self, app_model, message_id): message_id = str(message_id) - message = db.session.query(Message).filter( - Message.id == message_id, - Message.app_id == app_model.id - ).first() + message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() if not message: raise NotFound("Message Not Exists.") @@ -232,9 +235,9 @@ class MessageApi(Resource): return message -api.add_resource(MessageSuggestedQuestionApi, '/apps//chat-messages//suggested-questions') -api.add_resource(ChatMessageListApi, '/apps//chat-messages', endpoint='console_chat_messages') -api.add_resource(MessageFeedbackApi, '/apps//feedbacks') -api.add_resource(MessageAnnotationApi, '/apps//annotations') -api.add_resource(MessageAnnotationCountApi, '/apps//annotations/count') -api.add_resource(MessageApi, '/apps//messages/', endpoint='console_message') +api.add_resource(MessageSuggestedQuestionApi, "/apps//chat-messages//suggested-questions") +api.add_resource(ChatMessageListApi, "/apps//chat-messages", endpoint="console_chat_messages") +api.add_resource(MessageFeedbackApi, "/apps//feedbacks") +api.add_resource(MessageAnnotationApi, "/apps//annotations") +api.add_resource(MessageAnnotationCountApi, "/apps//annotations/count") +api.add_resource(MessageApi, "/apps//messages/", endpoint="console_message") diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index c8df879a2..702afe986 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -19,19 +19,15 @@ from services.app_model_config_service import AppModelConfigService class ModelConfigResource(Resource): - @setup_required @login_required @account_initialization_required @get_app_model(mode=[AppMode.AGENT_CHAT, AppMode.CHAT, AppMode.COMPLETION]) def post(self, app_model): - """Modify app model config""" # validate config model_configuration = AppModelConfigService.validate_configuration( - tenant_id=current_user.current_tenant_id, - config=request.json, - app_mode=AppMode.value_of(app_model.mode) + tenant_id=current_user.current_tenant_id, config=request.json, app_mode=AppMode.value_of(app_model.mode) ) new_app_model_config = AppModelConfig( @@ -41,15 +37,15 @@ class ModelConfigResource(Resource): if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: # get original app model config - original_app_model_config: AppModelConfig = db.session.query(AppModelConfig).filter( - AppModelConfig.id == app_model.app_model_config_id - ).first() + original_app_model_config: AppModelConfig = ( + db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first() + ) agent_mode = original_app_model_config.agent_mode_dict # decrypt agent tool parameters if it's secret-input parameter_map = {} masked_parameter_map = {} tool_map = {} - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue @@ -66,7 +62,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) except Exception as e: continue @@ -79,18 +75,18 @@ class ModelConfigResource(Resource): parameters = {} masked_parameter = {} - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" masked_parameter_map[key] = masked_parameter parameter_map[key] = parameters tool_map[key] = tool_runtime # encrypt agent tool parameters if it's secret-input agent_mode = new_app_model_config.agent_mode_dict - for tool in agent_mode.get('tools') or []: + for tool in agent_mode.get("tools") or []: agent_tool_entity = AgentToolEntity(**tool) # get tool - key = f'{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}' + key = f"{agent_tool_entity.provider_id}.{agent_tool_entity.provider_type}.{agent_tool_entity.tool_name}" if key in tool_map: tool_runtime = tool_map[key] else: @@ -108,7 +104,7 @@ class ModelConfigResource(Resource): tool_runtime=tool_runtime, provider_name=agent_tool_entity.provider_id, provider_type=agent_tool_entity.provider_type, - identity_id=f'AGENT.{app_model.id}' + identity_id=f"AGENT.{app_model.id}", ) manager.delete_tool_parameters_cache() @@ -116,15 +112,17 @@ class ModelConfigResource(Resource): if agent_tool_entity.tool_parameters: if key not in masked_parameter_map: continue - + for masked_key, masked_value in masked_parameter_map[key].items(): - if masked_key in agent_tool_entity.tool_parameters and \ - agent_tool_entity.tool_parameters[masked_key] == masked_value: + if ( + masked_key in agent_tool_entity.tool_parameters + and agent_tool_entity.tool_parameters[masked_key] == masked_value + ): agent_tool_entity.tool_parameters[masked_key] = parameter_map[key].get(masked_key) # encrypt parameters if agent_tool_entity.tool_parameters: - tool['tool_parameters'] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) + tool["tool_parameters"] = manager.encrypt_tool_parameters(agent_tool_entity.tool_parameters or {}) # update app model config new_app_model_config.agent_mode = json.dumps(agent_mode) @@ -135,12 +133,9 @@ class ModelConfigResource(Resource): app_model.app_model_config_id = new_app_model_config.id db.session.commit() - app_model_config_was_updated.send( - app_model, - app_model_config=new_app_model_config - ) + app_model_config_was_updated.send(app_model, app_model_config=new_app_model_config) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ModelConfigResource, '/apps//model-config') +api.add_resource(ModelConfigResource, "/apps//model-config") diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index c0cf7b9e3..374bd2b81 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -18,13 +18,11 @@ class TraceAppConfigApi(Resource): @account_initialization_required def get(self, app_id): parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - trace_config = OpsService.get_tracing_app_config( - app_id=app_id, tracing_provider=args['tracing_provider'] - ) + trace_config = OpsService.get_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not trace_config: return {"has_not_configured": True} return trace_config @@ -37,19 +35,17 @@ class TraceAppConfigApi(Resource): def post(self, app_id): """Create a new trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.create_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigIsExist() - if result.get('error'): + if result.get("error"): raise TracingConfigCheckError() return result except Exception as e: @@ -61,15 +57,13 @@ class TraceAppConfigApi(Resource): def patch(self, app_id): """Update an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='json') - parser.add_argument('tracing_config', type=dict, required=True, location='json') + parser.add_argument("tracing_provider", type=str, required=True, location="json") + parser.add_argument("tracing_config", type=dict, required=True, location="json") args = parser.parse_args() try: result = OpsService.update_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'], - tracing_config=args['tracing_config'] + app_id=app_id, tracing_provider=args["tracing_provider"], tracing_config=args["tracing_config"] ) if not result: raise TracingConfigNotExist() @@ -83,14 +77,11 @@ class TraceAppConfigApi(Resource): def delete(self, app_id): """Delete an existing trace app configuration""" parser = reqparse.RequestParser() - parser.add_argument('tracing_provider', type=str, required=True, location='args') + parser.add_argument("tracing_provider", type=str, required=True, location="args") args = parser.parse_args() try: - result = OpsService.delete_tracing_app_config( - app_id=app_id, - tracing_provider=args['tracing_provider'] - ) + result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() return {"result": "success"} @@ -98,4 +89,4 @@ class TraceAppConfigApi(Resource): raise e -api.add_resource(TraceAppConfigApi, '/apps//trace-config') +api.add_resource(TraceAppConfigApi, "/apps//trace-config") diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 7db58c048..d903a2609 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -15,23 +15,23 @@ from models.model import Site def parse_app_site_args(): parser = reqparse.RequestParser() - parser.add_argument('title', type=str, required=False, location='json') - parser.add_argument('icon_type', type=str, required=False, location='json') - parser.add_argument('icon', type=str, required=False, location='json') - parser.add_argument('icon_background', type=str, required=False, location='json') - parser.add_argument('description', type=str, required=False, location='json') - parser.add_argument('default_language', type=supported_language, required=False, location='json') - parser.add_argument('chat_color_theme', type=str, required=False, location='json') - parser.add_argument('chat_color_theme_inverted', type=bool, required=False, location='json') - parser.add_argument('customize_domain', type=str, required=False, location='json') - parser.add_argument('copyright', type=str, required=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, location='json') - parser.add_argument('custom_disclaimer', type=str, required=False, location='json') - parser.add_argument('customize_token_strategy', type=str, choices=['must', 'allow', 'not_allow'], - required=False, - location='json') - parser.add_argument('prompt_public', type=bool, required=False, location='json') - parser.add_argument('show_workflow_steps', type=bool, required=False, location='json') + parser.add_argument("title", type=str, required=False, location="json") + parser.add_argument("icon_type", type=str, required=False, location="json") + parser.add_argument("icon", type=str, required=False, location="json") + parser.add_argument("icon_background", type=str, required=False, location="json") + parser.add_argument("description", type=str, required=False, location="json") + parser.add_argument("default_language", type=supported_language, required=False, location="json") + parser.add_argument("chat_color_theme", type=str, required=False, location="json") + parser.add_argument("chat_color_theme_inverted", type=bool, required=False, location="json") + parser.add_argument("customize_domain", type=str, required=False, location="json") + parser.add_argument("copyright", type=str, required=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, location="json") + parser.add_argument("custom_disclaimer", type=str, required=False, location="json") + parser.add_argument( + "customize_token_strategy", type=str, choices=["must", "allow", "not_allow"], required=False, location="json" + ) + parser.add_argument("prompt_public", type=bool, required=False, location="json") + parser.add_argument("show_workflow_steps", type=bool, required=False, location="json") return parser.parse_args() @@ -48,26 +48,24 @@ class AppSite(Resource): if not current_user.is_editor: raise Forbidden() - site = db.session.query(Site). \ - filter(Site.app_id == app_model.id). \ - one_or_404() + site = db.session.query(Site).filter(Site.app_id == app_model.id).one_or_404() for attr_name in [ - 'title', - 'icon_type', - 'icon', - 'icon_background', - 'description', - 'default_language', - 'chat_color_theme', - 'chat_color_theme_inverted', - 'customize_domain', - 'copyright', - 'privacy_policy', - 'custom_disclaimer', - 'customize_token_strategy', - 'prompt_public', - 'show_workflow_steps' + "title", + "icon_type", + "icon", + "icon_background", + "description", + "default_language", + "chat_color_theme", + "chat_color_theme_inverted", + "customize_domain", + "copyright", + "privacy_policy", + "custom_disclaimer", + "customize_token_strategy", + "prompt_public", + "show_workflow_steps", ]: value = args.get(attr_name) if value is not None: @@ -79,7 +77,6 @@ class AppSite(Resource): class AppSiteAccessTokenReset(Resource): - @setup_required @login_required @account_initialization_required @@ -101,5 +98,5 @@ class AppSiteAccessTokenReset(Resource): return site -api.add_resource(AppSite, '/apps//site') -api.add_resource(AppSiteAccessTokenReset, '/apps//site/access-token-reset') +api.add_resource(AppSite, "/apps//site") +api.add_resource(AppSiteAccessTokenReset, "/apps//site/access-token-reset") diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index b882ffef3..bf65efeae 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -17,7 +17,6 @@ from models.model import AppMode class DailyConversationStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -26,58 +25,52 @@ class DailyConversationStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.conversation_id) AS conversation_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'conversation_count': i.conversation_count - }) + response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTerminalsStatistic(Resource): - @setup_required @login_required @account_initialization_required @@ -86,54 +79,49 @@ class DailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct messages.from_end_user_id) AS terminal_count FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class DailyTokenCostStatistic(Resource): @@ -145,58 +133,53 @@ class DailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, (sum(messages.message_tokens) + sum(messages.answer_tokens)) as token_count, sum(total_price) as total_price FROM messages where app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - 'total_price': i.total_price, - 'currency': 'USD' - }) + response_data.append( + {"date": str(i.date), "token_count": i.token_count, "total_price": i.total_price, "currency": "USD"} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageSessionInteractionStatistic(Resource): @@ -208,8 +191,8 @@ class AverageSessionInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """SELECT date(DATE_TRUNC('day', c.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, @@ -218,30 +201,30 @@ FROM (SELECT m.conversation_id, COUNT(m.id) AS message_count FROM conversations c JOIN messages m ON c.id = m.conversation_id WHERE c.override_model_configs IS NULL AND c.app_id = :app_id""" - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and c.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and c.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and c.created_at < :end" + arg_dict["end"] = end_datetime_utc sql_query += """ GROUP BY m.conversation_id) subquery @@ -250,18 +233,15 @@ GROUP BY date ORDER BY date""" response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class UserSatisfactionRateStatistic(Resource): @@ -273,57 +253,57 @@ class UserSatisfactionRateStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', m.created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, COUNT(m.id) as message_count, COUNT(mf.id) as feedback_count FROM messages m LEFT JOIN message_feedbacks mf on mf.message_id=m.id and mf.rating='like' WHERE m.app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and m.created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and m.created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and m.created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'rate': round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), - }) + response_data.append( + { + "date": str(i.date), + "rate": round((i.feedback_count * 1000 / i.message_count) if i.message_count > 0 else 0, 2), + } + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class AverageResponseTimeStatistic(Resource): @@ -335,56 +315,51 @@ class AverageResponseTimeStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, AVG(provider_response_latency) as latency FROM messages WHERE app_id = :app_id - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} + """ + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'latency': round(i.latency * 1000, 4) - }) + response_data.append({"date": str(i.date), "latency": round(i.latency * 1000, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) class TokensPerSecondStatistic(Resource): @@ -396,63 +371,58 @@ class TokensPerSecondStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = '''SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, + sql_query = """SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, CASE WHEN SUM(provider_response_latency) = 0 THEN 0 ELSE (SUM(answer_tokens) / SUM(provider_response_latency)) END as tokens_per_second FROM messages -WHERE app_id = :app_id''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id} +WHERE app_id = :app_id""" + arg_dict = {"tz": account.timezone, "app_id": app_model.id} timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'tps': round(i.tokens_per_second, 4) - }) + response_data.append({"date": str(i.date), "tps": round(i.tokens_per_second, 4)}) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(DailyConversationStatistic, '/apps//statistics/daily-conversations') -api.add_resource(DailyTerminalsStatistic, '/apps//statistics/daily-end-users') -api.add_resource(DailyTokenCostStatistic, '/apps//statistics/token-costs') -api.add_resource(AverageSessionInteractionStatistic, '/apps//statistics/average-session-interactions') -api.add_resource(UserSatisfactionRateStatistic, '/apps//statistics/user-satisfaction-rate') -api.add_resource(AverageResponseTimeStatistic, '/apps//statistics/average-response-time') -api.add_resource(TokensPerSecondStatistic, '/apps//statistics/tokens-per-second') +api.add_resource(DailyConversationStatistic, "/apps//statistics/daily-conversations") +api.add_resource(DailyTerminalsStatistic, "/apps//statistics/daily-end-users") +api.add_resource(DailyTokenCostStatistic, "/apps//statistics/token-costs") +api.add_resource(AverageSessionInteractionStatistic, "/apps//statistics/average-session-interactions") +api.add_resource(UserSatisfactionRateStatistic, "/apps//statistics/user-satisfaction-rate") +api.add_resource(AverageResponseTimeStatistic, "/apps//statistics/average-response-time") +api.add_resource(TokensPerSecondStatistic, "/apps//statistics/tokens-per-second") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a2052b976..e44820f63 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -64,51 +64,51 @@ class DraftWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - - content_type = request.headers.get('Content-Type', '') - if 'application/json' in content_type: + content_type = request.headers.get("Content-Type", "") + + if "application/json" in content_type: parser = reqparse.RequestParser() - parser.add_argument('graph', type=dict, required=True, nullable=False, location='json') - parser.add_argument('features', type=dict, required=True, nullable=False, location='json') - parser.add_argument('hash', type=str, required=False, location='json') + parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") + parser.add_argument("features", type=dict, required=True, nullable=False, location="json") + parser.add_argument("hash", type=str, required=False, location="json") # TODO: set this to required=True after frontend is updated - parser.add_argument('environment_variables', type=list, required=False, location='json') - parser.add_argument('conversation_variables', type=list, required=False, location='json') + parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() - elif 'text/plain' in content_type: + elif "text/plain" in content_type: try: - data = json.loads(request.data.decode('utf-8')) - if 'graph' not in data or 'features' not in data: - raise ValueError('graph or features not found in data') + data = json.loads(request.data.decode("utf-8")) + if "graph" not in data or "features" not in data: + raise ValueError("graph or features not found in data") - if not isinstance(data.get('graph'), dict) or not isinstance(data.get('features'), dict): - raise ValueError('graph or features is not a dict') + if not isinstance(data.get("graph"), dict) or not isinstance(data.get("features"), dict): + raise ValueError("graph or features is not a dict") args = { - 'graph': data.get('graph'), - 'features': data.get('features'), - 'hash': data.get('hash'), - 'environment_variables': data.get('environment_variables'), - 'conversation_variables': data.get('conversation_variables'), + "graph": data.get("graph"), + "features": data.get("features"), + "hash": data.get("hash"), + "environment_variables": data.get("environment_variables"), + "conversation_variables": data.get("conversation_variables"), } except json.JSONDecodeError: - return {'message': 'Invalid JSON data'}, 400 + return {"message": "Invalid JSON data"}, 400 else: abort(415) workflow_service = WorkflowService() try: - environment_variables_list = args.get('environment_variables') or [] + environment_variables_list = args.get("environment_variables") or [] environment_variables = [factory.build_variable_from_mapping(obj) for obj in environment_variables_list] - conversation_variables_list = args.get('conversation_variables') or [] + conversation_variables_list = args.get("conversation_variables") or [] conversation_variables = [factory.build_variable_from_mapping(obj) for obj in conversation_variables_list] workflow = workflow_service.sync_draft_workflow( app_model=app_model, - graph=args['graph'], - features=args['features'], - unique_hash=args.get('hash'), + graph=args["graph"], + features=args["features"], + unique_hash=args.get("hash"), account=current_user, environment_variables=environment_variables, conversation_variables=conversation_variables, @@ -119,7 +119,7 @@ class DraftWorkflowApi(Resource): return { "result": "success", "hash": workflow.unique_hash, - "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at) + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), } @@ -138,13 +138,11 @@ class DraftWorkflowImportApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('data', type=str, required=True, nullable=False, location='json') + parser.add_argument("data", type=str, required=True, nullable=False, location="json") args = parser.parse_args() workflow = AppDslService.import_and_overwrite_workflow( - app_model=app_model, - data=args['data'], - account=current_user + app_model=app_model, data=args["data"], account=current_user ) return workflow @@ -162,21 +160,17 @@ class AdvancedChatDraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') - parser.add_argument('query', type=str, required=True, location='json', default='') - parser.add_argument('files', type=list, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') + parser.add_argument("inputs", type=dict, location="json") + parser.add_argument("query", type=str, required=True, location="json", default="") + parser.add_argument("files", type=list, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -190,6 +184,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class AdvancedChatDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -202,18 +197,14 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -227,6 +218,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class WorkflowDraftRunIterationNodeApi(Resource): @setup_required @login_required @@ -239,18 +231,14 @@ class WorkflowDraftRunIterationNodeApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, location='json') + parser.add_argument("inputs", type=dict, location="json") args = parser.parse_args() try: response = AppGenerateService.generate_single_iteration( - app_model=app_model, - user=current_user, - node_id=node_id, - args=args, - streaming=True + app_model=app_model, user=current_user, node_id=node_id, args=args, streaming=True ) return helper.compact_generate_response(response) @@ -264,6 +252,7 @@ class WorkflowDraftRunIterationNodeApi(Resource): logging.exception("internal server error.") raise InternalServerError() + class DraftWorkflowRunApi(Resource): @setup_required @login_required @@ -276,19 +265,15 @@ class DraftWorkflowRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.DEBUGGER, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.DEBUGGER, streaming=True ) return helper.compact_generate_response(response) @@ -311,12 +296,10 @@ class WorkflowTaskStopApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + AppQueueManager.set_stop_flag(task_id, InvokeFrom.DEBUGGER, current_user.id) - return { - "result": "success" - } + return {"result": "success"} class DraftWorkflowNodeRunApi(Resource): @@ -332,24 +315,20 @@ class DraftWorkflowNodeRunApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() workflow_service = WorkflowService() workflow_node_execution = workflow_service.run_draft_workflow_node( - app_model=app_model, - node_id=node_id, - user_inputs=args.get('inputs'), - account=current_user + app_model=app_model, node_id=node_id, user_inputs=args.get("inputs"), account=current_user ) return workflow_node_execution class PublishedWorkflowApi(Resource): - @setup_required @login_required @account_initialization_required @@ -362,7 +341,7 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # fetch published workflow by app_model workflow_service = WorkflowService() workflow = workflow_service.get_published_workflow(app_model=app_model) @@ -381,14 +360,11 @@ class PublishedWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + workflow_service = WorkflowService() workflow = workflow_service.publish_workflow(app_model=app_model, account=current_user) - return { - "result": "success", - "created_at": TimestampField().format(workflow.created_at) - } + return {"result": "success", "created_at": TimestampField().format(workflow.created_at)} class DefaultBlockConfigsApi(Resource): @@ -403,7 +379,7 @@ class DefaultBlockConfigsApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + # Get default block configs workflow_service = WorkflowService() return workflow_service.get_default_block_configs() @@ -421,24 +397,21 @@ class DefaultBlockConfigApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('q', type=str, location='args') + parser.add_argument("q", type=str, location="args") args = parser.parse_args() filters = None - if args.get('q'): + if args.get("q"): try: - filters = json.loads(args.get('q')) + filters = json.loads(args.get("q")) except json.JSONDecodeError: - raise ValueError('Invalid filters') + raise ValueError("Invalid filters") # Get default block configs workflow_service = WorkflowService() - return workflow_service.get_default_block_config( - node_type=block_type, - filters=filters - ) + return workflow_service.get_default_block_config(node_type=block_type, filters=filters) class ConvertToWorkflowApi(Resource): @@ -455,41 +428,43 @@ class ConvertToWorkflowApi(Resource): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() - + if request.data: parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_type', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon', type=str, required=False, nullable=True, location='json') - parser.add_argument('icon_background', type=str, required=False, nullable=True, location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_type", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon", type=str, required=False, nullable=True, location="json") + parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json") args = parser.parse_args() else: args = {} # convert to workflow mode workflow_service = WorkflowService() - new_app_model = workflow_service.convert_to_workflow( - app_model=app_model, - account=current_user, - args=args - ) + new_app_model = workflow_service.convert_to_workflow(app_model=app_model, account=current_user, args=args) # return app id return { - 'new_app_id': new_app_model.id, + "new_app_id": new_app_model.id, } -api.add_resource(DraftWorkflowApi, '/apps//workflows/draft') -api.add_resource(DraftWorkflowImportApi, '/apps//workflows/draft/import') -api.add_resource(AdvancedChatDraftWorkflowRunApi, '/apps//advanced-chat/workflows/draft/run') -api.add_resource(DraftWorkflowRunApi, '/apps//workflows/draft/run') -api.add_resource(WorkflowTaskStopApi, '/apps//workflow-runs/tasks//stop') -api.add_resource(DraftWorkflowNodeRunApi, '/apps//workflows/draft/nodes//run') -api.add_resource(AdvancedChatDraftRunIterationNodeApi, '/apps//advanced-chat/workflows/draft/iteration/nodes//run') -api.add_resource(WorkflowDraftRunIterationNodeApi, '/apps//workflows/draft/iteration/nodes//run') -api.add_resource(PublishedWorkflowApi, '/apps//workflows/publish') -api.add_resource(DefaultBlockConfigsApi, '/apps//workflows/default-workflow-block-configs') -api.add_resource(DefaultBlockConfigApi, '/apps//workflows/default-workflow-block-configs' - '/') -api.add_resource(ConvertToWorkflowApi, '/apps//convert-to-workflow') +api.add_resource(DraftWorkflowApi, "/apps//workflows/draft") +api.add_resource(DraftWorkflowImportApi, "/apps//workflows/draft/import") +api.add_resource(AdvancedChatDraftWorkflowRunApi, "/apps//advanced-chat/workflows/draft/run") +api.add_resource(DraftWorkflowRunApi, "/apps//workflows/draft/run") +api.add_resource(WorkflowTaskStopApi, "/apps//workflow-runs/tasks//stop") +api.add_resource(DraftWorkflowNodeRunApi, "/apps//workflows/draft/nodes//run") +api.add_resource( + AdvancedChatDraftRunIterationNodeApi, + "/apps//advanced-chat/workflows/draft/iteration/nodes//run", +) +api.add_resource( + WorkflowDraftRunIterationNodeApi, "/apps//workflows/draft/iteration/nodes//run" +) +api.add_resource(PublishedWorkflowApi, "/apps//workflows/publish") +api.add_resource(DefaultBlockConfigsApi, "/apps//workflows/default-workflow-block-configs") +api.add_resource( + DefaultBlockConfigApi, "/apps//workflows/default-workflow-block-configs" "/" +) +api.add_resource(ConvertToWorkflowApi, "/apps//convert-to-workflow") diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6d1709ed8..dc962409c 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -22,20 +22,19 @@ class WorkflowAppLogApi(Resource): Get workflow app logs """ parser = reqparse.RequestParser() - parser.add_argument('keyword', type=str, location='args') - parser.add_argument('status', type=str, choices=['succeeded', 'failed', 'stopped'], location='args') - parser.add_argument('page', type=int_range(1, 99999), default=1, location='args') - parser.add_argument('limit', type=int_range(1, 100), default=20, location='args') + parser.add_argument("keyword", type=str, location="args") + parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") + parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") + parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() # get paginate workflow app logs workflow_app_service = WorkflowAppService() workflow_app_log_pagination = workflow_app_service.get_paginate_workflow_app_logs( - app_model=app_model, - args=args + app_model=app_model, args=args ) return workflow_app_log_pagination -api.add_resource(WorkflowAppLogApi, '/apps//workflow-app-logs') +api.add_resource(WorkflowAppLogApi, "/apps//workflow-app-logs") diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 35d982e37..a055d03de 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -28,15 +28,12 @@ class AdvancedChatAppWorkflowRunListApi(Resource): Get advanced chat app workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_advanced_chat_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_advanced_chat_workflow_runs(app_model=app_model, args=args) return result @@ -52,15 +49,12 @@ class WorkflowRunListApi(Resource): Get workflow run list """ parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() workflow_run_service = WorkflowRunService() - result = workflow_run_service.get_paginate_workflow_runs( - app_model=app_model, - args=args - ) + result = workflow_run_service.get_paginate_workflow_runs(app_model=app_model, args=args) return result @@ -98,12 +92,10 @@ class WorkflowRunNodeExecutionListApi(Resource): workflow_run_service = WorkflowRunService() node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) - return { - 'data': node_executions - } + return {"data": node_executions} -api.add_resource(AdvancedChatAppWorkflowRunListApi, '/apps//advanced-chat/workflow-runs') -api.add_resource(WorkflowRunListApi, '/apps//workflow-runs') -api.add_resource(WorkflowRunDetailApi, '/apps//workflow-runs/') -api.add_resource(WorkflowRunNodeExecutionListApi, '/apps//workflow-runs//node-executions') +api.add_resource(AdvancedChatAppWorkflowRunListApi, "/apps//advanced-chat/workflow-runs") +api.add_resource(WorkflowRunListApi, "/apps//workflow-runs") +api.add_resource(WorkflowRunDetailApi, "/apps//workflow-runs/") +api.add_resource(WorkflowRunNodeExecutionListApi, "/apps//workflow-runs//node-executions") diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py index 1d7dc395f..db2f68358 100644 --- a/api/controllers/console/app/workflow_statistic.py +++ b/api/controllers/console/app/workflow_statistic.py @@ -26,56 +26,56 @@ class WorkflowDailyRunsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(id) AS runs FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'runs': i.runs - }) + response_data.append({"date": str(i.date), "runs": i.runs}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTerminalsStatistic(Resource): @setup_required @@ -86,56 +86,56 @@ class WorkflowDailyTerminalsStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, count(distinct workflow_runs.created_by) AS terminal_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) + rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'terminal_count': i.terminal_count - }) + response_data.append({"date": str(i.date), "terminal_count": i.terminal_count}) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowDailyTokenCostStatistic(Resource): @setup_required @@ -146,58 +146,63 @@ class WorkflowDailyTokenCostStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = ''' + sql_query = """ SELECT date(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, SUM(workflow_runs.total_tokens) as token_count FROM workflow_runs WHERE app_id = :app_id AND triggered_from = :triggered_from - ''' - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + """ + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at >= :start' - arg_dict['start'] = start_datetime_utc + sql_query += " and created_at >= :start" + arg_dict["start"] = start_datetime_utc - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query += ' and created_at < :end' - arg_dict['end'] = end_datetime_utc + sql_query += " and created_at < :end" + arg_dict["end"] = end_datetime_utc - sql_query += ' GROUP BY date order by date' + sql_query += " GROUP BY date order by date" response_data = [] with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'token_count': i.token_count, - }) + response_data.append( + { + "date": str(i.date), + "token_count": i.token_count, + } + ) + + return jsonify({"data": response_data}) - return jsonify({ - 'data': response_data - }) class WorkflowAverageAppInteractionStatistic(Resource): @setup_required @@ -208,8 +213,8 @@ class WorkflowAverageAppInteractionStatistic(Resource): account = current_user parser = reqparse.RequestParser() - parser.add_argument('start', type=datetime_string('%Y-%m-%d %H:%M'), location='args') - parser.add_argument('end', type=datetime_string('%Y-%m-%d %H:%M'), location='args') + parser.add_argument("start", type=datetime_string("%Y-%m-%d %H:%M"), location="args") + parser.add_argument("end", type=datetime_string("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() sql_query = """ @@ -229,50 +234,54 @@ class WorkflowAverageAppInteractionStatistic(Resource): GROUP BY date, c.created_by) sub GROUP BY sub.date """ - arg_dict = {'tz': account.timezone, 'app_id': app_model.id, 'triggered_from': WorkflowRunTriggeredFrom.APP_RUN.value} + arg_dict = { + "tz": account.timezone, + "app_id": app_model.id, + "triggered_from": WorkflowRunTriggeredFrom.APP_RUN.value, + } timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc - if args['start']: - start_datetime = datetime.strptime(args['start'], '%Y-%m-%d %H:%M') + if args["start"]: + start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{start}}', ' AND c.created_at >= :start') - arg_dict['start'] = start_datetime_utc + sql_query = sql_query.replace("{{start}}", " AND c.created_at >= :start") + arg_dict["start"] = start_datetime_utc else: - sql_query = sql_query.replace('{{start}}', '') + sql_query = sql_query.replace("{{start}}", "") - if args['end']: - end_datetime = datetime.strptime(args['end'], '%Y-%m-%d %H:%M') + if args["end"]: + end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) - sql_query = sql_query.replace('{{end}}', ' and c.created_at < :end') - arg_dict['end'] = end_datetime_utc + sql_query = sql_query.replace("{{end}}", " and c.created_at < :end") + arg_dict["end"] = end_datetime_utc else: - sql_query = sql_query.replace('{{end}}', '') + sql_query = sql_query.replace("{{end}}", "") response_data = [] - + with db.engine.begin() as conn: rs = conn.execute(db.text(sql_query), arg_dict) for i in rs: - response_data.append({ - 'date': str(i.date), - 'interactions': float(i.interactions.quantize(Decimal('0.01'))) - }) + response_data.append( + {"date": str(i.date), "interactions": float(i.interactions.quantize(Decimal("0.01")))} + ) - return jsonify({ - 'data': response_data - }) + return jsonify({"data": response_data}) -api.add_resource(WorkflowDailyRunsStatistic, '/apps//workflow/statistics/daily-conversations') -api.add_resource(WorkflowDailyTerminalsStatistic, '/apps//workflow/statistics/daily-terminals') -api.add_resource(WorkflowDailyTokenCostStatistic, '/apps//workflow/statistics/token-costs') -api.add_resource(WorkflowAverageAppInteractionStatistic, '/apps//workflow/statistics/average-app-interactions') + +api.add_resource(WorkflowDailyRunsStatistic, "/apps//workflow/statistics/daily-conversations") +api.add_resource(WorkflowDailyTerminalsStatistic, "/apps//workflow/statistics/daily-terminals") +api.add_resource(WorkflowDailyTokenCostStatistic, "/apps//workflow/statistics/token-costs") +api.add_resource( + WorkflowAverageAppInteractionStatistic, "/apps//workflow/statistics/average-app-interactions" +) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index d61ab6d6a..5e0a4bc81 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -8,24 +8,23 @@ from libs.login import current_user from models.model import App, AppMode -def get_app_model(view: Optional[Callable] = None, *, - mode: Union[AppMode, list[AppMode]] = None): +def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode]] = None): def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - if not kwargs.get('app_id'): - raise ValueError('missing app_id in path parameters') + if not kwargs.get("app_id"): + raise ValueError("missing app_id in path parameters") - app_id = kwargs.get('app_id') + app_id = kwargs.get("app_id") app_id = str(app_id) - del kwargs['app_id'] + del kwargs["app_id"] - app_model = db.session.query(App).filter( - App.id == app_id, - App.tenant_id == current_user.current_tenant_id, - App.status == 'normal' - ).first() + app_model = ( + db.session.query(App) + .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal") + .first() + ) if not app_model: raise AppNotFoundError() @@ -44,9 +43,10 @@ def get_app_model(view: Optional[Callable] = None, *, mode_values = {m.value for m in modes} raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model return view_func(*args, **kwargs) + return decorated_view if view is None: diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 8efb55cdb..8ba6b53e7 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -17,60 +17,61 @@ from services.account_service import RegisterService class ActivateCheckApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='args') - parser.add_argument('email', type=email, required=False, nullable=True, location='args') - parser.add_argument('token', type=str, required=True, nullable=False, location='args') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="args") + parser.add_argument("email", type=email, required=False, nullable=True, location="args") + parser.add_argument("token", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - workspaceId = args['workspace_id'] - reg_email = args['email'] - token = args['token'] + workspaceId = args["workspace_id"] + reg_email = args["email"] + token = args["token"] invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token) - return {'is_valid': invitation is not None, 'workspace_name': invitation['tenant'].name if invitation else None} + return {"is_valid": invitation is not None, "workspace_name": invitation["tenant"].name if invitation else None} class ActivateApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('workspace_id', type=str, required=False, nullable=True, location='json') - parser.add_argument('email', type=email, required=False, nullable=True, location='json') - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('name', type=str_len(30), required=True, nullable=False, location='json') - parser.add_argument('password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('interface_language', type=supported_language, required=True, nullable=False, - location='json') - parser.add_argument('timezone', type=timezone, required=True, nullable=False, location='json') + parser.add_argument("workspace_id", type=str, required=False, nullable=True, location="json") + parser.add_argument("email", type=email, required=False, nullable=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("name", type=str_len(30), required=True, nullable=False, location="json") + parser.add_argument("password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument( + "interface_language", type=supported_language, required=True, nullable=False, location="json" + ) + parser.add_argument("timezone", type=timezone, required=True, nullable=False, location="json") args = parser.parse_args() - invitation = RegisterService.get_invitation_if_token_valid(args['workspace_id'], args['email'], args['token']) + invitation = RegisterService.get_invitation_if_token_valid(args["workspace_id"], args["email"], args["token"]) if invitation is None: raise AlreadyActivateError() - RegisterService.revoke_token(args['workspace_id'], args['email'], args['token']) + RegisterService.revoke_token(args["workspace_id"], args["email"], args["token"]) - account = invitation['account'] - account.name = args['name'] + account = invitation["account"] + account.name = args["name"] # generate password salt salt = secrets.token_bytes(16) base64_salt = base64.b64encode(salt).decode() # encrypt password with salt - password_hashed = hash_password(args['password'], salt) + password_hashed = hash_password(args["password"], salt) base64_password_hashed = base64.b64encode(password_hashed).decode() account.password = base64_password_hashed account.password_salt = base64_salt - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ActivateCheckApi, '/activate/check') -api.add_resource(ActivateApi, '/activate') +api.add_resource(ActivateCheckApi, "/activate/check") +api.add_resource(ActivateApi, "/activate") diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index f79b93b74..50db6eebc 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -19,18 +19,19 @@ class ApiKeyAuthDataSource(Resource): data_source_api_key_bindings = ApiKeyAuthService.get_provider_auth_list(current_user.current_tenant_id) if data_source_api_key_bindings: return { - 'sources': [{ - 'id': data_source_api_key_binding.id, - 'category': data_source_api_key_binding.category, - 'provider': data_source_api_key_binding.provider, - 'disabled': data_source_api_key_binding.disabled, - 'created_at': int(data_source_api_key_binding.created_at.timestamp()), - 'updated_at': int(data_source_api_key_binding.updated_at.timestamp()), - } - for data_source_api_key_binding in - data_source_api_key_bindings] + "sources": [ + { + "id": data_source_api_key_binding.id, + "category": data_source_api_key_binding.category, + "provider": data_source_api_key_binding.provider, + "disabled": data_source_api_key_binding.disabled, + "created_at": int(data_source_api_key_binding.created_at.timestamp()), + "updated_at": int(data_source_api_key_binding.updated_at.timestamp()), + } + for data_source_api_key_binding in data_source_api_key_bindings + ] } - return {'sources': []} + return {"sources": []} class ApiKeyAuthDataSourceBinding(Resource): @@ -42,16 +43,16 @@ class ApiKeyAuthDataSourceBinding(Resource): if not current_user.is_admin_or_owner: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('category', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("category", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() ApiKeyAuthService.validate_api_key_auth_args(args) try: ApiKeyAuthService.create_provider_auth(current_user.current_tenant_id, args) except Exception as e: raise ApiKeyAuthFailedError(str(e)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ApiKeyAuthDataSourceBindingDelete(Resource): @@ -65,9 +66,9 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(ApiKeyAuthDataSource, '/api-key-auth/data-source') -api.add_resource(ApiKeyAuthDataSourceBinding, '/api-key-auth/data-source/binding') -api.add_resource(ApiKeyAuthDataSourceBindingDelete, '/api-key-auth/data-source/') +api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") +api.add_resource(ApiKeyAuthDataSourceBinding, "/api-key-auth/data-source/binding") +api.add_resource(ApiKeyAuthDataSourceBindingDelete, "/api-key-auth/data-source/") diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 45cfa9d7e..fd31e5ccc 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -17,13 +17,13 @@ from ..wraps import account_initialization_required def get_oauth_providers(): with current_app.app_context(): - notion_oauth = NotionOAuth(client_id=dify_config.NOTION_CLIENT_ID, - client_secret=dify_config.NOTION_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/data-source/callback/notion') + notion_oauth = NotionOAuth( + client_id=dify_config.NOTION_CLIENT_ID, + client_secret=dify_config.NOTION_CLIENT_SECRET, + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/data-source/callback/notion", + ) - OAUTH_PROVIDERS = { - 'notion': notion_oauth - } + OAUTH_PROVIDERS = {"notion": notion_oauth} return OAUTH_PROVIDERS @@ -37,18 +37,16 @@ class OAuthDataSource(Resource): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if dify_config.NOTION_INTEGRATION_TYPE == 'internal': + return {"error": "Invalid provider"}, 400 + if dify_config.NOTION_INTEGRATION_TYPE == "internal": internal_secret = dify_config.NOTION_INTERNAL_SECRET if not internal_secret: - return {'error': 'Internal secret is not set'}, + return ({"error": "Internal secret is not set"},) oauth_provider.save_internal_access_token(internal_secret) - return { 'data': '' } + return {"data": ""} else: auth_url = oauth_provider.get_authorization_url() - return { 'data': auth_url }, 200 - - + return {"data": auth_url}, 200 class OAuthDataSourceCallback(Resource): @@ -57,18 +55,18 @@ class OAuthDataSourceCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}') - elif 'error' in request.args: - error = request.args.get('error') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&code={code}") + elif "error" in request.args: + error = request.args.get("error") - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error={error}") else: - return redirect(f'{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied') - + return redirect(f"{dify_config.CONSOLE_WEB_URL}?type=notion&error=Access denied") + class OAuthDataSourceBinding(Resource): def get(self, provider: str): @@ -76,17 +74,18 @@ class OAuthDataSourceBinding(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 - if 'code' in request.args: - code = request.args.get('code') + return {"error": "Invalid provider"}, 400 + if "code" in request.args: + code = request.args.get("code") try: oauth_provider.get_access_token(code) except requests.exceptions.HTTPError as e: logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}" + ) + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class OAuthDataSourceSync(Resource): @@ -100,18 +99,17 @@ class OAuthDataSourceSync(Resource): with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 try: oauth_provider.sync_data_source(binding_id) except requests.exceptions.HTTPError as e: - logging.exception( - f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") - return {'error': 'OAuth data source process failed'}, 400 + logging.exception(f"An error occurred during the OAuthCallback process with {provider}: {e.response.text}") + return {"error": "OAuth data source process failed"}, 400 - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') -api.add_resource(OAuthDataSourceBinding, '/oauth/data-source/binding/') -api.add_resource(OAuthDataSourceSync, '/oauth/data-source///sync') +api.add_resource(OAuthDataSource, "/oauth/data-source/") +api.add_resource(OAuthDataSourceCallback, "/oauth/data-source/callback/") +api.add_resource(OAuthDataSourceBinding, "/oauth/data-source/binding/") +api.add_resource(OAuthDataSourceSync, "/oauth/data-source///sync") diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index 53dab3298..ea23e097d 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -2,31 +2,30 @@ from libs.exception import BaseHTTPException class ApiKeyAuthFailedError(BaseHTTPException): - error_code = 'auth_failed' + error_code = "auth_failed" description = "{message}" code = 500 class InvalidEmailError(BaseHTTPException): - error_code = 'invalid_email' + error_code = "invalid_email" description = "The email address is not valid." code = 400 class PasswordMismatchError(BaseHTTPException): - error_code = 'password_mismatch' + error_code = "password_mismatch" description = "The passwords do not match." code = 400 class InvalidTokenError(BaseHTTPException): - error_code = 'invalid_or_expired_token' + error_code = "invalid_or_expired_token" description = "The token is invalid or has expired." code = 400 class PasswordResetRateLimitExceededError(BaseHTTPException): - error_code = 'password_reset_rate_limit_exceeded' + error_code = "password_reset_rate_limit_exceeded" description = "Password reset rate limit exceeded. Try again later." code = 429 - diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d78be770a..0b01a4906 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -21,14 +21,13 @@ from services.errors.account import RateLimitExceededError class ForgotPasswordSendEmailApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('email', type=str, required=True, location='json') + parser.add_argument("email", type=str, required=True, location="json") args = parser.parse_args() - email = args['email'] + email = args["email"] if not email_validate(email): raise InvalidEmailError() @@ -49,38 +48,36 @@ class ForgotPasswordSendEmailApi(Resource): class ForgotPasswordCheckApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") args = parser.parse_args() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: - return {'is_valid': False, 'email': None} - return {'is_valid': True, 'email': reset_data.get('email')} + return {"is_valid": False, "email": None} + return {"is_valid": True, "email": reset_data.get("email")} class ForgotPasswordResetApi(Resource): - @setup_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=True, nullable=False, location='json') - parser.add_argument('new_password', type=valid_password, required=True, nullable=False, location='json') - parser.add_argument('password_confirm', type=valid_password, required=True, nullable=False, location='json') + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") args = parser.parse_args() - new_password = args['new_password'] - password_confirm = args['password_confirm'] + new_password = args["new_password"] + password_confirm = args["password_confirm"] if str(new_password).strip() != str(password_confirm).strip(): raise PasswordMismatchError() - token = args['token'] + token = args["token"] reset_data = AccountService.get_reset_password_data(token) if reset_data is None: @@ -94,14 +91,14 @@ class ForgotPasswordResetApi(Resource): password_hashed = hash_password(new_password, salt) base64_password_hashed = base64.b64encode(password_hashed).decode() - account = Account.query.filter_by(email=reset_data.get('email')).first() + account = Account.query.filter_by(email=reset_data.get("email")).first() account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() - return {'result': 'success'} + return {"result": "success"} -api.add_resource(ForgotPasswordSendEmailApi, '/forgot-password') -api.add_resource(ForgotPasswordCheckApi, '/forgot-password/validity') -api.add_resource(ForgotPasswordResetApi, '/forgot-password/resets') +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index c135ece67..62837af2b 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -20,37 +20,39 @@ class LoginApi(Resource): def post(self): """Authenticate user and login.""" parser = reqparse.RequestParser() - parser.add_argument('email', type=email, required=True, location='json') - parser.add_argument('password', type=valid_password, required=True, location='json') - parser.add_argument('remember_me', type=bool, required=False, default=False, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + parser.add_argument("remember_me", type=bool, required=False, default=False, location="json") args = parser.parse_args() # todo: Verify the recaptcha try: - account = AccountService.authenticate(args['email'], args['password']) + account = AccountService.authenticate(args["email"], args["password"]) except services.errors.account.AccountLoginError as e: - return {'code': 'unauthorized', 'message': str(e)}, 401 + return {"code": "unauthorized", "message": str(e)}, 401 # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return {'result': 'fail', 'data': 'workspace not found, please contact system admin to invite you to join in a workspace'} + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token = AccountService.login(account, ip_address=get_remote_ip(request)) - return {'result': 'success', 'data': token} + return {"result": "success", "data": token} class LogoutApi(Resource): - @setup_required def get(self): account = cast(Account, flask_login.current_user) - token = request.headers.get('Authorization', '').split(' ')[1] + token = request.headers.get("Authorization", "").split(" ")[1] AccountService.logout(account=account, token=token) flask_login.logout_user() - return {'result': 'success'} + return {"result": "success"} class ResetPasswordApi(Resource): @@ -80,11 +82,11 @@ class ResetPasswordApi(Resource): # 'subject': 'Reset your Dify password', # 'html': """ #

Dear User,

- #

The Dify team has generated a new password for you, details as follows:

+ #

The Dify team has generated a new password for you, details as follows:

#

{new_password}

#

Please change your password to log in as soon as possible.

#

Regards,

- #

The Dify Team

+ #

The Dify Team

# """ # } @@ -101,8 +103,8 @@ class ResetPasswordApi(Resource): # # handle error # pass - return {'result': 'success'} + return {"result": "success"} -api.add_resource(LoginApi, '/login') -api.add_resource(LogoutApi, '/logout') +api.add_resource(LoginApi, "/login") +api.add_resource(LogoutApi, "/logout") diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 4a651bfe7..ae1b49f3e 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -25,7 +25,7 @@ def get_oauth_providers(): github_oauth = GitHubOAuth( client_id=dify_config.GITHUB_CLIENT_ID, client_secret=dify_config.GITHUB_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/github', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/github", ) if not dify_config.GOOGLE_CLIENT_ID or not dify_config.GOOGLE_CLIENT_SECRET: google_oauth = None @@ -33,10 +33,10 @@ def get_oauth_providers(): google_oauth = GoogleOAuth( client_id=dify_config.GOOGLE_CLIENT_ID, client_secret=dify_config.GOOGLE_CLIENT_SECRET, - redirect_uri=dify_config.CONSOLE_API_URL + '/console/api/oauth/authorize/google', + redirect_uri=dify_config.CONSOLE_API_URL + "/console/api/oauth/authorize/google", ) - OAUTH_PROVIDERS = {'github': github_oauth, 'google': google_oauth} + OAUTH_PROVIDERS = {"github": github_oauth, "google": google_oauth} return OAUTH_PROVIDERS @@ -47,7 +47,7 @@ class OAuthLogin(Resource): oauth_provider = OAUTH_PROVIDERS.get(provider) print(vars(oauth_provider)) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 auth_url = oauth_provider.get_authorization_url() return redirect(auth_url) @@ -59,20 +59,20 @@ class OAuthCallback(Resource): with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) if not oauth_provider: - return {'error': 'Invalid provider'}, 400 + return {"error": "Invalid provider"}, 400 - code = request.args.get('code') + code = request.args.get("code") try: token = oauth_provider.get_access_token(code) user_info = oauth_provider.get_user_info(token) except requests.exceptions.HTTPError as e: - logging.exception(f'An error occurred during the OAuth process with {provider}: {e.response.text}') - return {'error': 'OAuth process failed'}, 400 + logging.exception(f"An error occurred during the OAuth process with {provider}: {e.response.text}") + return {"error": "OAuth process failed"}, 400 account = _generate_account(provider, user_info) # Check account status if account.status == AccountStatus.BANNED.value or account.status == AccountStatus.CLOSED.value: - return {'error': 'Account is banned or closed.'}, 403 + return {"error": "Account is banned or closed."}, 403 if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value @@ -83,7 +83,7 @@ class OAuthCallback(Resource): token = AccountService.login(account, ip_address=get_remote_ip(request)) - return redirect(f'{dify_config.CONSOLE_WEB_URL}?console_token={token}') + return redirect(f"{dify_config.CONSOLE_WEB_URL}?console_token={token}") def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Optional[Account]: @@ -101,7 +101,7 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): if not account: # Create account - account_name = user_info.name if user_info.name else 'Dify' + account_name = user_info.name if user_info.name else "Dify" account = RegisterService.register( email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider ) @@ -121,5 +121,5 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): return account -api.add_resource(OAuthLogin, '/oauth/login/') -api.add_resource(OAuthCallback, '/oauth/authorize/') +api.add_resource(OAuthLogin, "/oauth/login/") +api.add_resource(OAuthCallback, "/oauth/authorize/") diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py index 72a6129ef..9a1d91486 100644 --- a/api/controllers/console/billing/billing.py +++ b/api/controllers/console/billing/billing.py @@ -9,28 +9,24 @@ from services.billing_service import BillingService class Subscription(Resource): - @setup_required @login_required @account_initialization_required @only_edition_cloud def get(self): - parser = reqparse.RequestParser() - parser.add_argument('plan', type=str, required=True, location='args', choices=['professional', 'team']) - parser.add_argument('interval', type=str, required=True, location='args', choices=['month', 'year']) + parser.add_argument("plan", type=str, required=True, location="args", choices=["professional", "team"]) + parser.add_argument("interval", type=str, required=True, location="args", choices=["month", "year"]) args = parser.parse_args() BillingService.is_tenant_owner_or_admin(current_user) - return BillingService.get_subscription(args['plan'], - args['interval'], - current_user.email, - current_user.current_tenant_id) + return BillingService.get_subscription( + args["plan"], args["interval"], current_user.email, current_user.current_tenant_id + ) class Invoices(Resource): - @setup_required @login_required @account_initialization_required @@ -40,5 +36,5 @@ class Invoices(Resource): return BillingService.get_invoices(current_user.email, current_user.current_tenant_id) -api.add_resource(Subscription, '/billing/subscription') -api.add_resource(Invoices, '/billing/invoices') +api.add_resource(Subscription, "/billing/subscription") +api.add_resource(Invoices, "/billing/invoices") diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 0ca0f0a85..0e1acab94 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -22,19 +22,22 @@ from tasks.document_indexing_sync_task import document_indexing_sync_task class DataSourceApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates - data_source_integrates = db.session.query(DataSourceOauthBinding).filter( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.disabled == False - ).all() + data_source_integrates = ( + db.session.query(DataSourceOauthBinding) + .filter( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.disabled == False, + ) + .all() + ) - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") data_source_oauth_base_path = "/console/api/oauth/data-source" providers = ["notion"] @@ -44,26 +47,30 @@ class DataSourceApi(Resource): existing_integrates = filter(lambda item: item.provider == provider, data_source_integrates) if existing_integrates: for existing_integrate in list(existing_integrates): - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'disabled': existing_integrate.disabled, - 'source_info': existing_integrate.source_info, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "disabled": existing_integrate.disabled, + "source_info": existing_integrate.source_info, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'source_info': None, - 'is_bound': False, - 'disabled': None, - 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' - }) - return {'data': integrate_data}, 200 + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "source_info": None, + "is_bound": False, + "disabled": None, + "link": f"{base_url}{data_source_oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data}, 200 @setup_required @login_required @@ -71,92 +78,82 @@ class DataSourceApi(Resource): def patch(self, binding_id, action): binding_id = str(binding_id) action = str(action) - data_source_binding = DataSourceOauthBinding.query.filter_by( - id=binding_id - ).first() + data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first() if data_source_binding is None: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") # enable binding - if action == 'enable': + if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is not disabled.') + raise ValueError("Data source is not disabled.") # disable binding - if action == 'disable': + if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: - raise ValueError('Data source is disabled.') - return {'result': 'success'}, 200 + raise ValueError("Data source is disabled.") + return {"result": "success"}, 200 class DataSourceNotionListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(integrate_notion_info_list_fields) def get(self): - dataset_id = request.args.get('dataset_id', default=None, type=str) + dataset_id = request.args.get("dataset_id", default=None, type=str) exist_page_ids = [] # import notion in the exist dataset if dataset_id: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') - if dataset.data_source_type != 'notion_import': - raise ValueError('Dataset is not notion type.') + raise NotFound("Dataset not found.") + if dataset.data_source_type != "notion_import": + raise ValueError("Dataset is not notion type.") documents = Document.query.filter_by( dataset_id=dataset_id, tenant_id=current_user.current_tenant_id, - data_source_type='notion_import', - enabled=True + data_source_type="notion_import", + enabled=True, ).all() if documents: for document in documents: data_source_info = json.loads(document.data_source_info) - exist_page_ids.append(data_source_info['notion_page_id']) + exist_page_ids.append(data_source_info["notion_page_id"]) # get all authorized pages data_source_bindings = DataSourceOauthBinding.query.filter_by( - tenant_id=current_user.current_tenant_id, - provider='notion', - disabled=False + tenant_id=current_user.current_tenant_id, provider="notion", disabled=False ).all() if not data_source_bindings: - return { - 'notion_info': [] - }, 200 + return {"notion_info": []}, 200 pre_import_info_list = [] for data_source_binding in data_source_bindings: source_info = data_source_binding.source_info - pages = source_info['pages'] + pages = source_info["pages"] # Filter out already bound pages for page in pages: - if page['page_id'] in exist_page_ids: - page['is_bound'] = True + if page["page_id"] in exist_page_ids: + page["is_bound"] = True else: - page['is_bound'] = False + page["is_bound"] = False pre_import_info = { - 'workspace_name': source_info['workspace_name'], - 'workspace_icon': source_info['workspace_icon'], - 'workspace_id': source_info['workspace_id'], - 'pages': pages, + "workspace_name": source_info["workspace_name"], + "workspace_icon": source_info["workspace_icon"], + "workspace_id": source_info["workspace_id"], + "pages": pages, } pre_import_info_list.append(pre_import_info) - return { - 'notion_info': pre_import_info_list - }, 200 + return {"notion_info": pre_import_info_list}, 200 class DataSourceNotionApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,64 +163,67 @@ class DataSourceNotionApi(Resource): data_source_binding = DataSourceOauthBinding.query.filter( db.and_( DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == 'notion', + DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info['workspace_id'] == f'"{workspace_id}"' + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', ) ).first() if not data_source_binding: - raise NotFound('Data source binding not found.') + raise NotFound("Data source binding not found.") extractor = NotionExtractor( notion_workspace_id=workspace_id, notion_obj_id=page_id, notion_page_type=page_type, notion_access_token=data_source_binding.access_token, - tenant_id=current_user.current_tenant_id + tenant_id=current_user.current_tenant_id, ) text_docs = extractor.extract() - return { - 'content': "\n".join([doc.page_content for doc in text_docs]) - }, 200 + return {"content": "\n".join([doc.page_content for doc in text_docs])}, 200 @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, location='json') + parser.add_argument("notion_info_list", type=list, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) - notion_info_list = args['notion_info_list'] + notion_info_list = args["notion_info_list"] extract_settings = [] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) indexing_runner = IndexingRunner() - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + ) return response, 200 class DataSourceNotionDatasetSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -240,7 +240,6 @@ class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDocumentSyncApi(Resource): - @setup_required @login_required @account_initialization_required @@ -258,10 +257,14 @@ class DataSourceNotionDocumentSyncApi(Resource): return 200 -api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates//') -api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, - '/notion/workspaces//pages///preview', - '/datasets/notion-indexing-estimate') -api.add_resource(DataSourceNotionDatasetSyncApi, '/datasets//notion/sync') -api.add_resource(DataSourceNotionDocumentSyncApi, '/datasets//documents//notion/sync') +api.add_resource(DataSourceApi, "/data-source/integrates", "/data-source/integrates//") +api.add_resource(DataSourceNotionListApi, "/notion/pre-import/pages") +api.add_resource( + DataSourceNotionApi, + "/notion/workspaces//pages///preview", + "/datasets/notion-indexing-estimate", +) +api.add_resource(DataSourceNotionDatasetSyncApi, "/datasets//notion/sync") +api.add_resource( + DataSourceNotionDocumentSyncApi, "/datasets//documents//notion/sync" +) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index b9a1c2515..d36973059 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -31,45 +31,40 @@ from services.dataset_service import DatasetPermissionService, DatasetService, D def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name def _validate_description_length(description): if len(description) > 400: - raise ValueError('Description cannot exceed 400 characters.') + raise ValueError("Description cannot exceed 400 characters.") return description class DatasetListApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - ids = request.args.getlist('ids') - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + ids = request.args.getlist("ids") + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") if ids: datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id) else: - datasets, total = DatasetService.get_datasets(page, limit, provider, - current_user.current_tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets( + page, limit, provider, current_user.current_tenant_id, current_user, search, tag_ids + ) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -77,28 +72,22 @@ class DatasetListApi(Resource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True + item["embedding_available"] = True - if item.get('permission') == 'partial_members': - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item['id']) - item.update({'partial_member_list': part_users_list}) + if item.get("permission") == "partial_members": + part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) + item.update({"partial_member_list": part_users_list}) else: - item.update({'partial_member_list': []}) + item.update({"partial_member_list": []}) - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 @setup_required @@ -106,13 +95,21 @@ class DatasetListApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator @@ -122,9 +119,9 @@ class DatasetListApi(Resource): try: dataset = DatasetService.create_empty_dataset( tenant_id=current_user.current_tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], - account=current_user + name=args["name"], + indexing_technique=args["indexing_technique"], + account=current_user, ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() @@ -142,42 +139,36 @@ class DatasetApi(Resource): if dataset is None: raise NotFound("Dataset not found.") try: - DatasetService.check_dataset_permission( - dataset, current_user) + DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) data = marshal(dataset, dataset_detail_fields) - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") - if data['indexing_technique'] == 'high_quality': + if data["indexing_technique"] == "high_quality": item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}" if item_model in model_names: - data['embedding_available'] = True + data["embedding_available"] = True else: - data['embedding_available'] = False + data["embedding_available"] = False else: - data['embedding_available'] = True + data["embedding_available"] = True - if data.get('permission') == 'partial_members': + if data.get("permission") == "partial_members": part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({'partial_member_list': part_users_list}) + data.update({"partial_member_list": part_users_list}) return data, 200 @@ -191,42 +182,49 @@ class DatasetApi(Resource): raise NotFound("Dataset not found.") parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('description', - location='json', store_missing=False, - type=_validate_description_length) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.' - ) - parser.add_argument('embedding_model', type=str, - location='json', help='Invalid embedding model.') - parser.add_argument('embedding_model_provider', type=str, - location='json', help='Invalid embedding model provider.') - parser.add_argument('retrieval_model', type=dict, location='json', help='Invalid retrieval model.') - parser.add_argument('partial_member_list', type=list, location='json', help='Invalid parent user list.') + parser.add_argument( + "name", + nullable=False, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument("description", location="json", store_missing=False, type=_validate_description_length) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + ) + parser.add_argument("embedding_model", type=str, location="json", help="Invalid embedding model.") + parser.add_argument( + "embedding_model_provider", type=str, location="json", help="Invalid embedding model provider." + ) + parser.add_argument("retrieval_model", type=dict, location="json", help="Invalid retrieval model.") + parser.add_argument("partial_member_list", type=list, location="json", help="Invalid parent user list.") args = parser.parse_args() data = request.get_json() # check embedding model setting - if data.get('indexing_technique') == 'high_quality': - DatasetService.check_embedding_model_setting(dataset.tenant_id, - data.get('embedding_model_provider'), - data.get('embedding_model') - ) + if data.get("indexing_technique") == "high_quality": + DatasetService.check_embedding_model_setting( + dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model") + ) # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator DatasetPermissionService.check_permission( - current_user, dataset, data.get('permission'), data.get('partial_member_list') + current_user, dataset, data.get("permission"), data.get("partial_member_list") ) - dataset = DatasetService.update_dataset( - dataset_id_str, args, current_user) + dataset = DatasetService.update_dataset(dataset_id_str, args, current_user) if dataset is None: raise NotFound("Dataset not found.") @@ -234,16 +232,19 @@ class DatasetApi(Resource): result_data = marshal(dataset, dataset_detail_fields) tenant_id = current_user.current_tenant_id - if data.get('partial_member_list') and data.get('permission') == 'partial_members': + if data.get("partial_member_list") and data.get("permission") == "partial_members": DatasetPermissionService.update_partial_member_list( - tenant_id, dataset_id_str, data.get('partial_member_list') + tenant_id, dataset_id_str, data.get("partial_member_list") ) # clear partial member list when permission is only_me or all_team_members - elif data.get('permission') == DatasetPermissionEnum.ONLY_ME or data.get('permission') == DatasetPermissionEnum.ALL_TEAM: + elif ( + data.get("permission") == DatasetPermissionEnum.ONLY_ME + or data.get("permission") == DatasetPermissionEnum.ALL_TEAM + ): DatasetPermissionService.clear_partial_member_list(dataset_id_str) partial_member_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - result_data.update({'partial_member_list': partial_member_list}) + result_data.update({"partial_member_list": partial_member_list}) return result_data, 200 @@ -260,12 +261,13 @@ class DatasetApi(Resource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() + class DatasetUseCheckApi(Resource): @setup_required @login_required @@ -274,10 +276,10 @@ class DatasetUseCheckApi(Resource): dataset_id_str = str(dataset_id) dataset_is_using = DatasetService.dataset_use_check(dataset_id_str) - return {'is_using': dataset_is_using}, 200 + return {"is_using": dataset_is_using}, 200 + class DatasetQueryApi(Resource): - @setup_required @login_required @account_initialization_required @@ -292,51 +294,53 @@ class DatasetQueryApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) - dataset_queries, total = DatasetService.get_dataset_queries( - dataset_id=dataset.id, - page=page, - per_page=limit - ) + dataset_queries, total = DatasetService.get_dataset_queries(dataset_id=dataset.id, page=page, per_page=limit) response = { - 'data': marshal(dataset_queries, dataset_query_detail_fields), - 'has_more': len(dataset_queries) == limit, - 'limit': limit, - 'total': total, - 'page': page + "data": marshal(dataset_queries, dataset_query_detail_fields), + "has_more": len(dataset_queries) == limit, + "limit": limit, + "total": total, + "page": page, } return response, 200 class DatasetIndexingEstimateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('indexing_technique', type=str, required=True, - choices=Dataset.INDEXING_TECHNIQUE_LIST, - nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') + parser.add_argument("info_list", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument( + "indexing_technique", + type=str, + required=True, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + nullable=True, + location="json", + ) + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument("dataset_id", type=str, required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) args = parser.parse_args() # validate args DocumentService.estimate_args_validate(args) extract_settings = [] - if args['info_list']['data_source_type'] == 'upload_file': - file_ids = args['info_list']['file_info_list']['file_ids'] - file_details = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id.in_(file_ids) - ).all() + if args["info_list"]["data_source_type"] == "upload_file": + file_ids = args["info_list"]["file_info_list"]["file_ids"] + file_details = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids)) + .all() + ) if file_details is None: raise NotFound("File not found.") @@ -344,55 +348,58 @@ class DatasetIndexingEstimateApi(Resource): if file_details: for file_detail in file_details: extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=args['doc_form'] + datasource_type="upload_file", upload_file=file_detail, document_model=args["doc_form"] ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'notion_import': - notion_info_list = args['info_list']['notion_info_list'] + elif args["info_list"]["data_source_type"] == "notion_import": + notion_info_list = args["info_list"]["notion_info_list"] for notion_info in notion_info_list: - workspace_id = notion_info['workspace_id'] - for page in notion_info['pages']: + workspace_id = notion_info["workspace_id"] + for page in notion_info["pages"]: extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ "notion_workspace_id": workspace_id, - "notion_obj_id": page['page_id'], - "notion_page_type": page['type'], - "tenant_id": current_user.current_tenant_id + "notion_obj_id": page["page_id"], + "notion_page_type": page["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) - elif args['info_list']['data_source_type'] == 'website_crawl': - website_info_list = args['info_list']['website_info_list'] - for url in website_info_list['urls']: + elif args["info_list"]["data_source_type"] == "website_crawl": + website_info_list = args["info_list"]["website_info_list"] + for url in website_info_list["urls"]: extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": website_info_list['provider'], - "job_id": website_info_list['job_id'], + "provider": website_info_list["provider"], + "job_id": website_info_list["job_id"], "url": url, "tenant_id": current_user.current_tenant_id, - "mode": 'crawl', - "only_main_content": website_info_list['only_main_content'] + "mode": "crawl", + "only_main_content": website_info_list["only_main_content"], }, - document_model=args['doc_form'] + document_model=args["doc_form"], ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - args['process_rule'], args['doc_form'], - args['doc_language'], args['dataset_id'], - args['indexing_technique']) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + args["process_rule"], + args["doc_form"], + args["doc_language"], + args["dataset_id"], + args["indexing_technique"], + ) except LLMBadRequestError: raise ProviderNotInitializeError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "No Embedding Model available. Please configure a valid provider " "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -402,7 +409,6 @@ class DatasetIndexingEstimateApi(Resource): class DatasetRelatedAppListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -426,52 +432,52 @@ class DatasetRelatedAppListApi(Resource): if app_model: related_apps.append(app_model) - return { - 'data': related_apps, - 'total': len(related_apps) - }, 200 + return {"data": related_apps, "total": len(related_apps)}, 200 class DatasetIndexingStatusApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - documents = db.session.query(Document).filter( - Document.dataset_id == dataset_id, - Document.tenant_id == current_user.current_tenant_id - ).all() + documents = ( + db.session.query(Document) + .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id) + .all() + ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DatasetApiKeyApi(Resource): max_keys = 10 - token_prefix = 'dataset-' - resource_type = 'dataset' + token_prefix = "dataset-" + resource_type = "dataset" @setup_required @login_required @account_initialization_required @marshal_with(api_key_list) def get(self): - keys = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - all() + keys = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .all() + ) return {"items": keys} @setup_required @@ -483,15 +489,17 @@ class DatasetApiKeyApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - current_key_count = db.session.query(ApiToken). \ - filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \ - count() + current_key_count = ( + db.session.query(ApiToken) + .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id) + .count() + ) if current_key_count >= self.max_keys: flask_restful.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code='max_keys_exceeded' + code="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) @@ -505,7 +513,7 @@ class DatasetApiKeyApi(Resource): class DatasetApiDeleteApi(Resource): - resource_type = 'dataset' + resource_type = "dataset" @setup_required @login_required @@ -517,18 +525,23 @@ class DatasetApiDeleteApi(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = db.session.query(ApiToken). \ - filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type, - ApiToken.id == api_key_id). \ - first() + key = ( + db.session.query(ApiToken) + .filter( + ApiToken.tenant_id == current_user.current_tenant_id, + ApiToken.type == self.resource_type, + ApiToken.id == api_key_id, + ) + .first() + ) if key is None: - flask_restful.abort(404, message='API key not found') + flask_restful.abort(404, message="API key not found") db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete() db.session.commit() - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DatasetApiBaseUrlApi(Resource): @@ -537,8 +550,10 @@ class DatasetApiBaseUrlApi(Resource): @account_initialization_required def get(self): return { - 'api_base_url': (dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL - else request.host_url.rstrip('/')) + '/v1' + "api_base_url": ( + dify_config.SERVICE_API_URL if dify_config.SERVICE_API_URL else request.host_url.rstrip("/") + ) + + "/v1" } @@ -549,15 +564,26 @@ class DatasetRetrievalSettingApi(Resource): def get(self): vector_type = dify_config.VECTOR_STORE match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.PGVECTOR | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.PGVECTOR + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -573,15 +599,27 @@ class DatasetRetrievalSettingMockApi(Resource): @account_initialization_required def get(self, vector_type): match vector_type: - case VectorType.MILVUS | VectorType.RELYT | VectorType.TIDB_VECTOR | VectorType.CHROMA | VectorType.TENCENT | VectorType.PGVECTO_RS: + case ( + VectorType.MILVUS + | VectorType.RELYT + | VectorType.TIDB_VECTOR + | VectorType.CHROMA + | VectorType.TENCENT + | VectorType.PGVECTO_RS + ): + return {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} + case ( + VectorType.QDRANT + | VectorType.WEAVIATE + | VectorType.OPENSEARCH + | VectorType.ANALYTICDB + | VectorType.MYSCALE + | VectorType.ORACLE + | VectorType.ELASTICSEARCH + | VectorType.PGVECTOR + ): return { - 'retrieval_method': [ - RetrievalMethod.SEMANTIC_SEARCH.value - ] - } - case VectorType.QDRANT | VectorType.WEAVIATE | VectorType.OPENSEARCH | VectorType.ANALYTICDB | VectorType.MYSCALE | VectorType.ORACLE | VectorType.ELASTICSEARCH | VectorType.PGVECTOR: - return { - 'retrieval_method': [ + "retrieval_method": [ RetrievalMethod.SEMANTIC_SEARCH.value, RetrievalMethod.FULL_TEXT_SEARCH.value, RetrievalMethod.HYBRID_SEARCH.value, @@ -591,7 +629,6 @@ class DatasetRetrievalSettingMockApi(Resource): raise ValueError(f"Unsupported vector db type {vector_type}.") - class DatasetErrorDocs(Resource): @setup_required @login_required @@ -603,10 +640,7 @@ class DatasetErrorDocs(Resource): raise NotFound("Dataset not found.") results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str) - return { - 'data': [marshal(item, document_status_fields) for item in results], - 'total': len(results) - }, 200 + return {"data": [marshal(item, document_status_fields) for item in results], "total": len(results)}, 200 class DatasetPermissionUserListApi(Resource): @@ -626,21 +660,21 @@ class DatasetPermissionUserListApi(Resource): partial_members_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) return { - 'data': partial_members_list, + "data": partial_members_list, }, 200 -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') -api.add_resource(DatasetUseCheckApi, '/datasets//use-check') -api.add_resource(DatasetQueryApi, '/datasets//queries') -api.add_resource(DatasetErrorDocs, '/datasets//error-docs') -api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate') -api.add_resource(DatasetRelatedAppListApi, '/datasets//related-apps') -api.add_resource(DatasetIndexingStatusApi, '/datasets//indexing-status') -api.add_resource(DatasetApiKeyApi, '/datasets/api-keys') -api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/') -api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info') -api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting') -api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/') -api.add_resource(DatasetPermissionUserListApi, '/datasets//permission-part-users') +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DatasetUseCheckApi, "/datasets//use-check") +api.add_resource(DatasetQueryApi, "/datasets//queries") +api.add_resource(DatasetErrorDocs, "/datasets//error-docs") +api.add_resource(DatasetIndexingEstimateApi, "/datasets/indexing-estimate") +api.add_resource(DatasetRelatedAppListApi, "/datasets//related-apps") +api.add_resource(DatasetIndexingStatusApi, "/datasets//indexing-status") +api.add_resource(DatasetApiKeyApi, "/datasets/api-keys") +api.add_resource(DatasetApiDeleteApi, "/datasets/api-keys/") +api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info") +api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") +api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/") +api.add_resource(DatasetPermissionUserListApi, "/datasets//permission-part-users") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 976b97660..7d0b9f046 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -57,7 +57,7 @@ class DocumentResource(Resource): def get_document(self, dataset_id: str, document_id: str) -> Document: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -67,17 +67,17 @@ class DocumentResource(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') + raise Forbidden("No permission.") return document def get_batch_documents(self, dataset_id: str, batch: str) -> list[Document]: dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -87,7 +87,7 @@ class DocumentResource(Resource): documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") return documents @@ -99,11 +99,11 @@ class GetProcessRuleApi(Resource): def get(self): req_data = request.args - document_id = req_data.get('document_id') + document_id = req_data.get("document_id") # get default rules - mode = DocumentService.DEFAULT_RULES['mode'] - rules = DocumentService.DEFAULT_RULES['rules'] + mode = DocumentService.DEFAULT_RULES["mode"] + rules = DocumentService.DEFAULT_RULES["rules"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -111,7 +111,7 @@ class GetProcessRuleApi(Resource): dataset = DatasetService.get_dataset(document.dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -119,19 +119,18 @@ class GetProcessRuleApi(Resource): raise Forbidden(str(e)) # get the latest process rule - dataset_process_rule = db.session.query(DatasetProcessRule). \ - filter(DatasetProcessRule.dataset_id == document.dataset_id). \ - order_by(DatasetProcessRule.created_at.desc()). \ - limit(1). \ - one_or_none() + dataset_process_rule = ( + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == document.dataset_id) + .order_by(DatasetProcessRule.created_at.desc()) + .limit(1) + .one_or_none() + ) if dataset_process_rule: mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return { - 'mode': mode, - 'rules': rules - } + return {"mode": mode, "rules": rules} class DatasetDocumentListApi(Resource): @@ -140,49 +139,48 @@ class DatasetDocumentListApi(Resource): @account_initialization_required def get(self, dataset_id): dataset_id = str(dataset_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - sort = request.args.get('sort', default='-created_at', type=str) + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + sort = request.args.get("sort", default="-created_at", type=str) # "yes", "true", "t", "y", "1" convert to True, while others convert to False. try: - fetch = string_to_bool(request.args.get('fetch', default='false')) + fetch = string_to_bool(request.args.get("fetch", default="false")) except (ArgumentTypeError, ValueError, Exception) as e: fetch = False dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) - if sort.startswith('-'): + if sort.startswith("-"): sort_logic = desc sort = sort[1:] else: sort_logic = asc - if sort == 'hit_count': - sub_query = db.select(DocumentSegment.document_id, - db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) \ - .group_by(DocumentSegment.document_id) \ + if sort == "hit_count": + sub_query = ( + db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .group_by(DocumentSegment.document_id) .subquery() + ) - query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id) \ - .order_by( - sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), - sort_logic(Document.position), - ) - elif sort == 'created_at': + query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by( + sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)), + sort_logic(Document.position), + ) + elif sort == "created_at": query = query.order_by( sort_logic(Document.created_at), sort_logic(Document.position), @@ -193,48 +191,47 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) else: data = marshal(documents, document_fields) response = { - 'data': data, - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": data, + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response - documents_and_batch_fields = { - 'documents': fields.List(fields.Nested(document_fields)), - 'batch': fields.String - } + documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} @setup_required @login_required @account_initialization_required @marshal_with(documents_and_batch_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self, dataset_id): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_dataset_editor: @@ -246,21 +243,22 @@ class DatasetDocumentListApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('data_source', type=dict, required=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, location='json') - parser.add_argument('duplicate', type=bool, default=True, nullable=False, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("data_source", type=dict, required=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, location="json") + parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") # validate args DocumentService.document_create_args_validate(args) @@ -274,51 +272,53 @@ class DatasetDocumentListApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - return { - 'documents': documents, - 'batch': batch - } + return {"documents": documents, "batch": batch} class DatasetInitApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(dataset_and_document_fields) - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def post(self): # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, required=True, - nullable=False, location='json') - parser.add_argument('data_source', type=dict, required=True, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument( + "indexing_technique", + type=str, + choices=Dataset.INDEXING_TECHNIQUE_LIST, + required=True, + nullable=False, + location="json", + ) + parser.add_argument("data_source", type=dict, required=True, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=True, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator if not current_user.is_dataset_editor: raise Forbidden() - if args['indexing_technique'] == 'high_quality': + if args["indexing_technique"] == "high_quality": try: model_manager = ModelManager() model_manager.get_default_model_instance( - tenant_id=current_user.current_tenant_id, - model_type=ModelType.TEXT_EMBEDDING + tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING ) except InvokeAuthorizationError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -327,9 +327,7 @@ class DatasetInitApi(Resource): try: dataset, documents, batch = DocumentService.save_document_without_dataset_id( - tenant_id=current_user.current_tenant_id, - document_data=args, - account=current_user + tenant_id=current_user.current_tenant_id, document_data=args, account=current_user ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) @@ -338,17 +336,12 @@ class DatasetInitApi(Resource): except ModelCurrentlyNotSupportError: raise ProviderModelCurrentlyNotSupportError() - response = { - 'dataset': dataset, - 'documents': documents, - 'batch': batch - } + response = {"dataset": dataset, "documents": documents, "batch": batch} return response class DocumentIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -357,50 +350,49 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule data_process_rule_dict = data_process_rule.to_dict() - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} - if document.data_source_type == 'upload_file': + if document.data_source_type == "upload_file": data_source_info = document.data_source_info_dict - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] - file = db.session.query(UploadFile).filter( - UploadFile.tenant_id == document.tenant_id, - UploadFile.id == file_id - ).first() + file = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id) + .first() + ) # raise error if file not found if not file: - raise NotFound('File not found.') + raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file, document_model=document.doc_form ) indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, [extract_setting], - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + [extract_setting], + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -410,7 +402,6 @@ class DocumentIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingEstimateApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -418,13 +409,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): dataset_id = str(dataset_id) batch = str(batch) documents = self.get_batch_documents(dataset_id, batch) - response = { - "tokens": 0, - "total_price": 0, - "currency": "USD", - "total_segments": 0, - "preview": [] - } + response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []} if not documents: return response data_process_rule = documents[0].dataset_process_rule @@ -432,82 +417,83 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): info_list = [] extract_settings = [] for document in documents: - if document.indexing_status in ['completed', 'error']: + if document.indexing_status in ["completed", "error"]: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict # format document files info - if data_source_info and 'upload_file_id' in data_source_info: - file_id = data_source_info['upload_file_id'] + if data_source_info and "upload_file_id" in data_source_info: + file_id = data_source_info["upload_file_id"] info_list.append(file_id) # format document notion info - elif data_source_info and 'notion_workspace_id' in data_source_info and 'notion_page_id' in data_source_info: + elif ( + data_source_info and "notion_workspace_id" in data_source_info and "notion_page_id" in data_source_info + ): pages = [] - page = { - 'page_id': data_source_info['notion_page_id'], - 'type': data_source_info['type'] - } + page = {"page_id": data_source_info["notion_page_id"], "type": data_source_info["type"]} pages.append(page) - notion_info = { - 'workspace_id': data_source_info['notion_workspace_id'], - 'pages': pages - } + notion_info = {"workspace_id": data_source_info["notion_workspace_id"], "pages": pages} info_list.append(notion_info) - if document.data_source_type == 'upload_file': - file_id = data_source_info['upload_file_id'] - file_detail = db.session.query(UploadFile).filter( - UploadFile.tenant_id == current_user.current_tenant_id, - UploadFile.id == file_id - ).first() + if document.data_source_type == "upload_file": + file_id = data_source_info["upload_file_id"] + file_detail = ( + db.session.query(UploadFile) + .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id) + .first() + ) if file_detail is None: raise NotFound("File not found.") extract_setting = ExtractSetting( - datasource_type="upload_file", - upload_file=file_detail, - document_model=document.doc_form + datasource_type="upload_file", upload_file=file_detail, document_model=document.doc_form ) extract_settings.append(extract_setting) - elif document.data_source_type == 'notion_import': + elif document.data_source_type == "notion_import": extract_setting = ExtractSetting( datasource_type="notion_import", notion_info={ - "notion_workspace_id": data_source_info['notion_workspace_id'], - "notion_obj_id": data_source_info['notion_page_id'], - "notion_page_type": data_source_info['type'], - "tenant_id": current_user.current_tenant_id + "notion_workspace_id": data_source_info["notion_workspace_id"], + "notion_obj_id": data_source_info["notion_page_id"], + "notion_page_type": data_source_info["type"], + "tenant_id": current_user.current_tenant_id, }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) - elif document.data_source_type == 'website_crawl': + elif document.data_source_type == "website_crawl": extract_setting = ExtractSetting( datasource_type="website_crawl", website_info={ - "provider": data_source_info['provider'], - "job_id": data_source_info['job_id'], - "url": data_source_info['url'], + "provider": data_source_info["provider"], + "job_id": data_source_info["job_id"], + "url": data_source_info["url"], "tenant_id": current_user.current_tenant_id, - "mode": data_source_info['mode'], - "only_main_content": data_source_info['only_main_content'] + "mode": data_source_info["mode"], + "only_main_content": data_source_info["only_main_content"], }, - document_model=document.doc_form + document_model=document.doc_form, ) extract_settings.append(extract_setting) else: - raise ValueError('Data source type not support') + raise ValueError("Data source type not support") indexing_runner = IndexingRunner() try: - response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings, - data_process_rule_dict, document.doc_form, - 'English', dataset_id) + response = indexing_runner.indexing_estimate( + current_user.current_tenant_id, + extract_settings, + data_process_rule_dict, + document.doc_form, + "English", + dataset_id, + ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except Exception as e: @@ -516,7 +502,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -526,24 +511,24 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data class DocumentIndexingStatusApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -552,25 +537,24 @@ class DocumentIndexingStatusApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query \ - .filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() - total_segments = DocumentSegment.query \ - .filter(DocumentSegment.document_id == str(document_id), - DocumentSegment.status != 're_segment') \ - .count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" return marshal(document, document_status_fields) class DocumentDetailApi(DocumentResource): - METADATA_CHOICES = {'all', 'only', 'without'} + METADATA_CHOICES = {"all", "only", "without"} @setup_required @login_required @@ -580,77 +564,73 @@ class DocumentDetailApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - metadata = request.args.get('metadata', 'all') + metadata = request.args.get("metadata", "all") if metadata not in self.METADATA_CHOICES: - raise InvalidMetadataError(f'Invalid metadata value: {metadata}') + raise InvalidMetadataError(f"Invalid metadata value: {metadata}") - if metadata == 'only': - response = { - 'id': document.id, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata - } - elif metadata == 'without': + if metadata == "only": + response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} + elif metadata == "without": process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, } else: process_rules = DatasetService.get_process_rules(dataset_id) data_source_info = document.data_source_detail_dict response = { - 'id': document.id, - 'position': document.position, - 'data_source_type': document.data_source_type, - 'data_source_info': data_source_info, - 'dataset_process_rule_id': document.dataset_process_rule_id, - 'dataset_process_rule': process_rules, - 'name': document.name, - 'created_from': document.created_from, - 'created_by': document.created_by, - 'created_at': document.created_at.timestamp(), - 'tokens': document.tokens, - 'indexing_status': document.indexing_status, - 'completed_at': int(document.completed_at.timestamp()) if document.completed_at else None, - 'updated_at': int(document.updated_at.timestamp()) if document.updated_at else None, - 'indexing_latency': document.indexing_latency, - 'error': document.error, - 'enabled': document.enabled, - 'disabled_at': int(document.disabled_at.timestamp()) if document.disabled_at else None, - 'disabled_by': document.disabled_by, - 'archived': document.archived, - 'doc_type': document.doc_type, - 'doc_metadata': document.doc_metadata, - 'segment_count': document.segment_count, - 'average_segment_length': document.average_segment_length, - 'hit_count': document.hit_count, - 'display_status': document.display_status, - 'doc_form': document.doc_form + "id": document.id, + "position": document.position, + "data_source_type": document.data_source_type, + "data_source_info": data_source_info, + "dataset_process_rule_id": document.dataset_process_rule_id, + "dataset_process_rule": process_rules, + "name": document.name, + "created_from": document.created_from, + "created_by": document.created_by, + "created_at": document.created_at.timestamp(), + "tokens": document.tokens, + "indexing_status": document.indexing_status, + "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None, + "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None, + "indexing_latency": document.indexing_latency, + "error": document.error, + "enabled": document.enabled, + "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None, + "disabled_by": document.disabled_by, + "archived": document.archived, + "doc_type": document.doc_type, + "doc_metadata": document.doc_metadata, + "segment_count": document.segment_count, + "average_segment_length": document.average_segment_length, + "hit_count": document.hit_count, + "display_status": document.display_status, + "doc_form": document.doc_form, } return response, 200 @@ -671,7 +651,7 @@ class DocumentProcessingApi(DocumentResource): if action == "pause": if document.indexing_status != "indexing": - raise InvalidActionError('Document not in indexing state.') + raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -680,7 +660,7 @@ class DocumentProcessingApi(DocumentResource): elif action == "resume": if document.indexing_status not in ["paused", "error"]: - raise InvalidActionError('Document not in paused or error state.') + raise InvalidActionError("Document not in paused or error state.") document.paused_by = None document.paused_at = None @@ -689,7 +669,7 @@ class DocumentProcessingApi(DocumentResource): else: raise InvalidActionError() - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentDeleteApi(DocumentResource): @@ -710,9 +690,9 @@ class DocumentDeleteApi(DocumentResource): try: DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentMetadataApi(DocumentResource): @@ -726,26 +706,26 @@ class DocumentMetadataApi(DocumentResource): req_data = request.get_json() - doc_type = req_data.get('doc_type') - doc_metadata = req_data.get('doc_metadata') + doc_type = req_data.get("doc_type") + doc_metadata = req_data.get("doc_metadata") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() if doc_type is None or doc_metadata is None: - raise ValueError('Both doc_type and doc_metadata must be provided.') + raise ValueError("Both doc_type and doc_metadata must be provided.") if doc_type not in DocumentService.DOCUMENT_METADATA_SCHEMA: - raise ValueError('Invalid doc_type.') + raise ValueError("Invalid doc_type.") if not isinstance(doc_metadata, dict): - raise ValueError('doc_metadata must be a dictionary.') + raise ValueError("doc_metadata must be a dictionary.") metadata_schema = DocumentService.DOCUMENT_METADATA_SCHEMA[doc_type] document.doc_metadata = {} - if doc_type == 'others': + if doc_type == "others": document.doc_metadata = doc_metadata else: for key, value_type in metadata_schema.items(): @@ -757,14 +737,14 @@ class DocumentMetadataApi(DocumentResource): document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success', 'message': 'Document metadata updated.'}, 200 + return {"result": "success", "message": "Document metadata updated."}, 200 class DocumentStatusApi(DocumentResource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, action): dataset_id = str(dataset_id) document_id = str(document_id) @@ -784,14 +764,14 @@ class DocumentStatusApi(DocumentResource): document = self.get_document(dataset_id, document_id) - indexing_cache_key = 'document_{}_indexing'.format(document.id) + indexing_cache_key = "document_{}_indexing".format(document.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") if action == "enable": if document.enabled: - raise InvalidActionError('Document already enabled.') + raise InvalidActionError("Document already enabled.") document.enabled = True document.disabled_at = None @@ -804,13 +784,13 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": - if not document.completed_at or document.indexing_status != 'completed': - raise InvalidActionError('Document is not completed.') + if not document.completed_at or document.indexing_status != "completed": + raise InvalidActionError("Document is not completed.") if not document.enabled: - raise InvalidActionError('Document already disabled.') + raise InvalidActionError("Document already disabled.") document.enabled = False document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -823,11 +803,11 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "archive": if document.archived: - raise InvalidActionError('Document already archived.') + raise InvalidActionError("Document already archived.") document.archived = True document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) @@ -841,10 +821,10 @@ class DocumentStatusApi(DocumentResource): remove_document_from_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "un_archive": if not document.archived: - raise InvalidActionError('Document is not archived.') + raise InvalidActionError("Document is not archived.") document.archived = False document.archived_at = None @@ -857,13 +837,12 @@ class DocumentStatusApi(DocumentResource): add_document_to_index_task.delay(document_id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() class DocumentPauseApi(DocumentResource): - @setup_required @login_required @account_initialization_required @@ -874,7 +853,7 @@ class DocumentPauseApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) @@ -890,9 +869,9 @@ class DocumentPauseApi(DocumentResource): # pause document DocumentService.pause_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot pause completed document.') + raise DocumentIndexingError("Cannot pause completed document.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRecoverApi(DocumentResource): @@ -905,7 +884,7 @@ class DocumentRecoverApi(DocumentResource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document = DocumentService.get_document(dataset.id, document_id) # 404 if document not found @@ -919,9 +898,9 @@ class DocumentRecoverApi(DocumentResource): # pause document DocumentService.recover_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Document is not in paused status.') + raise DocumentIndexingError("Document is not in paused status.") - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRetryApi(DocumentResource): @@ -932,15 +911,14 @@ class DocumentRetryApi(DocumentResource): """retry document.""" parser = reqparse.RequestParser() - parser.add_argument('document_ids', type=list, required=True, nullable=False, - location='json') + parser.add_argument("document_ids", type=list, required=True, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) retry_documents = [] if not dataset: - raise NotFound('Dataset not found.') - for document_id in args['document_ids']: + raise NotFound("Dataset not found.") + for document_id in args["document_ids"]: try: document_id = str(document_id) @@ -955,7 +933,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == 'completed': + if document.indexing_status == "completed": raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception as e: @@ -964,7 +942,7 @@ class DocumentRetryApi(DocumentResource): # retry document DocumentService.retry_document(dataset_id, retry_documents) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class DocumentRenameApi(DocumentResource): @@ -979,13 +957,13 @@ class DocumentRenameApi(DocumentResource): dataset = DatasetService.get_dataset(dataset_id) DatasetService.check_dataset_operator_permission(current_user, dataset) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") args = parser.parse_args() try: - document = DocumentService.rename_document(dataset_id, document_id, args['name']) + document = DocumentService.rename_document(dataset_id, document_id, args["name"]) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") return document @@ -999,51 +977,43 @@ class WebsiteDocumentSyncApi(DocumentResource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if document.tenant_id != current_user.current_tenant_id: - raise Forbidden('No permission.') - if document.data_source_type != 'website_crawl': - raise ValueError('Document is not a website document.') + raise Forbidden("No permission.") + if document.data_source_type != "website_crawl": + raise ValueError("Document is not a website document.") # 403 if document is archived if DocumentService.check_archived(document): raise ArchivedDocumentImmutableError() # sync document DocumentService.sync_website_document(dataset_id, document) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(GetProcessRuleApi, '/datasets/process-rule') -api.add_resource(DatasetDocumentListApi, - '/datasets//documents') -api.add_resource(DatasetInitApi, - '/datasets/init') -api.add_resource(DocumentIndexingEstimateApi, - '/datasets//documents//indexing-estimate') -api.add_resource(DocumentBatchIndexingEstimateApi, - '/datasets//batch//indexing-estimate') -api.add_resource(DocumentBatchIndexingStatusApi, - '/datasets//batch//indexing-status') -api.add_resource(DocumentIndexingStatusApi, - '/datasets//documents//indexing-status') -api.add_resource(DocumentDetailApi, - '/datasets//documents/') -api.add_resource(DocumentProcessingApi, - '/datasets//documents//processing/') -api.add_resource(DocumentDeleteApi, - '/datasets//documents/') -api.add_resource(DocumentMetadataApi, - '/datasets//documents//metadata') -api.add_resource(DocumentStatusApi, - '/datasets//documents//status/') -api.add_resource(DocumentPauseApi, '/datasets//documents//processing/pause') -api.add_resource(DocumentRecoverApi, '/datasets//documents//processing/resume') -api.add_resource(DocumentRetryApi, '/datasets//retry') -api.add_resource(DocumentRenameApi, - '/datasets//documents//rename') +api.add_resource(GetProcessRuleApi, "/datasets/process-rule") +api.add_resource(DatasetDocumentListApi, "/datasets//documents") +api.add_resource(DatasetInitApi, "/datasets/init") +api.add_resource( + DocumentIndexingEstimateApi, "/datasets//documents//indexing-estimate" +) +api.add_resource(DocumentBatchIndexingEstimateApi, "/datasets//batch//indexing-estimate") +api.add_resource(DocumentBatchIndexingStatusApi, "/datasets//batch//indexing-status") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") +api.add_resource(DocumentDetailApi, "/datasets//documents/") +api.add_resource( + DocumentProcessingApi, "/datasets//documents//processing/" +) +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentMetadataApi, "/datasets//documents//metadata") +api.add_resource(DocumentStatusApi, "/datasets//documents//status/") +api.add_resource(DocumentPauseApi, "/datasets//documents//processing/pause") +api.add_resource(DocumentRecoverApi, "/datasets//documents//processing/resume") +api.add_resource(DocumentRetryApi, "/datasets//retry") +api.add_resource(DocumentRenameApi, "/datasets//documents//rename") -api.add_resource(WebsiteDocumentSyncApi, '/datasets//documents//website-sync') +api.add_resource(WebsiteDocumentSyncApi, "/datasets//documents//website-sync") diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a4210d5a0..240564938 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource): document_id = str(document_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") try: DatasetService.check_dataset_permission(dataset, current_user) @@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource): document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") parser = reqparse.RequestParser() - parser.add_argument('last_id', type=str, default=None, location='args') - parser.add_argument('limit', type=int, default=20, location='args') - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('hit_count_gte', type=int, - default=None, location='args') - parser.add_argument('enabled', type=str, default='all', location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("last_id", type=str, default=None, location="args") + parser.add_argument("limit", type=int, default=20, location="args") + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("hit_count_gte", type=int, default=None, location="args") + parser.add_argument("enabled", type=str, default="all", location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - last_id = args['last_id'] - limit = min(args['limit'], 100) - status_list = args['status'] - hit_count_gte = args['hit_count_gte'] - keyword = args['keyword'] + last_id = args["last_id"] + limit = min(args["limit"], 100) + status_list = args["status"] + hit_count_gte = args["hit_count_gte"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if last_id is not None: last_segment = db.session.get(DocumentSegment, str(last_id)) if last_segment: - query = query.filter( - DocumentSegment.position > last_segment.position) + query = query.filter(DocumentSegment.position > last_segment.position) else: - return {'data': [], 'has_more': False, 'limit': limit}, 200 + return {"data": [], "has_more": False, "limit": limit}, 200 if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource): query = query.filter(DocumentSegment.hit_count >= hit_count_gte) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) - if args['enabled'].lower() != 'all': - if args['enabled'].lower() == 'true': + if args["enabled"].lower() != "all": + if args["enabled"].lower() == "true": query = query.filter(DocumentSegment.enabled == True) - elif args['enabled'].lower() == 'false': + elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) total = query.count() @@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource): segments = segments[:-1] return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'has_more': has_more, - 'limit': limit, - 'total': total + "data": marshal(segments, segment_fields), + "doc_form": document.doc_form, + "has_more": has_more, + "limit": limit, + "total": total, }, 200 @@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, segment_id, action): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # The role of the current user in the ta table must be admin, owner, or editor @@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") - if segment.status != 'completed': - raise NotFound('Segment is not completed, enable or disable function is not allowed') + if segment.status != "completed": + raise NotFound("Segment is not completed, enable or disable function is not allowed") - document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id) + document_indexing_cache_key = "document_{}_indexing".format(segment.document_id) cache_result = redis_client.get(document_indexing_cache_key) if cache_result is not None: raise InvalidActionError("Document is being indexed, please try again later") - indexing_cache_key = 'segment_{}_indexing'.format(segment.id) + indexing_cache_key = "segment_{}_indexing".format(segment.id) cache_result = redis_client.get(indexing_cache_key) if cache_result is not None: raise InvalidActionError("Segment is being indexed, please try again later") @@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource): enable_segment_to_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 elif action == "disable": if not segment.enabled: raise InvalidActionError("Segment is already disabled.") @@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource): disable_segment_from_index_task.delay(segment.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 else: raise InvalidActionError() @@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") if not current_user.is_editor: raise Forbidden() # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) try: @@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.create_segment(args, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 class DatasetDocumentSegmentUpdateApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') + @cloud_edition_billing_resource_check("vector_space") def patch(self, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, or editor if not current_user.is_editor: raise Forbidden() @@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise Forbidden(str(e)) # validate args parser = reqparse.RequestParser() - parser.add_argument('content', type=str, required=True, nullable=False, location='json') - parser.add_argument('answer', type=str, required=False, nullable=True, location='json') - parser.add_argument('keywords', type=list, required=False, nullable=True, location='json') + parser.add_argument("content", type=str, required=True, nullable=False, location="json") + parser.add_argument("answer", type=str, required=False, nullable=True, location="json") + parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") args = parser.parse_args() SegmentService.segment_create_args_validate(args, document) segment = SegmentService.update_segment(args, segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 @setup_required @login_required @@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource): dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin or owner if not current_user.is_editor: raise Forbidden() @@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DatasetDocumentSegmentBatchImportApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('vector_space') - @cloud_edition_billing_knowledge_limit_check('add_segment') + @cloud_edition_billing_resource_check("vector_space") + @cloud_edition_billing_knowledge_limit_check("add_segment") def post(self, dataset_id, document_id): # check dataset dataset_id = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id) if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith('.csv'): + if not file.filename.endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource): df = pd.read_csv(file) result = [] for index, row in df.iterrows(): - if document.doc_form == 'qa_model': - data = {'content': row[0], 'answer': row[1]} + if document.doc_form == "qa_model": + data = {"content": row[0], "answer": row[1]} else: - data = {'content': row[0]} + data = {"content": row[0]} result.append(data) if len(result) == 0: raise ValueError("The CSV file is empty.") # async job job_id = str(uuid.uuid4()) - indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id)) + indexing_cache_key = "segment_batch_import_{}".format(str(job_id)) # send batch add segments task - redis_client.setnx(indexing_cache_key, 'waiting') - batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id, - current_user.current_tenant_id, current_user.id) + redis_client.setnx(indexing_cache_key, "waiting") + batch_create_segment_to_index_task.delay( + str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id + ) except Exception as e: - return {'error': str(e)}, 500 - return { - 'job_id': job_id, - 'job_status': 'waiting' - }, 200 + return {"error": str(e)}, 500 + return {"job_id": job_id, "job_status": "waiting"}, 200 @setup_required @login_required @account_initialization_required def get(self, job_id): job_id = str(job_id) - indexing_cache_key = 'segment_batch_import_{}'.format(job_id) + indexing_cache_key = "segment_batch_import_{}".format(job_id) cache_result = redis_client.get(indexing_cache_key) if cache_result is None: raise ValueError("The job is not exist.") - return { - 'job_id': job_id, - 'job_status': cache_result.decode() - }, 200 + return {"job_id": job_id, "job_status": cache_result.decode()}, 200 -api.add_resource(DatasetDocumentSegmentListApi, - '/datasets//documents//segments') -api.add_resource(DatasetDocumentSegmentApi, - '/datasets//segments//') -api.add_resource(DatasetDocumentSegmentAddApi, - '/datasets//documents//segment') -api.add_resource(DatasetDocumentSegmentUpdateApi, - '/datasets//documents//segments/') -api.add_resource(DatasetDocumentSegmentBatchImportApi, - '/datasets//documents//segments/batch_import', - '/datasets/batch_import_status/') +api.add_resource(DatasetDocumentSegmentListApi, "/datasets//documents//segments") +api.add_resource(DatasetDocumentSegmentApi, "/datasets//segments//") +api.add_resource(DatasetDocumentSegmentAddApi, "/datasets//documents//segment") +api.add_resource( + DatasetDocumentSegmentUpdateApi, + "/datasets//documents//segments/", +) +api.add_resource( + DatasetDocumentSegmentBatchImportApi, + "/datasets//documents//segments/batch_import", + "/datasets/batch_import_status/", +) diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 9270b610c..6a7a3971a 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -2,90 +2,90 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class WebsiteCrawlError(BaseHTTPException): - error_code = 'crawl_failed' + error_code = "crawl_failed" description = "{message}" code = 500 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 class IndexingEstimateError(BaseHTTPException): - error_code = 'indexing_estimate_error' + error_code = "indexing_estimate_error" description = "Knowledge indexing estimate failed: {message}" code = 500 diff --git a/api/controllers/console/datasets/file.py b/api/controllers/console/datasets/file.py index 3b2083bcc..d6a464545 100644 --- a/api/controllers/console/datasets/file.py +++ b/api/controllers/console/datasets/file.py @@ -21,7 +21,6 @@ PREVIEW_WORDS_LIMIT = 3000 class FileApi(Resource): - @setup_required @login_required @account_initialization_required @@ -31,23 +30,22 @@ class FileApi(Resource): batch_count_limit = dify_config.UPLOAD_FILE_BATCH_LIMIT image_file_size_limit = dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT return { - 'file_size_limit': file_size_limit, - 'batch_count_limit': batch_count_limit, - 'image_file_size_limit': image_file_size_limit + "file_size_limit": file_size_limit, + "batch_count_limit": batch_count_limit, + "image_file_size_limit": image_file_size_limit, }, 200 @setup_required @login_required @account_initialization_required @marshal_with(file_fields) - @cloud_edition_billing_resource_check(resource='documents') + @cloud_edition_billing_resource_check(resource="documents") def post(self): - # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -69,7 +67,7 @@ class FilePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) text = FileService.get_file_preview(file_id) - return {'content': text} + return {"content": text} class FileSupportTypeApi(Resource): @@ -78,10 +76,10 @@ class FileSupportTypeApi(Resource): @account_initialization_required def get(self): etl_type = dify_config.ETL_TYPE - allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == 'Unstructured' else ALLOWED_EXTENSIONS - return {'allowed_extensions': allowed_extensions} + allowed_extensions = UNSTRUCTURED_ALLOWED_EXTENSIONS if etl_type == "Unstructured" else ALLOWED_EXTENSIONS + return {"allowed_extensions": allowed_extensions} -api.add_resource(FileApi, '/files/upload') -api.add_resource(FilePreviewApi, '/files//preview') -api.add_resource(FileSupportTypeApi, '/files/support-type') +api.add_resource(FileApi, "/files/upload") +api.add_resource(FilePreviewApi, "/files//preview") +api.add_resource(FileSupportTypeApi, "/files/support-type") diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py index 8771bf909..0b4a7be98 100644 --- a/api/controllers/console/datasets/hit_testing.py +++ b/api/controllers/console/datasets/hit_testing.py @@ -29,7 +29,6 @@ from services.hit_testing_service import HitTestingService class HitTestingApi(Resource): - @setup_required @login_required @account_initialization_required @@ -46,8 +45,8 @@ class HitTestingApi(Resource): raise Forbidden(str(e)) parser = reqparse.RequestParser() - parser.add_argument('query', type=str, location='json') - parser.add_argument('retrieval_model', type=dict, required=False, location='json') + parser.add_argument("query", type=str, location="json") + parser.add_argument("retrieval_model", type=dict, required=False, location="json") args = parser.parse_args() HitTestingService.hit_testing_args_check(args) @@ -55,13 +54,13 @@ class HitTestingApi(Resource): try: response = HitTestingService.retrieve( dataset=dataset, - query=args['query'], + query=args["query"], account=current_user, - retrieval_model=args['retrieval_model'], - limit=10 + retrieval_model=args["retrieval_model"], + limit=10, ) - return {"query": response['query'], 'records': marshal(response['records'], hit_testing_record_fields)} + return {"query": response["query"], "records": marshal(response["records"], hit_testing_record_fields)} except services.errors.index.IndexNotInitializedError: raise DatasetNotInitializedError() except ProviderTokenNotInitError as ex: @@ -73,7 +72,8 @@ class HitTestingApi(Resource): except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model or Reranking Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except InvokeError as e: raise CompletionRequestError(e.description) except ValueError as e: @@ -83,4 +83,4 @@ class HitTestingApi(Resource): raise InternalServerError(str(e)) -api.add_resource(HitTestingApi, '/datasets//hit-testing') +api.add_resource(HitTestingApi, "/datasets//hit-testing") diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index bbd91256f..cb54f1aac 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -9,16 +9,14 @@ from services.website_service import WebsiteService class WebsiteCrawlApi(Resource): - @setup_required @login_required @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], - required=True, nullable=True, location='json') - parser.add_argument('url', type=str, required=True, nullable=True, location='json') - parser.add_argument('options', type=dict, required=True, nullable=True, location='json') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, nullable=True, location="json") + parser.add_argument("url", type=str, required=True, nullable=True, location="json") + parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() WebsiteService.document_create_args_validate(args) # crawl url @@ -35,15 +33,15 @@ class WebsiteCrawlStatusApi(Resource): @account_initialization_required def get(self, job_id: str): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, choices=['firecrawl'], required=True, location='args') + parser.add_argument("provider", type=str, choices=["firecrawl"], required=True, location="args") args = parser.parse_args() # get crawl status try: - result = WebsiteService.get_crawl_status(job_id, args['provider']) + result = WebsiteService.get_crawl_status(job_id, args["provider"]) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 -api.add_resource(WebsiteCrawlApi, '/website/crawl') -api.add_resource(WebsiteCrawlStatusApi, '/website/crawl/status/') +api.add_resource(WebsiteCrawlApi, "/website/crawl") +api.add_resource(WebsiteCrawlStatusApi, "/website/crawl/status/") diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index 888dad83c..1c70ea6c5 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -2,35 +2,41 @@ from libs.exception import BaseHTTPException class AlreadySetupError(BaseHTTPException): - error_code = 'already_setup' + error_code = "already_setup" description = "Dify has been successfully installed. Please refresh the page or return to the dashboard homepage." code = 403 class NotSetupError(BaseHTTPException): - error_code = 'not_setup' - description = "Dify has not been initialized and installed yet. " \ - "Please proceed with the initialization and installation process first." + error_code = "not_setup" + description = ( + "Dify has not been initialized and installed yet. " + "Please proceed with the initialization and installation process first." + ) code = 401 + class NotInitValidateError(BaseHTTPException): - error_code = 'not_init_validated' - description = "Init validation has not been completed yet. " \ - "Please proceed with the init validation process first." + error_code = "not_init_validated" + description = ( + "Init validation has not been completed yet. " "Please proceed with the init validation process first." + ) code = 401 + class InitValidateFailedError(BaseHTTPException): - error_code = 'init_validate_failed' + error_code = "init_validate_failed" description = "Init validation failed. Please check the password and try again." code = 401 + class AccountNotLinkTenantError(BaseHTTPException): - error_code = 'account_not_link_tenant' + error_code = "account_not_link_tenant" description = "Account not link tenant." code = 403 class AlreadyActivateError(BaseHTTPException): - error_code = 'already_activate' + error_code = "already_activate" description = "Auth Token is invalid or account already activated, please check again." code = 403 diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 27cc83042..71cb060ec 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -33,14 +33,10 @@ class ChatAudioApi(InstalledAppResource): def post(self, installed_app): app_model = installed_app.app - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=None - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -76,30 +72,31 @@ class ChatTextApi(InstalledAppResource): app_model = installed_app.app try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None - response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - voice=voice, - text=text - ) + response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text) return response except services.errors.app_model_config.AppModelConfigBrokenError: logging.exception("App model config broken.") @@ -127,7 +124,7 @@ class ChatTextApi(InstalledAppResource): raise InternalServerError() -api.add_resource(ChatAudioApi, '/installed-apps//audio-to-text', endpoint='installed_app_audio') -api.add_resource(ChatTextApi, '/installed-apps//text-to-audio', endpoint='installed_app_text') +api.add_resource(ChatAudioApi, "/installed-apps//audio-to-text", endpoint="installed_app_audio") +api.add_resource(ChatTextApi, "/installed-apps//text-to-audio", endpoint="installed_app_text") # api.add_resource(ChatTextApiWithMessageId, '/installed-apps//text-to-audio/message-id', # endpoint='installed_app_text_with_message_id') diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 869b56e13..c039e8bca 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -30,33 +30,28 @@ from services.app_generate_service import AppGenerateService # define completion api for user class CompletionApi(InstalledAppResource): - def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming ) return helper.compact_generate_response(response) @@ -85,12 +80,12 @@ class CompletionApi(InstalledAppResource): class CompletionStopApi(InstalledAppResource): def post(self, installed_app, task_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(InstalledAppResource): @@ -101,25 +96,21 @@ class ChatApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json") args = parser.parse_args() - args['auto_generate_name'] = False + args["auto_generate_name"] = False installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) db.session.commit() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -154,10 +145,22 @@ class ChatStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/installed-apps//completion-messages', endpoint='installed_app_completion') -api.add_resource(CompletionStopApi, '/installed-apps//completion-messages//stop', endpoint='installed_app_stop_completion') -api.add_resource(ChatApi, '/installed-apps//chat-messages', endpoint='installed_app_chat_completion') -api.add_resource(ChatStopApi, '/installed-apps//chat-messages//stop', endpoint='installed_app_stop_chat_completion') +api.add_resource( + CompletionApi, "/installed-apps//completion-messages", endpoint="installed_app_completion" +) +api.add_resource( + CompletionStopApi, + "/installed-apps//completion-messages//stop", + endpoint="installed_app_stop_completion", +) +api.add_resource( + ChatApi, "/installed-apps//chat-messages", endpoint="installed_app_chat_completion" +) +api.add_resource( + ChatStopApi, + "/installed-apps//chat-messages//stop", + endpoint="installed_app_stop_chat_completion", +) diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py index ea0fa4e17..2918024b6 100644 --- a/api/controllers/console/explore/conversation.py +++ b/api/controllers/console/explore/conversation.py @@ -16,7 +16,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(InstalledAppResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app @@ -25,21 +24,21 @@ class ConversationListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=current_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.EXPLORE, pinned=pinned, ) @@ -65,7 +64,6 @@ class ConversationApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource): - @marshal_with(simple_conversation_fields) def post(self, installed_app, c_id): app_model = installed_app.app @@ -76,24 +74,19 @@ class ConversationRenameApi(InstalledAppResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: return ConversationService.rename( - app_model, - conversation_id, - current_user, - args['name'], - args['auto_generate'] + app_model, conversation_id, current_user, args["name"], args["auto_generate"] ) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(InstalledAppResource): - def patch(self, installed_app, c_id): app_model = installed_app.app app_mode = AppMode.value_of(app_model.mode) @@ -123,8 +116,26 @@ class ConversationUnPinApi(InstalledAppResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/installed-apps//conversations//name', endpoint='installed_app_conversation_rename') -api.add_resource(ConversationListApi, '/installed-apps//conversations', endpoint='installed_app_conversations') -api.add_resource(ConversationApi, '/installed-apps//conversations/', endpoint='installed_app_conversation') -api.add_resource(ConversationPinApi, '/installed-apps//conversations//pin', endpoint='installed_app_conversation_pin') -api.add_resource(ConversationUnPinApi, '/installed-apps//conversations//unpin', endpoint='installed_app_conversation_unpin') +api.add_resource( + ConversationRenameApi, + "/installed-apps//conversations//name", + endpoint="installed_app_conversation_rename", +) +api.add_resource( + ConversationListApi, "/installed-apps//conversations", endpoint="installed_app_conversations" +) +api.add_resource( + ConversationApi, + "/installed-apps//conversations/", + endpoint="installed_app_conversation", +) +api.add_resource( + ConversationPinApi, + "/installed-apps//conversations//pin", + endpoint="installed_app_conversation_pin", +) +api.add_resource( + ConversationUnPinApi, + "/installed-apps//conversations//unpin", + endpoint="installed_app_conversation_unpin", +) diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 9c3216ecc..18221b779 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -2,24 +2,24 @@ from libs.exception import BaseHTTPException class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Not Completion App" code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "App mode is invalid." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Only support workflow app." code = 400 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index ec7bbed30..b71078760 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -21,72 +21,71 @@ class InstalledAppsListApi(Resource): @marshal_with(installed_app_list_fields) def get(self): current_tenant_id = current_user.current_tenant_id - installed_apps = db.session.query(InstalledApp).filter( - InstalledApp.tenant_id == current_tenant_id - ).all() + installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all() current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant) installed_apps = [ { - 'id': installed_app.id, - 'app': installed_app.app, - 'app_owner_tenant_id': installed_app.app_owner_tenant_id, - 'is_pinned': installed_app.is_pinned, - 'last_used_at': installed_app.last_used_at, - 'editable': current_user.role in ["owner", "admin"], - 'uninstallable': current_tenant_id == installed_app.app_owner_tenant_id + "id": installed_app.id, + "app": installed_app.app, + "app_owner_tenant_id": installed_app.app_owner_tenant_id, + "is_pinned": installed_app.is_pinned, + "last_used_at": installed_app.last_used_at, + "editable": current_user.role in ["owner", "admin"], + "uninstallable": current_tenant_id == installed_app.app_owner_tenant_id, } for installed_app in installed_apps ] - installed_apps.sort(key=lambda app: (-app['is_pinned'], - app['last_used_at'] is None, - -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0)) + installed_apps.sort( + key=lambda app: ( + -app["is_pinned"], + app["last_used_at"] is None, + -app["last_used_at"].timestamp() if app["last_used_at"] is not None else 0, + ) + ) - return {'installed_apps': installed_apps} + return {"installed_apps": installed_apps} @login_required @account_initialization_required - @cloud_edition_billing_resource_check('apps') + @cloud_edition_billing_resource_check("apps") def post(self): parser = reqparse.RequestParser() - parser.add_argument('app_id', type=str, required=True, help='Invalid app_id') + parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args['app_id']).first() + recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: - raise NotFound('App not found') + raise NotFound("App not found") current_tenant_id = current_user.current_tenant_id - app = db.session.query(App).filter( - App.id == args['app_id'] - ).first() + app = db.session.query(App).filter(App.id == args["app_id"]).first() if app is None: - raise NotFound('App not found') + raise NotFound("App not found") if not app.is_public: - raise Forbidden('You can\'t install a non-public app') + raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter(and_( - InstalledApp.app_id == args['app_id'], - InstalledApp.tenant_id == current_tenant_id - )).first() + installed_app = InstalledApp.query.filter( + and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) + ).first() if installed_app is None: # todo: position recommended_app.install_count += 1 new_installed_app = InstalledApp( - app_id=args['app_id'], + app_id=args["app_id"], tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None) + last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() - return {'message': 'App installed successfully'} + return {"message": "App installed successfully"} class InstalledAppApi(InstalledAppResource): @@ -94,30 +93,31 @@ class InstalledAppApi(InstalledAppResource): update and delete an installed app use InstalledAppResource to apply default decorators and get installed_app """ + def delete(self, installed_app): if installed_app.app_owner_tenant_id == current_user.current_tenant_id: - raise BadRequest('You can\'t uninstall an app owned by the current tenant') + raise BadRequest("You can't uninstall an app owned by the current tenant") db.session.delete(installed_app) db.session.commit() - return {'result': 'success', 'message': 'App uninstalled successfully'} + return {"result": "success", "message": "App uninstalled successfully"} def patch(self, installed_app): parser = reqparse.RequestParser() - parser.add_argument('is_pinned', type=inputs.boolean) + parser.add_argument("is_pinned", type=inputs.boolean) args = parser.parse_args() commit_args = False - if 'is_pinned' in args: - installed_app.is_pinned = args['is_pinned'] + if "is_pinned" in args: + installed_app.is_pinned = args["is_pinned"] commit_args = True if commit_args: db.session.commit() - return {'result': 'success', 'message': 'App info updated successfully'} + return {"result": "success", "message": "App info updated successfully"} -api.add_resource(InstalledAppsListApi, '/installed-apps') -api.add_resource(InstalledAppApi, '/installed-apps/') +api.add_resource(InstalledAppsListApi, "/installed-apps") +api.add_resource(InstalledAppApi, "/installed-apps/") diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 3523a8690..f5eb18517 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -44,19 +44,21 @@ class MessageListApi(InstalledAppResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, current_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, current_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: raise NotFound("First Message Not Exists.") + class MessageFeedbackApi(InstalledAppResource): def post(self, installed_app, message_id): app_model = installed_app.app @@ -64,30 +66,32 @@ class MessageFeedbackApi(InstalledAppResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, current_user, args['rating']) + MessageService.create_feedback(app_model, message_id, current_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(InstalledAppResource): def get(self, installed_app, message_id): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -95,7 +99,7 @@ class MessageMoreLikeThisApi(InstalledAppResource): user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) except MessageNotExistsError: @@ -128,10 +132,7 @@ class MessageSuggestedQuestionApi(InstalledAppResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=current_user, - message_id=message_id, - invoke_from=InvokeFrom.EXPLORE + app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE ) except MessageNotExistsError: raise NotFound("Message not found") @@ -151,10 +152,22 @@ class MessageSuggestedQuestionApi(InstalledAppResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/installed-apps//messages', endpoint='installed_app_messages') -api.add_resource(MessageFeedbackApi, '/installed-apps//messages//feedbacks', endpoint='installed_app_message_feedback') -api.add_resource(MessageMoreLikeThisApi, '/installed-apps//messages//more-like-this', endpoint='installed_app_more_like_this') -api.add_resource(MessageSuggestedQuestionApi, '/installed-apps//messages//suggested-questions', endpoint='installed_app_suggested_question') +api.add_resource(MessageListApi, "/installed-apps//messages", endpoint="installed_app_messages") +api.add_resource( + MessageFeedbackApi, + "/installed-apps//messages//feedbacks", + endpoint="installed_app_message_feedback", +) +api.add_resource( + MessageMoreLikeThisApi, + "/installed-apps//messages//more-like-this", + endpoint="installed_app_more_like_this", +) +api.add_resource( + MessageSuggestedQuestionApi, + "/installed-apps//messages//suggested-questions", + endpoint="installed_app_suggested_question", +) diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 0a168d630..ad55b0404 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from configs import dify_config @@ -11,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(InstalledAppResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -56,30 +54,35 @@ class AppParameterApi(InstalledAppResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -90,6 +93,7 @@ class ExploreAppMetaApi(InstalledAppResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/installed-apps//parameters', - endpoint='installed_app_parameters') -api.add_resource(ExploreAppMetaApi, '/installed-apps//meta', endpoint='installed_app_meta') +api.add_resource( + AppParameterApi, "/installed-apps//parameters", endpoint="installed_app_parameters" +) +api.add_resource(ExploreAppMetaApi, "/installed-apps//meta", endpoint="installed_app_meta") diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py index 6e10e2ec9..5daaa1e7c 100644 --- a/api/controllers/console/explore/recommended_app.py +++ b/api/controllers/console/explore/recommended_app.py @@ -8,28 +8,28 @@ from libs.login import login_required from services.recommended_app_service import RecommendedAppService app_fields = { - 'id': fields.String, - 'name': fields.String, - 'mode': fields.String, - 'icon': fields.String, - 'icon_background': fields.String + "id": fields.String, + "name": fields.String, + "mode": fields.String, + "icon": fields.String, + "icon_background": fields.String, } recommended_app_fields = { - 'app': fields.Nested(app_fields, attribute='app'), - 'app_id': fields.String, - 'description': fields.String(attribute='description'), - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'category': fields.String, - 'position': fields.Integer, - 'is_listed': fields.Boolean + "app": fields.Nested(app_fields, attribute="app"), + "app_id": fields.String, + "description": fields.String(attribute="description"), + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "category": fields.String, + "position": fields.Integer, + "is_listed": fields.Boolean, } recommended_app_list_fields = { - 'recommended_apps': fields.List(fields.Nested(recommended_app_fields)), - 'categories': fields.List(fields.String) + "recommended_apps": fields.List(fields.Nested(recommended_app_fields)), + "categories": fields.List(fields.String), } @@ -40,11 +40,11 @@ class RecommendedAppListApi(Resource): def get(self): # language args parser = reqparse.RequestParser() - parser.add_argument('language', type=str, location='args') + parser.add_argument("language", type=str, location="args") args = parser.parse_args() - if args.get('language') and args.get('language') in languages: - language_prefix = args.get('language') + if args.get("language") and args.get("language") in languages: + language_prefix = args.get("language") elif current_user and current_user.interface_language: language_prefix = current_user.interface_language else: @@ -61,5 +61,5 @@ class RecommendedAppApi(Resource): return RecommendedAppService.get_recommend_app_detail(app_id) -api.add_resource(RecommendedAppListApi, '/explore/apps') -api.add_resource(RecommendedAppApi, '/explore/apps/') +api.add_resource(RecommendedAppListApi, "/explore/apps") +api.add_resource(RecommendedAppApi, "/explore/apps/") diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index cf86b2fee..a7ccf737a 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -11,56 +11,54 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(InstalledAppResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, current_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, current_user, args["last_id"], args["limit"]) def post(self, installed_app): app_model = installed_app.app - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, current_user, args['message_id']) + SavedMessageService.save(app_model, current_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(InstalledAppResource): @@ -69,13 +67,21 @@ class SavedMessageApi(InstalledAppResource): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, current_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/installed-apps//saved-messages', endpoint='installed_app_saved_messages') -api.add_resource(SavedMessageApi, '/installed-apps//saved-messages/', endpoint='installed_app_saved_message') +api.add_resource( + SavedMessageListApi, + "/installed-apps//saved-messages", + endpoint="installed_app_saved_messages", +) +api.add_resource( + SavedMessageApi, + "/installed-apps//saved-messages/", + endpoint="installed_app_saved_message", +) diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7c5e211d4..45f99b1db 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -35,17 +35,13 @@ class InstalledAppWorkflowRunApi(InstalledAppResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=current_user, - args=args, - invoke_from=InvokeFrom.EXPLORE, - streaming=True + app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True ) return helper.compact_generate_response(response) @@ -76,10 +72,10 @@ class InstalledAppWorkflowTaskStopApi(InstalledAppResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.EXPLORE, current_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(InstalledAppWorkflowRunApi, '/installed-apps//workflows/run') -api.add_resource(InstalledAppWorkflowTaskStopApi, '/installed-apps//workflows/tasks//stop') +api.add_resource(InstalledAppWorkflowRunApi, "/installed-apps//workflows/run") +api.add_resource( + InstalledAppWorkflowTaskStopApi, "/installed-apps//workflows/tasks//stop" +) diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 84890f1b4..3c9317847 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -14,29 +14,33 @@ def installed_app_required(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - if not kwargs.get('installed_app_id'): - raise ValueError('missing installed_app_id in path parameters') + if not kwargs.get("installed_app_id"): + raise ValueError("missing installed_app_id in path parameters") - installed_app_id = kwargs.get('installed_app_id') + installed_app_id = kwargs.get("installed_app_id") installed_app_id = str(installed_app_id) - del kwargs['installed_app_id'] + del kwargs["installed_app_id"] - installed_app = db.session.query(InstalledApp).filter( - InstalledApp.id == str(installed_app_id), - InstalledApp.tenant_id == current_user.current_tenant_id - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter( + InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id + ) + .first() + ) if installed_app is None: - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") if not installed_app.app: db.session.delete(installed_app) db.session.commit() - raise NotFound('Installed app not found') + raise NotFound("Installed app not found") return view(installed_app, *args, **kwargs) + return decorated if view: diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index fe73bcb98..5d6a8bf15 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -13,23 +13,18 @@ from services.code_based_extension_service import CodeBasedExtensionService class CodeBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('module', type=str, required=True, location='args') + parser.add_argument("module", type=str, required=True, location="args") args = parser.parse_args() - return { - 'module': args['module'], - 'data': CodeBasedExtensionService.get_code_based_extension(args['module']) - } + return {"module": args["module"], "data": CodeBasedExtensionService.get_code_based_extension(args["module"])} class APIBasedExtensionAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -44,23 +39,22 @@ class APIBasedExtensionAPI(Resource): @marshal_with(api_based_extension_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() extension_data = APIBasedExtension( tenant_id=current_user.current_tenant_id, - name=args['name'], - api_endpoint=args['api_endpoint'], - api_key=args['api_key'] + name=args["name"], + api_endpoint=args["api_endpoint"], + api_key=args["api_key"], ) return APIBasedExtensionService.save(extension_data) class APIBasedExtensionDetailAPI(Resource): - @setup_required @login_required @account_initialization_required @@ -82,16 +76,16 @@ class APIBasedExtensionDetailAPI(Resource): extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('api_endpoint', type=str, required=True, location='json') - parser.add_argument('api_key', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("api_endpoint", type=str, required=True, location="json") + parser.add_argument("api_key", type=str, required=True, location="json") args = parser.parse_args() - extension_data_from_db.name = args['name'] - extension_data_from_db.api_endpoint = args['api_endpoint'] + extension_data_from_db.name = args["name"] + extension_data_from_db.api_endpoint = args["api_endpoint"] - if args['api_key'] != HIDDEN_VALUE: - extension_data_from_db.api_key = args['api_key'] + if args["api_key"] != HIDDEN_VALUE: + extension_data_from_db.api_key = args["api_key"] return APIBasedExtensionService.save(extension_data_from_db) @@ -106,10 +100,10 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(CodeBasedExtensionAPI, '/code-based-extension') +api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") -api.add_resource(APIBasedExtensionAPI, '/api-based-extension') -api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/') +api.add_resource(APIBasedExtensionAPI, "/api-based-extension") +api.add_resource(APIBasedExtensionDetailAPI, "/api-based-extension/") diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py index 8475cd848..f0482f749 100644 --- a/api/controllers/console/feature.py +++ b/api/controllers/console/feature.py @@ -10,7 +10,6 @@ from .wraps import account_initialization_required, cloud_utm_record class FeatureApi(Resource): - @setup_required @login_required @account_initialization_required @@ -24,5 +23,5 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(FeatureApi, '/features') -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(FeatureApi, "/features") +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py index 6feb1003a..7d3ae677e 100644 --- a/api/controllers/console/init_validate.py +++ b/api/controllers/console/init_validate.py @@ -14,12 +14,11 @@ from .wraps import only_edition_self_hosted class InitValidateAPI(Resource): - def get(self): init_status = get_init_validate_status() if init_status: - return { 'status': 'finished' } - return {'status': 'not_started' } + return {"status": "finished"} + return {"status": "not_started"} @only_edition_self_hosted def post(self): @@ -29,22 +28,23 @@ class InitValidateAPI(Resource): raise AlreadySetupError() parser = reqparse.RequestParser() - parser.add_argument('password', type=str_len(30), - required=True, location='json') - input_password = parser.parse_args()['password'] + parser.add_argument("password", type=str_len(30), required=True, location="json") + input_password = parser.parse_args()["password"] - if input_password != os.environ.get('INIT_PASSWORD'): - session['is_init_validated'] = False + if input_password != os.environ.get("INIT_PASSWORD"): + session["is_init_validated"] = False raise InitValidateFailedError() - - session['is_init_validated'] = True - return {'result': 'success'}, 201 + + session["is_init_validated"] = True + return {"result": "success"}, 201 + def get_init_validate_status(): - if dify_config.EDITION == 'SELF_HOSTED': - if os.environ.get('INIT_PASSWORD'): - return session.get('is_init_validated') or DifySetup.query.first() - + if dify_config.EDITION == "SELF_HOSTED": + if os.environ.get("INIT_PASSWORD"): + return session.get("is_init_validated") or DifySetup.query.first() + return True -api.add_resource(InitValidateAPI, '/init') + +api.add_resource(InitValidateAPI, "/init") diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py index 7664ba8c1..cd28cc946 100644 --- a/api/controllers/console/ping.py +++ b/api/controllers/console/ping.py @@ -4,14 +4,11 @@ from controllers.console import api class PingApi(Resource): - def get(self): """ For connection health check """ - return { - "result": "pong" - } + return {"result": "pong"} -api.add_resource(PingApi, '/ping') +api.add_resource(PingApi, "/ping") diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index ef7cc6bc0..827695e00 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -16,17 +16,13 @@ from .wraps import only_edition_self_hosted class SetupApi(Resource): - def get(self): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": setup_status = get_setup_status() if setup_status: - return { - 'step': 'finished', - 'setup_at': setup_status.setup_at.isoformat() - } - return {'step': 'not_started'} - return {'step': 'finished'} + return {"step": "finished", "setup_at": setup_status.setup_at.isoformat()} + return {"step": "not_started"} + return {"step": "finished"} @only_edition_self_hosted def post(self): @@ -38,28 +34,22 @@ class SetupApi(Resource): tenant_count = TenantService.get_tenant_count() if tenant_count > 0: raise AlreadySetupError() - + if not get_init_validate_status(): raise NotInitValidateError() parser = reqparse.RequestParser() - parser.add_argument('email', type=email, - required=True, location='json') - parser.add_argument('name', type=str_len( - 30), required=True, location='json') - parser.add_argument('password', type=valid_password, - required=True, location='json') + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("name", type=str_len(30), required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") args = parser.parse_args() # setup RegisterService.setup( - email=args['email'], - name=args['name'], - password=args['password'], - ip_address=get_remote_ip(request) + email=args["email"], name=args["name"], password=args["password"], ip_address=get_remote_ip(request) ) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 def setup_required(view): @@ -68,7 +58,7 @@ def setup_required(view): # check setup if not get_init_validate_status(): raise NotInitValidateError() - + elif not get_setup_status(): raise NotSetupError() @@ -78,9 +68,10 @@ def setup_required(view): def get_setup_status(): - if dify_config.EDITION == 'SELF_HOSTED': + if dify_config.EDITION == "SELF_HOSTED": return DifySetup.query.first() else: return True -api.add_resource(SetupApi, '/setup') + +api.add_resource(SetupApi, "/setup") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index 004afaa53..7293aeeb3 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -14,19 +14,18 @@ from services.tag_service import TagService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 50 characters.') + raise ValueError("Name must be between 1 to 50 characters.") return name class TagListApi(Resource): - @setup_required @login_required @account_initialization_required @marshal_with(tag_fields) def get(self): - tag_type = request.args.get('type', type=str) - keyword = request.args.get('keyword', default=None, type=str) + tag_type = request.args.get("type", type=str) + keyword = request.args.get("keyword", default=None, type=str) tags = TagService.get_tags(tag_type, current_user.current_tenant_id, keyword) return tags, 200 @@ -40,28 +39,21 @@ class TagListApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() tag = TagService.save_tags(args) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': 0 - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} return response, 200 class TagUpdateDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -72,20 +64,15 @@ class TagUpdateDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='Name must be between 1 to 50 characters.', - type=_validate_name) + parser.add_argument( + "name", nullable=False, required=True, help="Name must be between 1 to 50 characters.", type=_validate_name + ) args = parser.parse_args() tag = TagService.update_tags(args, tag_id) binding_count = TagService.get_tag_binding_count(tag_id) - response = { - 'id': tag.id, - 'name': tag.name, - 'type': tag.type, - 'binding_count': binding_count - } + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} return response, 200 @@ -104,7 +91,6 @@ class TagUpdateDeleteApi(Resource): class TagBindingCreateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -114,14 +100,15 @@ class TagBindingCreateApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_ids', type=list, nullable=False, required=True, location='json', - help='Tag IDs is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, location='json', - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target ID is required." + ) + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.save_tag_binding(args) @@ -129,7 +116,6 @@ class TagBindingCreateApi(Resource): class TagBindingDeleteApi(Resource): - @setup_required @login_required @account_initialization_required @@ -139,21 +125,18 @@ class TagBindingDeleteApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('tag_id', type=str, nullable=False, required=True, - help='Tag ID is required.') - parser.add_argument('target_id', type=str, nullable=False, required=True, - help='Target ID is required.') - parser.add_argument('type', type=str, location='json', - choices=Tag.TAG_TYPE_LIST, - nullable=True, - help='Invalid tag type.') + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + parser.add_argument( + "type", type=str, location="json", choices=Tag.TAG_TYPE_LIST, nullable=True, help="Invalid tag type." + ) args = parser.parse_args() TagService.delete_tag_binding(args) return 200 -api.add_resource(TagListApi, '/tags') -api.add_resource(TagUpdateDeleteApi, '/tags/') -api.add_resource(TagBindingCreateApi, '/tag-bindings/create') -api.add_resource(TagBindingDeleteApi, '/tag-bindings/remove') +api.add_resource(TagListApi, "/tags") +api.add_resource(TagUpdateDeleteApi, "/tags/") +api.add_resource(TagBindingCreateApi, "/tag-bindings/create") +api.add_resource(TagBindingDeleteApi, "/tag-bindings/remove") diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py index 1fcf4bdc0..76adbfe6a 100644 --- a/api/controllers/console/version.py +++ b/api/controllers/console/version.py @@ -1,4 +1,3 @@ - import json import logging @@ -11,42 +10,39 @@ from . import api class VersionApi(Resource): - def get(self): parser = reqparse.RequestParser() - parser.add_argument('current_version', type=str, required=True, location='args') + parser.add_argument("current_version", type=str, required=True, location="args") args = parser.parse_args() check_update_url = dify_config.CHECK_UPDATE_URL result = { - 'version': dify_config.CURRENT_VERSION, - 'release_date': '', - 'release_notes': '', - 'can_auto_update': False, - 'features': { - 'can_replace_logo': dify_config.CAN_REPLACE_LOGO, - 'model_load_balancing_enabled': dify_config.MODEL_LB_ENABLED - } + "version": dify_config.CURRENT_VERSION, + "release_date": "", + "release_notes": "", + "can_auto_update": False, + "features": { + "can_replace_logo": dify_config.CAN_REPLACE_LOGO, + "model_load_balancing_enabled": dify_config.MODEL_LB_ENABLED, + }, } if not check_update_url: return result try: - response = requests.get(check_update_url, { - 'current_version': args.get('current_version') - }) + response = requests.get(check_update_url, {"current_version": args.get("current_version")}) except Exception as error: logging.warning("Check update version error: {}.".format(str(error))) - result['version'] = args.get('current_version') + result["version"] = args.get("current_version") return result content = json.loads(response.content) - result['version'] = content['version'] - result['release_date'] = content['releaseDate'] - result['release_notes'] = content['releaseNotes'] - result['can_auto_update'] = content['canAutoUpdate'] + result["version"] = content["version"] + result["release_date"] = content["releaseDate"] + result["release_notes"] = content["releaseNotes"] + result["can_auto_update"] = content["canAutoUpdate"] return result -api.add_resource(VersionApi, '/version') +api.add_resource(VersionApi, "/version") diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 1056d5eb6..dec426128 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -26,52 +26,53 @@ from services.errors.account import CurrentPasswordIncorrectError as ServiceCurr class AccountInitApi(Resource): - @setup_required @login_required def post(self): account = current_user - if account.status == 'active': + if account.status == "active": raise AccountAlreadyInitedError() parser = reqparse.RequestParser() - if dify_config.EDITION == 'CLOUD': - parser.add_argument('invitation_code', type=str, location='json') + if dify_config.EDITION == "CLOUD": + parser.add_argument("invitation_code", type=str, location="json") - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') - parser.add_argument('timezone', type=timezone, - required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") + parser.add_argument("timezone", type=timezone, required=True, location="json") args = parser.parse_args() - if dify_config.EDITION == 'CLOUD': - if not args['invitation_code']: - raise ValueError('invitation_code is required') + if dify_config.EDITION == "CLOUD": + if not args["invitation_code"]: + raise ValueError("invitation_code is required") # check invitation code - invitation_code = db.session.query(InvitationCode).filter( - InvitationCode.code == args['invitation_code'], - InvitationCode.status == 'unused', - ).first() + invitation_code = ( + db.session.query(InvitationCode) + .filter( + InvitationCode.code == args["invitation_code"], + InvitationCode.status == "unused", + ) + .first() + ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = 'used' + invitation_code.status = "used" invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id - account.interface_language = args['interface_language'] - account.timezone = args['timezone'] - account.interface_theme = 'light' - account.status = 'active' + account.interface_language = args["interface_language"] + account.timezone = args["timezone"] + account.interface_theme = "light" + account.status = "active" account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) db.session.commit() - return {'result': 'success'} + return {"result": "success"} class AccountProfileApi(Resource): @@ -90,15 +91,14 @@ class AccountNameApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() # Validate account name length - if len(args['name']) < 3 or len(args['name']) > 30: - raise ValueError( - "Account name must be between 3 and 30 characters.") + if len(args["name"]) < 3 or len(args["name"]) > 30: + raise ValueError("Account name must be between 3 and 30 characters.") - updated_account = AccountService.update_account(current_user, name=args['name']) + updated_account = AccountService.update_account(current_user, name=args["name"]) return updated_account @@ -110,10 +110,10 @@ class AccountAvatarApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('avatar', type=str, required=True, location='json') + parser.add_argument("avatar", type=str, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, avatar=args['avatar']) + updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) return updated_account @@ -125,11 +125,10 @@ class AccountInterfaceLanguageApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument( - 'interface_language', type=supported_language, required=True, location='json') + parser.add_argument("interface_language", type=supported_language, required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_language=args['interface_language']) + updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) return updated_account @@ -141,11 +140,10 @@ class AccountInterfaceThemeApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('interface_theme', type=str, choices=[ - 'light', 'dark'], required=True, location='json') + parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") args = parser.parse_args() - updated_account = AccountService.update_account(current_user, interface_theme=args['interface_theme']) + updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) return updated_account @@ -157,15 +155,14 @@ class AccountTimezoneApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('timezone', type=str, - required=True, location='json') + parser.add_argument("timezone", type=str, required=True, location="json") args = parser.parse_args() # Validate timezone string, e.g. America/New_York, Asia/Shanghai - if args['timezone'] not in pytz.all_timezones: + if args["timezone"] not in pytz.all_timezones: raise ValueError("Invalid timezone string.") - updated_account = AccountService.update_account(current_user, timezone=args['timezone']) + updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) return updated_account @@ -177,20 +174,16 @@ class AccountPasswordApi(Resource): @marshal_with(account_fields) def post(self): parser = reqparse.RequestParser() - parser.add_argument('password', type=str, - required=False, location='json') - parser.add_argument('new_password', type=str, - required=True, location='json') - parser.add_argument('repeat_new_password', type=str, - required=True, location='json') + parser.add_argument("password", type=str, required=False, location="json") + parser.add_argument("new_password", type=str, required=True, location="json") + parser.add_argument("repeat_new_password", type=str, required=True, location="json") args = parser.parse_args() - if args['new_password'] != args['repeat_new_password']: + if args["new_password"] != args["repeat_new_password"]: raise RepeatPasswordNotMatchError() try: - AccountService.update_account_password( - current_user, args['password'], args['new_password']) + AccountService.update_account_password(current_user, args["password"], args["new_password"]) except ServiceCurrentPasswordIncorrectError: raise CurrentPasswordIncorrectError() @@ -199,14 +192,14 @@ class AccountPasswordApi(Resource): class AccountIntegrateApi(Resource): integrate_fields = { - 'provider': fields.String, - 'created_at': TimestampField, - 'is_bound': fields.Boolean, - 'link': fields.String + "provider": fields.String, + "created_at": TimestampField, + "is_bound": fields.Boolean, + "link": fields.String, } integrate_list_fields = { - 'data': fields.List(fields.Nested(integrate_fields)), + "data": fields.List(fields.Nested(integrate_fields)), } @setup_required @@ -216,10 +209,9 @@ class AccountIntegrateApi(Resource): def get(self): account = current_user - account_integrates = db.session.query(AccountIntegrate).filter( - AccountIntegrate.account_id == account.id).all() + account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() - base_url = request.url_root.rstrip('/') + base_url = request.url_root.rstrip("/") oauth_base_path = "/console/api/oauth/login" providers = ["github", "google"] @@ -227,36 +219,38 @@ class AccountIntegrateApi(Resource): for provider in providers: existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) if existing_integrate: - integrate_data.append({ - 'id': existing_integrate.id, - 'provider': provider, - 'created_at': existing_integrate.created_at, - 'is_bound': True, - 'link': None - }) + integrate_data.append( + { + "id": existing_integrate.id, + "provider": provider, + "created_at": existing_integrate.created_at, + "is_bound": True, + "link": None, + } + ) else: - integrate_data.append({ - 'id': None, - 'provider': provider, - 'created_at': None, - 'is_bound': False, - 'link': f'{base_url}{oauth_base_path}/{provider}' - }) - - return {'data': integrate_data} - + integrate_data.append( + { + "id": None, + "provider": provider, + "created_at": None, + "is_bound": False, + "link": f"{base_url}{oauth_base_path}/{provider}", + } + ) + return {"data": integrate_data} # Register API resources -api.add_resource(AccountInitApi, '/account/init') -api.add_resource(AccountProfileApi, '/account/profile') -api.add_resource(AccountNameApi, '/account/name') -api.add_resource(AccountAvatarApi, '/account/avatar') -api.add_resource(AccountInterfaceLanguageApi, '/account/interface-language') -api.add_resource(AccountInterfaceThemeApi, '/account/interface-theme') -api.add_resource(AccountTimezoneApi, '/account/timezone') -api.add_resource(AccountPasswordApi, '/account/password') -api.add_resource(AccountIntegrateApi, '/account/integrates') +api.add_resource(AccountInitApi, "/account/init") +api.add_resource(AccountProfileApi, "/account/profile") +api.add_resource(AccountNameApi, "/account/name") +api.add_resource(AccountAvatarApi, "/account/avatar") +api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") +api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") +api.add_resource(AccountTimezoneApi, "/account/timezone") +api.add_resource(AccountPasswordApi, "/account/password") +api.add_resource(AccountIntegrateApi, "/account/integrates") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 99f55835b..9e13c7b92 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -2,36 +2,36 @@ from libs.exception import BaseHTTPException class RepeatPasswordNotMatchError(BaseHTTPException): - error_code = 'repeat_password_not_match' + error_code = "repeat_password_not_match" description = "New password and repeat password does not match." code = 400 class CurrentPasswordIncorrectError(BaseHTTPException): - error_code = 'current_password_incorrect' + error_code = "current_password_incorrect" description = "Current password is incorrect." code = 400 class ProviderRequestFailedError(BaseHTTPException): - error_code = 'provider_request_failed' + error_code = "provider_request_failed" description = None code = 400 class InvalidInvitationCodeError(BaseHTTPException): - error_code = 'invalid_invitation_code' + error_code = "invalid_invitation_code" description = "Invalid invitation code." code = 400 class AccountAlreadyInitedError(BaseHTTPException): - error_code = 'account_already_inited' + error_code = "account_already_inited" description = "The account has been initialized. Please refresh the page." code = 400 class AccountNotInitializedError(BaseHTTPException): - error_code = 'account_not_initialized' + error_code = "account_not_initialized" description = "The account has not been initialized yet. Please proceed with the initialization process first." code = 400 diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 50514e39f..771a86662 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -22,10 +22,16 @@ class LoadBalancingCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing credentials @@ -38,18 +44,18 @@ class LoadBalancingCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response @@ -65,10 +71,16 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() # validate model load balancing config credentials @@ -81,26 +93,30 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): model_load_balancing_service.validate_load_balancing_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'], + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], config_id=config_id, ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response # Load Balancing Config -api.add_resource(LoadBalancingCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate') +api.add_resource( + LoadBalancingCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs/credentials-validate", +) -api.add_resource(LoadBalancingConfigCredentialsValidateApi, - '/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate') +api.add_resource( + LoadBalancingConfigCredentialsValidateApi, + "/workspaces/current/model-providers//models/load-balancing-configs//credentials-validate", +) diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 34e9da384..3e87bebf5 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -23,7 +23,7 @@ class MemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_tenant_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 class MemberInviteEmailApi(Resource): @@ -32,48 +32,46 @@ class MemberInviteEmailApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('members') + @cloud_edition_billing_resource_check("members") def post(self): parser = reqparse.RequestParser() - parser.add_argument('emails', type=str, required=True, location='json', action='append') - parser.add_argument('role', type=str, required=True, default='admin', location='json') - parser.add_argument('language', type=str, required=False, location='json') + parser.add_argument("emails", type=str, required=True, location="json", action="append") + parser.add_argument("role", type=str, required=True, default="admin", location="json") + parser.add_argument("language", type=str, required=False, location="json") args = parser.parse_args() - invitee_emails = args['emails'] - invitee_role = args['role'] - interface_language = args['language'] + invitee_emails = args["emails"] + invitee_role = args["role"] + interface_language = args["language"] if not TenantAccountRole.is_non_owner_role(invitee_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL for invitee_email in invitee_emails: try: - token = RegisterService.invite_new_member(inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter) - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/activate?email={invitee_email}&token={token}' - }) + token = RegisterService.invite_new_member( + inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter + ) + invitation_results.append( + { + "status": "success", + "email": invitee_email, + "url": f"{console_web_url}/activate?email={invitee_email}&token={token}", + } + ) except AccountAlreadyInTenantError: - invitation_results.append({ - 'status': 'success', - 'email': invitee_email, - 'url': f'{console_web_url}/signin' - }) + invitation_results.append( + {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} + ) break except Exception as e: - invitation_results.append({ - 'status': 'failed', - 'email': invitee_email, - 'message': str(e) - }) + invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) return { - 'result': 'success', - 'invitation_results': invitation_results, + "result": "success", + "invitation_results": invitation_results, }, 201 @@ -91,15 +89,15 @@ class MemberCancelInviteApi(Resource): try: TenantService.remove_member_from_tenant(current_user.current_tenant, member, current_user) except services.errors.account.CannotOperateSelfError as e: - return {'code': 'cannot-operate-self', 'message': str(e)}, 400 + return {"code": "cannot-operate-self", "message": str(e)}, 400 except services.errors.account.NoPermissionError as e: - return {'code': 'forbidden', 'message': str(e)}, 403 + return {"code": "forbidden", "message": str(e)}, 403 except services.errors.account.MemberNotInTenantError as e: - return {'code': 'member-not-found', 'message': str(e)}, 404 + return {"code": "member-not-found", "message": str(e)}, 404 except Exception as e: raise ValueError(str(e)) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class MemberUpdateRoleApi(Resource): @@ -110,12 +108,12 @@ class MemberUpdateRoleApi(Resource): @account_initialization_required def put(self, member_id): parser = reqparse.RequestParser() - parser.add_argument('role', type=str, required=True, location='json') + parser.add_argument("role", type=str, required=True, location="json") args = parser.parse_args() - new_role = args['role'] + new_role = args["role"] if not TenantAccountRole.is_valid_role(new_role): - return {'code': 'invalid-role', 'message': 'Invalid role'}, 400 + return {"code": "invalid-role", "message": "Invalid role"}, 400 member = db.session.get(Account, str(member_id)) if not member: @@ -128,7 +126,7 @@ class MemberUpdateRoleApi(Resource): # todo: 403 - return {'result': 'success'} + return {"result": "success"} class DatasetOperatorMemberListApi(Resource): @@ -140,11 +138,11 @@ class DatasetOperatorMemberListApi(Resource): @marshal_with(account_with_role_list_fields) def get(self): members = TenantService.get_dataset_operator_members(current_user.current_tenant) - return {'result': 'success', 'accounts': members}, 200 + return {"result": "success", "accounts": members}, 200 -api.add_resource(MemberListApi, '/workspaces/current/members') -api.add_resource(MemberInviteEmailApi, '/workspaces/current/members/invite-email') -api.add_resource(MemberCancelInviteApi, '/workspaces/current/members/') -api.add_resource(MemberUpdateRoleApi, '/workspaces/current/members//update-role') -api.add_resource(DatasetOperatorMemberListApi, '/workspaces/current/dataset-operators') +api.add_resource(MemberListApi, "/workspaces/current/members") +api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") +api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") +api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") +api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index c888159f8..8c3842022 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -17,7 +17,6 @@ from services.model_provider_service import ModelProviderService class ModelProviderListApi(Resource): - @setup_required @login_required @account_initialization_required @@ -25,21 +24,23 @@ class ModelProviderListApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=False, nullable=True, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=False, + nullable=True, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() - provider_list = model_provider_service.get_provider_list( - tenant_id=tenant_id, - model_type=args.get('model_type') - ) + provider_list = model_provider_service.get_provider_list(tenant_id=tenant_id, model_type=args.get("model_type")) return jsonable_encoder({"data": provider_list}) class ModelProviderCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -47,25 +48,18 @@ class ModelProviderCredentialApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - credentials = model_provider_service.get_provider_credentials( - tenant_id=tenant_id, - provider=provider - ) + credentials = model_provider_service.get_provider_credentials(tenant_id=tenant_id, provider=provider) - return { - "credentials": credentials - } + return {"credentials": credentials} class ModelProviderValidateApi(Resource): - @setup_required @login_required @account_initialization_required def post(self, provider: str): - parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id @@ -77,24 +71,21 @@ class ModelProviderValidateApi(Resource): try: model_provider_service.provider_credentials_validate( - tenant_id=tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderApi(Resource): - @setup_required @login_required @account_initialization_required @@ -103,21 +94,19 @@ class ModelProviderApi(Resource): raise Forbidden() parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() try: model_provider_service.save_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider, - credentials=args['credentials'] + tenant_id=current_user.current_tenant_id, provider=provider, credentials=args["credentials"] ) except CredentialsValidateFailedError as ex: raise ValueError(str(ex)) - return {'result': 'success'}, 201 + return {"result": "success"}, 201 @setup_required @login_required @@ -127,12 +116,9 @@ class ModelProviderApi(Resource): raise Forbidden() model_provider_service = ModelProviderService() - model_provider_service.remove_provider_credentials( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + model_provider_service.remove_provider_credentials(tenant_id=current_user.current_tenant_id, provider=provider) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderIconApi(Resource): @@ -146,16 +132,13 @@ class ModelProviderIconApi(Resource): def get(self, provider: str, icon_type: str, lang: str): model_provider_service = ModelProviderService() icon, mimetype = model_provider_service.get_model_provider_icon( - provider=provider, - icon_type=icon_type, - lang=lang + provider=provider, icon_type=icon_type, lang=lang ) return send_file(io.BytesIO(icon), mimetype=mimetype) class PreferredProviderTypeUpdateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -166,18 +149,22 @@ class PreferredProviderTypeUpdateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('preferred_provider_type', type=str, required=True, nullable=False, - choices=['system', 'custom'], location='json') + parser.add_argument( + "preferred_provider_type", + type=str, + required=True, + nullable=False, + choices=["system", "custom"], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.switch_preferred_provider( - tenant_id=tenant_id, - provider=provider, - preferred_provider_type=args['preferred_provider_type'] + tenant_id=tenant_id, provider=provider, preferred_provider_type=args["preferred_provider_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderPaymentCheckoutUrlApi(Resource): @@ -185,13 +172,15 @@ class ModelProviderPaymentCheckoutUrlApi(Resource): @login_required @account_initialization_required def get(self, provider: str): - if provider != 'anthropic': - raise ValueError(f'provider name {provider} is invalid') + if provider != "anthropic": + raise ValueError(f"provider name {provider} is invalid") BillingService.is_tenant_owner_or_admin(current_user) - data = BillingService.get_model_provider_payment_link(provider_name=provider, - tenant_id=current_user.current_tenant_id, - account_id=current_user.id, - prefilled_email=current_user.email) + data = BillingService.get_model_provider_payment_link( + provider_name=provider, + tenant_id=current_user.current_tenant_id, + account_id=current_user.id, + prefilled_email=current_user.email, + ) return data @@ -201,10 +190,7 @@ class ModelProviderFreeQuotaSubmitApi(Resource): @account_initialization_required def post(self, provider: str): model_provider_service = ModelProviderService() - result = model_provider_service.free_quota_submit( - tenant_id=current_user.current_tenant_id, - provider=provider - ) + result = model_provider_service.free_quota_submit(tenant_id=current_user.current_tenant_id, provider=provider) return result @@ -215,32 +201,36 @@ class ModelProviderFreeQuotaQualificationVerifyApi(Resource): @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('token', type=str, required=False, nullable=True, location='args') + parser.add_argument("token", type=str, required=False, nullable=True, location="args") args = parser.parse_args() model_provider_service = ModelProviderService() result = model_provider_service.free_quota_qualification_verify( - tenant_id=current_user.current_tenant_id, - provider=provider, - token=args['token'] + tenant_id=current_user.current_tenant_id, provider=provider, token=args["token"] ) return result -api.add_resource(ModelProviderListApi, '/workspaces/current/model-providers') +api.add_resource(ModelProviderListApi, "/workspaces/current/model-providers") -api.add_resource(ModelProviderCredentialApi, '/workspaces/current/model-providers//credentials') -api.add_resource(ModelProviderValidateApi, '/workspaces/current/model-providers//credentials/validate') -api.add_resource(ModelProviderApi, '/workspaces/current/model-providers/') -api.add_resource(ModelProviderIconApi, '/workspaces/current/model-providers//' - '/') +api.add_resource(ModelProviderCredentialApi, "/workspaces/current/model-providers//credentials") +api.add_resource(ModelProviderValidateApi, "/workspaces/current/model-providers//credentials/validate") +api.add_resource(ModelProviderApi, "/workspaces/current/model-providers/") +api.add_resource( + ModelProviderIconApi, "/workspaces/current/model-providers//" "/" +) -api.add_resource(PreferredProviderTypeUpdateApi, - '/workspaces/current/model-providers//preferred-provider-type') -api.add_resource(ModelProviderPaymentCheckoutUrlApi, - '/workspaces/current/model-providers//checkout-url') -api.add_resource(ModelProviderFreeQuotaSubmitApi, - '/workspaces/current/model-providers//free-quota-submit') -api.add_resource(ModelProviderFreeQuotaQualificationVerifyApi, - '/workspaces/current/model-providers//free-quota-qualification-verify') +api.add_resource( + PreferredProviderTypeUpdateApi, "/workspaces/current/model-providers//preferred-provider-type" +) +api.add_resource( + ModelProviderPaymentCheckoutUrlApi, "/workspaces/current/model-providers//checkout-url" +) +api.add_resource( + ModelProviderFreeQuotaSubmitApi, "/workspaces/current/model-providers//free-quota-submit" +) +api.add_resource( + ModelProviderFreeQuotaQualificationVerifyApi, + "/workspaces/current/model-providers//free-quota-qualification-verify", +) diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 35f02761d..dc88f6b81 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -16,27 +16,29 @@ from services.model_provider_service import ModelProviderService class DefaultModelApi(Resource): - @setup_required @login_required @account_initialization_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() default_model_entity = model_provider_service.get_default_model_of_model_type( - tenant_id=tenant_id, - model_type=args['model_type'] + tenant_id=tenant_id, model_type=args["model_type"] ) - return jsonable_encoder({ - "data": default_model_entity - }) + return jsonable_encoder({"data": default_model_entity}) @setup_required @login_required @@ -44,40 +46,39 @@ class DefaultModelApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + parser = reqparse.RequestParser() - parser.add_argument('model_settings', type=list, required=True, nullable=False, location='json') + parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - model_settings = args['model_settings'] + model_settings = args["model_settings"] for model_setting in model_settings: - if 'model_type' not in model_setting or model_setting['model_type'] not in [mt.value for mt in ModelType]: - raise ValueError('invalid model type') + if "model_type" not in model_setting or model_setting["model_type"] not in [mt.value for mt in ModelType]: + raise ValueError("invalid model type") - if 'provider' not in model_setting: + if "provider" not in model_setting: continue - if 'model' not in model_setting: - raise ValueError('invalid model') + if "model" not in model_setting: + raise ValueError("invalid model") try: model_provider_service.update_default_model_of_model_type( tenant_id=tenant_id, - model_type=model_setting['model_type'], - provider=model_setting['provider'], - model=model_setting['model'] + model_type=model_setting["model_type"], + provider=model_setting["provider"], + model=model_setting["model"], ) except Exception: logging.warning(f"{model_setting['model_type']} save error") - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -85,14 +86,9 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_provider( - tenant_id=tenant_id, - provider=provider - ) + models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) @setup_required @login_required @@ -104,62 +100,66 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=False, nullable=True, location='json') - parser.add_argument('load_balancing', type=dict, required=False, nullable=True, location='json') - parser.add_argument('config_from', type=str, required=False, nullable=True, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json") + parser.add_argument("config_from", type=str, required=False, nullable=True, location="json") args = parser.parse_args() model_load_balancing_service = ModelLoadBalancingService() - if ('load_balancing' in args and args['load_balancing'] and - 'enabled' in args['load_balancing'] and args['load_balancing']['enabled']): - if 'configs' not in args['load_balancing']: - raise ValueError('invalid load balancing configs') + if ( + "load_balancing" in args + and args["load_balancing"] + and "enabled" in args["load_balancing"] + and args["load_balancing"]["enabled"] + ): + if "configs" not in args["load_balancing"]: + raise ValueError("invalid load balancing configs") # save load balancing configs model_load_balancing_service.update_load_balancing_configs( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - configs=args['load_balancing']['configs'] + model=args["model"], + model_type=args["model_type"], + configs=args["load_balancing"]["configs"], ) # enable load balancing model_load_balancing_service.enable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) else: # disable load balancing model_load_balancing_service.disable_model_load_balancing( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - if args.get('config_from', '') != 'predefined-model': + if args.get("config_from", "") != "predefined-model": model_provider_service = ModelProviderService() try: model_provider_service.save_model_credentials( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: logging.exception(f"save model credentials error: {ex}") raise ValueError(str(ex)) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 @setup_required @login_required @@ -171,24 +171,26 @@ class ModelProviderModelApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.remove_model_credentials( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'}, 204 + return {"result": "success"}, 204 class ModelProviderModelCredentialApi(Resource): - @setup_required @login_required @account_initialization_required @@ -196,38 +198,34 @@ class ModelProviderModelCredentialApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="args", + ) args = parser.parse_args() model_provider_service = ModelProviderService() credentials = model_provider_service.get_model_credentials( - tenant_id=tenant_id, - provider=provider, - model_type=args['model_type'], - model=args['model'] + tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"] ) model_load_balancing_service = ModelLoadBalancingService() is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) return { "credentials": credentials, - "load_balancing": { - "enabled": is_load_balancing_enabled, - "configs": load_balancing_configs - } + "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs}, } class ModelProviderModelEnableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -235,24 +233,26 @@ class ModelProviderModelEnableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.enable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelDisableApi(Resource): - @setup_required @login_required @account_initialization_required @@ -260,24 +260,26 @@ class ModelProviderModelDisableApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) args = parser.parse_args() model_provider_service = ModelProviderService() model_provider_service.disable_model( - tenant_id=tenant_id, - provider=provider, - model=args['model'], - model_type=args['model_type'] + tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"] ) - return {'result': 'success'} + return {"result": "success"} class ModelProviderModelValidateApi(Resource): - @setup_required @login_required @account_initialization_required @@ -285,10 +287,16 @@ class ModelProviderModelValidateApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='json') - parser.add_argument('model_type', type=str, required=True, nullable=False, - choices=[mt.value for mt in ModelType], location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("model", type=str, required=True, nullable=False, location="json") + parser.add_argument( + "model_type", + type=str, + required=True, + nullable=False, + choices=[mt.value for mt in ModelType], + location="json", + ) + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() model_provider_service = ModelProviderService() @@ -300,48 +308,42 @@ class ModelProviderModelValidateApi(Resource): model_provider_service.model_credentials_validate( tenant_id=tenant_id, provider=provider, - model=args['model'], - model_type=args['model_type'], - credentials=args['credentials'] + model=args["model"], + model_type=args["model_type"], + credentials=args["credentials"], ) except CredentialsValidateFailedError as ex: result = False error = str(ex) - response = {'result': 'success' if result else 'error'} + response = {"result": "success" if result else "error"} if not result: - response['error'] = error + response["error"] = error return response class ModelProviderModelParameterRuleApi(Resource): - @setup_required @login_required @account_initialization_required def get(self, provider: str): parser = reqparse.RequestParser() - parser.add_argument('model', type=str, required=True, nullable=False, location='args') + parser.add_argument("model", type=str, required=True, nullable=False, location="args") args = parser.parse_args() tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() parameter_rules = model_provider_service.get_model_parameter_rules( - tenant_id=tenant_id, - provider=provider, - model=args['model'] + tenant_id=tenant_id, provider=provider, model=args["model"] ) - return jsonable_encoder({ - "data": parameter_rules - }) + return jsonable_encoder({"data": parameter_rules}) class ModelProviderAvailableModelApi(Resource): - @setup_required @login_required @account_initialization_required @@ -349,27 +351,31 @@ class ModelProviderAvailableModelApi(Resource): tenant_id = current_user.current_tenant_id model_provider_service = ModelProviderService() - models = model_provider_service.get_models_by_model_type( - tenant_id=tenant_id, - model_type=model_type - ) + models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type) - return jsonable_encoder({ - "data": models - }) + return jsonable_encoder({"data": models}) -api.add_resource(ModelProviderModelApi, '/workspaces/current/model-providers//models') -api.add_resource(ModelProviderModelEnableApi, '/workspaces/current/model-providers//models/enable', - endpoint='model-provider-model-enable') -api.add_resource(ModelProviderModelDisableApi, '/workspaces/current/model-providers//models/disable', - endpoint='model-provider-model-disable') -api.add_resource(ModelProviderModelCredentialApi, - '/workspaces/current/model-providers//models/credentials') -api.add_resource(ModelProviderModelValidateApi, - '/workspaces/current/model-providers//models/credentials/validate') +api.add_resource(ModelProviderModelApi, "/workspaces/current/model-providers//models") +api.add_resource( + ModelProviderModelEnableApi, + "/workspaces/current/model-providers//models/enable", + endpoint="model-provider-model-enable", +) +api.add_resource( + ModelProviderModelDisableApi, + "/workspaces/current/model-providers//models/disable", + endpoint="model-provider-model-disable", +) +api.add_resource( + ModelProviderModelCredentialApi, "/workspaces/current/model-providers//models/credentials" +) +api.add_resource( + ModelProviderModelValidateApi, "/workspaces/current/model-providers//models/credentials/validate" +) -api.add_resource(ModelProviderModelParameterRuleApi, - '/workspaces/current/model-providers//models/parameter-rules') -api.add_resource(ModelProviderAvailableModelApi, '/workspaces/current/models/model-types/') -api.add_resource(DefaultModelApi, '/workspaces/current/default-model') +api.add_resource( + ModelProviderModelParameterRuleApi, "/workspaces/current/model-providers//models/parameter-rules" +) +api.add_resource(ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/") +api.add_resource(DefaultModelApi, "/workspaces/current/default-model") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index bafeabb08..c41a898fd 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -28,10 +28,18 @@ class ToolProviderListApi(Resource): tenant_id = current_user.current_tenant_id req = reqparse.RequestParser() - req.add_argument('type', type=str, choices=['builtin', 'model', 'api', 'workflow'], required=False, nullable=True, location='args') + req.add_argument( + "type", + type=str, + choices=["builtin", "model", "api", "workflow"], + required=False, + nullable=True, + location="args", + ) args = req.parse_args() - return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get('type', None)) + return ToolCommonService.list_tool_providers(user_id, tenant_id, args.get("type", None)) + class ToolBuiltinProviderListToolsApi(Resource): @setup_required @@ -41,11 +49,14 @@ class ToolBuiltinProviderListToolsApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.list_builtin_tool_provider_tools( - user_id, - tenant_id, - provider, - )) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_tool_provider_tools( + user_id, + tenant_id, + provider, + ) + ) + class ToolBuiltinProviderDeleteApi(Resource): @setup_required @@ -54,7 +65,7 @@ class ToolBuiltinProviderDeleteApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id @@ -63,7 +74,8 @@ class ToolBuiltinProviderDeleteApi(Resource): tenant_id, provider, ) - + + class ToolBuiltinProviderUpdateApi(Resource): @setup_required @login_required @@ -71,12 +83,12 @@ class ToolBuiltinProviderUpdateApi(Resource): def post(self, provider): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") args = parser.parse_args() @@ -84,9 +96,10 @@ class ToolBuiltinProviderUpdateApi(Resource): user_id, tenant_id, provider, - args['credentials'], + args["credentials"], ) - + + class ToolBuiltinProviderGetCredentialsApi(Resource): @setup_required @login_required @@ -101,6 +114,7 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): provider, ) + class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): @@ -108,6 +122,7 @@ class ToolBuiltinProviderIconApi(Resource): icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) + class ToolApiProviderAddApi(Resource): @setup_required @login_required @@ -115,35 +130,36 @@ class ToolApiProviderAddApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json', default=[]) - parser.add_argument('custom_disclaimer', type=str, required=False, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json", default=[]) + parser.add_argument("custom_disclaimer", type=str, required=False, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.create_api_tool_provider( user_id, tenant_id, - args['provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args.get('privacy_policy', ''), - args.get('custom_disclaimer', ''), - args.get('labels', []), + args["provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args.get("privacy_policy", ""), + args.get("custom_disclaimer", ""), + args.get("labels", []), ) + class ToolApiProviderGetRemoteSchemaApi(Resource): @setup_required @login_required @@ -151,16 +167,17 @@ class ToolApiProviderGetRemoteSchemaApi(Resource): def get(self): parser = reqparse.RequestParser() - parser.add_argument('url', type=str, required=True, nullable=False, location='args') + parser.add_argument("url", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider_remote_schema( current_user.id, current_user.current_tenant_id, - args['url'], + args["url"], ) - + + class ToolApiProviderListToolsApi(Resource): @setup_required @login_required @@ -171,15 +188,18 @@ class ToolApiProviderListToolsApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(ApiToolManageService.list_api_tool_provider_tools( - user_id, - tenant_id, - args['provider'], - )) + return jsonable_encoder( + ApiToolManageService.list_api_tool_provider_tools( + user_id, + tenant_id, + args["provider"], + ) + ) + class ToolApiProviderUpdateApi(Resource): @setup_required @@ -188,37 +208,38 @@ class ToolApiProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('original_provider', type=str, required=True, nullable=False, location='json') - parser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - parser.add_argument('privacy_policy', type=str, required=True, nullable=True, location='json') - parser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - parser.add_argument('custom_disclaimer', type=str, required=True, nullable=True, location='json') + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("original_provider", type=str, required=True, nullable=False, location="json") + parser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + parser.add_argument("privacy_policy", type=str, required=True, nullable=True, location="json") + parser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + parser.add_argument("custom_disclaimer", type=str, required=True, nullable=True, location="json") args = parser.parse_args() return ApiToolManageService.update_api_tool_provider( user_id, tenant_id, - args['provider'], - args['original_provider'], - args['icon'], - args['credentials'], - args['schema_type'], - args['schema'], - args['privacy_policy'], - args['custom_disclaimer'], - args.get('labels', []), + args["provider"], + args["original_provider"], + args["icon"], + args["credentials"], + args["schema_type"], + args["schema"], + args["privacy_policy"], + args["custom_disclaimer"], + args.get("labels", []), ) + class ToolApiProviderDeleteApi(Resource): @setup_required @login_required @@ -226,22 +247,23 @@ class ToolApiProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='json') + parser.add_argument("provider", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.delete_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolApiProviderGetApi(Resource): @setup_required @login_required @@ -252,16 +274,17 @@ class ToolApiProviderGetApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('provider', type=str, required=True, nullable=False, location='args') + parser.add_argument("provider", type=str, required=True, nullable=False, location="args") args = parser.parse_args() return ApiToolManageService.get_api_tool_provider( user_id, tenant_id, - args['provider'], + args["provider"], ) + class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @@ -269,6 +292,7 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): def get(self, provider): return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider) + class ToolApiProviderSchemaApi(Resource): @setup_required @login_required @@ -276,14 +300,15 @@ class ToolApiProviderSchemaApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.parser_api_schema( - schema=args['schema'], + schema=args["schema"], ) + class ToolApiProviderPreviousTestApi(Resource): @setup_required @login_required @@ -291,25 +316,26 @@ class ToolApiProviderPreviousTestApi(Resource): def post(self): parser = reqparse.RequestParser() - parser.add_argument('tool_name', type=str, required=True, nullable=False, location='json') - parser.add_argument('provider_name', type=str, required=False, nullable=False, location='json') - parser.add_argument('credentials', type=dict, required=True, nullable=False, location='json') - parser.add_argument('parameters', type=dict, required=True, nullable=False, location='json') - parser.add_argument('schema_type', type=str, required=True, nullable=False, location='json') - parser.add_argument('schema', type=str, required=True, nullable=False, location='json') + parser.add_argument("tool_name", type=str, required=True, nullable=False, location="json") + parser.add_argument("provider_name", type=str, required=False, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("parameters", type=dict, required=True, nullable=False, location="json") + parser.add_argument("schema_type", type=str, required=True, nullable=False, location="json") + parser.add_argument("schema", type=str, required=True, nullable=False, location="json") args = parser.parse_args() return ApiToolManageService.test_api_tool_preview( current_user.current_tenant_id, - args['provider_name'] if args['provider_name'] else '', - args['tool_name'], - args['credentials'], - args['parameters'], - args['schema_type'], - args['schema'], + args["provider_name"] if args["provider_name"] else "", + args["tool_name"], + args["credentials"], + args["parameters"], + args["schema_type"], + args["schema"], ) + class ToolWorkflowProviderCreateApi(Resource): @setup_required @login_required @@ -317,35 +343,36 @@ class ToolWorkflowProviderCreateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_app_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') + reqparser.add_argument("workflow_app_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") args = reqparser.parse_args() return WorkflowToolManageService.create_workflow_tool( user_id, tenant_id, - args['workflow_app_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_app_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderUpdateApi(Resource): @setup_required @login_required @@ -353,38 +380,39 @@ class ToolWorkflowProviderUpdateApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') - reqparser.add_argument('name', type=alphanumeric, required=True, nullable=False, location='json') - reqparser.add_argument('label', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('description', type=str, required=True, nullable=False, location='json') - reqparser.add_argument('icon', type=dict, required=True, nullable=False, location='json') - reqparser.add_argument('parameters', type=list[dict], required=True, nullable=False, location='json') - reqparser.add_argument('privacy_policy', type=str, required=False, nullable=True, location='json', default='') - reqparser.add_argument('labels', type=list[str], required=False, nullable=True, location='json') - + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") + reqparser.add_argument("name", type=alphanumeric, required=True, nullable=False, location="json") + reqparser.add_argument("label", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("description", type=str, required=True, nullable=False, location="json") + reqparser.add_argument("icon", type=dict, required=True, nullable=False, location="json") + reqparser.add_argument("parameters", type=list[dict], required=True, nullable=False, location="json") + reqparser.add_argument("privacy_policy", type=str, required=False, nullable=True, location="json", default="") + reqparser.add_argument("labels", type=list[str], required=False, nullable=True, location="json") + args = reqparser.parse_args() - if not args['workflow_tool_id']: - raise ValueError('incorrect workflow_tool_id') - + if not args["workflow_tool_id"]: + raise ValueError("incorrect workflow_tool_id") + return WorkflowToolManageService.update_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], - args['name'], - args['label'], - args['icon'], - args['description'], - args['parameters'], - args['privacy_policy'], - args.get('labels', []), + args["workflow_tool_id"], + args["name"], + args["label"], + args["icon"], + args["description"], + args["parameters"], + args["privacy_policy"], + args.get("labels", []), ) + class ToolWorkflowProviderDeleteApi(Resource): @setup_required @login_required @@ -392,21 +420,22 @@ class ToolWorkflowProviderDeleteApi(Resource): def post(self): if not current_user.is_admin_or_owner: raise Forbidden() - + user_id = current_user.id tenant_id = current_user.current_tenant_id reqparser = reqparse.RequestParser() - reqparser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='json') + reqparser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="json") args = reqparser.parse_args() return WorkflowToolManageService.delete_workflow_tool( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - + + class ToolWorkflowProviderGetApi(Resource): @setup_required @login_required @@ -416,28 +445,29 @@ class ToolWorkflowProviderGetApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=False, nullable=True, location='args') - parser.add_argument('workflow_app_id', type=uuid_value, required=False, nullable=True, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=False, nullable=True, location="args") + parser.add_argument("workflow_app_id", type=uuid_value, required=False, nullable=True, location="args") args = parser.parse_args() - if args.get('workflow_tool_id'): + if args.get("workflow_tool_id"): tool = WorkflowToolManageService.get_workflow_tool_by_tool_id( user_id, tenant_id, - args['workflow_tool_id'], + args["workflow_tool_id"], ) - elif args.get('workflow_app_id'): + elif args.get("workflow_app_id"): tool = WorkflowToolManageService.get_workflow_tool_by_app_id( user_id, tenant_id, - args['workflow_app_id'], + args["workflow_app_id"], ) else: - raise ValueError('incorrect workflow_tool_id or workflow_app_id') + raise ValueError("incorrect workflow_tool_id or workflow_app_id") return jsonable_encoder(tool) - + + class ToolWorkflowProviderListToolApi(Resource): @setup_required @login_required @@ -447,15 +477,18 @@ class ToolWorkflowProviderListToolApi(Resource): tenant_id = current_user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument('workflow_tool_id', type=uuid_value, required=True, nullable=False, location='args') + parser.add_argument("workflow_tool_id", type=uuid_value, required=True, nullable=False, location="args") args = parser.parse_args() - return jsonable_encoder(WorkflowToolManageService.list_single_workflow_tools( - user_id, - tenant_id, - args['workflow_tool_id'], - )) + return jsonable_encoder( + WorkflowToolManageService.list_single_workflow_tools( + user_id, + tenant_id, + args["workflow_tool_id"], + ) + ) + class ToolBuiltinListApi(Resource): @setup_required @@ -465,11 +498,17 @@ class ToolBuiltinListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolApiListApi(Resource): @setup_required @login_required @@ -478,11 +517,17 @@ class ToolApiListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in ApiToolManageService.list_api_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolWorkflowListApi(Resource): @setup_required @login_required @@ -491,11 +536,17 @@ class ToolWorkflowListApi(Resource): user_id = current_user.id tenant_id = current_user.current_tenant_id - return jsonable_encoder([provider.to_dict() for provider in WorkflowToolManageService.list_tenant_workflow_tools( - user_id, - tenant_id, - )]) - + return jsonable_encoder( + [ + provider.to_dict() + for provider in WorkflowToolManageService.list_tenant_workflow_tools( + user_id, + tenant_id, + ) + ] + ) + + class ToolLabelsApi(Resource): @setup_required @login_required @@ -503,36 +554,41 @@ class ToolLabelsApi(Resource): def get(self): return jsonable_encoder(ToolLabelsService.list_tool_labels()) + # tool provider -api.add_resource(ToolProviderListApi, '/workspaces/current/tool-providers') +api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") # builtin tool provider -api.add_resource(ToolBuiltinProviderListToolsApi, '/workspaces/current/tool-provider/builtin//tools') -api.add_resource(ToolBuiltinProviderDeleteApi, '/workspaces/current/tool-provider/builtin//delete') -api.add_resource(ToolBuiltinProviderUpdateApi, '/workspaces/current/tool-provider/builtin//update') -api.add_resource(ToolBuiltinProviderGetCredentialsApi, '/workspaces/current/tool-provider/builtin//credentials') -api.add_resource(ToolBuiltinProviderCredentialsSchemaApi, '/workspaces/current/tool-provider/builtin//credentials_schema') -api.add_resource(ToolBuiltinProviderIconApi, '/workspaces/current/tool-provider/builtin//icon') +api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") +api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") +api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" +) +api.add_resource( + ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin//credentials_schema" +) +api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") # api tool provider -api.add_resource(ToolApiProviderAddApi, '/workspaces/current/tool-provider/api/add') -api.add_resource(ToolApiProviderGetRemoteSchemaApi, '/workspaces/current/tool-provider/api/remote') -api.add_resource(ToolApiProviderListToolsApi, '/workspaces/current/tool-provider/api/tools') -api.add_resource(ToolApiProviderUpdateApi, '/workspaces/current/tool-provider/api/update') -api.add_resource(ToolApiProviderDeleteApi, '/workspaces/current/tool-provider/api/delete') -api.add_resource(ToolApiProviderGetApi, '/workspaces/current/tool-provider/api/get') -api.add_resource(ToolApiProviderSchemaApi, '/workspaces/current/tool-provider/api/schema') -api.add_resource(ToolApiProviderPreviousTestApi, '/workspaces/current/tool-provider/api/test/pre') +api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add") +api.add_resource(ToolApiProviderGetRemoteSchemaApi, "/workspaces/current/tool-provider/api/remote") +api.add_resource(ToolApiProviderListToolsApi, "/workspaces/current/tool-provider/api/tools") +api.add_resource(ToolApiProviderUpdateApi, "/workspaces/current/tool-provider/api/update") +api.add_resource(ToolApiProviderDeleteApi, "/workspaces/current/tool-provider/api/delete") +api.add_resource(ToolApiProviderGetApi, "/workspaces/current/tool-provider/api/get") +api.add_resource(ToolApiProviderSchemaApi, "/workspaces/current/tool-provider/api/schema") +api.add_resource(ToolApiProviderPreviousTestApi, "/workspaces/current/tool-provider/api/test/pre") # workflow tool provider -api.add_resource(ToolWorkflowProviderCreateApi, '/workspaces/current/tool-provider/workflow/create') -api.add_resource(ToolWorkflowProviderUpdateApi, '/workspaces/current/tool-provider/workflow/update') -api.add_resource(ToolWorkflowProviderDeleteApi, '/workspaces/current/tool-provider/workflow/delete') -api.add_resource(ToolWorkflowProviderGetApi, '/workspaces/current/tool-provider/workflow/get') -api.add_resource(ToolWorkflowProviderListToolApi, '/workspaces/current/tool-provider/workflow/tools') +api.add_resource(ToolWorkflowProviderCreateApi, "/workspaces/current/tool-provider/workflow/create") +api.add_resource(ToolWorkflowProviderUpdateApi, "/workspaces/current/tool-provider/workflow/update") +api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provider/workflow/delete") +api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get") +api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools") -api.add_resource(ToolBuiltinListApi, '/workspaces/current/tools/builtin') -api.add_resource(ToolApiListApi, '/workspaces/current/tools/api') -api.add_resource(ToolWorkflowListApi, '/workspaces/current/tools/workflow') +api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin") +api.add_resource(ToolApiListApi, "/workspaces/current/tools/api") +api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow") -api.add_resource(ToolLabelsApi, '/workspaces/current/tool-labels') \ No newline at end of file +api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels") diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 7a11a45ae..623f0b8b7 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -26,39 +26,34 @@ from services.file_service import FileService from services.workspace_service import WorkspaceService provider_fields = { - 'provider_name': fields.String, - 'provider_type': fields.String, - 'is_valid': fields.Boolean, - 'token_is_set': fields.Boolean, + "provider_name": fields.String, + "provider_type": fields.String, + "is_valid": fields.Boolean, + "token_is_set": fields.Boolean, } tenant_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'role': fields.String, - 'in_trial': fields.Boolean, - 'trial_end_reason': fields.String, - 'custom_config': fields.Raw(attribute='custom_config'), + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "role": fields.String, + "in_trial": fields.Boolean, + "trial_end_reason": fields.String, + "custom_config": fields.Raw(attribute="custom_config"), } tenants_fields = { - 'id': fields.String, - 'name': fields.String, - 'plan': fields.String, - 'status': fields.String, - 'created_at': TimestampField, - 'current': fields.Boolean + "id": fields.String, + "name": fields.String, + "plan": fields.String, + "status": fields.String, + "created_at": TimestampField, + "current": fields.Boolean, } -workspace_fields = { - 'id': fields.String, - 'name': fields.String, - 'status': fields.String, - 'created_at': TimestampField -} +workspace_fields = {"id": fields.String, "name": fields.String, "status": fields.String, "created_at": TimestampField} class TenantListApi(Resource): @@ -71,7 +66,7 @@ class TenantListApi(Resource): for tenant in tenants: if tenant.id == current_user.current_tenant_id: tenant.current = True # Set current=True for current tenant - return {'workspaces': marshal(tenants, tenants_fields)}, 200 + return {"workspaces": marshal(tenants, tenants_fields)}, 200 class WorkspaceListApi(Resource): @@ -79,31 +74,37 @@ class WorkspaceListApi(Resource): @admin_required def get(self): parser = reqparse.RequestParser() - parser.add_argument('page', type=inputs.int_range(1, 99999), required=False, default=1, location='args') - parser.add_argument('limit', type=inputs.int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("page", type=inputs.int_range(1, 99999), required=False, default=1, location="args") + parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = db.session.query(Tenant).order_by(Tenant.created_at.desc())\ - .paginate(page=args['page'], per_page=args['limit']) + tenants = ( + db.session.query(Tenant) + .order_by(Tenant.created_at.desc()) + .paginate(page=args["page"], per_page=args["limit"]) + ) has_more = False - if len(tenants.items) == args['limit']: + if len(tenants.items) == args["limit"]: current_page_first_tenant = tenants[-1] - rest_count = db.session.query(Tenant).filter( - Tenant.created_at < current_page_first_tenant.created_at, - Tenant.id != current_page_first_tenant.id - ).count() + rest_count = ( + db.session.query(Tenant) + .filter( + Tenant.created_at < current_page_first_tenant.created_at, Tenant.id != current_page_first_tenant.id + ) + .count() + ) if rest_count > 0: has_more = True total = db.session.query(Tenant).count() return { - 'data': marshal(tenants.items, workspace_fields), - 'has_more': has_more, - 'limit': args['limit'], - 'page': args['page'], - 'total': total - }, 200 + "data": marshal(tenants.items, workspace_fields), + "has_more": has_more, + "limit": args["limit"], + "page": args["page"], + "total": total, + }, 200 class TenantApi(Resource): @@ -112,8 +113,8 @@ class TenantApi(Resource): @account_initialization_required @marshal_with(tenant_fields) def get(self): - if request.path == '/info': - logging.warning('Deprecated URL /info was used.') + if request.path == "/info": + logging.warning("Deprecated URL /info was used.") tenant = current_user.current_tenant @@ -125,7 +126,7 @@ class TenantApi(Resource): tenant = tenants[0] # else, raise Unauthorized else: - raise Unauthorized('workspace is archived') + raise Unauthorized("workspace is archived") return WorkspaceService.get_tenant_info(tenant), 200 @@ -136,62 +137,64 @@ class SwitchWorkspaceApi(Resource): @account_initialization_required def post(self): parser = reqparse.RequestParser() - parser.add_argument('tenant_id', type=str, required=True, location='json') + parser.add_argument("tenant_id", type=str, required=True, location="json") args = parser.parse_args() # check if tenant_id is valid, 403 if not try: - TenantService.switch_tenant(current_user, args['tenant_id']) + TenantService.switch_tenant(current_user, args["tenant_id"]) except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args['tenant_id']) # Get new tenant + new_tenant = db.session.query(Tenant).get(args["tenant_id"]) # Get new tenant + + return {"result": "success", "new_tenant": marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - return {'result': 'success', 'new_tenant': marshal(WorkspaceService.get_tenant_info(new_tenant), tenant_fields)} - class CustomConfigWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): parser = reqparse.RequestParser() - parser.add_argument('remove_webapp_brand', type=bool, location='json') - parser.add_argument('replace_webapp_logo', type=str, location='json') + parser.add_argument("remove_webapp_brand", type=bool, location="json") + parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() tenant = db.session.query(Tenant).filter(Tenant.id == current_user.current_tenant_id).one_or_404() custom_config_dict = { - 'remove_webapp_brand': args['remove_webapp_brand'], - 'replace_webapp_logo': args['replace_webapp_logo'] if args['replace_webapp_logo'] is not None else tenant.custom_config_dict.get('replace_webapp_logo') , + "remove_webapp_brand": args["remove_webapp_brand"], + "replace_webapp_logo": args["replace_webapp_logo"] + if args["replace_webapp_logo"] is not None + else tenant.custom_config_dict.get("replace_webapp_logo"), } tenant.custom_config_dict = custom_config_dict db.session.commit() - return {'result': 'success', 'tenant': marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} - + return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)} + class WebappLogoWorkspaceApi(Resource): @setup_required @login_required @account_initialization_required - @cloud_edition_billing_resource_check('workspace_custom') + @cloud_edition_billing_resource_check("workspace_custom") def post(self): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() - extension = file.filename.split('.')[-1] - if extension.lower() not in ['svg', 'png']: + extension = file.filename.split(".")[-1] + if extension.lower() not in ["svg", "png"]: raise UnsupportedFileTypeError() try: @@ -201,14 +204,14 @@ class WebappLogoWorkspaceApi(Resource): raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() - - return { 'id': upload_file.id }, 201 + + return {"id": upload_file.id}, 201 -api.add_resource(TenantListApi, '/workspaces') # GET for getting all tenants -api.add_resource(WorkspaceListApi, '/all-workspaces') # GET for getting all tenants -api.add_resource(TenantApi, '/workspaces/current', endpoint='workspaces_current') # GET for getting current tenant info -api.add_resource(TenantApi, '/info', endpoint='info') # Deprecated -api.add_resource(SwitchWorkspaceApi, '/workspaces/switch') # POST for switching tenant -api.add_resource(CustomConfigWorkspaceApi, '/workspaces/custom-config') -api.add_resource(WebappLogoWorkspaceApi, '/workspaces/custom-config/webapp-logo/upload') +api.add_resource(TenantListApi, "/workspaces") # GET for getting all tenants +api.add_resource(WorkspaceListApi, "/all-workspaces") # GET for getting all tenants +api.add_resource(TenantApi, "/workspaces/current", endpoint="workspaces_current") # GET for getting current tenant info +api.add_resource(TenantApi, "/info", endpoint="info") # Deprecated +api.add_resource(SwitchWorkspaceApi, "/workspaces/switch") # POST for switching tenant +api.add_resource(CustomConfigWorkspaceApi, "/workspaces/custom-config") +api.add_resource(WebappLogoWorkspaceApi, "/workspaces/custom-config/webapp-logo/upload") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 3baf69acf..5a964c84f 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -16,7 +16,7 @@ def account_initialization_required(view): # check account initialization account = current_user - if account.status == 'uninitialized': + if account.status == "uninitialized": raise AccountNotInitializedError() return view(*args, **kwargs) @@ -27,7 +27,7 @@ def account_initialization_required(view): def only_edition_cloud(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'CLOUD': + if dify_config.EDITION != "CLOUD": abort(404) return view(*args, **kwargs) @@ -38,7 +38,7 @@ def only_edition_cloud(view): def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): - if dify_config.EDITION != 'SELF_HOSTED': + if dify_config.EDITION != "SELF_HOSTED": abort(404) return view(*args, **kwargs) @@ -46,8 +46,9 @@ def only_edition_self_hosted(view): return decorated -def cloud_edition_billing_resource_check(resource: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check( + resource: str, error_msg: str = "You have reached the limit of your subscription." +): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): @@ -58,22 +59,22 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota annotation_quota_limit = features.annotation_quota_limit - if resource == 'members' and 0 < members.limit <= members.size: + if resource == "members" and 0 < members.limit <= members.size: abort(403, error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: + elif resource == "apps" and 0 < apps.limit <= apps.size: abort(403, error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: abort(403, error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: # The api of file upload is used in the multiple places, so we need to check the source of the request from datasets - source = request.args.get('source') - if source == 'datasets': + source = request.args.get("source") + if source == "datasets": abort(403, error_msg) else: return view(*args, **kwargs) - elif resource == 'workspace_custom' and not features.can_replace_logo: + elif resource == "workspace_custom" and not features.can_replace_logo: abort(403, error_msg) - elif resource == 'annotation' and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: + elif resource == "annotation" and 0 < annotation_quota_limit.limit < annotation_quota_limit.size: abort(403, error_msg) else: return view(*args, **kwargs) @@ -85,15 +86,17 @@ def cloud_edition_billing_resource_check(resource: str, return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check( + resource: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", +): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": abort(403, error_msg) else: return view(*args, **kwargs) @@ -112,7 +115,7 @@ def cloud_utm_record(view): features = FeatureService.get_features(current_user.current_tenant_id) if features.billing.enabled: - utm_info = request.cookies.get('utm_info') + utm_info = request.cookies.get("utm_info") if utm_info: utm_info = json.loads(utm_info) diff --git a/api/controllers/files/__init__.py b/api/controllers/files/__init__.py index 8d38ab986..97d5c3f88 100644 --- a/api/controllers/files/__init__.py +++ b/api/controllers/files/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('files', __name__) +bp = Blueprint("files", __name__) api = ExternalApi(bp) diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py index 247b5d45e..2432285d9 100644 --- a/api/controllers/files/image_preview.py +++ b/api/controllers/files/image_preview.py @@ -13,35 +13,30 @@ class ImagePreviewApi(Resource): def get(self, file_id): file_id = str(file_id) - timestamp = request.args.get('timestamp') - nonce = request.args.get('nonce') - sign = request.args.get('sign') + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") + sign = request.args.get("sign") if not timestamp or not nonce or not sign: - return {'content': 'Invalid request.'}, 400 + return {"content": "Invalid request."}, 400 try: - generator, mimetype = FileService.get_image_preview( - file_id, - timestamp, - nonce, - sign - ) + generator, mimetype = FileService.get_image_preview(file_id, timestamp, nonce, sign) except services.errors.file.UnsupportedFileTypeError: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) - + class WorkspaceWebappLogoApi(Resource): def get(self, workspace_id): workspace_id = str(workspace_id) custom_config = TenantService.get_custom_config(workspace_id) - webapp_logo_file_id = custom_config.get('replace_webapp_logo') if custom_config is not None else None + webapp_logo_file_id = custom_config.get("replace_webapp_logo") if custom_config is not None else None if not webapp_logo_file_id: - raise NotFound('webapp logo is not found') + raise NotFound("webapp logo is not found") try: generator, mimetype = FileService.get_public_image_preview( @@ -53,11 +48,11 @@ class WorkspaceWebappLogoApi(Resource): return Response(generator, mimetype=mimetype) -api.add_resource(ImagePreviewApi, '/files//image-preview') -api.add_resource(WorkspaceWebappLogoApi, '/files/workspaces//webapp-logo') +api.add_resource(ImagePreviewApi, "/files//image-preview") +api.add_resource(WorkspaceWebappLogoApi, "/files/workspaces//webapp-logo") class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 5a07ad2ea..38ac0815d 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -13,36 +13,39 @@ class ToolFilePreviewApi(Resource): parser = reqparse.RequestParser() - parser.add_argument('timestamp', type=str, required=True, location='args') - parser.add_argument('nonce', type=str, required=True, location='args') - parser.add_argument('sign', type=str, required=True, location='args') + parser.add_argument("timestamp", type=str, required=True, location="args") + parser.add_argument("nonce", type=str, required=True, location="args") + parser.add_argument("sign", type=str, required=True, location="args") args = parser.parse_args() - if not ToolFileManager.verify_file(file_id=file_id, - timestamp=args['timestamp'], - nonce=args['nonce'], - sign=args['sign'], + if not ToolFileManager.verify_file( + file_id=file_id, + timestamp=args["timestamp"], + nonce=args["nonce"], + sign=args["sign"], ): - raise Forbidden('Invalid request.') - + raise Forbidden("Invalid request.") + try: result = ToolFileManager.get_file_generator_by_tool_file_id( file_id, ) if not result: - raise NotFound('file is not found') - + raise NotFound("file is not found") + generator, mimetype = result except Exception: raise UnsupportedFileTypeError() return Response(generator, mimetype=mimetype) -api.add_resource(ToolFilePreviewApi, '/files/tools/.') + +api.add_resource(ToolFilePreviewApi, "/files/tools/.") + class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index ad49a649c..9f124736a 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -2,8 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('inner_api', __name__, url_prefix='/inner/api') +bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) from .workspace import workspace - diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index 06610d893..914b60f26 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -9,29 +9,24 @@ from services.account_service import TenantService class EnterpriseWorkspace(Resource): - @setup_required @inner_api_only def post(self): parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, location='json') - parser.add_argument('owner_email', type=str, required=True, location='json') + parser.add_argument("name", type=str, required=True, location="json") + parser.add_argument("owner_email", type=str, required=True, location="json") args = parser.parse_args() - account = Account.query.filter_by(email=args['owner_email']).first() + account = Account.query.filter_by(email=args["owner_email"]).first() if account is None: - return { - 'message': 'owner account not found.' - }, 404 + return {"message": "owner account not found."}, 404 - tenant = TenantService.create_tenant(args['name']) - TenantService.create_tenant_member(tenant, account, role='owner') + tenant = TenantService.create_tenant(args["name"]) + TenantService.create_tenant_member(tenant, account, role="owner") tenant_was_created.send(tenant) - return { - 'message': 'enterprise workspace created.' - } + return {"message": "enterprise workspace created."} -api.add_resource(EnterpriseWorkspace, '/enterprise/workspace') +api.add_resource(EnterpriseWorkspace, "/enterprise/workspace") diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 5c37f5276..51ffe683f 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -17,7 +17,7 @@ def inner_api_only(view): abort(404) # get header 'X-Inner-Api-Key' - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY: abort(401) @@ -33,29 +33,29 @@ def inner_api_user_auth(view): return view(*args, **kwargs) # get header 'X-Inner-Api-Key' - authorization = request.headers.get('Authorization') + authorization = request.headers.get("Authorization") if not authorization: return view(*args, **kwargs) - parts = authorization.split(':') + parts = authorization.split(":") if len(parts) != 2: return view(*args, **kwargs) user_id, token = parts - if ' ' in user_id: - user_id = user_id.split(' ')[1] + if " " in user_id: + user_id = user_id.split(" ")[1] - inner_api_key = request.headers.get('X-Inner-Api-Key') + inner_api_key = request.headers.get("X-Inner-Api-Key") - data_to_sign = f'DIFY {user_id}' + data_to_sign = f"DIFY {user_id}" - signature = hmac_new(inner_api_key.encode('utf-8'), data_to_sign.encode('utf-8'), sha1) - signature = b64encode(signature.digest()).decode('utf-8') + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + signature = b64encode(signature.digest()).decode("utf-8") if signature != token: return view(*args, **kwargs) - kwargs['user'] = db.session.query(EndUser).filter(EndUser.id == user_id).first() + kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first() return view(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index 082660a89..ad39c160a 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('service_api', __name__, url_prefix='/v1') +bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index af8e38ed3..ecc2d73de 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,4 +1,3 @@ - from flask_restful import Resource, fields, marshal_with from configs import dify_config @@ -13,32 +12,30 @@ class AppParameterApi(Resource): """Resource for app variables.""" variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @validate_app_token @@ -56,30 +53,35 @@ class AppParameterApi(Resource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -89,16 +91,14 @@ class AppMetaApi(Resource): """Get app meta""" return AppService().get_app_meta(app_model) + class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): """Get app information""" - return { - 'name':app_model.name, - 'description':app_model.description - } + return {"name": app_model.name, "description": app_model.description} -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMetaApi, '/meta') -api.add_resource(AppInfoApi, '/info') +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMetaApi, "/meta") +api.add_resource(AppInfoApi, "/info") diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index 3c009af34..85aab047a 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -33,14 +33,10 @@ from services.errors.audio import ( class AudioApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) def post(self, app_model: App, end_user: EndUser): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -74,30 +70,32 @@ class TextApi(Resource): def post(self, app_model: App, end_user: EndUser): try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get('voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(Resource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 2511f46ba..f1771baf3 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -33,21 +33,21 @@ from services.app_generate_service import AppGenerateService class CompletionApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" - args['auto_generate_name'] = False + args["auto_generate_name"] = False try: response = AppGenerateService.generate( @@ -84,12 +84,12 @@ class CompletionApi(Resource): class CompletionStopApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise AppUnavailableError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(Resource): @@ -100,25 +100,21 @@ class ChatApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json') - parser.add_argument('auto_generate_name', type=bool, required=False, default=True, location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="dev", location="json") + parser.add_argument("auto_generate_name", type=bool, required=False, default=True, location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -153,10 +149,10 @@ class ChatStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 4598a1461..734027a1c 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ from services.conversation_service import ConversationService class ConversationApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model: App, end_user: EndUser): @@ -23,20 +22,26 @@ class ConversationApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() try: return ConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.SERVICE_API, - sort_by=args['sort_by'] + sort_by=args["sort_by"], ) except services.errors.conversation.LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -56,11 +61,10 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ConversationRenameApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) @marshal_with(simple_conversation_fields) def post(self, app_model: App, end_user: EndUser, c_id): @@ -71,22 +75,16 @@ class ConversationRenameApi(Resource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='conversation_name') -api.add_resource(ConversationApi, '/conversations') -api.add_resource(ConversationDetailApi, '/conversations/', endpoint='conversation_detail') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name") +api.add_resource(ConversationApi, "/conversations") +api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail") diff --git a/api/controllers/service_api/app/error.py b/api/controllers/service_api/app/error.py index ac9edb1b4..ca91da80c 100644 --- a/api/controllers/service_api/app/error.py +++ b/api/controllers/service_api/app/error.py @@ -2,104 +2,108 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py index 5dbc1b1d1..e0a772eb3 100644 --- a/api/controllers/service_api/app/file.py +++ b/api/controllers/service_api/app/file.py @@ -16,15 +16,13 @@ from services.file_service import FileService class FileApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.FORM)) @marshal_with(file_fields) def post(self, app_model: App, end_user: EndUser): - - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if not file.mimetype: @@ -43,4 +41,4 @@ class FileApi(Resource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 875870e66..b39aaf7dd 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -17,61 +17,59 @@ from services.message_service import MessageService class MessageListApi(Resource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } agent_thought_fields = { - 'id': fields.String, - 'chain_id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'thought': fields.String, - 'tool': fields.String, - 'tool_labels': fields.Raw, - 'tool_input': fields.String, - 'created_at': TimestampField, - 'observation': fields.String, - 'message_files': fields.List(fields.String, attribute='files') + "id": fields.String, + "chain_id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "thought": fields.String, + "tool": fields.String, + "tool_labels": fields.Raw, + "tool_input": fields.String, + "created_at": TimestampField, + "observation": fields.String, + "message_files": fields.List(fields.String, attribute="files"), } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) @@ -82,14 +80,15 @@ class MessageListApi(Resource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -102,15 +101,15 @@ class MessageFeedbackApi(Resource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageSuggestedApi(Resource): @@ -123,10 +122,7 @@ class MessageSuggestedApi(Resource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.SERVICE_API + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.SERVICE_API ) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") @@ -136,9 +132,9 @@ class MessageSuggestedApi(Resource): logging.exception("internal server error.") raise InternalServerError() - return {'result': 'success', 'data': questions} + return {"result": "success", "data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageSuggestedApi, '/messages//suggested') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageSuggestedApi, "/messages//suggested") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 9446f9d58..5822e0921 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -30,19 +30,20 @@ from services.app_generate_service import AppGenerateService logger = logging.getLogger(__name__) workflow_run_fields = { - 'id': fields.String, - 'workflow_id': fields.String, - 'status': fields.String, - 'inputs': fields.Raw, - 'outputs': fields.Raw, - 'error': fields.String, - 'total_steps': fields.Integer, - 'total_tokens': fields.Integer, - 'created_at': fields.DateTime, - 'finished_at': fields.DateTime, - 'elapsed_time': fields.Float, + "id": fields.String, + "workflow_id": fields.String, + "status": fields.String, + "inputs": fields.Raw, + "outputs": fields.Raw, + "error": fields.String, + "total_steps": fields.Integer, + "total_tokens": fields.Integer, + "created_at": fields.DateTime, + "finished_at": fields.DateTime, + "elapsed_time": fields.Float, } + class WorkflowRunDetailApi(Resource): @validate_app_token @marshal_with(workflow_run_fields) @@ -56,6 +57,8 @@ class WorkflowRunDetailApi(Resource): workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_id).first() return workflow_run + + class WorkflowRunApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON, required=True)) def post(self, app_model: App, end_user: EndUser): @@ -67,20 +70,16 @@ class WorkflowRunApi(Resource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") args = parser.parse_args() - streaming = args.get('response_mode') == 'streaming' + streaming = args.get("response_mode") == "streaming" try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.SERVICE_API, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=streaming ) return helper.compact_generate_response(response) @@ -111,11 +110,9 @@ class WorkflowTaskStopApi(Resource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.SERVICE_API, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowRunDetailApi, '/workflows/run/') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowRunDetailApi, "/workflows/run/") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index e0863859a..c2c0672a0 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -16,7 +16,7 @@ from services.dataset_service import DatasetService def _validate_name(name): if not name or len(name) < 1 or len(name) > 40: - raise ValueError('Name must be between 1 to 40 characters.') + raise ValueError("Name must be between 1 to 40 characters.") return name @@ -26,24 +26,18 @@ class DatasetListApi(DatasetApiResource): def get(self, tenant_id): """Resource for getting datasets.""" - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - provider = request.args.get('provider', default="vendor") - search = request.args.get('keyword', default=None, type=str) - tag_ids = request.args.getlist('tag_ids') + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + provider = request.args.get("provider", default="vendor") + search = request.args.get("keyword", default=None, type=str) + tag_ids = request.args.getlist("tag_ids") - datasets, total = DatasetService.get_datasets(page, limit, provider, - tenant_id, current_user, search, tag_ids) + datasets, total = DatasetService.get_datasets(page, limit, provider, tenant_id, current_user, search, tag_ids) # check embedding setting provider_manager = ProviderManager() - configurations = provider_manager.get_configurations( - tenant_id=current_user.current_tenant_id - ) + configurations = provider_manager.get_configurations(tenant_id=current_user.current_tenant_id) - embedding_models = configurations.get_models( - model_type=ModelType.TEXT_EMBEDDING, - only_active=True - ) + embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) model_names = [] for embedding_model in embedding_models: @@ -51,50 +45,59 @@ class DatasetListApi(DatasetApiResource): data = marshal(datasets, dataset_detail_fields) for item in data: - if item['indexing_technique'] == 'high_quality': + if item["indexing_technique"] == "high_quality": item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}" if item_model in model_names: - item['embedding_available'] = True + item["embedding_available"] = True else: - item['embedding_available'] = False + item["embedding_available"] = False else: - item['embedding_available'] = True - response = { - 'data': data, - 'has_more': len(datasets) == limit, - 'limit': limit, - 'total': total, - 'page': page - } + item["embedding_available"] = True + response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 - def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() - parser.add_argument('name', nullable=False, required=True, - help='type is required. Name must be between 1 to 40 characters.', - type=_validate_name) - parser.add_argument('indexing_technique', type=str, location='json', - choices=Dataset.INDEXING_TECHNIQUE_LIST, - help='Invalid indexing technique.') - parser.add_argument('permission', type=str, location='json', choices=( - DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), help='Invalid permission.', required=False, nullable=False) + parser.add_argument( + "name", + nullable=False, + required=True, + help="type is required. Name must be between 1 to 40 characters.", + type=_validate_name, + ) + parser.add_argument( + "indexing_technique", + type=str, + location="json", + choices=Dataset.INDEXING_TECHNIQUE_LIST, + help="Invalid indexing technique.", + ) + parser.add_argument( + "permission", + type=str, + location="json", + choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM), + help="Invalid permission.", + required=False, + nullable=False, + ) args = parser.parse_args() try: dataset = DatasetService.create_empty_dataset( tenant_id=tenant_id, - name=args['name'], - indexing_technique=args['indexing_technique'], + name=args["name"], + indexing_technique=args["indexing_technique"], account=current_user, - permission=args['permission'] + permission=args["permission"], ) except services.errors.dataset.DatasetNameDuplicateError: raise DatasetNameDuplicateError() return marshal(dataset, dataset_detail_fields), 200 + class DatasetApi(DatasetApiResource): """Resource for dataset.""" @@ -106,7 +109,7 @@ class DatasetApi(DatasetApiResource): dataset_id (UUID): The ID of the dataset to be deleted. Returns: - dict: A dictionary with a key 'result' and a value 'success' + dict: A dictionary with a key 'result' and a value 'success' if the dataset was successfully deleted. Omitted in HTTP response. int: HTTP status code 204 indicating that the operation was successful. @@ -118,11 +121,12 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): - return {'result': 'success'}, 204 + return {"result": "success"}, 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() -api.add_resource(DatasetListApi, '/datasets') -api.add_resource(DatasetApi, '/datasets/') + +api.add_resource(DatasetListApi, "/datasets") +api.add_resource(DatasetApi, "/datasets/") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index ac1ea820a..fb48a6c76 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -27,47 +27,40 @@ from services.file_service import FileService class DocumentAddByTextApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=True, nullable=False, location='json') - parser.add_argument('text', type=str, required=True, nullable=False, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('original_document_id', type=str, required=False, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('indexing_technique', type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=True, nullable=False, location="json") + parser.add_argument("text", type=str, required=True, nullable=False, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("original_document_id", type=str, required=False, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument( + "indexing_technique", type=str, choices=Dataset.INDEXING_TECHNIQUE_LIST, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if not dataset.indexing_technique and not args['indexing_technique']: - raise ValueError('indexing_technique is required.') + if not dataset.indexing_technique and not args["indexing_technique"]: + raise ValueError("indexing_technique is required.") - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -76,60 +69,49 @@ class DocumentAddByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, nullable=True, location='json') - parser.add_argument('text', type=str, required=False, nullable=True, location='json') - parser.add_argument('process_rule', type=dict, required=False, nullable=True, location='json') - parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False, location='json') - parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False, - location='json') - parser.add_argument('retrieval_model', type=dict, required=False, nullable=False, - location='json') + parser.add_argument("name", type=str, required=False, nullable=True, location="json") + parser.add_argument("text", type=str, required=False, nullable=True, location="json") + parser.add_argument("process_rule", type=dict, required=False, nullable=True, location="json") + parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") + parser.add_argument( + "doc_language", type=str, default="English", required=False, nullable=False, location="json" + ) + parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json") args = parser.parse_args() dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") - if args['text']: - upload_file = FileService.upload_text(args.get('text'), args.get('name')) + if args["text"]: + upload_file = FileService.upload_text(args.get("text"), args.get("name")) data_source = { - 'type': 'upload_file', - 'info_list': { - 'data_source_type': 'upload_file', - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } + "type": "upload_file", + "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, } - args['data_source'] = data_source + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -137,65 +119,53 @@ class DocumentUpdateByTextApi(DatasetApiResource): dataset=dataset, document_data=args, account=current_user, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentAddByFileApi(DatasetApiResource): """Resource for documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_resource_check('documents', 'dataset') + + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_resource_check("documents", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if not dataset.indexing_technique and not args.get('indexing_technique'): - raise ValueError('indexing_technique is required.') + raise ValueError("Dataset is not exist.") + if not dataset.indexing_technique and not args.get("indexing_technique"): + raise ValueError("indexing_technique is required.") # save file info - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args DocumentService.document_create_args_validate(args) @@ -204,63 +174,49 @@ class DocumentAddByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} - if 'data' in request.form: - args = json.loads(request.form['data']) - if 'doc_form' not in args: - args['doc_form'] = 'text_model' - if 'doc_language' not in args: - args['doc_language'] = 'English' + if "data" in request.form: + args = json.loads(request.form["data"]) + if "doc_form" not in args: + args["doc_form"] = "text_model" + if "doc_language" not in args: + args["doc_language"] = "English" # get dataset info dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') - if 'file' in request.files: + raise ValueError("Dataset is not exist.") + if "file" in request.files: # save file info - file = request.files['file'] - + file = request.files["file"] if len(request.files) > 1: raise TooManyFilesError() upload_file = FileService.upload_file(file, current_user) - data_source = { - 'type': 'upload_file', - 'info_list': { - 'file_info_list': { - 'file_ids': [upload_file.id] - } - } - } - args['data_source'] = data_source + data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} + args["data_source"] = data_source # validate args - args['original_document_id'] = str(document_id) + args["original_document_id"] = str(document_id) DocumentService.document_create_args_validate(args) try: @@ -268,16 +224,13 @@ class DocumentUpdateByFileApi(DatasetApiResource): dataset=dataset, document_data=args, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if 'process_rule' not in args else None, - created_from='api' + dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + created_from="api", ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) document = documents[0] - documents_and_batch_fields = { - 'document': marshal(document, document_fields), - 'batch': batch - } + documents_and_batch_fields = {"document": marshal(document, document_fields), "batch": batch} return documents_and_batch_fields, 200 @@ -289,13 +242,10 @@ class DocumentDeleteApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise ValueError('Dataset is not exist.') + raise ValueError("Dataset is not exist.") document = DocumentService.get_document(dataset.id, document_id) @@ -311,44 +261,39 @@ class DocumentDeleteApi(DatasetApiResource): # delete document DocumentService.delete_document(document) except services.errors.document.DocumentIndexingError: - raise DocumentIndexingError('Cannot delete document during indexing.') + raise DocumentIndexingError("Cannot delete document during indexing.") - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class DocumentListApi(DatasetApiResource): def get(self, tenant_id, dataset_id): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - page = request.args.get('page', default=1, type=int) - limit = request.args.get('limit', default=20, type=int) - search = request.args.get('keyword', default=None, type=str) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + page = request.args.get("page", default=1, type=int) + limit = request.args.get("limit", default=20, type=int) + search = request.args.get("keyword", default=None, type=str) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") - query = Document.query.filter_by( - dataset_id=str(dataset_id), tenant_id=tenant_id) + query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: - search = f'%{search}%' + search = f"%{search}%" query = query.filter(Document.name.like(search)) query = query.order_by(desc(Document.created_at)) - paginated_documents = query.paginate( - page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { - 'data': marshal(documents, document_fields), - 'has_more': len(documents) == limit, - 'limit': limit, - 'total': paginated_documents.total, - 'page': page + "data": marshal(documents, document_fields), + "has_more": len(documents) == limit, + "limit": limit, + "total": paginated_documents.total, + "page": page, } return response @@ -360,38 +305,36 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # get documents documents = DocumentService.get_batch_documents(dataset_id, batch) if not documents: - raise NotFound('Documents not found.') + raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() - total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id), - DocumentSegment.status != 're_segment').count() + completed_segments = DocumentSegment.query.filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ).count() + total_segments = DocumentSegment.query.filter( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" + ).count() document.completed_segments = completed_segments document.total_segments = total_segments if document.is_paused: - document.indexing_status = 'paused' + document.indexing_status = "paused" documents_status.append(marshal(document, document_status_fields)) - data = { - 'data': documents_status - } + data = {"data": documents_status} return data -api.add_resource(DocumentAddByTextApi, '/datasets//document/create_by_text') -api.add_resource(DocumentAddByFileApi, '/datasets//document/create_by_file') -api.add_resource(DocumentUpdateByTextApi, '/datasets//documents//update_by_text') -api.add_resource(DocumentUpdateByFileApi, '/datasets//documents//update_by_file') -api.add_resource(DocumentDeleteApi, '/datasets//documents/') -api.add_resource(DocumentListApi, '/datasets//documents') -api.add_resource(DocumentIndexingStatusApi, '/datasets//documents//indexing-status') +api.add_resource(DocumentAddByTextApi, "/datasets//document/create_by_text") +api.add_resource(DocumentAddByFileApi, "/datasets//document/create_by_file") +api.add_resource(DocumentUpdateByTextApi, "/datasets//documents//update_by_text") +api.add_resource(DocumentUpdateByFileApi, "/datasets//documents//update_by_file") +api.add_resource(DocumentDeleteApi, "/datasets//documents/") +api.add_resource(DocumentListApi, "/datasets//documents") +api.add_resource(DocumentIndexingStatusApi, "/datasets//documents//indexing-status") diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index e77693b6c..5ff5e08c7 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -2,78 +2,78 @@ from libs.exception import BaseHTTPException class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = 'high_quality_dataset_only' + error_code = "high_quality_dataset_only" description = "Current operation only supports 'high-quality' datasets." code = 400 class DatasetNotInitializedError(BaseHTTPException): - error_code = 'dataset_not_initialized' + error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." code = 400 class ArchivedDocumentImmutableError(BaseHTTPException): - error_code = 'archived_document_immutable' + error_code = "archived_document_immutable" description = "The archived document is not editable." code = 403 class DatasetNameDuplicateError(BaseHTTPException): - error_code = 'dataset_name_duplicate' + error_code = "dataset_name_duplicate" description = "The dataset name already exists. Please modify your dataset name." code = 409 class InvalidActionError(BaseHTTPException): - error_code = 'invalid_action' + error_code = "invalid_action" description = "Invalid action." code = 400 class DocumentAlreadyFinishedError(BaseHTTPException): - error_code = 'document_already_finished' + error_code = "document_already_finished" description = "The document has been processed. Please refresh the page or go to the document details." code = 400 class DocumentIndexingError(BaseHTTPException): - error_code = 'document_indexing' + error_code = "document_indexing" description = "The document is being processed and cannot be edited." code = 400 class InvalidMetadataError(BaseHTTPException): - error_code = 'invalid_metadata' + error_code = "invalid_metadata" description = "The metadata content is incorrect. Please check and verify." code = 400 class DatasetInUseError(BaseHTTPException): - error_code = 'dataset_in_use' + error_code = "dataset_in_use" description = "The dataset is being used by some apps. Please remove the dataset from the apps before deleting it." code = 409 diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 0fa2aa65b..5e10f3b48 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -21,52 +21,47 @@ from services.dataset_service import DatasetService, DocumentService, SegmentSer class SegmentApi(DatasetApiResource): """Resource for segments.""" - @cloud_edition_billing_resource_check('vector_space', 'dataset') - @cloud_edition_billing_knowledge_limit_check('add_segment', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") - except ProviderTokenNotInitError as ex: + "in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # validate args parser = reqparse.RequestParser() - parser.add_argument('segments', type=list, required=False, nullable=True, location='json') + parser.add_argument("segments", type=list, required=False, nullable=True, location="json") args = parser.parse_args() - if args['segments'] is not None: - for args_item in args['segments']: + if args["segments"] is not None: + for args_item in args["segments"]: SegmentService.segment_create_args_validate(args_item, document) - segments = SegmentService.multi_create_segment(args['segments'], document, dataset) - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form - }, 200 + segments = SegmentService.multi_create_segment(args["segments"], document, dataset) + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form}, 200 else: return {"error": "Segemtns is required"}, 400 @@ -75,61 +70,53 @@ class SegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check document document_id = str(document_id) document = DocumentService.get_document(dataset.id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check embedding model setting - if dataset.indexing_technique == 'high_quality': + if dataset.indexing_technique == "high_quality": try: model_manager = ModelManager() model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) parser = reqparse.RequestParser() - parser.add_argument('status', type=str, - action='append', default=[], location='args') - parser.add_argument('keyword', type=str, default=None, location='args') + parser.add_argument("status", type=str, action="append", default=[], location="args") + parser.add_argument("keyword", type=str, default=None, location="args") args = parser.parse_args() - status_list = args['status'] - keyword = args['keyword'] + status_list = args["status"] + keyword = args["keyword"] query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) if keyword: - query = query.where(DocumentSegment.content.ilike(f'%{keyword}%')) + query = query.where(DocumentSegment.content.ilike(f"%{keyword}%")) total = query.count() segments = query.order_by(DocumentSegment.position).all() - return { - 'data': marshal(segments, segment_fields), - 'doc_form': document.doc_form, - 'total': total - }, 200 + return {"data": marshal(segments, segment_fields), "doc_form": document.doc_form, "total": total}, 200 class DatasetSegmentApi(DatasetApiResource): @@ -137,48 +124,41 @@ class DatasetSegmentApi(DatasetApiResource): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') + raise NotFound("Document not found.") # check segment segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 - @cloud_edition_billing_resource_check('vector_space', 'dataset') + @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).filter( - Dataset.tenant_id == tenant_id, - Dataset.id == dataset_id - ).first() + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() if not dataset: - raise NotFound('Dataset not found.') + raise NotFound("Dataset not found.") # check user's model setting DatasetService.check_dataset_model_setting(dataset) # check document document_id = str(document_id) document = DocumentService.get_document(dataset_id, document_id) if not document: - raise NotFound('Document not found.') - if dataset.indexing_technique == 'high_quality': + raise NotFound("Document not found.") + if dataset.indexing_technique == "high_quality": # check embedding model setting try: model_manager = ModelManager() @@ -186,35 +166,34 @@ class DatasetSegmentApi(DatasetApiResource): tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model + model=dataset.embedding_model, ) except LLMBadRequestError: raise ProviderNotInitializeError( "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider.") + "in the Settings -> Model Provider." + ) except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), - DocumentSegment.tenant_id == current_user.current_tenant_id + DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id ).first() if not segment: - raise NotFound('Segment not found.') + raise NotFound("Segment not found.") # validate args parser = reqparse.RequestParser() - parser.add_argument('segment', type=dict, required=False, nullable=True, location='json') + parser.add_argument("segment", type=dict, required=False, nullable=True, location="json") args = parser.parse_args() - SegmentService.segment_create_args_validate(args['segment'], document) - segment = SegmentService.update_segment(args['segment'], segment, document, dataset) - return { - 'data': marshal(segment, segment_fields), - 'doc_form': document.doc_form - }, 200 + SegmentService.segment_create_args_validate(args["segment"], document) + segment = SegmentService.update_segment(args["segment"], segment, document, dataset) + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 -api.add_resource(SegmentApi, '/datasets//documents//segments') -api.add_resource(DatasetSegmentApi, '/datasets//documents//segments/') +api.add_resource(SegmentApi, "/datasets//documents//segments") +api.add_resource( + DatasetSegmentApi, "/datasets//documents//segments/" +) diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py index c910063eb..d24c4597e 100644 --- a/api/controllers/service_api/index.py +++ b/api/controllers/service_api/index.py @@ -13,4 +13,4 @@ class IndexApi(Resource): } -api.add_resource(IndexApi, '/') +api.add_resource(IndexApi, "/") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 819512edf..a596c6f28 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -21,9 +21,10 @@ class WhereisUserArg(Enum): """ Enum for whereis_user_arg. """ - QUERY = 'query' - JSON = 'json' - FORM = 'form' + + QUERY = "query" + JSON = "json" + FORM = "form" class FetchUserArg(BaseModel): @@ -35,13 +36,13 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio def decorator(view_func): @wraps(view_func) def decorated_view(*args, **kwargs): - api_token = validate_and_get_api_token('app') + api_token = validate_and_get_api_token("app") app_model = db.session.query(App).filter(App.id == api_token.app_id).first() if not app_model: raise Forbidden("The app no longer exists.") - if app_model.status != 'normal': + if app_model.status != "normal": raise Forbidden("The app's status is abnormal.") if not app_model.enable_api: @@ -51,15 +52,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if tenant.status == TenantStatus.ARCHIVE: raise Forbidden("The workspace's status is archived.") - kwargs['app_model'] = app_model + kwargs["app_model"] = app_model if fetch_user_arg: if fetch_user_arg.fetch_from == WhereisUserArg.QUERY: - user_id = request.args.get('user') + user_id = request.args.get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.JSON: - user_id = request.get_json().get('user') + user_id = request.get_json().get("user") elif fetch_user_arg.fetch_from == WhereisUserArg.FORM: - user_id = request.form.get('user') + user_id = request.form.get("user") else: # use default-user user_id = None @@ -70,9 +71,10 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if user_id: user_id = str(user_id) - kwargs['end_user'] = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) return view_func(*args, **kwargs) + return decorated_view if view is None: @@ -81,9 +83,9 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio return decorator(view) -def cloud_edition_billing_resource_check(resource: str, - api_token_type: str, - error_msg: str = "You have reached the limit of your subscription."): +def cloud_edition_billing_resource_check( + resource: str, api_token_type: str, error_msg: str = "You have reached the limit of your subscription." +): def interceptor(view): def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) @@ -95,33 +97,37 @@ def cloud_edition_billing_resource_check(resource: str, vector_space = features.vector_space documents_upload_quota = features.documents_upload_quota - if resource == 'members' and 0 < members.limit <= members.size: + if resource == "members" and 0 < members.limit <= members.size: raise Forbidden(error_msg) - elif resource == 'apps' and 0 < apps.limit <= apps.size: + elif resource == "apps" and 0 < apps.limit <= apps.size: raise Forbidden(error_msg) - elif resource == 'vector_space' and 0 < vector_space.limit <= vector_space.size: + elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size: raise Forbidden(error_msg) - elif resource == 'documents' and 0 < documents_upload_quota.limit <= documents_upload_quota.size: + elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size: raise Forbidden(error_msg) else: return view(*args, **kwargs) return view(*args, **kwargs) + return decorated + return interceptor -def cloud_edition_billing_knowledge_limit_check(resource: str, - api_token_type: str, - error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan."): +def cloud_edition_billing_knowledge_limit_check( + resource: str, + api_token_type: str, + error_msg: str = "To unlock this feature and elevate your Dify experience, please upgrade to a paid plan.", +): def interceptor(view): @wraps(view) def decorated(*args, **kwargs): api_token = validate_and_get_api_token(api_token_type) features = FeatureService.get_features(api_token.tenant_id) if features.billing.enabled: - if resource == 'add_segment': - if features.billing.subscription.plan == 'sandbox': + if resource == "add_segment": + if features.billing.subscription.plan == "sandbox": raise Forbidden(error_msg) else: return view(*args, **kwargs) @@ -132,17 +138,20 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, return interceptor + def validate_dataset_token(view=None): def decorator(view): @wraps(view) def decorated(*args, **kwargs): - api_token = validate_and_get_api_token('dataset') - tenant_account_join = db.session.query(Tenant, TenantAccountJoin) \ - .filter(Tenant.id == api_token.tenant_id) \ - .filter(TenantAccountJoin.tenant_id == Tenant.id) \ - .filter(TenantAccountJoin.role.in_(['owner'])) \ - .filter(Tenant.status == TenantStatus.NORMAL) \ - .one_or_none() # TODO: only owner information is required, so only one is returned. + api_token = validate_and_get_api_token("dataset") + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == api_token.tenant_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role.in_(["owner"])) + .filter(Tenant.status == TenantStatus.NORMAL) + .one_or_none() + ) # TODO: only owner information is required, so only one is returned. if tenant_account_join: tenant, ta = tenant_account_join account = Account.query.filter_by(id=ta.account_id).first() @@ -156,6 +165,7 @@ def validate_dataset_token(view=None): else: raise Unauthorized("Tenant does not exist.") return view(api_token.tenant_id, *args, **kwargs) + return decorated if view: @@ -170,20 +180,24 @@ def validate_and_get_api_token(scope=None): """ Validate and get API token. """ - auth_header = request.headers.get('Authorization') - if auth_header is None or ' ' not in auth_header: + auth_header = request.headers.get("Authorization") + if auth_header is None or " " not in auth_header: raise Unauthorized("Authorization header must be provided and start with 'Bearer'") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': + if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - api_token = db.session.query(ApiToken).filter( - ApiToken.token == auth_token, - ApiToken.type == scope, - ).first() + api_token = ( + db.session.query(ApiToken) + .filter( + ApiToken.token == auth_token, + ApiToken.type == scope, + ) + .first() + ) if not api_token: raise Unauthorized("Access token is invalid") @@ -199,23 +213,26 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str] Create or update session terminal based on user ID. """ if not user_id: - user_id = 'DEFAULT-USER' + user_id = "DEFAULT-USER" - end_user = db.session.query(EndUser) \ + end_user = ( + db.session.query(EndUser) .filter( - EndUser.tenant_id == app_model.tenant_id, - EndUser.app_id == app_model.id, - EndUser.session_id == user_id, - EndUser.type == 'service_api' - ).first() + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + EndUser.session_id == user_id, + EndUser.type == "service_api", + ) + .first() + ) if end_user is None: end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='service_api', - is_anonymous=True if user_id == 'DEFAULT-USER' else False, - session_id=user_id + type="service_api", + is_anonymous=True if user_id == "DEFAULT-USER" else False, + session_id=user_id, ) db.session.add(end_user) db.session.commit() diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index aa19bdc03..630b9468a 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -2,7 +2,7 @@ from flask import Blueprint from libs.external_api import ExternalApi -bp = Blueprint('web', __name__, url_prefix='/api') +bp = Blueprint("web", __name__, url_prefix="/api") api = ExternalApi(bp) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index f4db82552..aabca9333 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -10,33 +10,32 @@ from services.app_service import AppService class AppParameterApi(WebApiResource): """Resource for app variables.""" + variable_fields = { - 'key': fields.String, - 'name': fields.String, - 'description': fields.String, - 'type': fields.String, - 'default': fields.String, - 'max_length': fields.Integer, - 'options': fields.List(fields.String) + "key": fields.String, + "name": fields.String, + "description": fields.String, + "type": fields.String, + "default": fields.String, + "max_length": fields.Integer, + "options": fields.List(fields.String), } - system_parameters_fields = { - 'image_file_size_limit': fields.String - } + system_parameters_fields = {"image_file_size_limit": fields.String} parameters_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw, - 'suggested_questions_after_answer': fields.Raw, - 'speech_to_text': fields.Raw, - 'text_to_speech': fields.Raw, - 'retriever_resource': fields.Raw, - 'annotation_reply': fields.Raw, - 'more_like_this': fields.Raw, - 'user_input_form': fields.Raw, - 'sensitive_word_avoidance': fields.Raw, - 'file_upload': fields.Raw, - 'system_parameters': fields.Nested(system_parameters_fields) + "opening_statement": fields.String, + "suggested_questions": fields.Raw, + "suggested_questions_after_answer": fields.Raw, + "speech_to_text": fields.Raw, + "text_to_speech": fields.Raw, + "retriever_resource": fields.Raw, + "annotation_reply": fields.Raw, + "more_like_this": fields.Raw, + "user_input_form": fields.Raw, + "sensitive_word_avoidance": fields.Raw, + "file_upload": fields.Raw, + "system_parameters": fields.Nested(system_parameters_fields), } @marshal_with(parameters_fields) @@ -53,30 +52,35 @@ class AppParameterApi(WebApiResource): app_model_config = app_model.app_model_config features_dict = app_model_config.to_dict() - user_input_form = features_dict.get('user_input_form', []) + user_input_form = features_dict.get("user_input_form", []) return { - 'opening_statement': features_dict.get('opening_statement'), - 'suggested_questions': features_dict.get('suggested_questions', []), - 'suggested_questions_after_answer': features_dict.get('suggested_questions_after_answer', - {"enabled": False}), - 'speech_to_text': features_dict.get('speech_to_text', {"enabled": False}), - 'text_to_speech': features_dict.get('text_to_speech', {"enabled": False}), - 'retriever_resource': features_dict.get('retriever_resource', {"enabled": False}), - 'annotation_reply': features_dict.get('annotation_reply', {"enabled": False}), - 'more_like_this': features_dict.get('more_like_this', {"enabled": False}), - 'user_input_form': user_input_form, - 'sensitive_word_avoidance': features_dict.get('sensitive_word_avoidance', - {"enabled": False, "type": "", "configs": []}), - 'file_upload': features_dict.get('file_upload', {"image": { - "enabled": False, - "number_limits": 3, - "detail": "high", - "transfer_methods": ["remote_url", "local_file"] - }}), - 'system_parameters': { - 'image_file_size_limit': dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT - } + "opening_statement": features_dict.get("opening_statement"), + "suggested_questions": features_dict.get("suggested_questions", []), + "suggested_questions_after_answer": features_dict.get( + "suggested_questions_after_answer", {"enabled": False} + ), + "speech_to_text": features_dict.get("speech_to_text", {"enabled": False}), + "text_to_speech": features_dict.get("text_to_speech", {"enabled": False}), + "retriever_resource": features_dict.get("retriever_resource", {"enabled": False}), + "annotation_reply": features_dict.get("annotation_reply", {"enabled": False}), + "more_like_this": features_dict.get("more_like_this", {"enabled": False}), + "user_input_form": user_input_form, + "sensitive_word_avoidance": features_dict.get( + "sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []} + ), + "file_upload": features_dict.get( + "file_upload", + { + "image": { + "enabled": False, + "number_limits": 3, + "detail": "high", + "transfer_methods": ["remote_url", "local_file"], + } + }, + ), + "system_parameters": {"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT}, } @@ -86,5 +90,5 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) -api.add_resource(AppParameterApi, '/parameters') -api.add_resource(AppMeta, '/meta') \ No newline at end of file +api.add_resource(AppParameterApi, "/parameters") +api.add_resource(AppMeta, "/meta") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 0e905f905..d062d2893 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -31,14 +31,10 @@ from services.errors.audio import ( class AudioApi(WebApiResource): def post(self, app_model: App, end_user): - file = request.files['file'] + file = request.files["file"] try: - response = AudioService.transcript_asr( - app_model=app_model, - file=file, - end_user=end_user - ) + response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=end_user) return response except services.errors.app_model_config.AppModelConfigBrokenError: @@ -70,34 +66,36 @@ class AudioApi(WebApiResource): class TextApi(WebApiResource): def post(self, app_model: App, end_user): from flask_restful import reqparse + try: parser = reqparse.RequestParser() - parser.add_argument('message_id', type=str, required=False, location='json') - parser.add_argument('voice', type=str, location='json') - parser.add_argument('text', type=str, location='json') - parser.add_argument('streaming', type=bool, location='json') + parser.add_argument("message_id", type=str, required=False, location="json") + parser.add_argument("voice", type=str, location="json") + parser.add_argument("text", type=str, location="json") + parser.add_argument("streaming", type=bool, location="json") args = parser.parse_args() - message_id = args.get('message_id', None) - text = args.get('text', None) - if (app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] - and app_model.workflow - and app_model.workflow.features_dict): - text_to_speech = app_model.workflow.features_dict.get('text_to_speech') - voice = args.get('voice') if args.get('voice') else text_to_speech.get('voice') + message_id = args.get("message_id", None) + text = args.get("text", None) + if ( + app_model.mode in [AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value] + and app_model.workflow + and app_model.workflow.features_dict + ): + text_to_speech = app_model.workflow.features_dict.get("text_to_speech") + voice = args.get("voice") if args.get("voice") else text_to_speech.get("voice") else: try: - voice = args.get('voice') if args.get( - 'voice') else app_model.app_model_config.text_to_speech_dict.get('voice') + voice = ( + args.get("voice") + if args.get("voice") + else app_model.app_model_config.text_to_speech_dict.get("voice") + ) except Exception: voice = None response = AudioService.transcript_tts( - app_model=app_model, - message_id=message_id, - end_user=end_user.external_user_id, - voice=voice, - text=text + app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text ) return response @@ -127,5 +125,5 @@ class TextApi(WebApiResource): raise InternalServerError() -api.add_resource(AudioApi, '/audio-to-text') -api.add_resource(TextApi, '/text-to-audio') +api.add_resource(AudioApi, "/audio-to-text") +api.add_resource(TextApi, "/text-to-audio") diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 948d5fabb..bd636a048 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -28,30 +28,25 @@ from services.app_generate_service import AppGenerateService # define completion api for user class CompletionApi(WebApiResource): - def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, location='json', default='') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, location="json", default="") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -79,12 +74,12 @@ class CompletionApi(WebApiResource): class CompletionStopApi(WebApiResource): def post(self, app_model, end_user, task_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 class ChatApi(WebApiResource): @@ -94,25 +89,21 @@ class ChatApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, location='json') - parser.add_argument('query', type=str, required=True, location='json') - parser.add_argument('files', type=list, required=False, location='json') - parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json') - parser.add_argument('conversation_id', type=uuid_value, location='json') - parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json') + parser.add_argument("inputs", type=dict, required=True, location="json") + parser.add_argument("query", type=str, required=True, location="json") + parser.add_argument("files", type=list, required=False, location="json") + parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json") + parser.add_argument("conversation_id", type=uuid_value, location="json") + parser.add_argument("retriever_from", type=str, required=False, default="web_app", location="json") args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' - args['auto_generate_name'] = False + streaming = args["response_mode"] == "streaming" + args["auto_generate_name"] = False try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=streaming ) return helper.compact_generate_response(response) @@ -146,10 +137,10 @@ class ChatStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return {'result': 'success'}, 200 + return {"result": "success"}, 200 -api.add_resource(CompletionApi, '/completion-messages') -api.add_resource(CompletionStopApi, '/completion-messages//stop') -api.add_resource(ChatApi, '/chat-messages') -api.add_resource(ChatStopApi, '/chat-messages//stop') +api.add_resource(CompletionApi, "/completion-messages") +api.add_resource(CompletionStopApi, "/completion-messages//stop") +api.add_resource(ChatApi, "/chat-messages") +api.add_resource(ChatStopApi, "/chat-messages//stop") diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py index 334ee382a..6bbfa94c2 100644 --- a/api/controllers/web/conversation.py +++ b/api/controllers/web/conversation.py @@ -15,7 +15,6 @@ from services.web_conversation_service import WebConversationService class ConversationListApi(WebApiResource): - @marshal_with(conversation_infinite_scroll_pagination_fields) def get(self, app_model, end_user): app_mode = AppMode.value_of(app_model.mode) @@ -23,26 +22,32 @@ class ConversationListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') - parser.add_argument('pinned', type=str, choices=['true', 'false', None], location='args') - parser.add_argument('sort_by', type=str, choices=['created_at', '-created_at', 'updated_at', '-updated_at'], - required=False, default='-updated_at', location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") + parser.add_argument("pinned", type=str, choices=["true", "false", None], location="args") + parser.add_argument( + "sort_by", + type=str, + choices=["created_at", "-created_at", "updated_at", "-updated_at"], + required=False, + default="-updated_at", + location="args", + ) args = parser.parse_args() pinned = None - if 'pinned' in args and args['pinned'] is not None: - pinned = True if args['pinned'] == 'true' else False + if "pinned" in args and args["pinned"] is not None: + pinned = True if args["pinned"] == "true" else False try: return WebConversationService.pagination_by_last_id( app_model=app_model, user=end_user, - last_id=args['last_id'], - limit=args['limit'], + last_id=args["last_id"], + limit=args["limit"], invoke_from=InvokeFrom.WEB_APP, pinned=pinned, - sort_by=args['sort_by'] + sort_by=args["sort_by"], ) except LastConversationNotExistsError: raise NotFound("Last Conversation Not Exists.") @@ -65,7 +70,6 @@ class ConversationApi(WebApiResource): class ConversationRenameApi(WebApiResource): - @marshal_with(simple_conversation_fields) def post(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) @@ -75,24 +79,17 @@ class ConversationRenameApi(WebApiResource): conversation_id = str(c_id) parser = reqparse.RequestParser() - parser.add_argument('name', type=str, required=False, location='json') - parser.add_argument('auto_generate', type=bool, required=False, default=False, location='json') + parser.add_argument("name", type=str, required=False, location="json") + parser.add_argument("auto_generate", type=bool, required=False, default=False, location="json") args = parser.parse_args() try: - return ConversationService.rename( - app_model, - conversation_id, - end_user, - args['name'], - args['auto_generate'] - ) + return ConversationService.rename(app_model, conversation_id, end_user, args["name"], args["auto_generate"]) except ConversationNotExistsError: raise NotFound("Conversation Not Exists.") class ConversationPinApi(WebApiResource): - def patch(self, app_model, end_user, c_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT]: @@ -120,8 +117,8 @@ class ConversationUnPinApi(WebApiResource): return {"result": "success"} -api.add_resource(ConversationRenameApi, '/conversations//name', endpoint='web_conversation_name') -api.add_resource(ConversationListApi, '/conversations') -api.add_resource(ConversationApi, '/conversations/') -api.add_resource(ConversationPinApi, '/conversations//pin') -api.add_resource(ConversationUnPinApi, '/conversations//unpin') +api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="web_conversation_name") +api.add_resource(ConversationListApi, "/conversations") +api.add_resource(ConversationApi, "/conversations/") +api.add_resource(ConversationPinApi, "/conversations//pin") +api.add_resource(ConversationUnPinApi, "/conversations//unpin") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index bc87f5105..2f6bb39cf 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -2,122 +2,126 @@ from libs.exception import BaseHTTPException class AppUnavailableError(BaseHTTPException): - error_code = 'app_unavailable' + error_code = "app_unavailable" description = "App unavailable, please check your app configurations." code = 400 class NotCompletionAppError(BaseHTTPException): - error_code = 'not_completion_app' + error_code = "not_completion_app" description = "Please check if your Completion app mode matches the right API route." code = 400 class NotChatAppError(BaseHTTPException): - error_code = 'not_chat_app' + error_code = "not_chat_app" description = "Please check if your app mode matches the right API route." code = 400 class NotWorkflowAppError(BaseHTTPException): - error_code = 'not_workflow_app' + error_code = "not_workflow_app" description = "Please check if your Workflow app mode matches the right API route." code = 400 class ConversationCompletedError(BaseHTTPException): - error_code = 'conversation_completed' + error_code = "conversation_completed" description = "The conversation has ended. Please start a new conversation." code = 400 class ProviderNotInitializeError(BaseHTTPException): - error_code = 'provider_not_initialize' - description = "No valid model provider credentials found. " \ - "Please go to Settings -> Model Provider to complete your provider credentials." + error_code = "provider_not_initialize" + description = ( + "No valid model provider credentials found. " + "Please go to Settings -> Model Provider to complete your provider credentials." + ) code = 400 class ProviderQuotaExceededError(BaseHTTPException): - error_code = 'provider_quota_exceeded' - description = "Your quota for Dify Hosted OpenAI has been exhausted. " \ - "Please go to Settings -> Model Provider to complete your own provider credentials." + error_code = "provider_quota_exceeded" + description = ( + "Your quota for Dify Hosted OpenAI has been exhausted. " + "Please go to Settings -> Model Provider to complete your own provider credentials." + ) code = 400 class ProviderModelCurrentlyNotSupportError(BaseHTTPException): - error_code = 'model_currently_not_support' + error_code = "model_currently_not_support" description = "Dify Hosted OpenAI trial currently not support the GPT-4 model." code = 400 class CompletionRequestError(BaseHTTPException): - error_code = 'completion_request_error' + error_code = "completion_request_error" description = "Completion request failed." code = 400 class AppMoreLikeThisDisabledError(BaseHTTPException): - error_code = 'app_more_like_this_disabled' + error_code = "app_more_like_this_disabled" description = "The 'More like this' feature is disabled. Please refresh your page." code = 403 class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): - error_code = 'app_suggested_questions_after_answer_disabled' + error_code = "app_suggested_questions_after_answer_disabled" description = "The 'Suggested Questions After Answer' feature is disabled. Please refresh your page." code = 403 class NoAudioUploadedError(BaseHTTPException): - error_code = 'no_audio_uploaded' + error_code = "no_audio_uploaded" description = "Please upload your audio." code = 400 class AudioTooLargeError(BaseHTTPException): - error_code = 'audio_too_large' + error_code = "audio_too_large" description = "Audio size exceeded. {message}" code = 413 class UnsupportedAudioTypeError(BaseHTTPException): - error_code = 'unsupported_audio_type' + error_code = "unsupported_audio_type" description = "Audio type not allowed." code = 415 class ProviderNotSupportSpeechToTextError(BaseHTTPException): - error_code = 'provider_not_support_speech_to_text' + error_code = "provider_not_support_speech_to_text" description = "Provider not support speech to text." code = 400 class NoFileUploadedError(BaseHTTPException): - error_code = 'no_file_uploaded' + error_code = "no_file_uploaded" description = "Please upload your file." code = 400 class TooManyFilesError(BaseHTTPException): - error_code = 'too_many_files' + error_code = "too_many_files" description = "Only one file is allowed." code = 400 class FileTooLargeError(BaseHTTPException): - error_code = 'file_too_large' + error_code = "file_too_large" description = "File size exceeded. {message}" code = 413 class UnsupportedFileTypeError(BaseHTTPException): - error_code = 'unsupported_file_type' + error_code = "unsupported_file_type" description = "File type not allowed." code = 415 class WebSSOAuthRequiredError(BaseHTTPException): - error_code = 'web_sso_auth_required' + error_code = "web_sso_auth_required" description = "Web SSO authentication required." code = 401 diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py index 69b38faaf..0563ed223 100644 --- a/api/controllers/web/feature.py +++ b/api/controllers/web/feature.py @@ -9,4 +9,4 @@ class SystemFeatureApi(Resource): return FeatureService.get_system_features().model_dump() -api.add_resource(SystemFeatureApi, '/system-features') +api.add_resource(SystemFeatureApi, "/system-features") diff --git a/api/controllers/web/file.py b/api/controllers/web/file.py index ca83f6037..253b1d511 100644 --- a/api/controllers/web/file.py +++ b/api/controllers/web/file.py @@ -10,14 +10,13 @@ from services.file_service import FileService class FileApi(WebApiResource): - @marshal_with(file_fields) def post(self, app_model, end_user): # get file from request - file = request.files['file'] + file = request.files["file"] # check file - if 'file' not in request.files: + if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: @@ -32,4 +31,4 @@ class FileApi(WebApiResource): return upload_file, 201 -api.add_resource(FileApi, '/files/upload') +api.add_resource(FileApi, "/files/upload") diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 865d2270a..56aaaa930 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -33,48 +33,46 @@ from services.message_service import MessageService class MessageListApi(WebApiResource): - feedback_fields = { - 'rating': fields.String - } + feedback_fields = {"rating": fields.String} retriever_resource_fields = { - 'id': fields.String, - 'message_id': fields.String, - 'position': fields.Integer, - 'dataset_id': fields.String, - 'dataset_name': fields.String, - 'document_id': fields.String, - 'document_name': fields.String, - 'data_source_type': fields.String, - 'segment_id': fields.String, - 'score': fields.Float, - 'hit_count': fields.Integer, - 'word_count': fields.Integer, - 'segment_position': fields.Integer, - 'index_node_hash': fields.String, - 'content': fields.String, - 'created_at': TimestampField + "id": fields.String, + "message_id": fields.String, + "position": fields.Integer, + "dataset_id": fields.String, + "dataset_name": fields.String, + "document_id": fields.String, + "document_name": fields.String, + "data_source_type": fields.String, + "segment_id": fields.String, + "score": fields.Float, + "hit_count": fields.Integer, + "word_count": fields.Integer, + "segment_position": fields.Integer, + "index_node_hash": fields.String, + "content": fields.String, + "created_at": TimestampField, } message_fields = { - 'id': fields.String, - 'conversation_id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String(attribute='re_sign_file_url_answer'), - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)), - 'created_at': TimestampField, - 'agent_thoughts': fields.List(fields.Nested(agent_thought_fields)), - 'status': fields.String, - 'error': fields.String, + "id": fields.String, + "conversation_id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String(attribute="re_sign_file_url_answer"), + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "retriever_resources": fields.List(fields.Nested(retriever_resource_fields)), + "created_at": TimestampField, + "agent_thoughts": fields.List(fields.Nested(agent_thought_fields)), + "status": fields.String, + "error": fields.String, } message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(message_infinite_scroll_pagination_fields) @@ -84,14 +82,15 @@ class MessageListApi(WebApiResource): raise NotChatAppError() parser = reqparse.RequestParser() - parser.add_argument('conversation_id', required=True, type=uuid_value, location='args') - parser.add_argument('first_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("conversation_id", required=True, type=uuid_value, location="args") + parser.add_argument("first_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() try: - return MessageService.pagination_by_first_id(app_model, end_user, - args['conversation_id'], args['first_id'], args['limit']) + return MessageService.pagination_by_first_id( + app_model, end_user, args["conversation_id"], args["first_id"], args["limit"] + ) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") except services.errors.message.FirstMessageNotExistsError: @@ -103,29 +102,31 @@ class MessageFeedbackApi(WebApiResource): message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('rating', type=str, choices=['like', 'dislike', None], location='json') + parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() try: - MessageService.create_feedback(app_model, message_id, end_user, args['rating']) + MessageService.create_feedback(app_model, message_id, end_user, args["rating"]) except services.errors.message.MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class MessageMoreLikeThisApi(WebApiResource): def get(self, app_model, end_user, message_id): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() message_id = str(message_id) parser = reqparse.RequestParser() - parser.add_argument('response_mode', type=str, required=True, choices=['blocking', 'streaming'], location='args') + parser.add_argument( + "response_mode", type=str, required=True, choices=["blocking", "streaming"], location="args" + ) args = parser.parse_args() - streaming = args['response_mode'] == 'streaming' + streaming = args["response_mode"] == "streaming" try: response = AppGenerateService.generate_more_like_this( @@ -133,7 +134,7 @@ class MessageMoreLikeThisApi(WebApiResource): user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP, - streaming=streaming + streaming=streaming, ) return helper.compact_generate_response(response) @@ -166,10 +167,7 @@ class MessageSuggestedQuestionApi(WebApiResource): try: questions = MessageService.get_suggested_questions_after_answer( - app_model=app_model, - user=end_user, - message_id=message_id, - invoke_from=InvokeFrom.WEB_APP + app_model=app_model, user=end_user, message_id=message_id, invoke_from=InvokeFrom.WEB_APP ) except MessageNotExistsError: raise NotFound("Message not found") @@ -189,10 +187,10 @@ class MessageSuggestedQuestionApi(WebApiResource): logging.exception("internal server error.") raise InternalServerError() - return {'data': questions} + return {"data": questions} -api.add_resource(MessageListApi, '/messages') -api.add_resource(MessageFeedbackApi, '/messages//feedbacks') -api.add_resource(MessageMoreLikeThisApi, '/messages//more-like-this') -api.add_resource(MessageSuggestedQuestionApi, '/messages//suggested-questions') +api.add_resource(MessageListApi, "/messages") +api.add_resource(MessageFeedbackApi, "/messages//feedbacks") +api.add_resource(MessageMoreLikeThisApi, "/messages//more-like-this") +api.add_resource(MessageSuggestedQuestionApi, "/messages//suggested-questions") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index cce8943ea..a01ffd861 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -15,33 +15,31 @@ from services.feature_service import FeatureService class PassportResource(Resource): """Base resource for passport.""" + def get(self): system_features = FeatureService.get_system_features() - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") if app_code is None: - raise Unauthorized('X-App-Code header is missing.') + raise Unauthorized("X-App-Code header is missing.") if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) if app_web_sso_enabled: raise WebSSOAuthRequiredError() - + # get site from db and check if it is normal - site = db.session.query(Site).filter( - Site.code == app_code, - Site.status == 'normal' - ).first() + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() if not site: raise NotFound() # get app from db and check if it is normal and enable_site app_model = db.session.query(App).filter(App.id == site.app_id).first() - if not app_model or app_model.status != 'normal' or not app_model.enable_site: + if not app_model or app_model.status != "normal" or not app_model.enable_site: raise NotFound() end_user = EndUser( tenant_id=app_model.tenant_id, app_id=app_model.id, - type='browser', + type="browser", is_anonymous=True, session_id=generate_session_id(), ) @@ -51,20 +49,20 @@ class PassportResource(Resource): payload = { "iss": site.app_id, - 'sub': 'Web API Passport', - 'app_id': site.app_id, - 'app_code': app_code, - 'end_user_id': end_user.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": app_code, + "end_user_id": end_user.id, } tk = PassportService().issue(payload) return { - 'access_token': tk, + "access_token": tk, } -api.add_resource(PassportResource, '/passport') +api.add_resource(PassportResource, "/passport") def generate_session_id(): @@ -73,7 +71,6 @@ def generate_session_id(): """ while True: session_id = str(uuid.uuid4()) - existing_count = db.session.query(EndUser) \ - .filter(EndUser.session_id == session_id).count() + existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count() if existing_count == 0: return session_id diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index e17869ffd..8253f5fc5 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -10,67 +10,65 @@ from libs.helper import TimestampField, uuid_value from services.errors.message import MessageNotExistsError from services.saved_message_service import SavedMessageService -feedback_fields = { - 'rating': fields.String -} +feedback_fields = {"rating": fields.String} message_fields = { - 'id': fields.String, - 'inputs': fields.Raw, - 'query': fields.String, - 'answer': fields.String, - 'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'), - 'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True), - 'created_at': TimestampField + "id": fields.String, + "inputs": fields.Raw, + "query": fields.String, + "answer": fields.String, + "message_files": fields.List(fields.Nested(message_file_fields), attribute="files"), + "feedback": fields.Nested(feedback_fields, attribute="user_feedback", allow_null=True), + "created_at": TimestampField, } class SavedMessageListApi(WebApiResource): saved_message_infinite_scroll_pagination_fields = { - 'limit': fields.Integer, - 'has_more': fields.Boolean, - 'data': fields.List(fields.Nested(message_fields)) + "limit": fields.Integer, + "has_more": fields.Boolean, + "data": fields.List(fields.Nested(message_fields)), } @marshal_with(saved_message_infinite_scroll_pagination_fields) def get(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('last_id', type=uuid_value, location='args') - parser.add_argument('limit', type=int_range(1, 100), required=False, default=20, location='args') + parser.add_argument("last_id", type=uuid_value, location="args") + parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - return SavedMessageService.pagination_by_last_id(app_model, end_user, args['last_id'], args['limit']) + return SavedMessageService.pagination_by_last_id(app_model, end_user, args["last_id"], args["limit"]) def post(self, app_model, end_user): - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() parser = reqparse.RequestParser() - parser.add_argument('message_id', type=uuid_value, required=True, location='json') + parser.add_argument("message_id", type=uuid_value, required=True, location="json") args = parser.parse_args() try: - SavedMessageService.save(app_model, end_user, args['message_id']) + SavedMessageService.save(app_model, end_user, args["message_id"]) except MessageNotExistsError: raise NotFound("Message Not Exists.") - return {'result': 'success'} + return {"result": "success"} class SavedMessageApi(WebApiResource): def delete(self, app_model, end_user, message_id): message_id = str(message_id) - if app_model.mode != 'completion': + if app_model.mode != "completion": raise NotCompletionAppError() SavedMessageService.delete(app_model, end_user, message_id) - return {'result': 'success'} + return {"result": "success"} -api.add_resource(SavedMessageListApi, '/saved-messages') -api.add_resource(SavedMessageApi, '/saved-messages/') +api.add_resource(SavedMessageListApi, "/saved-messages") +api.add_resource(SavedMessageApi, "/saved-messages/") diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index 0f4a7cabe..2b4d0e763 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,4 +1,3 @@ - from flask_restful import fields, marshal_with from werkzeug.exceptions import Forbidden @@ -16,41 +15,41 @@ class AppSiteApi(WebApiResource): """Resource for app sites.""" model_config_fields = { - 'opening_statement': fields.String, - 'suggested_questions': fields.Raw(attribute='suggested_questions_list'), - 'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'), - 'more_like_this': fields.Raw(attribute='more_like_this_dict'), - 'model': fields.Raw(attribute='model_dict'), - 'user_input_form': fields.Raw(attribute='user_input_form_list'), - 'pre_prompt': fields.String, + "opening_statement": fields.String, + "suggested_questions": fields.Raw(attribute="suggested_questions_list"), + "suggested_questions_after_answer": fields.Raw(attribute="suggested_questions_after_answer_dict"), + "more_like_this": fields.Raw(attribute="more_like_this_dict"), + "model": fields.Raw(attribute="model_dict"), + "user_input_form": fields.Raw(attribute="user_input_form_list"), + "pre_prompt": fields.String, } site_fields = { - 'title': fields.String, - 'chat_color_theme': fields.String, - 'chat_color_theme_inverted': fields.Boolean, - 'icon_type': fields.String, - 'icon': fields.String, - 'icon_background': fields.String, - 'icon_url': AppIconUrlField, - 'description': fields.String, - 'copyright': fields.String, - 'privacy_policy': fields.String, - 'custom_disclaimer': fields.String, - 'default_language': fields.String, - 'prompt_public': fields.Boolean, - 'show_workflow_steps': fields.Boolean, + "title": fields.String, + "chat_color_theme": fields.String, + "chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "prompt_public": fields.Boolean, + "show_workflow_steps": fields.Boolean, } app_fields = { - 'app_id': fields.String, - 'end_user_id': fields.String, - 'enable_site': fields.Boolean, - 'site': fields.Nested(site_fields), - 'model_config': fields.Nested(model_config_fields, allow_null=True), - 'plan': fields.String, - 'can_replace_logo': fields.Boolean, - 'custom_config': fields.Raw(attribute='custom_config'), + "app_id": fields.String, + "end_user_id": fields.String, + "enable_site": fields.Boolean, + "site": fields.Nested(site_fields), + "model_config": fields.Nested(model_config_fields, allow_null=True), + "plan": fields.String, + "can_replace_logo": fields.Boolean, + "custom_config": fields.Raw(attribute="custom_config"), } @marshal_with(app_fields) @@ -70,7 +69,7 @@ class AppSiteApi(WebApiResource): return AppSiteInfo(app_model.tenant, app_model, site, end_user.id, can_replace_logo) -api.add_resource(AppSiteApi, '/site') +api.add_resource(AppSiteApi, "/site") class AppSiteInfo: @@ -88,9 +87,13 @@ class AppSiteInfo: if can_replace_logo: base_url = dify_config.FILES_URL - remove_webapp_brand = tenant.custom_config_dict.get('remove_webapp_brand', False) - replace_webapp_logo = f'{base_url}/files/workspaces/{tenant.id}/webapp-logo' if tenant.custom_config_dict.get('replace_webapp_logo') else None + remove_webapp_brand = tenant.custom_config_dict.get("remove_webapp_brand", False) + replace_webapp_logo = ( + f"{base_url}/files/workspaces/{tenant.id}/webapp-logo" + if tenant.custom_config_dict.get("replace_webapp_logo") + else None + ) self.custom_config = { - 'remove_webapp_brand': remove_webapp_brand, - 'replace_webapp_logo': replace_webapp_logo, + "remove_webapp_brand": remove_webapp_brand, + "replace_webapp_logo": replace_webapp_logo, } diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index 77c468e41..55b0c3e2a 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -33,17 +33,13 @@ class WorkflowRunApi(WebApiResource): raise NotWorkflowAppError() parser = reqparse.RequestParser() - parser.add_argument('inputs', type=dict, required=True, nullable=False, location='json') - parser.add_argument('files', type=list, required=False, location='json') + parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") + parser.add_argument("files", type=list, required=False, location="json") args = parser.parse_args() try: response = AppGenerateService.generate( - app_model=app_model, - user=end_user, - args=args, - invoke_from=InvokeFrom.WEB_APP, - streaming=True + app_model=app_model, user=end_user, args=args, invoke_from=InvokeFrom.WEB_APP, streaming=True ) return helper.compact_generate_response(response) @@ -73,10 +69,8 @@ class WorkflowTaskStopApi(WebApiResource): AppQueueManager.set_stop_flag(task_id, InvokeFrom.WEB_APP, end_user.id) - return { - "result": "success" - } + return {"result": "success"} -api.add_resource(WorkflowRunApi, '/workflows/run') -api.add_resource(WorkflowTaskStopApi, '/workflows/tasks//stop') +api.add_resource(WorkflowRunApi, "/workflows/run") +api.add_resource(WorkflowTaskStopApi, "/workflows/tasks//stop") diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index ae363672c..93dc691d6 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -19,7 +19,9 @@ def validate_jwt_token(view=None): app_model, end_user = decode_jwt_token() return view(app_model, end_user, *args, **kwargs) + return decorated + if view: return decorator(view) return decorator @@ -27,31 +29,31 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - app_code = request.headers.get('X-App-Code') + app_code = request.headers.get("X-App-Code") try: - auth_header = request.headers.get('Authorization') + auth_header = request.headers.get("Authorization") if auth_header is None: - raise Unauthorized('Authorization header is missing.') + raise Unauthorized("Authorization header is missing.") - if ' ' not in auth_header: - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, tk = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() - if auth_scheme != 'bearer': - raise Unauthorized('Invalid Authorization header format. Expected \'Bearer \' format.') + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(tk) - app_code = decoded.get('app_code') - app_model = db.session.query(App).filter(App.id == decoded['app_id']).first() + app_code = decoded.get("app_code") + app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() if not app_code or not site: - raise BadRequest('Site URL is no longer valid.') + raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: - raise BadRequest('Site is disabled.') - end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first() + raise BadRequest("Site is disabled.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() if not end_user: raise NotFound() @@ -60,7 +62,7 @@ def decode_jwt_token(): return app_model, end_user except Unauthorized as e: if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) if app_web_sso_enabled: raise WebSSOAuthRequiredError() @@ -69,20 +71,20 @@ def decode_jwt_token(): def _validate_web_sso_token(decoded, system_features, app_code): app_web_sso_enabled = False - + # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get('enabled', False) + app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) if app_web_sso_enabled: - source = decoded.get('token_source') - if not source or source != 'sso': + source = decoded.get("token_source") + if not source or source != "sso": raise WebSSOAuthRequiredError() # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login if not system_features.sso_enforced_for_web or not app_web_sso_enabled: - source = decoded.get('token_source') - if source and source == 'sso': - raise Unauthorized('sso token expired.') + source = decoded.get("token_source") + if source and source == "sso": + raise Unauthorized("sso token expired.") class WebApiResource(Resource): diff --git a/api/pyproject.toml b/api/pyproject.toml index eddeeb0cd..e2d6704e8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -71,7 +71,6 @@ ignore = [ [tool.ruff.format] exclude = [ "core/**/*.py", - "controllers/**/*.py", "models/**/*.py", "migrations/**/*", ]