Annotation management (#1767)

Co-authored-by: jyong <jyong@dify.ai>
This commit is contained in:
Jyong
2023-12-18 13:10:05 +08:00
committed by GitHub
parent a9b942981d
commit a71f2863ac
41 changed files with 1871 additions and 67 deletions

View File

@@ -165,7 +165,8 @@ class CompletionService:
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
'auto_generate_name': auto_generate_name,
'from_source': from_source
})
generate_worker_thread.start()
@@ -193,7 +194,7 @@ class CompletionService:
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev', auto_generate_name: bool = True):
retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
@@ -218,7 +219,8 @@ class CompletionService:
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
auto_generate_name=auto_generate_name,
from_source=from_source
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
@@ -385,6 +387,9 @@ class CompletionService:
result = json.loads(result)
if result.get('error'):
cls.handle_error(result)
if result['event'] == 'annotation' and 'data' in result:
message_result['annotation'] = result.get('data')
return cls.get_blocking_annotation_message_response_data(message_result)
if result['event'] == 'message' and 'data' in result:
message_result['message'] = result.get('data')
if result['event'] == 'message_end' and 'data' in result:
@@ -427,6 +432,9 @@ class CompletionService:
elif event == 'agent_thought':
yield "data: " + json.dumps(
cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'annotation':
yield "data: " + json.dumps(
cls.get_annotation_response_data(result.get('data'))) + "\n\n"
elif event == 'message_end':
yield "data: " + json.dumps(
cls.get_message_end_data(result.get('data'))) + "\n\n"
@@ -499,6 +507,25 @@ class CompletionService:
return response_data
@classmethod
def get_blocking_annotation_message_response_data(cls, data: dict):
message = data.get('annotation')
response_data = {
'event': 'annotation',
'task_id': message.get('task_id'),
'id': message.get('message_id'),
'answer': message.get('text'),
'metadata': {},
'created_at': int(time.time()),
'annotation_id': message.get('annotation_id'),
'annotation_author_name': message.get('annotation_author_name')
}
if message.get('mode') == 'chat':
response_data['conversation_id'] = message.get('conversation_id')
return response_data
@classmethod
def get_message_end_data(cls, data: dict):
response_data = {
@@ -551,6 +578,23 @@ class CompletionService:
return response_data
@classmethod
def get_annotation_response_data(cls, data: dict):
response_data = {
'event': 'annotation',
'task_id': data.get('task_id'),
'id': data.get('message_id'),
'answer': data.get('text'),
'created_at': int(time.time()),
'annotation_id': data.get('annotation_id'),
'annotation_author_name': data.get('annotation_author_name'),
}
if data.get('mode') == 'chat':
response_data['conversation_id'] = data.get('conversation_id')
return response_data
@classmethod
def handle_error(cls, result: dict):
logging.debug("error: %s", result)