Feat/add retriever rerank (#1560)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-11-17 22:13:37 +08:00
committed by GitHub
parent a4f37220a0
commit 4588831bff
44 changed files with 1899 additions and 164 deletions

View File

@@ -3,7 +3,7 @@ import pickle
from json import JSONDecodeError
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.dialects.postgresql import UUID, JSONB
from extensions.ext_database import db
from models.account import Account
@@ -15,6 +15,7 @@ class Dataset(db.Model):
__table_args__ = (
db.PrimaryKeyConstraint('id', name='dataset_pkey'),
db.Index('dataset_tenant_idx', 'tenant_id'),
db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
)
INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy']
@@ -39,7 +40,7 @@ class Dataset(db.Model):
embedding_model = db.Column(db.String(255), nullable=True)
embedding_model_provider = db.Column(db.String(255), nullable=True)
collection_binding_id = db.Column(UUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True)
@property
def dataset_keyword_table(self):
@@ -93,6 +94,20 @@ class Dataset(db.Model):
return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
.filter(Document.dataset_id == self.id).scalar()
@property
def retrieval_model_dict(self):
    """Return the retrieval settings for this dataset, or the defaults.

    Returns:
        ``self.retrieval_model`` when it is truthy; otherwise a newly built
        dict of default settings (``'semantic_search'`` search method,
        reranking disabled with empty provider/model names, ``top_k`` of 2,
        score threshold disabled).
    """
    # A stored value of None or an empty dict is falsy and falls through
    # to the defaults below.
    if self.retrieval_model:
        return self.retrieval_model

    # Built on every call, so a caller mutating the returned defaults does
    # not affect later lookups.
    return {
        'search_method': 'semantic_search',
        'reranking_enable': False,
        'reranking_model': {
            'reranking_provider_name': '',
            'reranking_model_name': '',
        },
        'top_k': 2,
        'score_threshold_enable': False,
    }
class DatasetProcessRule(db.Model):
__tablename__ = 'dataset_process_rules'
@@ -120,7 +135,7 @@ class DatasetProcessRule(db.Model):
],
'segmentation': {
'delimiter': '\n',
'max_tokens': 1000
'max_tokens': 512
}
}
@@ -462,4 +477,3 @@ class DatasetCollectionBinding(db.Model):
model_name = db.Column(db.String(40), nullable=False)
collection_name = db.Column(db.String(64), nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))