chore(api/controllers): Apply Ruff Formatter. (#7645)

Author: -LAN-
Date: 2024-08-26 15:29:10 +08:00
Committed by: GitHub
parent 7ae728a9a3
commit 13be84e4d4
104 changed files with 3849 additions and 3982 deletions
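The diff below follows one pattern throughout: Ruff's formatter normalizes string literals to double quotes, adds trailing commas to multi-line calls, and collapses hand-wrapped calls onto a single line where they fit within the line length. As a minimal, standalone illustration of that pattern (a sketch mirroring one parser line from this diff, not an excerpt of the project's Ruff configuration, which is not part of this commit):

    # Illustrative only: shows the kind of rewrite `ruff format` produces.
    from flask_restful import reqparse

    parser = reqparse.RequestParser()

    # Before formatting: single quotes and a hand-wrapped continuation line.
    # parser.add_argument('limit', type=int,
    #                     default=20, location='args')

    # After formatting: double quotes, call collapsed onto one line.
    parser.add_argument("limit", type=int, default=20, location="args")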


@@ -40,7 +40,7 @@ class DatasetDocumentSegmentListApi(Resource):
         document_id = str(document_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         try:
             DatasetService.check_dataset_permission(dataset, current_user)
@@ -50,37 +50,33 @@ class DatasetDocumentSegmentListApi(Resource):
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         parser = reqparse.RequestParser()
-        parser.add_argument('last_id', type=str, default=None, location='args')
-        parser.add_argument('limit', type=int, default=20, location='args')
-        parser.add_argument('status', type=str,
-                            action='append', default=[], location='args')
-        parser.add_argument('hit_count_gte', type=int,
-                            default=None, location='args')
-        parser.add_argument('enabled', type=str, default='all', location='args')
-        parser.add_argument('keyword', type=str, default=None, location='args')
+        parser.add_argument("last_id", type=str, default=None, location="args")
+        parser.add_argument("limit", type=int, default=20, location="args")
+        parser.add_argument("status", type=str, action="append", default=[], location="args")
+        parser.add_argument("hit_count_gte", type=int, default=None, location="args")
+        parser.add_argument("enabled", type=str, default="all", location="args")
+        parser.add_argument("keyword", type=str, default=None, location="args")
         args = parser.parse_args()
-        last_id = args['last_id']
-        limit = min(args['limit'], 100)
-        status_list = args['status']
-        hit_count_gte = args['hit_count_gte']
-        keyword = args['keyword']
+        last_id = args["last_id"]
+        limit = min(args["limit"], 100)
+        status_list = args["status"]
+        hit_count_gte = args["hit_count_gte"]
+        keyword = args["keyword"]
         query = DocumentSegment.query.filter(
-            DocumentSegment.document_id == str(document_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         )
         if last_id is not None:
             last_segment = db.session.get(DocumentSegment, str(last_id))
             if last_segment:
-                query = query.filter(
-                    DocumentSegment.position > last_segment.position)
+                query = query.filter(DocumentSegment.position > last_segment.position)
             else:
-                return {'data': [], 'has_more': False, 'limit': limit}, 200
+                return {"data": [], "has_more": False, "limit": limit}, 200
         if status_list:
             query = query.filter(DocumentSegment.status.in_(status_list))
@@ -89,12 +85,12 @@ class DatasetDocumentSegmentListApi(Resource):
             query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
         if keyword:
-            query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
+            query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
-        if args['enabled'].lower() != 'all':
-            if args['enabled'].lower() == 'true':
+        if args["enabled"].lower() != "all":
+            if args["enabled"].lower() == "true":
                 query = query.filter(DocumentSegment.enabled == True)
-            elif args['enabled'].lower() == 'false':
+            elif args["enabled"].lower() == "false":
                 query = query.filter(DocumentSegment.enabled == False)
         total = query.count()
@@ -106,11 +102,11 @@ class DatasetDocumentSegmentListApi(Resource):
             segments = segments[:-1]
         return {
-            'data': marshal(segments, segment_fields),
-            'doc_form': document.doc_form,
-            'has_more': has_more,
-            'limit': limit,
-            'total': total
+            "data": marshal(segments, segment_fields),
+            "doc_form": document.doc_form,
+            "has_more": has_more,
+            "limit": limit,
+            "total": total,
         }, 200
@@ -118,12 +114,12 @@ class DatasetDocumentSegmentApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, segment_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # The role of the current user in the ta table must be admin, owner, or editor
@@ -134,7 +130,7 @@ class DatasetDocumentSegmentApi(Resource):
             DatasetService.check_dataset_permission(dataset, current_user)
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()
@@ -142,32 +138,32 @@ class DatasetDocumentSegmentApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
-        if segment.status != 'completed':
-            raise NotFound('Segment is not completed, enable or disable function is not allowed')
+        if segment.status != "completed":
+            raise NotFound("Segment is not completed, enable or disable function is not allowed")
-        document_indexing_cache_key = 'document_{}_indexing'.format(segment.document_id)
+        document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
         cache_result = redis_client.get(document_indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Document is being indexed, please try again later")
-        indexing_cache_key = 'segment_{}_indexing'.format(segment.id)
+        indexing_cache_key = "segment_{}_indexing".format(segment.id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is not None:
             raise InvalidActionError("Segment is being indexed, please try again later")
@@ -186,7 +182,7 @@ class DatasetDocumentSegmentApi(Resource):
             enable_segment_to_index_task.delay(segment.id)
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         elif action == "disable":
             if not segment.enabled:
                 raise InvalidActionError("Segment is already disabled.")
@@ -201,7 +197,7 @@ class DatasetDocumentSegmentApi(Resource):
             disable_segment_from_index_task.delay(segment.id)
-            return {'result': 'success'}, 200
+            return {"result": "success"}, 200
         else:
             raise InvalidActionError()
@@ -210,35 +206,36 @@ class DatasetDocumentSegmentAddApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         if not current_user.is_editor:
             raise Forbidden()
         # check embedding model setting
-        if dataset.indexing_technique == 'high_quality':
+        if dataset.indexing_technique == "high_quality":
             try:
                 model_manager = ModelManager()
                 model_manager.get_model_instance(
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         try:
@@ -247,37 +244,34 @@ class DatasetDocumentSegmentAddApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.create_segment(args, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
 class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
+    @cloud_edition_billing_resource_check("vector_space")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
-        if dataset.indexing_technique == 'high_quality':
+            raise NotFound("Document not found.")
+        if dataset.indexing_technique == "high_quality":
             # check embedding model setting
             try:
                 model_manager = ModelManager()
@@ -285,22 +279,22 @@ class DatasetDocumentSegmentUpdateApi(Resource):
                     tenant_id=current_user.current_tenant_id,
                     provider=dataset.embedding_model_provider,
                     model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model
+                    model=dataset.embedding_model,
                 )
             except LLMBadRequestError:
                 raise ProviderNotInitializeError(
                     "No Embedding Model available. Please configure a valid provider "
-                    "in the Settings -> Model Provider.")
+                    "in the Settings -> Model Provider."
+                )
             except ProviderTokenNotInitError as ex:
                 raise ProviderNotInitializeError(ex.description)
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
             raise Forbidden()
@@ -310,16 +304,13 @@ class DatasetDocumentSegmentUpdateApi(Resource):
             raise Forbidden(str(e))
         # validate args
         parser = reqparse.RequestParser()
-        parser.add_argument('content', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('answer', type=str, required=False, nullable=True, location='json')
-        parser.add_argument('keywords', type=list, required=False, nullable=True, location='json')
+        parser.add_argument("content", type=str, required=True, nullable=False, location="json")
+        parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
+        parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
         args = parser.parse_args()
         SegmentService.segment_create_args_validate(args, document)
         segment = SegmentService.update_segment(args, segment, document, dataset)
-        return {
-            'data': marshal(segment, segment_fields),
-            'doc_form': document.doc_form
-        }, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
     @setup_required
     @login_required
@@ -329,22 +320,21 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check user's model setting
         DatasetService.check_dataset_model_setting(dataset)
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # check segment
         segment_id = str(segment_id)
         segment = DocumentSegment.query.filter(
-            DocumentSegment.id == str(segment_id),
-            DocumentSegment.tenant_id == current_user.current_tenant_id
+            DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
         ).first()
         if not segment:
-            raise NotFound('Segment not found.')
+            raise NotFound("Segment not found.")
         # The role of the current user in the ta table must be admin or owner
         if not current_user.is_editor:
             raise Forbidden()
@@ -353,36 +343,36 @@ class DatasetDocumentSegmentUpdateApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
         SegmentService.delete_segment(segment, document, dataset)
-        return {'result': 'success'}, 200
+        return {"result": "success"}, 200
 class DatasetDocumentSegmentBatchImportApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
-    @cloud_edition_billing_resource_check('vector_space')
-    @cloud_edition_billing_knowledge_limit_check('add_segment')
+    @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_knowledge_limit_check("add_segment")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
         if not dataset:
-            raise NotFound('Dataset not found.')
+            raise NotFound("Dataset not found.")
         # check document
         document_id = str(document_id)
         document = DocumentService.get_document(dataset_id, document_id)
         if not document:
-            raise NotFound('Document not found.')
+            raise NotFound("Document not found.")
         # get file from request
-        file = request.files['file']
+        file = request.files["file"]
         # check file
-        if 'file' not in request.files:
+        if "file" not in request.files:
             raise NoFileUploadedError()
         if len(request.files) > 1:
             raise TooManyFilesError()
         # check file type
-        if not file.filename.endswith('.csv'):
+        if not file.filename.endswith(".csv"):
             raise ValueError("Invalid file type. Only CSV files are allowed")
         try:
@@ -390,51 +380,47 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
             df = pd.read_csv(file)
             result = []
             for index, row in df.iterrows():
-                if document.doc_form == 'qa_model':
-                    data = {'content': row[0], 'answer': row[1]}
+                if document.doc_form == "qa_model":
+                    data = {"content": row[0], "answer": row[1]}
                 else:
-                    data = {'content': row[0]}
+                    data = {"content": row[0]}
                 result.append(data)
             if len(result) == 0:
                 raise ValueError("The CSV file is empty.")
             # async job
             job_id = str(uuid.uuid4())
-            indexing_cache_key = 'segment_batch_import_{}'.format(str(job_id))
+            indexing_cache_key = "segment_batch_import_{}".format(str(job_id))
             # send batch add segments task
-            redis_client.setnx(indexing_cache_key, 'waiting')
-            batch_create_segment_to_index_task.delay(str(job_id), result, dataset_id, document_id,
-                                                     current_user.current_tenant_id, current_user.id)
+            redis_client.setnx(indexing_cache_key, "waiting")
+            batch_create_segment_to_index_task.delay(
+                str(job_id), result, dataset_id, document_id, current_user.current_tenant_id, current_user.id
+            )
         except Exception as e:
-            return {'error': str(e)}, 500
-        return {
-            'job_id': job_id,
-            'job_status': 'waiting'
-        }, 200
+            return {"error": str(e)}, 500
+        return {"job_id": job_id, "job_status": "waiting"}, 200
     @setup_required
     @login_required
     @account_initialization_required
     def get(self, job_id):
         job_id = str(job_id)
-        indexing_cache_key = 'segment_batch_import_{}'.format(job_id)
+        indexing_cache_key = "segment_batch_import_{}".format(job_id)
         cache_result = redis_client.get(indexing_cache_key)
         if cache_result is None:
             raise ValueError("The job is not exist.")
-        return {
-            'job_id': job_id,
-            'job_status': cache_result.decode()
-        }, 200
+        return {"job_id": job_id, "job_status": cache_result.decode()}, 200
-api.add_resource(DatasetDocumentSegmentListApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
-api.add_resource(DatasetDocumentSegmentApi,
-                 '/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>')
-api.add_resource(DatasetDocumentSegmentAddApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment')
-api.add_resource(DatasetDocumentSegmentUpdateApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
-api.add_resource(DatasetDocumentSegmentBatchImportApi,
-                 '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import',
-                 '/datasets/batch_import_status/<uuid:job_id>')
+api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
+api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>")
+api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
+api.add_resource(
+    DatasetDocumentSegmentUpdateApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>",
+)
+api.add_resource(
+    DatasetDocumentSegmentBatchImportApi,
+    "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
+    "/datasets/batch_import_status/<uuid:job_id>",
+)