Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
@@ -1,29 +1,16 @@
import json
import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional, List
from typing import Generator, Union, Any

from flask import current_app, Flask
from redis.client import PubSub
from sqlalchemy import and_

from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.application_manager import ApplicationManager
from core.entities.application_entities import InvokeFrom
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, \
    LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
from services.app_model_config_service import AppModelConfigService
from services.errors.app import MoreLikeThisDisabledError
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.completion import CompletionStoppedError
from services.errors.conversation import ConversationNotExistsError, ConversationCompletedError
from services.errors.message import MessageNotExistsError

@@ -32,7 +19,7 @@ class CompletionService:
    @classmethod
    def completion(cls, app_model: App, user: Union[Account, EndUser], args: Any,
                   from_source: str, streaming: bool = True,
                   invoke_from: InvokeFrom, streaming: bool = True,
                   is_model_config_override: bool = False) -> Union[dict, Generator]:
        # is streaming mode
        inputs = args['inputs']
@@ -56,7 +43,7 @@ class CompletionService:
                Conversation.status == 'normal'
            ]

            if from_source == 'console':
            if isinstance(user, Account):
                conversation_filter.append(Conversation.from_account_id == user.id)
            else:
                conversation_filter.append(Conversation.from_end_user_id == user.id if user else None)
@@ -124,7 +111,7 @@ class CompletionService:
                tenant_id=app_model.tenant_id,
                account=user,
                config=args['model_config'],
                mode=app_model.mode
                app_mode=app_model.mode
            )

            app_model_config = AppModelConfig(
@@ -145,134 +132,29 @@ class CompletionService:
            user
        )

        generate_task_id = str(uuid.uuid4())

        pubsub = redis_client.pubsub()
        pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))

        user = cls.get_real_user_instead_of_proxy_obj(user)

        generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'generate_task_id': generate_task_id,
            'detached_app_model': app_model,
            'app_model_config': app_model_config.copy(),
            'query': query,
            'inputs': inputs,
            'files': file_objs,
            'detached_user': user,
            'detached_conversation': conversation,
            'streaming': streaming,
            'is_model_config_override': is_model_config_override,
            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
            'auto_generate_name': auto_generate_name,
            'from_source': from_source
        })

        generate_worker_thread.start()

        # wait for 10 minutes to close the thread
        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
                                generate_task_id)

        return cls.compact_response(pubsub, streaming)

    @classmethod
    def get_real_user_instead_of_proxy_obj(cls, user: Union[Account, EndUser]):
        if isinstance(user, Account):
            user = db.session.query(Account).filter(Account.id == user.id).first()
        elif isinstance(user, EndUser):
            user = db.session.query(EndUser).filter(EndUser.id == user.id).first()
        else:
            raise Exception("Unknown user type")

        return user

    @classmethod
    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
                        app_model_config: AppModelConfig,
                        query: str, inputs: dict, files: List[PromptMessageFile],
                        detached_user: Union[Account, EndUser],
                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                        retriever_from: str = 'dev', auto_generate_name: bool = True, from_source: str = 'console'):
        with flask_app.app_context():
            # fixed the state of the model object when it detached from the original session
            user = db.session.merge(detached_user)
            app_model = db.session.merge(detached_app_model)

            if detached_conversation:
                conversation = db.session.merge(detached_conversation)
            else:
                conversation = None

            try:
                # run
                Completion.generate(
                    task_id=generate_task_id,
                    app=app_model,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    user=user,
                    files=files,
                    conversation=conversation,
                    streaming=streaming,
                    is_override=is_model_config_override,
                    retriever_from=retriever_from,
                    auto_generate_name=auto_generate_name,
                    from_source=from_source
                )
            except (ConversationTaskInterruptException, ConversationTaskStoppedException):
                pass
            except (ValueError, LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
                    LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
                    ModelCurrentlyNotSupportError) as e:
                PubHandler.pub_error(user, generate_task_id, e)
            except LLMAuthorizationError:
                PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
            except Exception as e:
                logging.exception("Unknown Error in completion")
                PubHandler.pub_error(user, generate_task_id, e)
            finally:
                db.session.remove()

    @classmethod
    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
                            generate_task_id) -> threading.Thread:
        # wait for 10 minutes to close the thread
        timeout = 600

        def close_pubsub():
            with flask_app.app_context():
                try:
                    user = db.session.merge(detached_user)

                    sleep_iterations = 0
                    while sleep_iterations < timeout and worker_thread.is_alive():
                        if sleep_iterations > 0 and sleep_iterations % 10 == 0:
                            PubHandler.ping(user, generate_task_id)

                        time.sleep(1)
                        sleep_iterations += 1

                    if worker_thread.is_alive():
                        PubHandler.stop(user, generate_task_id)
                        try:
                            pubsub.close()
                        except Exception:
                            pass
                finally:
                    db.session.remove()

        countdown_thread = threading.Thread(target=close_pubsub)
        countdown_thread.start()

        return countdown_thread
        application_manager = ApplicationManager()
        return application_manager.generate(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=app_model_config.to_dict(),
            app_model_config_override=is_model_config_override,
            user=user,
            invoke_from=invoke_from,
            inputs=inputs,
            query=query,
            files=file_objs,
            conversation=conversation,
            stream=streaming,
            extras={
                "auto_generate_conversation_name": auto_generate_name
            }
        )

    @classmethod
    def generate_more_like_this(cls, app_model: App, user: Union[Account, EndUser],
                                message_id: str, streaming: bool = True,
                                retriever_from: str = 'dev') -> Union[dict, Generator]:
                                message_id: str, invoke_from: InvokeFrom, streaming: bool = True) \
            -> Union[dict, Generator]:
        if not user:
            raise ValueError('user cannot be None')

@@ -306,36 +188,24 @@ class CompletionService:
            message.files, app_model_config
        )

        generate_task_id = str(uuid.uuid4())

        pubsub = redis_client.pubsub()
        pubsub.subscribe(PubHandler.generate_channel_name(user, generate_task_id))

        user = cls.get_real_user_instead_of_proxy_obj(user)

        generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'generate_task_id': generate_task_id,
            'detached_app_model': app_model,
            'app_model_config': app_model_config.copy(),
            'query': message.query,
            'inputs': message.inputs,
            'files': file_objs,
            'detached_user': user,
            'detached_conversation': None,
            'streaming': streaming,
            'is_model_config_override': True,
            'retriever_from': retriever_from,
            'auto_generate_name': False
        })

        generate_worker_thread.start()

        # wait for 10 minutes to close the thread
        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
                                generate_task_id)

        return cls.compact_response(pubsub, streaming)
        application_manager = ApplicationManager()
        return application_manager.generate(
            tenant_id=app_model.tenant_id,
            app_id=app_model.id,
            app_model_config_id=app_model_config.id,
            app_model_config_dict=app_model_config.to_dict(),
            app_model_config_override=True,
            user=user,
            invoke_from=invoke_from,
            inputs=message.inputs,
            query=message.query,
            files=file_objs,
            conversation=None,
            stream=streaming,
            extras={
                "auto_generate_conversation_name": False
            }
        )

    @classmethod
    def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
@@ -375,247 +245,3 @@ class CompletionService:
        return filtered_inputs

    @classmethod
    def compact_response(cls, pubsub: PubSub, streaming: bool = False) -> Union[dict, Generator]:
        generate_channel = list(pubsub.channels.keys())[0].decode('utf-8')
        if not streaming:
            try:
                message_result = {}
                for message in pubsub.listen():
                    if message["type"] == "message":
                        result = message["data"].decode('utf-8')
                        result = json.loads(result)
                        if result.get('error'):
                            cls.handle_error(result)
                        if result['event'] == 'annotation' and 'data' in result:
                            message_result['annotation'] = result.get('data')
                            return cls.get_blocking_annotation_message_response_data(message_result)
                        if result['event'] == 'message' and 'data' in result:
                            message_result['message'] = result.get('data')
                        if result['event'] == 'message_end' and 'data' in result:
                            message_result['message_end'] = result.get('data')
                            return cls.get_blocking_message_response_data(message_result)
            except ValueError as e:
                if e.args[0] != "I/O operation on closed file.":  # ignore this error
                    raise CompletionStoppedError()
                else:
                    logging.exception(e)
                    raise
            finally:
                db.session.remove()

                try:
                    pubsub.unsubscribe(generate_channel)
                except ConnectionError:
                    pass
        else:
            def generate() -> Generator:
                try:
                    for message in pubsub.listen():
                        if message["type"] == "message":
                            result = message["data"].decode('utf-8')
                            result = json.loads(result)
                            if result.get('error'):
                                cls.handle_error(result)

                            event = result.get('event')
                            if event == "end":
                                logging.debug("{} finished".format(generate_channel))
                                break
                            if event == 'message':
                                yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
                            elif event == 'message_replace':
                                yield "data: " + json.dumps(
                                    cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
                            elif event == 'chain':
                                yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
                            elif event == 'agent_thought':
                                yield "data: " + json.dumps(
                                    cls.get_agent_thought_response_data(result.get('data'))) + "\n\n"
                            elif event == 'annotation':
                                yield "data: " + json.dumps(
                                    cls.get_annotation_response_data(result.get('data'))) + "\n\n"
                            elif event == 'message_end':
                                yield "data: " + json.dumps(
                                    cls.get_message_end_data(result.get('data'))) + "\n\n"
                            elif event == 'ping':
                                yield "event: ping\n\n"
                            else:
                                yield "data: " + json.dumps(result) + "\n\n"
                except ValueError as e:
                    if e.args[0] != "I/O operation on closed file.":  # ignore this error
                        logging.exception(e)
                        raise
                finally:
                    db.session.remove()

                    try:
                        pubsub.unsubscribe(generate_channel)
                    except ConnectionError:
                        pass

            return generate()

    @classmethod
    def get_message_response_data(cls, data: dict):
        response_data = {
            'event': 'message',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time())
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def get_message_replace_response_data(cls, data: dict):
        response_data = {
            'event': 'message_replace',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time())
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def get_blocking_message_response_data(cls, data: dict):
        message = data.get('message')
        response_data = {
            'event': 'message',
            'task_id': message.get('task_id'),
            'id': message.get('message_id'),
            'answer': message.get('text'),
            'metadata': {},
            'created_at': int(time.time())
        }

        if message.get('mode') == 'chat':
            response_data['conversation_id'] = message.get('conversation_id')
        if 'message_end' in data:
            message_end = data.get('message_end')
            if 'retriever_resources' in message_end:
                response_data['metadata']['retriever_resources'] = message_end.get('retriever_resources')

        return response_data

    @classmethod
    def get_blocking_annotation_message_response_data(cls, data: dict):
        message = data.get('annotation')
        response_data = {
            'event': 'annotation',
            'task_id': message.get('task_id'),
            'id': message.get('message_id'),
            'answer': message.get('text'),
            'metadata': {},
            'created_at': int(time.time()),
            'annotation_id': message.get('annotation_id'),
            'annotation_author_name': message.get('annotation_author_name')
        }

        if message.get('mode') == 'chat':
            response_data['conversation_id'] = message.get('conversation_id')

        return response_data

    @classmethod
    def get_message_end_data(cls, data: dict):
        response_data = {
            'event': 'message_end',
            'task_id': data.get('task_id'),
            'id': data.get('message_id')
        }
        if 'retriever_resources' in data:
            response_data['retriever_resources'] = data.get('retriever_resources')
        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def get_chain_response_data(cls, data: dict):
        response_data = {
            'event': 'chain',
            'id': data.get('chain_id'),
            'task_id': data.get('task_id'),
            'message_id': data.get('message_id'),
            'type': data.get('type'),
            'input': data.get('input'),
            'output': data.get('output'),
            'created_at': int(time.time())
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def get_agent_thought_response_data(cls, data: dict):
        response_data = {
            'event': 'agent_thought',
            'id': data.get('id'),
            'chain_id': data.get('chain_id'),
            'task_id': data.get('task_id'),
            'message_id': data.get('message_id'),
            'position': data.get('position'),
            'thought': data.get('thought'),
            'tool': data.get('tool'),
            'tool_input': data.get('tool_input'),
            'created_at': int(time.time())
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def get_annotation_response_data(cls, data: dict):
        response_data = {
            'event': 'annotation',
            'task_id': data.get('task_id'),
            'id': data.get('message_id'),
            'answer': data.get('text'),
            'created_at': int(time.time()),
            'annotation_id': data.get('annotation_id'),
            'annotation_author_name': data.get('annotation_author_name'),
        }

        if data.get('mode') == 'chat':
            response_data['conversation_id'] = data.get('conversation_id')

        return response_data

    @classmethod
    def handle_error(cls, result: dict):
        logging.debug("error: %s", result)
        error = result.get('error')
        description = result.get('description')

        # handle errors
        llm_errors = {
            'ValueError': LLMBadRequestError,
            'LLMBadRequestError': LLMBadRequestError,
            'LLMAPIConnectionError': LLMAPIConnectionError,
            'LLMAPIUnavailableError': LLMAPIUnavailableError,
            'LLMRateLimitError': LLMRateLimitError,
            'ProviderTokenNotInitError': ProviderTokenNotInitError,
            'QuotaExceededError': QuotaExceededError,
            'ModelCurrentlyNotSupportError': ModelCurrentlyNotSupportError
        }

        if error in llm_errors:
            raise llm_errors[error](description)
        elif error == 'LLMAuthorizationError':
            raise LLMAuthorizationError('Incorrect API key provided')
        else:
            raise Exception(description)
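For context, a minimal sketch of how a caller might invoke the refactored entry point after this change. This is not part of the commit: the surrounding wiring (app_model, current_user), the args payload, and the InvokeFrom member shown here are illustrative assumptions; only CompletionService.completion, its new invoke_from parameter, and the InvokeFrom import path are taken from the diff above.

# Illustrative usage sketch (assumed wiring, not part of this commit)
from core.entities.application_entities import InvokeFrom
from services.completion_service import CompletionService

# app_model (models.model.App) and current_user (Account or EndUser) are
# placeholders assumed to come from the surrounding controller/request context.
response = CompletionService.completion(
    app_model=app_model,
    user=current_user,
    args={'inputs': {}, 'query': 'Hello'},   # request payload, shape assumed
    invoke_from=InvokeFrom.WEB_APP,          # enum member assumed; replaces the old from_source string
    streaming=True,
)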