Feat:dataset retiever resource (#1123)

Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
Jyong
2023-09-10 15:17:43 +08:00
committed by GitHub
parent e161c511af
commit 642842d61b
32 changed files with 442 additions and 33 deletions

View File

@@ -11,7 +11,8 @@ from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, LLMRateLimitError, \
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -95,6 +96,7 @@ class CompletionService:
app_model_config_model = app_model_config.model_dict
app_model_config_model['completion_params'] = completion_params
app_model_config.retriever_resource = json.dumps({'enabled': True})
app_model_config = app_model_config.copy()
app_model_config.model = json.dumps(app_model_config_model)
@@ -145,7 +147,8 @@ class CompletionService:
'user': user,
'conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
})
generate_worker_thread.start()
@@ -169,7 +172,8 @@ class CompletionService:
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, user: Union[Account, EndUser],
conversation: Conversation, streaming: bool, is_model_config_override: bool):
conversation: Conversation, streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
with flask_app.app_context():
try:
if conversation:
@@ -188,6 +192,7 @@ class CompletionService:
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from
)
except ConversationTaskStoppedException:
pass
@@ -400,7 +405,11 @@ class CompletionService:
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':
yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
elif event == 'ping':
yield "event: ping\n\n"
else:
@@ -432,6 +441,20 @@ class CompletionService:
return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {
'event': 'message_end',
'task_id': data.get('task_id'),
'id': data.get('message_id')
}
if 'retriever_resources' in data:
response_data['retriever_resources'] = data.get('retriever_resources')
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def get_chain_response_data(cls, data: dict):
response_data = {