FEAT: NEW WORKFLOW ENGINE (#3160)

Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: Yeuoly <admin@srmxy.cn>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
Co-authored-by: jyong <jyong@dify.ai>
Co-authored-by: nite-knite <nkCoding@gmail.com>
Co-authored-by: jyong <718720800@qq.com>
Author: takatost
Date: 2024-04-08 18:51:46 +08:00
Committed by: GitHub
Parent: 2fb9850af5
Commit: 7753ba2d37
1161 changed files with 103836 additions and 10327 deletions

@@ -1,262 +0,0 @@
import json
import logging
import time
from typing import Any, Optional, Union, cast
from langchain.agents import openai_functions_agent, openai_functions_multi_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
from core.application_queue_manager import ApplicationQueueManager, PublishFrom
from core.callback_handler.entity.agent_loop import AgentLoop
from core.entities.application_entities import ModelConfigEntity
from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult
from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, UserPromptMessage
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageChain
class AgentLoopGatherCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
raise_error: bool = True
def __init__(self, model_config: ModelConfigEntity,
queue_manager: ApplicationQueueManager,
message: Message,
message_chain: MessageChain) -> None:
"""Initialize callback handler."""
self.model_config = model_config
self.queue_manager = queue_manager
self.message = message
self.message_chain = message_chain
model_type_instance = self.model_config.provider_model_bundle.model_type_instance
self.model_type_instance = cast(LargeLanguageModel, model_type_instance)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@property
def agent_loops(self) -> list[AgentLoop]:
return self._agent_loops
def clear_agent_loops(self) -> None:
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
return True
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
return True
def on_llm_before_invoke(self, prompt_messages: list[PromptMessage]) -> None:
if not self._current_loop:
            # The agent loop starts with an LLM query
self._current_loop = AgentLoop(
position=len(self._agent_loops) + 1,
prompt="\n".join([prompt_message.content for prompt_message in prompt_messages]),
status='llm_started',
started_at=time.perf_counter()
)
def on_llm_after_invoke(self, result: RuntimeLLMResult) -> None:
if self._current_loop and self._current_loop.status == 'llm_started':
self._current_loop.status = 'llm_end'
if result.usage:
self._current_loop.prompt_tokens = result.usage.prompt_tokens
else:
self._current_loop.prompt_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[UserPromptMessage(content=self._current_loop.prompt)]
)
completion_message = result.message
if completion_message.tool_calls:
                self._current_loop.completion = json.dumps({'function_call': completion_message.tool_calls})
else:
self._current_loop.completion = completion_message.content
if result.usage:
self._current_loop.completion_tokens = result.usage.completion_tokens
else:
self._current_loop.completion_tokens = self.model_type_instance.get_num_tokens(
model=self.model_config.model,
credentials=self.model_config.credentials,
prompt_messages=[AssistantPromptMessage(content=self._current_loop.completion)]
)
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
pass
def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
pass
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
logging.debug("Agent on_llm_error: %s", error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
# kwargs={'color': 'green', 'llm_prefix': 'Thought:', 'observation_prefix': 'Observation: '}
# input_str='action-input'
# serialized={'description': 'A search engine. Useful for when you need to answer questions about current events. Input should be a search query.', 'name': 'Search'}
pass
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
tool = action.tool
        tool_input = json.dumps(
            {"query": action.tool_input}
            if isinstance(action.tool_input, str) else action.tool_input
        )
completion = None
        if isinstance(action, (openai_functions_agent.base._FunctionsAgentAction,
                               openai_functions_multi_agent.base._FunctionsAgentAction)):
thought = action.log.strip()
completion = json.dumps({'function_call': action.message_log[0].additional_kwargs['function_call']})
else:
            # str.index raises ValueError when "Action:" is absent, so guard first
            if action.log and "Action:" in action.log:
                thought = action.log[:action.log.index("Action:")].strip()
            else:
                thought = ''
if self._current_loop and self._current_loop.status == 'llm_end':
self._current_loop.status = 'agent_action'
self._current_loop.thought = thought
self._current_loop.tool_name = tool
self._current_loop.tool_input = tool_input
if completion is not None:
self._current_loop.completion = completion
self._message_agent_thought = self._init_agent_thought()
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
# kwargs={'name': 'Search'}
# llm_prefix='Thought:'
# observation_prefix='Observation: '
# output='53 years'
if self._current_loop and self._current_loop.status == 'agent_action' and output and output != 'None':
self._current_loop.status = 'tool_end'
self._current_loop.tool_output = output
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._complete_agent_thought(self._message_agent_thought)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
logging.debug("Agent on_tool_error: %s", error)
self._agent_loops = []
self._current_loop = None
self._message_agent_thought = None
def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
"""Run on agent end."""
# Final Answer
        if self._current_loop and self._current_loop.status in ('llm_end', 'agent_action'):
self._current_loop.status = 'agent_finish'
self._current_loop.completed = True
self._current_loop.completed_at = time.perf_counter()
self._current_loop.latency = self._current_loop.completed_at - self._current_loop.started_at
self._current_loop.thought = '[DONE]'
self._message_agent_thought = self._init_agent_thought()
self._complete_agent_thought(self._message_agent_thought)
self._agent_loops.append(self._current_loop)
self._current_loop = None
self._message_agent_thought = None
elif not self._current_loop and self._agent_loops:
self._agent_loops[-1].status = 'agent_finish'
def _init_agent_thought(self) -> MessageAgentThought:
message_agent_thought = MessageAgentThought(
message_id=self.message.id,
message_chain_id=self.message_chain.id,
position=self._current_loop.position,
thought=self._current_loop.thought,
tool=self._current_loop.tool_name,
tool_input=self._current_loop.tool_input,
message=self._current_loop.prompt,
message_price_unit=0,
answer=self._current_loop.completion,
answer_price_unit=0,
created_by_role=('account' if self.message.from_source == 'console' else 'end_user'),
created_by=(self.message.from_account_id
if self.message.from_source == 'console' else self.message.from_end_user_id)
)
db.session.add(message_agent_thought)
db.session.commit()
self.queue_manager.publish_agent_thought(message_agent_thought, PublishFrom.APPLICATION_MANAGER)
return message_agent_thought
def _complete_agent_thought(self, message_agent_thought: MessageAgentThought) -> None:
loop_message_tokens = self._current_loop.prompt_tokens
loop_answer_tokens = self._current_loop.completion_tokens
        # compute usage and pricing from the loop's token counts
llm_usage = self.model_type_instance._calc_response_usage(
self.model_config.model,
self.model_config.credentials,
loop_message_tokens,
loop_answer_tokens
)
message_agent_thought.observation = self._current_loop.tool_output
        message_agent_thought.tool_process_data = ''  # currently not supported
message_agent_thought.message_token = loop_message_tokens
message_agent_thought.message_unit_price = llm_usage.prompt_unit_price
message_agent_thought.message_price_unit = llm_usage.prompt_price_unit
message_agent_thought.answer_token = loop_answer_tokens
message_agent_thought.answer_unit_price = llm_usage.completion_unit_price
message_agent_thought.answer_price_unit = llm_usage.completion_price_unit
message_agent_thought.latency = self._current_loop.latency
message_agent_thought.tokens = self._current_loop.prompt_tokens + self._current_loop.completion_tokens
message_agent_thought.total_price = llm_usage.total_price
message_agent_thought.currency = llm_usage.currency
db.session.commit()

@@ -36,7 +36,7 @@ class DifyAgentCallbackHandler(BaseCallbackHandler, BaseModel):
print_text("\n[on_tool_end]\n", color=self.color)
print_text("Tool: " + tool_name + "\n", color=self.color)
print_text("Inputs: " + str(tool_inputs) + "\n", color=self.color)
print_text("Outputs: " + str(tool_outputs) + "\n", color=self.color)
print_text("Outputs: " + str(tool_outputs)[:1000] + "\n", color=self.color)
print_text("\n")
def on_tool_error(

@@ -1,23 +0,0 @@
from typing import Optional

from pydantic import BaseModel


class AgentLoop(BaseModel):
    position: int = 1

    thought: Optional[str] = None
    tool_name: Optional[str] = None
    tool_input: Optional[str] = None
    tool_output: Optional[str] = None

    prompt: Optional[str] = None
    prompt_tokens: int = 0
    completion: Optional[str] = None
    completion_tokens: int = 0

    latency: Optional[float] = None

    status: str = 'llm_started'
    completed: bool = False

    started_at: Optional[float] = None
    completed_at: Optional[float] = None
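
For reference, the status field of this model walks a fixed path driven by the AgentLoopGatherCallbackHandler above: 'llm_started' -> 'llm_end' -> 'agent_action' -> 'tool_end', or 'agent_finish' when the agent returns a final answer. A minimal sketch of one loop's lifecycle, using only fields defined in the model (the values are illustrative):

import time

loop = AgentLoop(position=1, prompt="What is Dify?", started_at=time.perf_counter())
assert loop.status == 'llm_started'    # default; set when the LLM is first invoked

loop.status = 'llm_end'                # on_llm_after_invoke: reply received
loop.status = 'agent_action'           # on_agent_action: reply requested a tool
loop.tool_name = 'search'
loop.tool_input = '{"query": "dify"}'

loop.status = 'tool_end'               # on_tool_end: observation recorded
loop.tool_output = 'Dify is an LLM app platform.'
loop.completed = True
loop.completed_at = time.perf_counter()
loop.latency = loop.completed_at - loop.started_at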

@@ -1,6 +1,7 @@
-from core.application_queue_manager import ApplicationQueueManager, PublishFrom
-from core.entities.application_entities import InvokeFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
 from core.rag.models.document import Document
 from extensions.ext_database import db
 from models.dataset import DatasetQuery, DocumentSegment
@@ -10,7 +11,7 @@ from models.model import DatasetRetrieverResource
 class DatasetIndexToolCallbackHandler:
     """Callback handler for dataset tool."""
-    def __init__(self, queue_manager: ApplicationQueueManager,
+    def __init__(self, queue_manager: AppQueueManager,
                  app_id: str,
                  message_id: str,
                  user_id: str,
@@ -82,4 +83,7 @@ class DatasetIndexToolCallbackHandler:
         db.session.add(dataset_retriever_resource)
         db.session.commit()
-        self._queue_manager.publish_retriever_resources(resource, PublishFrom.APPLICATION_MANAGER)
+        self._queue_manager.publish(
+            QueueRetrieverResourcesEvent(retriever_resources=resource),
+            PublishFrom.APPLICATION_MANAGER
+        )
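
This hunk shows the queue-API migration that recurs throughout the commit: the dedicated publish_retriever_resources method is replaced by the generic publish call carrying a typed event. A minimal sketch of the new pattern, assuming only the names visible in the hunk above (publish_resources is a hypothetical wrapper for illustration):

from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent

def publish_resources(queue_manager: AppQueueManager, resource) -> None:
    # Wrap the payload in a typed queue event; consumers dispatch on the
    # event class instead of on a per-payload publish method.
    queue_manager.publish(
        QueueRetrieverResourcesEvent(retriever_resources=resource),
        PublishFrom.APPLICATION_MANAGER
    )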

@@ -1,157 +0,0 @@
import os
import sys
from typing import Any, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, BaseMessage, LLMResult
class DifyStdOutCallbackHandler(BaseCallbackHandler):
"""Callback Handler that prints to std out."""
def __init__(self, color: Optional[str] = None) -> None:
"""Initialize callback handler."""
self.color = color
def on_chat_model_start(
self,
serialized: dict[str, Any],
messages: list[list[BaseMessage]],
**kwargs: Any
) -> Any:
print_text("\n[on_chat_model_start]\n", color='blue')
for sub_messages in messages:
for sub_message in sub_messages:
print_text(str(sub_message) + "\n", color='blue')
def on_llm_start(
self, serialized: dict[str, Any], prompts: list[str], **kwargs: Any
) -> None:
"""Print out the prompts."""
print_text("\n[on_llm_start]\n", color='blue')
print_text(prompts[0] + "\n", color='blue')
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Do nothing."""
print_text("\n[on_llm_end]\nOutput: " + str(response.generations[0][0].text) + "\nllm_output: " + str(
response.llm_output) + "\n", color='blue')
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Do nothing."""
pass
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_llm_error]\nError: " + str(error) + "\n", color='blue')
def on_chain_start(
self, serialized: dict[str, Any], inputs: dict[str, Any], **kwargs: Any
) -> None:
"""Print out that we are entering a chain."""
chain_type = serialized['id'][-1]
print_text("\n[on_chain_start]\nChain: " + chain_type + "\nInputs: " + str(inputs) + "\n", color='pink')
def on_chain_end(self, outputs: dict[str, Any], **kwargs: Any) -> None:
"""Print out that we finished a chain."""
print_text("\n[on_chain_end]\nOutputs: " + str(outputs) + "\n", color='pink')
def on_chain_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_chain_error]\nError: " + str(error) + "\n", color='pink')
def on_tool_start(
self,
serialized: dict[str, Any],
input_str: str,
**kwargs: Any,
) -> None:
"""Do nothing."""
print_text("\n[on_tool_start] " + str(serialized), color='yellow')
def on_agent_action(
self, action: AgentAction, color: Optional[str] = None, **kwargs: Any
) -> Any:
"""Run on agent action."""
tool = action.tool
tool_input = action.tool_input
try:
action_name_position = action.log.index("\nAction:") + 1 if action.log else -1
thought = action.log[:action_name_position].strip() if action.log else ''
except ValueError:
thought = ''
log = f"Thought: {thought}\nTool: {tool}\nTool Input: {tool_input}"
print_text("\n[on_agent_action]\n" + log + "\n", color='green')
def on_tool_end(
self,
output: str,
color: Optional[str] = None,
observation_prefix: Optional[str] = None,
llm_prefix: Optional[str] = None,
**kwargs: Any,
) -> None:
"""If not the final action, print out observation."""
print_text("\n[on_tool_end]\n", color='yellow')
if observation_prefix:
print_text(f"\n{observation_prefix}")
print_text(output, color='yellow')
if llm_prefix:
print_text(f"\n{llm_prefix}")
print_text("\n")
def on_tool_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
print_text("\n[on_tool_error] Error: " + str(error) + "\n", color='yellow')
def on_text(
self,
text: str,
color: Optional[str] = None,
end: str = "",
        **kwargs: Any,
) -> None:
"""Run when agent ends."""
print_text("\n[on_text] " + text + "\n", color=color if color else self.color, end=end)
def on_agent_finish(
self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
) -> None:
"""Run on agent end."""
print_text("[on_agent_finish] " + finish.return_values['output'] + "\n", color='green', end="\n")
@property
def ignore_llm(self) -> bool:
"""Whether to ignore LLM callbacks."""
        return os.environ.get("DEBUG", "").lower() != 'true'
@property
def ignore_chain(self) -> bool:
"""Whether to ignore chain callbacks."""
        return os.environ.get("DEBUG", "").lower() != 'true'
@property
def ignore_agent(self) -> bool:
"""Whether to ignore agent callbacks."""
        return os.environ.get("DEBUG", "").lower() != 'true'
@property
def ignore_chat_model(self) -> bool:
"""Whether to ignore chat model callbacks."""
        return os.environ.get("DEBUG", "").lower() != 'true'
class DifyStreamingStdOutCallbackHandler(DifyStdOutCallbackHandler):
"""Callback handler for streaming. Only works with LLMs that support streaming."""
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
"""Run on new LLM token. Only available when streaming is enabled."""
sys.stdout.write(token)
sys.stdout.flush()
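
Both deleted classes are ordinary LangChain BaseCallbackHandler implementations, so before this commit they were attached through the standard callbacks list. A minimal usage sketch for the streaming variant, assuming a LangChain chat model (ChatOpenAI is only an illustrative choice):

from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI(
    streaming=True,  # stream tokens so on_llm_new_token fires per token
    callbacks=[DifyStreamingStdOutCallbackHandler()]
)
llm.predict("Hello")  # tokens are echoed to stdout as they arrive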

@@ -0,0 +1,5 @@
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
"""Callback Handler that prints to std out."""