py lint (#12102)
Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
@@ -41,6 +41,7 @@ from models.source import DataSourceOauthBinding
 from services.entities.knowledge_entities.knowledge_entities import (
     ChildChunkUpdateArgs,
     KnowledgeConfig,
+    RerankingModel,
     RetrievalModel,
     SegmentUpdateArgs,
 )
@@ -548,12 +549,14 @@ class DocumentService:
     }

     @staticmethod
-    def get_document(dataset_id: str, document_id: str) -> Optional[Document]:
-        document = (
-            db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        return document
+    def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
+        if document_id:
+            document = (
+                db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+            return document
+        else:
+            return None

     @staticmethod
     def get_document_by_id(document_id: str) -> Optional[Document]:
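The hunk above changes get_document to take an optional document_id and to return None when it is missing, so callers are expected to handle an absent result. A minimal standalone sketch of that Optional-handling pattern, using hypothetical names rather than the service code itself:

from typing import Optional


def find_title(titles: dict[str, str], key: Optional[str] = None) -> Optional[str]:
    # Mirror of the pattern above: only perform the lookup when a key was actually given.
    if key:
        return titles.get(key)
    else:
        return None


assert find_title({"doc-1": "Getting started"}, "doc-1") == "Getting started"
assert find_title({"doc-1": "Getting started"}) is None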
@@ -744,25 +747,26 @@ class DocumentService:
         if features.billing.enabled:
             if not knowledge_config.original_document_id:
                 count = 0
-                if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
-                    count = len(upload_file_list)
-                elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
-                    notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                    for notion_info in notion_info_list:
-                        count = count + len(notion_info.pages)
-                elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
-                    website_info = knowledge_config.data_source.info_list.website_info_list
-                    count = len(website_info.urls)
-                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-                if count > batch_upload_limit:
-                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+                if knowledge_config.data_source:
+                    if knowledge_config.data_source.info_list.data_source_type == "upload_file":
+                        upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
+                        count = len(upload_file_list)
+                    elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
+                        notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                        for notion_info in notion_info_list:  # type: ignore
+                            count = count + len(notion_info.pages)
+                    elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
+                        website_info = knowledge_config.data_source.info_list.website_info_list
+                        count = len(website_info.urls)  # type: ignore
+                    batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                    if count > batch_upload_limit:
+                        raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

-                DocumentService.check_documents_upload_quota(count, features)
+                    DocumentService.check_documents_upload_quota(count, features)

         # if dataset is empty, update dataset data_source_type
         if not dataset.data_source_type:
-            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type
+            dataset.data_source_type = knowledge_config.data_source.info_list.data_source_type  # type: ignore

         if not dataset.indexing_technique:
             if knowledge_config.indexing_technique not in Dataset.INDEXING_TECHNIQUE_LIST:
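The new `if knowledge_config.data_source:` branch narrows an optional field before it is dereferenced, which is what lets the counting logic pass the linter. A standalone sketch of that guard pattern, with hypothetical dataclasses standing in for the real config objects:

from dataclasses import dataclass
from typing import Optional


@dataclass
class FileInfoList:
    file_ids: list[str]


@dataclass
class DataSource:
    file_info_list: Optional[FileInfoList]


def count_upload_files(data_source: Optional[DataSource]) -> int:
    # Guard the optional objects before touching their attributes, as in the hunk above.
    count = 0
    if data_source and data_source.file_info_list:
        count = len(data_source.file_info_list.file_ids)
    return count


assert count_upload_files(None) == 0
assert count_upload_files(DataSource(FileInfoList(["a", "b"]))) == 2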
@@ -789,7 +793,7 @@ class DocumentService:
                     "score_threshold_enabled": False,
                 }

-                dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model
+                dataset.retrieval_model = knowledge_config.retrieval_model.model_dump() or default_retrieval_model  # type: ignore

         documents = []
         if knowledge_config.original_document_id:
@@ -801,34 +805,35 @@ class DocumentService:
             # save process rule
             if not dataset_process_rule:
                 process_rule = knowledge_config.process_rule
-                if process_rule.mode in ("custom", "hierarchical"):
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=process_rule.rules.model_dump_json(),
-                        created_by=account.id,
-                    )
-                elif process_rule.mode == "automatic":
-                    dataset_process_rule = DatasetProcessRule(
-                        dataset_id=dataset.id,
-                        mode=process_rule.mode,
-                        rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
-                        created_by=account.id,
-                    )
-                else:
-                    logging.warn(
-                        f"Invalid process rule mode: {process_rule['mode']}, can not find dataset process rule"
-                    )
-                    return
-                db.session.add(dataset_process_rule)
-                db.session.commit()
+                if process_rule:
+                    if process_rule.mode in ("custom", "hierarchical"):
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+                            created_by=account.id,
+                        )
+                    elif process_rule.mode == "automatic":
+                        dataset_process_rule = DatasetProcessRule(
+                            dataset_id=dataset.id,
+                            mode=process_rule.mode,
+                            rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
+                            created_by=account.id,
+                        )
+                    else:
+                        logging.warn(
+                            f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule"
+                        )
+                        return
+                    db.session.add(dataset_process_rule)
+                    db.session.commit()
             lock_name = "add_document_lock_dataset_id_{}".format(dataset.id)
             with redis_client.lock(lock_name, timeout=600):
                 position = DocumentService.get_documents_position(dataset.id)
                 document_ids = []
                 duplicate_document_ids = []
                 if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                    upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids  # type: ignore
                     for file_id in upload_file_list:
                         file = (
                             db.session.query(UploadFile)
@@ -854,7 +859,7 @@ class DocumentService:
                                 name=file_name,
                             ).first()
                             if document:
-                                document.dataset_process_rule_id = dataset_process_rule.id
+                                document.dataset_process_rule_id = dataset_process_rule.id  # type: ignore
                                 document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                                 document.created_from = created_from
                                 document.doc_form = knowledge_config.doc_form
@@ -868,7 +873,7 @@ class DocumentService:
                                 continue
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -886,6 +891,8 @@ class DocumentService:
                         position += 1
                 elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                     notion_info_list = knowledge_config.data_source.info_list.notion_info_list
+                    if not notion_info_list:
+                        raise ValueError("No notion info list found.")
                     exist_page_ids = []
                     exist_document = {}
                     documents = Document.query.filter_by(
@@ -921,7 +928,7 @@ class DocumentService:
                                 }
                                 document = DocumentService.build_document(
                                     dataset,
-                                    dataset_process_rule.id,
+                                    dataset_process_rule.id,  # type: ignore
                                     knowledge_config.data_source.info_list.data_source_type,
                                     knowledge_config.doc_form,
                                     knowledge_config.doc_language,
@@ -944,6 +951,8 @@ class DocumentService:
                         clean_notion_document_task.delay(list(exist_document.values()), dataset.id)
                 elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                     website_info = knowledge_config.data_source.info_list.website_info_list
+                    if not website_info:
+                        raise ValueError("No website info list found.")
                     urls = website_info.urls
                     for url in urls:
                         data_source_info = {
@@ -959,7 +968,7 @@ class DocumentService:
                         document_name = url
                         document = DocumentService.build_document(
                             dataset,
-                            dataset_process_rule.id,
+                            dataset_process_rule.id,  # type: ignore
                             knowledge_config.data_source.info_list.data_source_type,
                             knowledge_config.doc_form,
                             knowledge_config.doc_language,
@@ -1054,7 +1063,7 @@ class DocumentService:
                 dataset_process_rule = DatasetProcessRule(
                     dataset_id=dataset.id,
                     mode=process_rule.mode,
-                    rules=process_rule.rules.model_dump_json(),
+                    rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
                     created_by=account.id,
                 )
             elif process_rule.mode == "automatic":
@@ -1073,6 +1082,8 @@ class DocumentService:
             file_name = ""
             data_source_info = {}
             if document_data.data_source.info_list.data_source_type == "upload_file":
+                if not document_data.data_source.info_list.file_info_list:
+                    raise ValueError("No file info list found.")
                 upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
                 for file_id in upload_file_list:
                     file = (
@@ -1090,6 +1101,8 @@ class DocumentService:
                         "upload_file_id": file_id,
                     }
             elif document_data.data_source.info_list.data_source_type == "notion_import":
+                if not document_data.data_source.info_list.notion_info_list:
+                    raise ValueError("No notion info list found.")
                 notion_info_list = document_data.data_source.info_list.notion_info_list
                 for notion_info in notion_info_list:
                     workspace_id = notion_info.workspace_id
@@ -1107,20 +1120,21 @@ class DocumentService:
                         data_source_info = {
                             "notion_workspace_id": workspace_id,
                             "notion_page_id": page.page_id,
-                            "notion_page_icon": page.page_icon,
+                            "notion_page_icon": page.page_icon.model_dump() if page.page_icon else None,  # type: ignore
                             "type": page.type,
                         }
             elif document_data.data_source.info_list.data_source_type == "website_crawl":
                 website_info = document_data.data_source.info_list.website_info_list
-                urls = website_info.urls
-                for url in urls:
-                    data_source_info = {
-                        "url": url,
-                        "provider": website_info.provider,
-                        "job_id": website_info.job_id,
-                        "only_main_content": website_info.only_main_content,
-                        "mode": "crawl",
-                    }
+                if website_info:
+                    urls = website_info.urls
+                    for url in urls:
+                        data_source_info = {
+                            "url": url,
+                            "provider": website_info.provider,
+                            "job_id": website_info.job_id,
+                            "only_main_content": website_info.only_main_content,  # type: ignore
+                            "mode": "crawl",
+                        }
             document.data_source_type = document_data.data_source.info_list.data_source_type
             document.data_source_info = json.dumps(data_source_info)
             document.name = file_name
@@ -1155,15 +1169,21 @@ class DocumentService:
         if features.billing.enabled:
             count = 0
             if knowledge_config.data_source.info_list.data_source_type == "upload_file":
-                upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
+                upload_file_list = (
+                    knowledge_config.data_source.info_list.file_info_list.file_ids
+                    if knowledge_config.data_source.info_list.file_info_list
+                    else []
+                )
                 count = len(upload_file_list)
             elif knowledge_config.data_source.info_list.data_source_type == "notion_import":
                 notion_info_list = knowledge_config.data_source.info_list.notion_info_list
-                for notion_info in notion_info_list:
-                    count = count + len(notion_info.pages)
+                if notion_info_list:
+                    for notion_info in notion_info_list:
+                        count = count + len(notion_info.pages)
             elif knowledge_config.data_source.info_list.data_source_type == "website_crawl":
                 website_info = knowledge_config.data_source.info_list.website_info_list
-                count = len(website_info.urls)
+                if website_info:
+                    count = len(website_info.urls)
             batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
             if count > batch_upload_limit:
                 raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
@@ -1174,20 +1194,20 @@ class DocumentService:
         retrieval_model = None
         if knowledge_config.indexing_technique == "high_quality":
             dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                knowledge_config.embedding_model_provider, knowledge_config.embedding_model
+                knowledge_config.embedding_model_provider,  # type: ignore
+                knowledge_config.embedding_model,  # type: ignore
             )
             dataset_collection_binding_id = dataset_collection_binding.id
             if knowledge_config.retrieval_model:
                 retrieval_model = knowledge_config.retrieval_model
             else:
-                default_retrieval_model = {
-                    "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
-                    "reranking_enable": False,
-                    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
-                    "top_k": 2,
-                    "score_threshold_enabled": False,
-                }
-                retrieval_model = RetrievalModel(**default_retrieval_model)
+                retrieval_model = RetrievalModel(
+                    search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+                    reranking_enable=False,
+                    reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+                    top_k=2,
+                    score_threshold_enabled=False,
+                )
         # save dataset
         dataset = Dataset(
             tenant_id=tenant_id,
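In the hunk above, the fallback retrieval settings are now built by constructing RetrievalModel with explicit keyword arguments instead of unpacking an untyped dict, so each field can be checked statically. A standalone sketch of the same construction pattern using plain dataclasses (hypothetical names, not the real RetrievalModel/RerankingModel classes):

from dataclasses import dataclass


@dataclass
class Reranking:
    provider_name: str
    model_name: str


@dataclass
class Retrieval:
    search_method: str
    reranking_enable: bool
    reranking_model: Reranking
    top_k: int
    score_threshold_enabled: bool


# Explicit keyword arguments let a type checker verify every field; **dict unpacking
# of an untyped literal would hide mismatched keys or value types until runtime.
default_retrieval = Retrieval(
    search_method="semantic_search",
    reranking_enable=False,
    reranking_model=Reranking(provider_name="", model_name=""),
    top_k=2,
    score_threshold_enabled=False,
)
assert default_retrieval.top_k == 2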
@@ -1557,12 +1577,12 @@ class SegmentService:
             raise ValueError("Can't update disabled segment")
         try:
             word_count_change = segment.word_count
-            content = args.content
+            content = args.content or segment.content
             if segment.content == content:
                 segment.word_count = len(content)
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 if args.keywords:
                     segment.keywords = args.keywords
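The word-count update above now treats a missing answer as contributing zero characters instead of passing None to len(). A standalone sketch of that None-safe pattern, with hypothetical names:

from typing import Optional


def segment_word_count(content: str, answer: Optional[str] = None) -> int:
    # Same pattern as the hunk above: an absent answer adds nothing to the count.
    count = len(content)
    count += len(answer) if answer else 0
    return count


assert segment_word_count("hello") == 5
assert segment_word_count("hello", "world") == 10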
@@ -1577,7 +1597,12 @@ class SegmentService:
                 db.session.add(document)
                 # update segment index task
                 if args.enabled:
-                    VectorService.create_segments_vector([args.keywords], [segment], dataset)
+                    VectorService.create_segments_vector(
+                        [args.keywords] if args.keywords else None,
+                        [segment],
+                        dataset,
+                        document.doc_form,
+                    )
                 if document.doc_form == IndexType.PARENT_CHILD_INDEX and args.regenerate_child_chunks:
                     # regenerate child chunks
                     # get embedding model instance
@@ -1605,6 +1630,8 @@ class SegmentService:
                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )
@@ -1639,7 +1666,7 @@ class SegmentService:
                 segment.disabled_by = None
                 if document.doc_form == "qa_model":
                     segment.answer = args.answer
-                    segment.word_count += len(args.answer)
+                    segment.word_count += len(args.answer) if args.answer else 0
                 word_count_change = segment.word_count - word_count_change
                 # update document word count
                 if word_count_change != 0:
@@ -1673,6 +1700,8 @@ class SegmentService:
                         .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
                         .first()
                     )
+                    if not processing_rule:
+                        raise ValueError("No processing rule found.")
                     VectorService.generate_child_chunks(
                         segment, document, dataset, embedding_model_instance, processing_rule, True
                     )