feat: universal chat in explore (#649)

Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
This commit is contained in:
John Wang
2023-07-27 13:08:57 +08:00
committed by GitHub
parent 94b54b7ca9
commit 4fdb37771a
64 changed files with 3186 additions and 858 deletions

View File

@@ -37,6 +37,8 @@ class CompletionService:
if not query:
raise ValueError('query is required')
query = query.replace('\x00', '')
conversation_id = args['conversation_id'] if 'conversation_id' in args else None
conversation = None
@@ -140,6 +142,7 @@ class CompletionService:
suggested_questions=json.dumps(model_config['suggested_questions']),
suggested_questions_after_answer=json.dumps(model_config['suggested_questions_after_answer']),
more_like_this=json.dumps(model_config['more_like_this']),
sensitive_word_avoidance=json.dumps(model_config['sensitive_word_avoidance']),
model=json.dumps(model_config['model']),
user_input_form=json.dumps(model_config['user_input_form']),
pre_prompt=model_config['pre_prompt'],
@@ -171,7 +174,7 @@ class CompletionService:
generate_worker_thread.start()
# wait for 5 minutes to close the thread
# wait for 10 minutes to close the thread
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -179,9 +182,9 @@ class CompletionService:
@classmethod
def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
if isinstance(user, Account):
user = db.session.query(Account).get(user.id)
user = db.session.query(Account).filter(Account.id == user.id).first()
elif isinstance(user, EndUser):
user = db.session.query(EndUser).get(user.id)
user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
else:
raise Exception("Unknown user type")
@@ -226,12 +229,15 @@ class CompletionService:
@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
# wait for 5 minutes to close the thread
timeout = 300
# wait for 10 minutes to close the thread
timeout = 600
def close_pubsub():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
time.sleep(1)
sleep_iterations += 1
@@ -369,7 +375,7 @@ class CompletionService:
if len(value) > max_length:
raise ValueError(f'{variable} in input form must be less than {max_length} characters')
filtered_inputs[variable] = value
filtered_inputs[variable] = value.replace('\x00', '') if value else None
return filtered_inputs
@@ -418,6 +424,10 @@ class CompletionService:
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':
yield "data: " + json.dumps(cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
elif event == 'ping':
yield "event: ping\n\n"
else:
yield "data: " + json.dumps(result) + "\n\n"
except ValueError as e:
if e.args[0] != "I/O operation on closed file.": # ignore this error
logging.exception(e)
@@ -467,16 +477,14 @@ class CompletionService:
def get_agent_thought_response_data(cls, data: dict):
response_data = {
'event': 'agent_thought',
'id': data.get('agent_thought_id'),
'id': data.get('id'),
'chain_id': data.get('chain_id'),
'task_id': data.get('task_id'),
'message_id': data.get('message_id'),
'position': data.get('position'),
'thought': data.get('thought'),
'tool': data.get('tool'), # todo use real dataset obj replace it
'tool': data.get('tool'),
'tool_input': data.get('tool_input'),
'observation': data.get('observation'),
'answer': data.get('answer') if not data.get('thought') else '',
'created_at': int(time.time())
}