feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import json
|
||||
from typing import Type
|
||||
from typing import Type, Optional
|
||||
|
||||
from flask import current_app
|
||||
from langchain.tools import BaseTool
|
||||
@@ -28,7 +28,8 @@ class DatasetRetrieverTool(BaseTool):
|
||||
|
||||
tenant_id: str
|
||||
dataset_id: str
|
||||
k: int = 3
|
||||
top_k: int = 2
|
||||
score_threshold: Optional[float] = None
|
||||
conversation_message_task: ConversationMessageTask
|
||||
return_resource: bool
|
||||
retriever_from: str
|
||||
@@ -66,7 +67,7 @@ class DatasetRetrieverTool(BaseTool):
|
||||
)
|
||||
)
|
||||
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.k})
|
||||
documents = kw_table_index.search(query, search_kwargs={'k': self.top_k})
|
||||
return str("\n".join([document.page_content for document in documents]))
|
||||
else:
|
||||
|
||||
@@ -80,20 +81,21 @@ class DatasetRetrieverTool(BaseTool):
|
||||
return ''
|
||||
except ProviderTokenNotInitError:
|
||||
return ''
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
|
||||
embeddings = CacheEmbedding(embedding_model)
|
||||
vector_index = VectorIndex(
|
||||
dataset=dataset,
|
||||
config=current_app.config,
|
||||
embeddings=embeddings
|
||||
)
|
||||
|
||||
if self.k > 0:
|
||||
if self.top_k > 0:
|
||||
documents = vector_index.search(
|
||||
query,
|
||||
search_type='similarity_score_threshold',
|
||||
search_kwargs={
|
||||
'k': self.k,
|
||||
'k': self.top_k,
|
||||
'score_threshold': self.score_threshold,
|
||||
'filter': {
|
||||
'group_id': [dataset.id]
|
||||
}
|
||||
|
Reference in New Issue
Block a user