Model Runtime (#1858)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
Co-authored-by: chenhe <guchenhe@gmail.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
api/core/app_runner/__init__.py (new file, 0 lines)
api/core/app_runner/agent_app_runner.py (new file, 251 lines)
@@ -0,0 +1,251 @@
import json
import logging
from typing import cast

from core.agent.agent.agent_llm_callback import AgentLLMCallback
from core.app_runner.app_runner import AppRunner
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, PromptTemplateEntity, ModelConfigEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.agent_runner import AgentRunnerFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageChain, MessageAgentThought

logger = logging.getLogger(__name__)


class AgentApplicationRunner(AppRunner):
    """
    Agent Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run agent application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=None,
            memory=memory
        )

        # Create MessageChain
        message_chain = self._init_message_chain(
            message=message,
            query=query
        )

        # add agent callback to record agent thoughts
        agent_callback = AgentLoopGatherCallbackHandler(
            model_config=app_orchestration_config.model_config,
            message=message,
            queue_manager=queue_manager,
            message_chain=message_chain
        )

        # init LLM Callback
        agent_llm_callback = AgentLLMCallback(
            agent_callback=agent_callback
        )

        agent_runner = AgentRunnerFeature(
            tenant_id=application_generate_entity.tenant_id,
            app_orchestration_config=app_orchestration_config,
            model_config=app_orchestration_config.model_config,
            config=app_orchestration_config.agent,
            queue_manager=queue_manager,
            message=message,
            user_id=application_generate_entity.user_id,
            agent_llm_callback=agent_llm_callback,
            callback=agent_callback,
            memory=memory
        )

        # agent run
        result = agent_runner.run(
            query=query,
            invoke_from=application_generate_entity.invoke_from
        )

        if result:
            self._save_message_chain(
                message_chain=message_chain,
                output_text=result
            )

        if (result
                and app_orchestration_config.prompt_template.prompt_type == PromptTemplateEntity.PromptType.SIMPLE
                and app_orchestration_config.prompt_template.simple_prompt_template
        ):
            # Direct output if agent result exists and has pre prompt
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                stream=application_generate_entity.stream,
                text=result,
                usage=self._get_usage_of_all_agent_thoughts(
                    model_config=app_orchestration_config.model_config,
                    message=message
                )
            )
        else:
            # As normal LLM run, agent result as context
            context = result

            # reorganize all inputs and template to prompt messages
            # Include: prompt template, inputs, query(optional), files(optional)
            #          memory(optional), external data, dataset context(optional)
            prompt_messages, stop = self.originze_prompt_messages(
                app_record=app_record,
                model_config=app_orchestration_config.model_config,
                prompt_template_entity=app_orchestration_config.prompt_template,
                inputs=inputs,
                files=files,
                query=query,
                context=context,
                memory=memory
            )

            # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
            self.recale_llm_max_tokens(
                model_config=app_orchestration_config.model_config,
                prompt_messages=prompt_messages
            )

            # Invoke model
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            invoke_result = model_instance.invoke_llm(
                prompt_messages=prompt_messages,
                model_parameters=app_orchestration_config.model_config.parameters,
                stop=stop,
                stream=application_generate_entity.stream,
                user=application_generate_entity.user_id,
            )

            # handle invoke result
            self._handle_invoke_result(
                invoke_result=invoke_result,
                queue_manager=queue_manager,
                stream=application_generate_entity.stream
            )

    def _init_message_chain(self, message: Message, query: str) -> MessageChain:
        """
        Init MessageChain
        :param message: message
        :param query: query
        :return:
        """
        message_chain = MessageChain(
            message_id=message.id,
            type="AgentExecutor",
            input=json.dumps({
                "input": query
            })
        )

        db.session.add(message_chain)
        db.session.commit()

        return message_chain

    def _save_message_chain(self, message_chain: MessageChain, output_text: str) -> None:
        """
        Save MessageChain
        :param message_chain: message chain
        :param output_text: output text
        :return:
        """
        message_chain.output = json.dumps({
            "output": output_text
        })
        db.session.commit()

    def _get_usage_of_all_agent_thoughts(self, model_config: ModelConfigEntity,
                                         message: Message) -> LLMUsage:
        """
        Get usage of all agent thoughts
        :param model_config: model config
        :param message: message
        :return:
        """
        agent_thoughts = (db.session.query(MessageAgentThought)
                          .filter(MessageAgentThought.message_id == message.id).all())

        all_message_tokens = 0
        all_answer_tokens = 0
        for agent_thought in agent_thoughts:
            all_message_tokens += agent_thought.message_tokens
            all_answer_tokens += agent_thought.answer_tokens

        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        return model_type_instance._calc_response_usage(
            model_config.model,
            model_config.credentials,
            all_message_tokens,
            all_answer_tokens
        )
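Note on how this runner is driven: AgentApplicationRunner.run() returns None and publishes everything (chunks, agent thoughts, stop/end events) through the ApplicationQueueManager, so the caller is expected to have created the Conversation and Message rows and a queue manager beforehand. A rough calling sketch follows; the surrounding setup objects are assumed here rather than defined in this commit.

runner = AgentApplicationRunner()
runner.run(
    application_generate_entity=application_generate_entity,  # built by the application manager (assumed)
    queue_manager=queue_manager,                               # events are published here, not returned
    conversation=conversation,
    message=message,
)
# Results are consumed by listening on the queue manager,
# e.g. via the GenerateTaskPipeline added later in this commit.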
api/core/app_runner/app_runner.py (new file, 267 lines)
@@ -0,0 +1,267 @@
import time
from typing import cast, Optional, List, Tuple, Generator, Union

from core.application_queue_manager import ApplicationQueueManager
from core.entities.application_entities import ModelConfigEntity, PromptTemplateEntity, AppOrchestrationConfigEntity
from core.file.file_obj import FileObj
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_transform import PromptTransform
from models.model import App


class AppRunner:
    def get_pre_calculate_rest_tokens(self, app_record: App,
                                      model_config: ModelConfigEntity,
                                      prompt_template_entity: PromptTemplateEntity,
                                      inputs: dict[str, str],
                                      files: list[FileObj],
                                      query: Optional[str] = None) -> int:
        """
        Get pre calculate rest tokens
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :return:
        """
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        # get prompt messages without memory and context
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=model_config,
            prompt_template_entity=prompt_template_entity,
            inputs=inputs,
            files=files,
            query=query
        )

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        rest_tokens = model_context_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise InvokeBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                        "or shrink the max token, or switch to a llm with a larger token limit size.")

        return rest_tokens

    def recale_llm_max_tokens(self, model_config: ModelConfigEntity,
                              prompt_messages: List[PromptMessage]):
        # recalc max_tokens if sum(prompt_token + max_tokens) over model token limit
        model_type_instance = model_config.provider_model_bundle.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)

        max_tokens = 0
        for parameter_rule in model_config.model_schema.parameter_rules:
            if (parameter_rule.name == 'max_tokens'
                    or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                max_tokens = (model_config.parameters.get(parameter_rule.name)
                              or model_config.parameters.get(parameter_rule.use_template)) or 0

        if model_context_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_type_instance.get_num_tokens(
            model_config.model,
            model_config.credentials,
            prompt_messages
        )

        if prompt_tokens + max_tokens > model_context_tokens:
            max_tokens = max(model_context_tokens - prompt_tokens, 16)

            for parameter_rule in model_config.model_schema.parameter_rules:
                if (parameter_rule.name == 'max_tokens'
                        or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
                    model_config.parameters[parameter_rule.name] = max_tokens

    def originze_prompt_messages(self, app_record: App,
                                 model_config: ModelConfigEntity,
                                 prompt_template_entity: PromptTemplateEntity,
                                 inputs: dict[str, str],
                                 files: list[FileObj],
                                 query: Optional[str] = None,
                                 context: Optional[str] = None,
                                 memory: Optional[TokenBufferMemory] = None) \
            -> Tuple[List[PromptMessage], Optional[List[str]]]:
        """
        Organize prompt messages
        :param context:
        :param app_record: app record
        :param model_config: model config entity
        :param prompt_template_entity: prompt template entity
        :param inputs: inputs
        :param files: files
        :param query: query
        :param memory: memory
        :return:
        """
        prompt_transform = PromptTransform()

        # get prompt without memory and context
        if prompt_template_entity.prompt_type == PromptTemplateEntity.PromptType.SIMPLE:
            prompt_messages, stop = prompt_transform.get_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query if query else '',
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=app_record.mode,
                prompt_template_entity=prompt_template_entity,
                inputs=inputs,
                query=query,
                files=files,
                context=context,
                memory=memory,
                model_config=model_config
            )
            stop = model_config.stop

        return prompt_messages, stop

    def direct_output(self, queue_manager: ApplicationQueueManager,
                      app_orchestration_config: AppOrchestrationConfigEntity,
                      prompt_messages: list,
                      text: str,
                      stream: bool,
                      usage: Optional[LLMUsage] = None) -> None:
        """
        Direct output
        :param queue_manager: application queue manager
        :param app_orchestration_config: app orchestration config
        :param prompt_messages: prompt messages
        :param text: text
        :param stream: stream
        :param usage: usage
        :return:
        """
        if stream:
            index = 0
            for token in text:
                queue_manager.publish_chunk_message(LLMResultChunk(
                    model=app_orchestration_config.model_config.model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=AssistantPromptMessage(content=token)
                    )
                ))
                index += 1
                time.sleep(0.01)

        queue_manager.publish_message_end(
            llm_result=LLMResult(
                model=app_orchestration_config.model_config.model,
                prompt_messages=prompt_messages,
                message=AssistantPromptMessage(content=text),
                usage=usage if usage else LLMUsage.empty_usage()
            )
        )

    def _handle_invoke_result(self, invoke_result: Union[LLMResult, Generator],
                              queue_manager: ApplicationQueueManager,
                              stream: bool) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :param stream: stream
        :return:
        """
        if not stream:
            self._handle_invoke_result_direct(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )
        else:
            self._handle_invoke_result_stream(
                invoke_result=invoke_result,
                queue_manager=queue_manager
            )

    def _handle_invoke_result_direct(self, invoke_result: LLMResult,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result direct
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        queue_manager.publish_message_end(
            llm_result=invoke_result
        )

    def _handle_invoke_result_stream(self, invoke_result: Generator,
                                     queue_manager: ApplicationQueueManager) -> None:
        """
        Handle invoke result
        :param invoke_result: invoke result
        :param queue_manager: application queue manager
        :return:
        """
        model = None
        prompt_messages = []
        text = ''
        usage = None
        for result in invoke_result:
            queue_manager.publish_chunk_message(result)

            text += result.delta.message.content

            if not model:
                model = result.model

            if not prompt_messages:
                prompt_messages = result.prompt_messages

            if not usage and result.delta.usage:
                usage = result.delta.usage

        llm_result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=AssistantPromptMessage(content=text),
            usage=usage
        )

        queue_manager.publish_message_end(
            llm_result=llm_result
        )
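Note on the token budget used by get_pre_calculate_rest_tokens and recale_llm_max_tokens: both reduce to rest_tokens = context_size - max_tokens - prompt_tokens, with max_tokens later clamped so that prompt and completion together never exceed the context window. A minimal standalone sketch of that arithmetic follows; the function and variable names here are illustrative only and are not part of this commit.

def clamp_max_tokens(context_size: int, prompt_tokens: int, requested_max_tokens: int) -> int:
    # Pre-check: raise early if the prompt cannot fit together with the requested completion budget.
    if context_size - requested_max_tokens - prompt_tokens < 0:
        raise ValueError("prompt is too long for the model context window")
    # Clamp: never let prompt + completion exceed the window
    # (the 16-token floor mirrors the floor used in recale_llm_max_tokens above).
    if prompt_tokens + requested_max_tokens > context_size:
        return max(context_size - prompt_tokens, 16)
    return requested_max_tokens

# Example: a 4096-token context with a 3000-token prompt and max_tokens=2000 fails the
# pre-check (4096 - 2000 - 3000 < 0); with max_tokens=500 the budget passes unchanged.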
api/core/app_runner/basic_app_runner.py (new file, 363 lines)
@@ -0,0 +1,363 @@
import logging
from typing import Tuple, Optional

from core.app_runner.app_runner import AppRunner
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.entities.application_entities import ApplicationGenerateEntity, ModelConfigEntity, \
    AppOrchestrationConfigEntity, InvokeFrom, ExternalDataVariableEntity, DatasetEntity
from core.application_queue_manager import ApplicationQueueManager
from core.features.annotation_reply import AnnotationReplyFeature
from core.features.dataset_retrieval import DatasetRetrievalFeature
from core.features.external_data_fetch import ExternalDataFetchFeature
from core.features.hosting_moderation import HostingModerationFeature
from core.features.moderation import ModerationFeature
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.moderation.base import ModerationException
from core.prompt.prompt_transform import AppMode
from extensions.ext_database import db
from models.model import Conversation, Message, App, MessageAnnotation

logger = logging.getLogger(__name__)


class BasicApplicationRunner(AppRunner):
    """
    Basic Application Runner
    """

    def run(self, application_generate_entity: ApplicationGenerateEntity,
            queue_manager: ApplicationQueueManager,
            conversation: Conversation,
            message: Message) -> None:
        """
        Run application
        :param application_generate_entity: application generate entity
        :param queue_manager: application queue manager
        :param conversation: conversation
        :param message: message
        :return:
        """
        app_record = db.session.query(App).filter(App.id == application_generate_entity.app_id).first()
        if not app_record:
            raise ValueError(f"App not found")

        app_orchestration_config = application_generate_entity.app_orchestration_config_entity

        inputs = application_generate_entity.inputs
        query = application_generate_entity.query
        files = application_generate_entity.files

        # Pre-calculate the number of tokens of the prompt messages,
        # and return the rest number of tokens by model context token size limit and max token size limit.
        # If the rest number of tokens is not enough, raise exception.
        # Include: prompt template, inputs, query(optional), files(optional)
        # Not Include: memory, external data, dataset context
        self.get_pre_calculate_rest_tokens(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query
        )

        memory = None
        if application_generate_entity.conversation_id:
            # get memory of conversation (read-only)
            model_instance = ModelInstance(
                provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
                model=app_orchestration_config.model_config.model
            )

            memory = TokenBufferMemory(
                conversation=conversation,
                model_instance=model_instance
            )

        # organize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            memory=memory
        )

        # moderation
        try:
            # process sensitive_word_avoidance
            _, inputs, query = self.moderation_for_inputs(
                app_id=app_record.id,
                tenant_id=application_generate_entity.tenant_id,
                app_orchestration_config_entity=app_orchestration_config,
                inputs=inputs,
                query=query,
            )
        except ModerationException as e:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=app_orchestration_config,
                prompt_messages=prompt_messages,
                text=str(e),
                stream=application_generate_entity.stream
            )
            return

        if query:
            # annotation reply
            annotation_reply = self.query_app_annotations_to_reply(
                app_record=app_record,
                message=message,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from
            )

            if annotation_reply:
                queue_manager.publish_annotation_reply(
                    message_annotation_id=annotation_reply.id
                )
                self.direct_output(
                    queue_manager=queue_manager,
                    app_orchestration_config=app_orchestration_config,
                    prompt_messages=prompt_messages,
                    text=annotation_reply.content,
                    stream=application_generate_entity.stream
                )
                return

        # fill in variable inputs from external data tools if exists
        external_data_tools = app_orchestration_config.external_data_variables
        if external_data_tools:
            inputs = self.fill_in_inputs_from_external_data_tools(
                tenant_id=app_record.tenant_id,
                app_id=app_record.id,
                external_data_tools=external_data_tools,
                inputs=inputs,
                query=query
            )

        # get context from datasets
        context = None
        if app_orchestration_config.dataset:
            context = self.retrieve_dataset_context(
                tenant_id=app_record.tenant_id,
                app_record=app_record,
                queue_manager=queue_manager,
                model_config=app_orchestration_config.model_config,
                show_retrieve_source=app_orchestration_config.show_retrieve_source,
                dataset_config=app_orchestration_config.dataset,
                message=message,
                inputs=inputs,
                query=query,
                user_id=application_generate_entity.user_id,
                invoke_from=application_generate_entity.invoke_from,
                memory=memory
            )

        # reorganize all inputs and template to prompt messages
        # Include: prompt template, inputs, query(optional), files(optional)
        #          memory(optional), external data, dataset context(optional)
        prompt_messages, stop = self.originze_prompt_messages(
            app_record=app_record,
            model_config=app_orchestration_config.model_config,
            prompt_template_entity=app_orchestration_config.prompt_template,
            inputs=inputs,
            files=files,
            query=query,
            context=context,
            memory=memory
        )

        # check hosting moderation
        hosting_moderation_result = self.check_hosting_moderation(
            application_generate_entity=application_generate_entity,
            queue_manager=queue_manager,
            prompt_messages=prompt_messages
        )

        if hosting_moderation_result:
            return

        # Re-calculate the max tokens if sum(prompt_token + max_tokens) over model token limit
        self.recale_llm_max_tokens(
            model_config=app_orchestration_config.model_config,
            prompt_messages=prompt_messages
        )

        # Invoke model
        model_instance = ModelInstance(
            provider_model_bundle=app_orchestration_config.model_config.provider_model_bundle,
            model=app_orchestration_config.model_config.model
        )

        invoke_result = model_instance.invoke_llm(
            prompt_messages=prompt_messages,
            model_parameters=app_orchestration_config.model_config.parameters,
            stop=stop,
            stream=application_generate_entity.stream,
            user=application_generate_entity.user_id,
        )

        # handle invoke result
        self._handle_invoke_result(
            invoke_result=invoke_result,
            queue_manager=queue_manager,
            stream=application_generate_entity.stream
        )

    def moderation_for_inputs(self, app_id: str,
                              tenant_id: str,
                              app_orchestration_config_entity: AppOrchestrationConfigEntity,
                              inputs: dict,
                              query: str) -> Tuple[bool, dict, str]:
        """
        Process sensitive_word_avoidance.
        :param app_id: app id
        :param tenant_id: tenant id
        :param app_orchestration_config_entity: app orchestration config entity
        :param inputs: inputs
        :param query: query
        :return:
        """
        moderation_feature = ModerationFeature()
        return moderation_feature.check(
            app_id=app_id,
            tenant_id=tenant_id,
            app_orchestration_config_entity=app_orchestration_config_entity,
            inputs=inputs,
            query=query,
        )

    def query_app_annotations_to_reply(self, app_record: App,
                                       message: Message,
                                       query: str,
                                       user_id: str,
                                       invoke_from: InvokeFrom) -> Optional[MessageAnnotation]:
        """
        Query app annotations to reply
        :param app_record: app record
        :param message: message
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :return:
        """
        annotation_reply_feature = AnnotationReplyFeature()
        return annotation_reply_feature.query(
            app_record=app_record,
            message=message,
            query=query,
            user_id=user_id,
            invoke_from=invoke_from
        )

    def fill_in_inputs_from_external_data_tools(self, tenant_id: str,
                                                app_id: str,
                                                external_data_tools: list[ExternalDataVariableEntity],
                                                inputs: dict,
                                                query: str) -> dict:
        """
        Fill in variable inputs from external data tools if exists.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        external_data_fetch_feature = ExternalDataFetchFeature()
        return external_data_fetch_feature.fetch(
            tenant_id=tenant_id,
            app_id=app_id,
            external_data_tools=external_data_tools,
            inputs=inputs,
            query=query
        )

    def retrieve_dataset_context(self, tenant_id: str,
                                 app_record: App,
                                 queue_manager: ApplicationQueueManager,
                                 model_config: ModelConfigEntity,
                                 dataset_config: DatasetEntity,
                                 show_retrieve_source: bool,
                                 message: Message,
                                 inputs: dict,
                                 query: str,
                                 user_id: str,
                                 invoke_from: InvokeFrom,
                                 memory: Optional[TokenBufferMemory] = None) -> Optional[str]:
        """
        Retrieve dataset context
        :param tenant_id: tenant id
        :param app_record: app record
        :param queue_manager: queue manager
        :param model_config: model config
        :param dataset_config: dataset config
        :param show_retrieve_source: show retrieve source
        :param message: message
        :param inputs: inputs
        :param query: query
        :param user_id: user id
        :param invoke_from: invoke from
        :param memory: memory
        :return:
        """
        hit_callback = DatasetIndexToolCallbackHandler(
            queue_manager,
            app_record.id,
            message.id,
            user_id,
            invoke_from
        )

        if (app_record.mode == AppMode.COMPLETION.value and dataset_config
                and dataset_config.retrieve_config.query_variable):
            query = inputs.get(dataset_config.retrieve_config.query_variable, "")

        dataset_retrieval = DatasetRetrievalFeature()
        return dataset_retrieval.retrieve(
            tenant_id=tenant_id,
            model_config=model_config,
            config=dataset_config,
            query=query,
            invoke_from=invoke_from,
            show_retrieve_source=show_retrieve_source,
            hit_callback=hit_callback,
            memory=memory
        )

    def check_hosting_moderation(self, application_generate_entity: ApplicationGenerateEntity,
                                 queue_manager: ApplicationQueueManager,
                                 prompt_messages: list[PromptMessage]) -> bool:
        """
        Check hosting moderation
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param prompt_messages: prompt messages
        :return:
        """
        hosting_moderation_feature = HostingModerationFeature()
        moderation_result = hosting_moderation_feature.check(
            application_generate_entity=application_generate_entity,
            prompt_messages=prompt_messages
        )

        if moderation_result:
            self.direct_output(
                queue_manager=queue_manager,
                app_orchestration_config=application_generate_entity.app_orchestration_config_entity,
                prompt_messages=prompt_messages,
                text="I apologize for any confusion, " \
                     "but I'm an AI assistant to be helpful, harmless, and honest.",
                stream=application_generate_entity.stream
            )

        return moderation_result
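Note on retrieve_dataset_context in completion mode: when a query variable is configured, the retrieval query is swapped for the value of that input variable rather than the raw user query. A minimal illustration; the variable names below are made up for the example and are not part of this commit.

inputs = {"topic": "vector databases"}
# assume dataset_config.retrieve_config.query_variable == "topic"
query = inputs.get("topic", "")   # dataset retrieval now searches for "vector databases"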
api/core/app_runner/generate_task_pipeline.py (new file, 483 lines)
@@ -0,0 +1,483 @@
import json
import logging
import time
from typing import Union, Generator, cast, Optional

from pydantic import BaseModel

from core.app_runner.moderation_handler import OutputModerationHandler, ModerationRule
from core.entities.application_entities import ApplicationGenerateEntity
from core.application_queue_manager import ApplicationQueueManager
from core.entities.queue_entities import QueueErrorEvent, QueueStopEvent, QueueMessageEndEvent, \
    QueueRetrieverResourcesEvent, QueueAgentThoughtEvent, QueuePingEvent, QueueMessageEvent, QueueMessageReplaceEvent, \
    AnnotationReplyEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessageRole, \
    TextPromptMessageContent, PromptMessageContentType, ImagePromptMessageContent, PromptMessage
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from models.model import Message, Conversation, MessageAgentThought
from services.annotation_service import AppAnnotationService

logger = logging.getLogger(__name__)


class TaskState(BaseModel):
    """
    TaskState entity
    """
    llm_result: LLMResult
    metadata: dict = {}


class GenerateTaskPipeline:
    """
    GenerateTaskPipeline is a class that generate stream output and state management for Application.
    """

    def __init__(self, application_generate_entity: ApplicationGenerateEntity,
                 queue_manager: ApplicationQueueManager,
                 conversation: Conversation,
                 message: Message) -> None:
        """
        Initialize GenerateTaskPipeline.
        :param application_generate_entity: application generate entity
        :param queue_manager: queue manager
        :param conversation: conversation
        :param message: message
        """
        self._application_generate_entity = application_generate_entity
        self._queue_manager = queue_manager
        self._conversation = conversation
        self._message = message
        self._task_state = TaskState(
            llm_result=LLMResult(
                model=self._application_generate_entity.app_orchestration_config_entity.model_config.model,
                prompt_messages=[],
                message=AssistantPromptMessage(content=""),
                usage=LLMUsage.empty_usage()
            )
        )
        self._start_at = time.perf_counter()
        self._output_moderation_handler = self._init_output_moderation()

    def process(self, stream: bool) -> Union[dict, Generator]:
        """
        Process generate task pipeline.
        :return:
        """
        if stream:
            return self._process_stream_response()
        else:
            return self._process_blocking_response()

    def _process_blocking_response(self) -> dict:
        """
        Process blocking response.
        :return:
        """
        for queue_message in self._queue_manager.listen():
            event = queue_message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                # Save message
                self._save_message(event.llm_result)

                response = {
                    'event': 'message',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                    'mode': self._conversation.mode,
                    'answer': event.llm_result.message.content,
                    'metadata': {},
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                return response
            else:
                continue

    def _process_stream_response(self) -> Generator:
        """
        Process stream response.
        :return:
        """
        for message in self._queue_manager.listen():
            event = message.event

            if isinstance(event, QueueErrorEvent):
                raise self._handle_error(event)
            elif isinstance(event, (QueueStopEvent, QueueMessageEndEvent)):
                if isinstance(event, QueueMessageEndEvent):
                    self._task_state.llm_result = event.llm_result
                else:
                    model_config = self._application_generate_entity.app_orchestration_config_entity.model_config
                    model = model_config.model
                    model_type_instance = model_config.provider_model_bundle.model_type_instance
                    model_type_instance = cast(LargeLanguageModel, model_type_instance)

                    # calculate num tokens
                    prompt_tokens = 0
                    if event.stopped_by != QueueStopEvent.StopBy.ANNOTATION_REPLY:
                        prompt_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            self._task_state.llm_result.prompt_messages
                        )

                    completion_tokens = 0
                    if event.stopped_by == QueueStopEvent.StopBy.USER_MANUAL:
                        completion_tokens = model_type_instance.get_num_tokens(
                            model,
                            model_config.credentials,
                            [self._task_state.llm_result.message]
                        )

                    credentials = model_config.credentials

                    # transform usage
                    self._task_state.llm_result.usage = model_type_instance._calc_response_usage(
                        model,
                        credentials,
                        prompt_tokens,
                        completion_tokens
                    )

                # response moderation
                if self._output_moderation_handler:
                    self._output_moderation_handler.stop_thread()

                    self._task_state.llm_result.message.content = self._output_moderation_handler.moderation_completion(
                        completion=self._task_state.llm_result.message.content,
                        public_event=False
                    )

                    self._output_moderation_handler = None

                    replace_response = {
                        'event': 'message_replace',
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'answer': self._task_state.llm_result.message.content,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        replace_response['conversation_id'] = self._conversation.id

                    yield self._yield_response(replace_response)

                # Save message
                self._save_message(self._task_state.llm_result)

                response = {
                    'event': 'message_end',
                    'task_id': self._application_generate_entity.task_id,
                    'id': self._message.id,
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                if self._task_state.metadata:
                    response['metadata'] = self._task_state.metadata

                yield self._yield_response(response)
            elif isinstance(event, QueueRetrieverResourcesEvent):
                self._task_state.metadata['retriever_resources'] = event.retriever_resources
            elif isinstance(event, AnnotationReplyEvent):
                annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
                if annotation:
                    account = annotation.account
                    self._task_state.metadata['annotation_reply'] = {
                        'id': annotation.id,
                        'account': {
                            'id': annotation.account_id,
                            'name': account.name if account else 'Dify user'
                        }
                    }

                    self._task_state.llm_result.message.content = annotation.content
            elif isinstance(event, QueueAgentThoughtEvent):
                agent_thought = (
                    db.session.query(MessageAgentThought)
                    .filter(MessageAgentThought.id == event.agent_thought_id)
                    .first()
                )

                if agent_thought:
                    response = {
                        'event': 'agent_thought',
                        'id': agent_thought.id,
                        'task_id': self._application_generate_entity.task_id,
                        'message_id': self._message.id,
                        'position': agent_thought.position,
                        'thought': agent_thought.thought,
                        'tool': agent_thought.tool,
                        'tool_input': agent_thought.tool_input,
                        'created_at': int(self._message.created_at.timestamp())
                    }

                    if self._conversation.mode == 'chat':
                        response['conversation_id'] = self._conversation.id

                    yield self._yield_response(response)
            elif isinstance(event, QueueMessageEvent):
                chunk = event.chunk
                delta_text = chunk.delta.message.content
                if delta_text is None:
                    continue

                if not self._task_state.llm_result.prompt_messages:
                    self._task_state.llm_result.prompt_messages = chunk.prompt_messages

                if self._output_moderation_handler:
                    if self._output_moderation_handler.should_direct_output():
                        # stop subscribe new token when output moderation should direct output
                        self._task_state.llm_result.message.content = self._output_moderation_handler.get_final_output()
                        self._queue_manager.publish_chunk_message(LLMResultChunk(
                            model=self._task_state.llm_result.model,
                            prompt_messages=self._task_state.llm_result.prompt_messages,
                            delta=LLMResultChunkDelta(
                                index=0,
                                message=AssistantPromptMessage(content=self._task_state.llm_result.message.content)
                            )
                        ))
                        self._queue_manager.publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.OUTPUT_MODERATION))
                        continue
                    else:
                        self._output_moderation_handler.append_new_token(delta_text)

                self._task_state.llm_result.message.content += delta_text
                response = self._handle_chunk(delta_text)
                yield self._yield_response(response)
            elif isinstance(event, QueueMessageReplaceEvent):
                response = {
                    'event': 'message_replace',
                    'task_id': self._application_generate_entity.task_id,
                    'message_id': self._message.id,
                    'answer': event.text,
                    'created_at': int(self._message.created_at.timestamp())
                }

                if self._conversation.mode == 'chat':
                    response['conversation_id'] = self._conversation.id

                yield self._yield_response(response)
            elif isinstance(event, QueuePingEvent):
                yield "event: ping\n\n"
            else:
                continue

    def _save_message(self, llm_result: LLMResult) -> None:
        """
        Save message.
        :param llm_result: llm result
        :return:
        """
        usage = llm_result.usage

        self._message = db.session.query(Message).filter(Message.id == self._message.id).first()

        self._message.message = self._prompt_messages_to_prompt_for_saving(self._task_state.llm_result.prompt_messages)
        self._message.message_tokens = usage.prompt_tokens
        self._message.message_unit_price = usage.prompt_unit_price
        self._message.message_price_unit = usage.prompt_price_unit
        self._message.answer = PromptTemplateParser.remove_template_variables(llm_result.message.content.strip()) \
            if llm_result.message.content else ''
        self._message.answer_tokens = usage.completion_tokens
        self._message.answer_unit_price = usage.completion_unit_price
        self._message.answer_price_unit = usage.completion_price_unit
        self._message.provider_response_latency = time.perf_counter() - self._start_at
        self._message.total_price = usage.total_price

        db.session.commit()

        message_was_created.send(
            self._message,
            application_generate_entity=self._application_generate_entity,
            conversation=self._conversation,
            is_first_message=self._application_generate_entity.conversation_id is None,
            extras=self._application_generate_entity.extras
        )

    def _handle_chunk(self, text: str) -> dict:
        """
        Handle completed event.
        :param text: text
        :return:
        """
        response = {
            'event': 'message',
            'id': self._message.id,
            'task_id': self._application_generate_entity.task_id,
            'message_id': self._message.id,
            'answer': text,
            'created_at': int(self._message.created_at.timestamp())
        }

        if self._conversation.mode == 'chat':
            response['conversation_id'] = self._conversation.id

        return response

    def _handle_error(self, event: QueueErrorEvent) -> Exception:
        """
        Handle error event.
        :param event: event
        :return:
        """
        logger.debug("error: %s", event.error)
        e = event.error

        if isinstance(e, InvokeAuthorizationError):
            return InvokeAuthorizationError('Incorrect API key provided')
        elif isinstance(e, InvokeError) or isinstance(e, ValueError):
            return e
        else:
            return Exception(e.description if getattr(e, 'description', None) is not None else str(e))

    def _yield_response(self, response: dict) -> str:
        """
        Yield response.
        :param response: response
        :return:
        """
        return "data: " + json.dumps(response) + "\n\n"

    def _prompt_messages_to_prompt_for_saving(self, prompt_messages: list[PromptMessage]) -> list[dict]:
        """
        Prompt messages to prompt for saving.
        :param prompt_messages: prompt messages
        :return:
        """
        prompts = []
        if self._application_generate_entity.app_orchestration_config_entity.model_config.mode == 'chat':
            for prompt_message in prompt_messages:
                if prompt_message.role == PromptMessageRole.USER:
                    role = 'user'
                elif prompt_message.role == PromptMessageRole.ASSISTANT:
                    role = 'assistant'
                elif prompt_message.role == PromptMessageRole.SYSTEM:
                    role = 'system'
                else:
                    continue

                text = ''
                files = []
                if isinstance(prompt_message.content, list):
                    for content in prompt_message.content:
                        if content.type == PromptMessageContentType.TEXT:
                            content = cast(TextPromptMessageContent, content)
                            text += content.data
                        else:
                            content = cast(ImagePromptMessageContent, content)
                            files.append({
                                "type": 'image',
                                "data": content.data[:10] + '...[TRUNCATED]...' + content.data[-10:],
                                "detail": content.detail.value
                            })
                else:
                    text = prompt_message.content

                prompts.append({
                    "role": role,
                    "text": text,
                    "files": files
                })
        else:
            prompts.append({
                "role": 'user',
                "text": prompt_messages[0].content
            })

        return prompts

    def _init_output_moderation(self) -> Optional[OutputModerationHandler]:
        """
        Init output moderation.
        :return:
        """
        app_orchestration_config_entity = self._application_generate_entity.app_orchestration_config_entity
        sensitive_word_avoidance = app_orchestration_config_entity.sensitive_word_avoidance

        if sensitive_word_avoidance:
            return OutputModerationHandler(
                tenant_id=self._application_generate_entity.tenant_id,
                app_id=self._application_generate_entity.app_id,
                rule=ModerationRule(
                    type=sensitive_word_avoidance.type,
                    config=sensitive_word_avoidance.config
                ),
                on_message_replace_func=self._queue_manager.publish_message_replace
            )
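Note on the stream framing: _yield_response wraps every streamed event as a server-sent-events data line ("data: {...}\n\n"), with bare "event: ping\n\n" keep-alives in between. A minimal sketch of how a client could decode such a stream from an iterable of raw lines; the helper name is an assumption for illustration, not an API shipped in this commit.

import json

def iter_sse_events(lines):
    """Yield decoded event payloads from an iterable of raw SSE lines."""
    for line in lines:
        line = line.strip()
        if not line or line.startswith("event:"):
            continue  # skip blank separators and ping keep-alives
        if line.startswith("data: "):
            yield json.loads(line[len("data: "):])

# Each decoded payload carries an 'event' field such as 'message', 'agent_thought',
# 'message_replace' or 'message_end', mirroring the response dicts built above.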
api/core/app_runner/moderation_handler.py (new file, 138 lines)
@@ -0,0 +1,138 @@
import logging
import threading
import time
from typing import Any, Optional, Dict

from flask import current_app, Flask
from pydantic import BaseModel

from core.moderation.base import ModerationAction, ModerationOutputsResult
from core.moderation.factory import ModerationFactory

logger = logging.getLogger(__name__)


class ModerationRule(BaseModel):
    type: str
    config: Dict[str, Any]


class OutputModerationHandler(BaseModel):
    DEFAULT_BUFFER_SIZE: int = 300

    tenant_id: str
    app_id: str

    rule: ModerationRule
    on_message_replace_func: Any

    thread: Optional[threading.Thread] = None
    thread_running: bool = True
    buffer: str = ''
    is_final_chunk: bool = False
    final_output: Optional[str] = None

    class Config:
        arbitrary_types_allowed = True

    def should_direct_output(self):
        return self.final_output is not None

    def get_final_output(self):
        return self.final_output

    def append_new_token(self, token: str):
        self.buffer += token

        if not self.thread:
            self.thread = self.start_thread()

    def moderation_completion(self, completion: str, public_event: bool = False) -> str:
        self.buffer = completion
        self.is_final_chunk = True

        result = self.moderation(
            tenant_id=self.tenant_id,
            app_id=self.app_id,
            moderation_buffer=completion
        )

        if not result or not result.flagged:
            return completion

        if result.action == ModerationAction.DIRECT_OUTPUT:
            final_output = result.preset_response
        else:
            final_output = result.text

        if public_event:
            self.on_message_replace_func(final_output)

        return final_output

    def start_thread(self) -> threading.Thread:
        buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
        thread = threading.Thread(target=self.worker, kwargs={
            'flask_app': current_app._get_current_object(),
            'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
        })

        thread.start()

        return thread

    def stop_thread(self):
        if self.thread and self.thread.is_alive():
            self.thread_running = False

    def worker(self, flask_app: Flask, buffer_size: int):
        with flask_app.app_context():
            current_length = 0
            while self.thread_running:
                moderation_buffer = self.buffer
                buffer_length = len(moderation_buffer)
                if not self.is_final_chunk:
                    chunk_length = buffer_length - current_length
                    if 0 <= chunk_length < buffer_size:
                        time.sleep(1)
                        continue

                current_length = buffer_length

                result = self.moderation(
                    tenant_id=self.tenant_id,
                    app_id=self.app_id,
                    moderation_buffer=moderation_buffer
                )

                if not result or not result.flagged:
                    continue

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    final_output = result.preset_response
                    self.final_output = final_output
                else:
                    final_output = result.text + self.buffer[len(moderation_buffer):]

                # trigger replace event
                if self.thread_running:
                    self.on_message_replace_func(final_output)

                if result.action == ModerationAction.DIRECT_OUTPUT:
                    break

    def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
        try:
            moderation_factory = ModerationFactory(
                name=self.rule.type,
                app_id=app_id,
                tenant_id=tenant_id,
                config=self.rule.config
            )

            result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
            return result
        except Exception as e:
            logger.error("Moderation Output error: %s", e)

        return None
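Note on how OutputModerationHandler is driven: it is fed token by token via append_new_token, the background worker re-checks the buffer each time it grows by roughly MODERATION_BUFFER_SIZE characters, and on_message_replace_func is called whenever a chunk is flagged. A rough usage sketch, assuming it runs inside a Flask application context (start_thread reads MODERATION_BUFFER_SIZE from current_app.config); the rule type, config and replace callback below are placeholders for illustration, not values defined in this commit.

handler = OutputModerationHandler(
    tenant_id="tenant-id",
    app_id="app-id",
    rule=ModerationRule(type="keywords", config={"keywords": ["forbidden"]}),   # placeholder rule
    on_message_replace_func=lambda text: print("replace answer with:", text),   # placeholder callback
)

for token in "streamed model output...":
    handler.append_new_token(token)          # first token starts the worker thread

final_text = handler.moderation_completion(completion="streamed model output...",
                                           public_event=True)
handler.stop_thread()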