Feat: dataset retriever resource (#1123)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-09-10 15:17:43 +08:00
committed by GitHub
parent e161c511af
commit 642842d61b
32 changed files with 442 additions and 33 deletions

View File

@@ -1,4 +1,5 @@
import json
from json import JSONDecodeError
from flask import current_app, request
from flask_login import UserMixin
@@ -90,6 +91,7 @@ class AppModelConfig(db.Model):
pre_prompt = db.Column(db.Text)
agent_mode = db.Column(db.Text)
sensitive_word_avoidance = db.Column(db.Text)
retriever_resource = db.Column(db.Text)
@property
def app(self):
@@ -114,6 +116,11 @@ class AppModelConfig(db.Model):
return json.loads(self.speech_to_text) if self.speech_to_text \
else {"enabled": False}
@property
def retriever_resource_dict(self) -> dict:
    """Return the decoded ``retriever_resource`` setting.

    ``retriever_resource`` is stored as JSON text. When the column is
    empty (``None`` or ``""``) a disabled default, ``{"enabled": False}``,
    is returned instead of attempting to decode it.
    """
    raw_setting = self.retriever_resource
    if not raw_setting:
        # Nothing stored: report the feature as switched off.
        return {"enabled": False}
    return json.loads(raw_setting)
@property
def more_like_this_dict(self) -> dict:
    """Return the decoded ``more_like_this`` setting.

    ``more_like_this`` is stored as JSON text. When the column is empty
    (``None`` or ``""``) a disabled default, ``{"enabled": False}``, is
    returned instead of attempting to decode it.
    """
    raw_setting = self.more_like_this
    if not raw_setting:
        # Nothing stored: report the feature as switched off.
        return {"enabled": False}
    return json.loads(raw_setting)
@@ -140,6 +147,7 @@ class AppModelConfig(db.Model):
"suggested_questions": self.suggested_questions_list,
"suggested_questions_after_answer": self.suggested_questions_after_answer_dict,
"speech_to_text": self.speech_to_text_dict,
"retriever_resource": self.retriever_resource,
"more_like_this": self.more_like_this_dict,
"sensitive_word_avoidance": self.sensitive_word_avoidance_dict,
"model": self.model_dict,
@@ -164,7 +172,8 @@ class AppModelConfig(db.Model):
self.user_input_form = json.dumps(model_config['user_input_form'])
self.pre_prompt = model_config['pre_prompt']
self.agent_mode = json.dumps(model_config['agent_mode'])
self.retriever_resource = json.dumps(model_config['retriever_resource']) \
if model_config.get('retriever_resource') else None
return self
def copy(self):
@@ -318,6 +327,7 @@ class Conversation(db.Model):
model_config['suggested_questions'] = app_model_config.suggested_questions_list
model_config['suggested_questions_after_answer'] = app_model_config.suggested_questions_after_answer_dict
model_config['speech_to_text'] = app_model_config.speech_to_text_dict
model_config['retriever_resource'] = app_model_config.retriever_resource_dict
model_config['more_like_this'] = app_model_config.more_like_this_dict
model_config['sensitive_word_avoidance'] = app_model_config.sensitive_word_avoidance_dict
model_config['user_input_form'] = app_model_config.user_input_form_list
@@ -476,6 +486,11 @@ class Message(db.Model):
return db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == self.id) \
.order_by(MessageAgentThought.position.asc()).all()
@property
def retriever_resources(self):
    """Return the retriever resources recorded for this message.

    Queries ``DatasetRetrieverResource`` rows whose ``message_id`` equals
    this object's ``id`` and returns them as a list ordered by
    ``position`` ascending.
    """
    rows_for_message = db.session.query(DatasetRetrieverResource).filter(
        DatasetRetrieverResource.message_id == self.id
    )
    in_position_order = rows_for_message.order_by(
        DatasetRetrieverResource.position.asc()
    )
    return in_position_order.all()
class MessageFeedback(db.Model):
__tablename__ = 'message_feedbacks'
@@ -719,3 +734,31 @@ class MessageAgentThought(db.Model):
created_by_role = db.Column(db.String, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetRetrieverResource(db.Model):
    """ORM model mapped to the ``dataset_retriever_resources`` table.

    Each row is tied to a message through ``message_id`` (indexed) and
    carries a ``position`` used to order the rows of one message.
    The remaining columns identify a dataset, document and segment and
    store the segment's content and related counters.
    """

    __tablename__ = 'dataset_retriever_resources'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_retriever_resource_pkey'),
        # Rows are looked up by message, so message_id gets its own index.
        db.Index('dataset_retriever_resource_message_id_idx', 'message_id'),
    )

    # Primary key; the value is produced database-side by uuid_generate_v4().
    id = db.Column(
        UUID,
        nullable=False,
        server_default=db.text('uuid_generate_v4()'),
    )

    # Owning message and the row's order within that message.
    message_id = db.Column(UUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)

    # Dataset / document identification (ids plus names).
    dataset_id = db.Column(UUID, nullable=False)
    dataset_name = db.Column(db.Text, nullable=False)
    document_id = db.Column(UUID, nullable=False)
    document_name = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.Text, nullable=False)

    # Segment details; only segment_id and content are required.
    segment_id = db.Column(UUID, nullable=False)
    score = db.Column(db.Float, nullable=True)
    content = db.Column(db.Text, nullable=False)
    hit_count = db.Column(db.Integer, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    segment_position = db.Column(db.Integer, nullable=True)
    index_node_hash = db.Column(db.Text, nullable=True)

    # NOTE(review): presumably names the entry point that triggered the
    # retrieval — confirm against the code that writes this column.
    retriever_from = db.Column(db.Text, nullable=False)

    # Audit columns; created_at defaults to the database's current timestamp.
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(
        db.DateTime,
        nullable=False,
        server_default=db.func.current_timestamp(),
    )