chore(api/core): apply ruff reformatting (#7624)

This commit is contained in:
Bowen Liang
2024-09-10 17:00:20 +08:00
committed by GitHub
parent 178730266d
commit 2cf1187b32
724 changed files with 21180 additions and 21123 deletions

View File

@@ -32,14 +32,11 @@ from models.dataset import Dataset, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
@@ -48,15 +45,18 @@ class DatasetRetrieval:
self.application_generate_entity = application_generate_entity
def retrieve(
self, app_id: str, user_id: str, tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
self,
app_id: str,
user_id: str,
tenant_id: str,
model_config: ModelConfigWithCredentialsEntity,
config: DatasetEntity,
query: str,
invoke_from: InvokeFrom,
show_retrieve_source: bool,
hit_callback: DatasetIndexToolCallbackHandler,
message_id: str,
memory: Optional[TokenBufferMemory] = None,
) -> Optional[str]:
"""
Retrieve dataset.
@@ -84,16 +84,12 @@ class DatasetRetrieval:
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model
tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model,
credentials=model_config.credentials
model=model_config.model, credentials=model_config.credentials
)
if not model_schema:
@@ -102,39 +98,46 @@ class DatasetRetrieval:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features \
or ModelFeature.MULTI_TOOL_CALL in features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
all_documents = []
user_from = 'account' if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'
user_from = "account" if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else "end_user"
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
all_documents = self.single_retrieve(
app_id, tenant_id, user_id, user_from, available_datasets, query,
app_id,
tenant_id,
user_id,
user_from,
available_datasets,
query,
model_instance,
model_config, planning_strategy, message_id
model_config,
planning_strategy,
message_id,
)
elif retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
all_documents = self.multiple_retrieve(
app_id, tenant_id, user_id, user_from,
available_datasets, query, retrieve_config.top_k,
app_id,
tenant_id,
user_id,
user_from,
available_datasets,
query,
retrieve_config.top_k,
retrieve_config.score_threshold,
retrieve_config.rerank_mode,
retrieve_config.reranking_model,
@@ -145,89 +148,89 @@ class DatasetRetrieval:
document_score_list = {}
for item in all_documents:
if item.metadata.get('score'):
document_score_list[item.metadata['doc_id']] = item.metadata['score']
if item.metadata.get("score"):
document_score_list[item.metadata["doc_id"]] = item.metadata["score"]
document_context_list = []
index_node_ids = [document.metadata['doc_id'] for document in all_documents]
index_node_ids = [document.metadata["doc_id"] for document in all_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == 'completed',
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids)
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(segments,
key=lambda segment: index_node_id_to_position.get(segment.index_node_id,
float('inf')))
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer:
document_context_list.append(f'question:{segment.get_sign_content()} answer:{segment.answer}')
document_context_list.append(f"question:{segment.get_sign_content()} answer:{segment.answer}")
else:
document_context_list.append(segment.get_sign_content())
if show_retrieve_source:
context_list = []
resource_number = 1
for segment in sorted_segments:
dataset = Dataset.query.filter_by(
id=segment.dataset_id
dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
document = DatasetDocument.query.filter(DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
).first()
if dataset and document:
source = {
'position': resource_number,
'dataset_id': dataset.id,
'dataset_name': dataset.name,
'document_id': document.id,
'document_name': document.name,
'data_source_type': document.data_source_type,
'segment_id': segment.id,
'retriever_from': invoke_from.to_source(),
'score': document_score_list.get(segment.index_node_id, None)
"position": resource_number,
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, None),
}
if invoke_from.to_source() == 'dev':
source['hit_count'] = segment.hit_count
source['word_count'] = segment.word_count
source['segment_position'] = segment.position
source['index_node_hash'] = segment.index_node_hash
if invoke_from.to_source() == "dev":
source["hit_count"] = segment.hit_count
source["word_count"] = segment.word_count
source["segment_position"] = segment.position
source["index_node_hash"] = segment.index_node_hash
if segment.answer:
source['content'] = f'question:{segment.content} \nanswer:{segment.answer}'
source["content"] = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source['content'] = segment.content
source["content"] = segment.content
context_list.append(source)
resource_number += 1
if hit_callback:
hit_callback.return_retriever_resource_info(context_list)
return str("\n".join(document_context_list))
return ''
return ""
def single_retrieve(
self, app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
model_instance: ModelInstance,
model_config: ModelConfigWithCredentialsEntity,
planning_strategy: PlanningStrategy,
message_id: Optional[str] = None,
):
tools = []
for dataset in available_datasets:
description = dataset.description
if not description:
description = 'useful for when you want to answer queries about the ' + dataset.name
description = "useful for when you want to answer queries about the " + dataset.name
description = description.replace('\n', '').replace('\r', '')
description = description.replace("\n", "").replace("\r", "")
message_tool = PromptMessageTool(
name=dataset.id,
description=description,
@@ -235,14 +238,15 @@ class DatasetRetrieval:
"type": "object",
"properties": {},
"required": [],
}
},
)
tools.append(message_tool)
dataset_id = None
if planning_strategy == PlanningStrategy.REACT_ROUTER:
react_multi_dataset_router = ReactMultiDatasetRouter()
dataset_id = react_multi_dataset_router.invoke(query, tools, model_config, model_instance,
user_id, tenant_id)
dataset_id = react_multi_dataset_router.invoke(
query, tools, model_config, model_instance, user_id, tenant_id
)
elif planning_strategy == PlanningStrategy.ROUTER:
function_call_router = FunctionCallMultiDatasetRouter()
@@ -250,37 +254,37 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
top_k = retrieval_model_config["top_k"]
# get retrieval method
if dataset.indexing_technique == "economy":
retrieval_method = 'keyword_search'
retrieval_method = "keyword_search"
else:
retrieval_method = retrieval_model_config['search_method']
retrieval_method = retrieval_model_config["search_method"]
# get reranking model
reranking_model = retrieval_model_config['reranking_model'] \
if retrieval_model_config['reranking_enable'] else None
reranking_model = (
retrieval_model_config["reranking_model"] if retrieval_model_config["reranking_enable"] else None
)
# get score threshold
score_threshold = .0
score_threshold = 0.0
score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
if score_threshold_enabled:
score_threshold = retrieval_model_config.get("score_threshold")
with measure_time() as timer:
results = RetrievalService.retrieve(
retrieval_method=retrieval_method, dataset_id=dataset.id,
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k, score_threshold=score_threshold,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
reranking_mode=retrieval_model_config.get('reranking_mode', 'reranking_model'),
weights=retrieval_model_config.get('weights', None),
reranking_mode=retrieval_model_config.get("reranking_mode", "reranking_model"),
weights=retrieval_model_config.get("weights", None),
)
self._on_query(query, [dataset_id], app_id, user_from, user_id)
@@ -291,20 +295,20 @@ class DatasetRetrieval:
return []
def multiple_retrieve(
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
self,
app_id: str,
tenant_id: str,
user_id: str,
user_from: str,
available_datasets: list,
query: str,
top_k: int,
score_threshold: float,
reranking_mode: str,
reranking_model: Optional[dict] = None,
weights: Optional[dict] = None,
reranking_enable: bool = True,
message_id: Optional[str] = None,
):
threads = []
all_documents = []
@@ -312,13 +316,16 @@ class DatasetRetrieval:
index_type = None
for dataset in available_datasets:
index_type = dataset.indexing_technique
retrieval_thread = threading.Thread(target=self._retriever, kwargs={
'flask_app': current_app._get_current_object(),
'dataset_id': dataset.id,
'query': query,
'top_k': top_k,
'all_documents': all_documents,
})
retrieval_thread = threading.Thread(
target=self._retriever,
kwargs={
"flask_app": current_app._get_current_object(),
"dataset_id": dataset.id,
"query": query,
"top_k": top_k,
"all_documents": all_documents,
},
)
threads.append(retrieval_thread)
retrieval_thread.start()
for thread in threads:
@@ -327,16 +334,10 @@ class DatasetRetrieval:
with measure_time() as timer:
if reranking_enable:
# do rerank for searched documents
data_post_processor = DataPostProcessor(
tenant_id, reranking_mode,
reranking_model, weights, False
)
data_post_processor = DataPostProcessor(tenant_id, reranking_mode, reranking_model, weights, False)
all_documents = data_post_processor.invoke(
query=query,
documents=all_documents,
score_threshold=score_threshold,
top_n=top_k
query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k
)
else:
if index_type == "economy":
@@ -357,30 +358,26 @@ class DatasetRetrieval:
"""Handle retrieval end."""
for document in documents:
query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata['doc_id']
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if 'dataset_id' in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata['dataset_id'])
if "dataset_id" in document.metadata:
query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()
# get tracing instance
trace_manager: TraceQueueManager = self.application_generate_entity.trace_manager if self.application_generate_entity else None
trace_manager: TraceQueueManager = (
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE,
message_id=message_id,
documents=documents,
timer=timer
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
)
@@ -395,10 +392,10 @@ class DatasetRetrieval:
dataset_query = DatasetQuery(
dataset_id=dataset_id,
content=query,
source='app',
source="app",
source_app_id=app_id,
created_by_role=user_from,
created_by=user_id
created_by=user_id,
)
dataset_queries.append(dataset_query)
if dataset_queries:
@@ -407,9 +404,7 @@ class DatasetRetrieval:
def _retriever(self, flask_app: Flask, dataset_id: str, query: str, top_k: int, all_documents: list):
with flask_app.app_context():
dataset = db.session.query(Dataset).filter(
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
@@ -419,38 +414,42 @@ class DatasetRetrieval:
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(retrieval_method='keyword_search',
dataset_id=dataset.id,
query=query,
top_k=top_k
)
documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=top_k
)
if documents:
all_documents.extend(documents)
else:
if top_k > 0:
# retrieval source
documents = RetrievalService.retrieve(retrieval_method=retrieval_model['search_method'],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model.get('score_threshold', .0)
if retrieval_model['score_threshold_enabled'] else None,
reranking_model=retrieval_model.get('reranking_model', None)
if retrieval_model['reranking_enable'] else None,
reranking_mode=retrieval_model.get('reranking_mode')
if retrieval_model.get('reranking_mode') else 'reranking_model',
weights=retrieval_model.get('weights', None),
)
documents = RetrievalService.retrieve(
retrieval_method=retrieval_model["search_method"],
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=retrieval_model.get("score_threshold", 0.0)
if retrieval_model["score_threshold_enabled"]
else None,
reranking_model=retrieval_model.get("reranking_model", None)
if retrieval_model["reranking_enable"]
else None,
reranking_mode=retrieval_model.get("reranking_mode")
if retrieval_model.get("reranking_mode")
else "reranking_model",
weights=retrieval_model.get("weights", None),
)
all_documents.extend(documents)
def to_dataset_retriever_tool(self, tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler) \
-> Optional[list[DatasetRetrieverBaseTool]]:
def to_dataset_retriever_tool(
self,
tenant_id: str,
dataset_ids: list[str],
retrieve_config: DatasetRetrieveConfigEntity,
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
) -> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param tenant_id: tenant id
@@ -464,18 +463,14 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
continue
# pass if dataset is not available
if (dataset and dataset.available_document_count == 0
and dataset.available_document_count == 0):
if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
continue
available_datasets.append(dataset)
@@ -483,22 +478,18 @@ class DatasetRetrieval:
if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
# get retrieval model config
default_retrieval_model = {
'search_method': RetrievalMethod.SEMANTIC_SEARCH.value,
'reranking_enable': False,
'reranking_model': {
'reranking_provider_name': '',
'reranking_model_name': ''
},
'top_k': 2,
'score_threshold_enabled': False
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 2,
"score_threshold_enabled": False,
}
for dataset in available_datasets:
retrieval_model_config = dataset.retrieval_model \
if dataset.retrieval_model else default_retrieval_model
retrieval_model_config = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
# get top k
top_k = retrieval_model_config['top_k']
top_k = retrieval_model_config["top_k"]
# get score threshold
score_threshold = None
@@ -512,7 +503,7 @@ class DatasetRetrieval:
score_threshold=score_threshold,
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source()
retriever_from=invoke_from.to_source(),
)
tools.append(tool)
@@ -525,8 +516,8 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
reranking_provider_name=retrieve_config.reranking_model.get('reranking_provider_name'),
reranking_model_name=retrieve_config.reranking_model.get('reranking_model_name')
reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"),
reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"),
)
tools.append(tool)
@@ -547,7 +538,7 @@ class DatasetRetrieval:
for document in documents:
# get the document keywords
document_keywords = keyword_table_handler.extract_keywords(document.page_content, None)
document.metadata['keywords'] = document_keywords
document.metadata["keywords"] = document_keywords
documents_keywords.append(document_keywords)
# Counter query keywords(TF)
@@ -606,21 +597,19 @@ class DatasetRetrieval:
for document, score in zip(documents, similarities):
# format document
document.metadata['score'] = score
documents = sorted(documents, key=lambda x: x.metadata['score'], reverse=True)
document.metadata["score"] = score
documents = sorted(documents, key=lambda x: x.metadata["score"], reverse=True)
return documents[:top_k] if top_k else documents
def calculate_vector_score(self, all_documents: list[Document],
top_k: int, score_threshold: float) -> list[Document]:
def calculate_vector_score(
self, all_documents: list[Document], top_k: int, score_threshold: float
) -> list[Document]:
filter_documents = []
for document in all_documents:
if score_threshold is None or document.metadata['score'] >= score_threshold:
if score_threshold is None or document.metadata["score"] >= score_threshold:
filter_documents.append(document)
if not filter_documents:
return []
filter_documents = sorted(filter_documents, key=lambda x: x.metadata['score'], reverse=True)
filter_documents = sorted(filter_documents, key=lambda x: x.metadata["score"], reverse=True)
return filter_documents[:top_k] if top_k else filter_documents