feat: llm text stream support for workflow app (#3798)

Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
takatost
2024-04-28 17:37:00 +08:00
committed by GitHub
parent 8e4989ed03
commit ff67a6d338
27 changed files with 549 additions and 58 deletions

View File

@@ -28,9 +28,9 @@ from core.app.entities.task_entities import (
AdvancedChatTaskState,
ChatbotAppBlockingResponse,
ChatbotAppStreamResponse,
ChatflowStreamGenerateRoute,
ErrorStreamResponse,
MessageEndStreamResponse,
StreamGenerateRoute,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@@ -343,7 +343,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
**extras
)
def _get_stream_generate_routes(self) -> dict[str, StreamGenerateRoute]:
def _get_stream_generate_routes(self) -> dict[str, ChatflowStreamGenerateRoute]:
"""
Get stream generate routes.
:return:
@@ -366,7 +366,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = StreamGenerateRoute(
stream_generate_routes[start_node_id] = ChatflowStreamGenerateRoute(
answer_node_id=answer_node_id,
generate_route=generate_route
)
@@ -430,15 +430,14 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(token)
if should_direct_answer:
continue
self._task_state.answer += token
yield self._message_to_stream_response(token, self._message.id)
time.sleep(0.01)
# handle output moderation chunk
should_direct_answer = self._handle_output_moderation_chunk(route_chunk.text)
if should_direct_answer:
continue
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
break
@@ -463,10 +462,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
for route_chunk in route_chunks:
if route_chunk.type == 'text':
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
for token in route_chunk.text:
self._task_state.answer += token
yield self._message_to_stream_response(token, self._message.id)
time.sleep(0.01)
self._task_state.answer += route_chunk.text
yield self._message_to_stream_response(route_chunk.text, self._message.id)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
value_selector = route_chunk.value_selector

View File

@@ -28,11 +28,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStreamGenerateNodes,
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
from core.workflow.entities.node_entities import SystemVariable
from core.workflow.entities.node_entities import NodeType, SystemVariable
from core.workflow.nodes.end.end_node import EndNode
from extensions.ext_database import db
from models.account import Account
from models.model import EndUser
@@ -40,6 +42,7 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
WorkflowNodeExecution,
WorkflowRun,
)
@@ -83,6 +86,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
}
self._task_state = WorkflowTaskState()
self._stream_generate_nodes = self._get_stream_generate_nodes()
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -167,6 +171,14 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
elif isinstance(event, QueueNodeStartedEvent):
workflow_node_execution = self._handle_node_start(event)
# search stream_generate_routes if node id is answer start at node
if not self._task_state.current_stream_generate_state and event.node_id in self._stream_generate_nodes:
self._task_state.current_stream_generate_state = self._stream_generate_nodes[event.node_id]
# generate stream outputs when node started
yield from self._generate_stream_outputs_when_node_started()
yield self._workflow_node_start_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
@@ -174,6 +186,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
)
elif isinstance(event, QueueNodeSucceededEvent | QueueNodeFailedEvent):
workflow_node_execution = self._handle_node_finished(event)
yield self._workflow_node_finish_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution
@@ -193,6 +206,11 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if delta_text is None:
continue
if not self._is_stream_out_support(
event=event
):
continue
self._task_state.answer += delta_text
yield self._text_chunk_to_stream_response(delta_text)
elif isinstance(event, QueueMessageReplaceEvent):
@@ -254,3 +272,140 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
task_id=self._application_generate_entity.task_id,
text=TextReplaceStreamResponse.Data(text=text)
)
def _get_stream_generate_nodes(self) -> dict[str, WorkflowStreamGenerateNodes]:
"""
Get stream generate nodes.
:return:
"""
# find all answer nodes
graph = self._workflow.graph_dict
end_node_configs = [
node for node in graph['nodes']
if node.get('data', {}).get('type') == NodeType.END.value
]
# parse stream output node value selectors of end nodes
stream_generate_routes = {}
for node_config in end_node_configs:
# get generate route for stream output
end_node_id = node_config['id']
generate_nodes = EndNode.extract_generate_nodes(graph, node_config)
start_node_ids = self._get_end_start_at_node_ids(graph, end_node_id)
if not start_node_ids:
continue
for start_node_id in start_node_ids:
stream_generate_routes[start_node_id] = WorkflowStreamGenerateNodes(
end_node_id=end_node_id,
stream_node_ids=generate_nodes
)
return stream_generate_routes
def _get_end_start_at_node_ids(self, graph: dict, target_node_id: str) \
-> list[str]:
"""
Get end start at node id.
:param graph: graph
:param target_node_id: target node ID
:return:
"""
nodes = graph.get('nodes')
edges = graph.get('edges')
# fetch all ingoing edges from source node
ingoing_edges = []
for edge in edges:
if edge.get('target') == target_node_id:
ingoing_edges.append(edge)
if not ingoing_edges:
return []
start_node_ids = []
for ingoing_edge in ingoing_edges:
source_node_id = ingoing_edge.get('source')
source_node = next((node for node in nodes if node.get('id') == source_node_id), None)
if not source_node:
continue
node_type = source_node.get('data', {}).get('type')
if node_type in [
NodeType.IF_ELSE.value,
NodeType.QUESTION_CLASSIFIER.value
]:
start_node_id = target_node_id
start_node_ids.append(start_node_id)
elif node_type == NodeType.START.value:
start_node_id = source_node_id
start_node_ids.append(start_node_id)
else:
sub_start_node_ids = self._get_end_start_at_node_ids(graph, source_node_id)
if sub_start_node_ids:
start_node_ids.extend(sub_start_node_ids)
return start_node_ids
def _generate_stream_outputs_when_node_started(self) -> Generator:
"""
Generate stream outputs.
:return:
"""
if self._task_state.current_stream_generate_state:
stream_node_ids = self._task_state.current_stream_generate_state.stream_node_ids
for node_id, node_execution_info in self._task_state.ran_node_execution_infos.items():
if node_id not in stream_node_ids:
continue
node_execution_info = self._task_state.ran_node_execution_infos[node_id]
# get chunk node execution
route_chunk_node_execution = db.session.query(WorkflowNodeExecution).filter(
WorkflowNodeExecution.id == node_execution_info.workflow_node_execution_id).first()
if not route_chunk_node_execution:
continue
outputs = route_chunk_node_execution.outputs_dict
if not outputs:
continue
# get value from outputs
text = outputs.get('text')
if text:
self._task_state.answer += text
yield self._text_chunk_to_stream_response(text)
def _is_stream_out_support(self, event: QueueTextChunkEvent) -> bool:
"""
Is stream out support
:param event: queue text chunk event
:return:
"""
if not event.metadata:
return False
if 'node_id' not in event.metadata:
return False
node_id = event.metadata.get('node_id')
node_type = event.metadata.get('node_type')
stream_output_value_selector = event.metadata.get('value_selector')
if not stream_output_value_selector:
return False
if not self._task_state.current_stream_generate_state:
return False
if node_id not in self._task_state.current_stream_generate_state.stream_node_ids:
return False
if node_type != NodeType.LLM:
# only LLM support chunk stream output
return False
return True

View File

@@ -6,6 +6,7 @@ from core.app.entities.queue_entities import (
QueueNodeFailedEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
@@ -119,7 +120,15 @@ class WorkflowEventTriggerCallback(BaseWorkflowCallback):
"""
Publish text chunk
"""
pass
self._queue_manager.publish(
QueueTextChunkEvent(
text=text,
metadata={
"node_id": node_id,
**metadata
}
), PublishFrom.APPLICATION_MANAGER
)
def on_event(self, event: AppQueueEvent) -> None:
"""

View File

@@ -9,9 +9,17 @@ from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.answer.entities import GenerateRouteChunk
class StreamGenerateRoute(BaseModel):
class WorkflowStreamGenerateNodes(BaseModel):
"""
StreamGenerateRoute entity
WorkflowStreamGenerateNodes entity
"""
end_node_id: str
stream_node_ids: list[str]
class ChatflowStreamGenerateRoute(BaseModel):
"""
ChatflowStreamGenerateRoute entity
"""
answer_node_id: str
generate_route: list[GenerateRouteChunk]
@@ -55,6 +63,8 @@ class WorkflowTaskState(TaskState):
ran_node_execution_infos: dict[str, NodeExecutionInfo] = {}
latest_node_execution_info: Optional[NodeExecutionInfo] = None
current_stream_generate_state: Optional[WorkflowStreamGenerateNodes] = None
class AdvancedChatTaskState(WorkflowTaskState):
"""
@@ -62,7 +72,7 @@ class AdvancedChatTaskState(WorkflowTaskState):
"""
usage: LLMUsage
current_stream_generate_state: Optional[StreamGenerateRoute] = None
current_stream_generate_state: Optional[ChatflowStreamGenerateRoute] = None
class StreamEvent(Enum):

View File

@@ -36,6 +36,49 @@ class EndNode(BaseNode):
outputs=outputs
)
@classmethod
def extract_generate_nodes(cls, graph: dict, config: dict) -> list[str]:
"""
Extract generate nodes
:param graph: graph
:param config: node config
:return:
"""
node_data = cls._node_data_cls(**config.get("data", {}))
node_data = cast(cls._node_data_cls, node_data)
return cls.extract_generate_nodes_from_node_data(graph, node_data)
@classmethod
def extract_generate_nodes_from_node_data(cls, graph: dict, node_data: EndNodeData) -> list[str]:
"""
Extract generate nodes from node data
:param graph: graph
:param node_data: node data object
:return:
"""
nodes = graph.get('nodes')
node_mapping = {node.get('id'): node for node in nodes}
variable_selectors = node_data.outputs
generate_nodes = []
for variable_selector in variable_selectors:
if not variable_selector.value_selector:
continue
node_id = variable_selector.value_selector[0]
if node_id != 'sys' and node_id in node_mapping:
node = node_mapping[node_id]
node_type = node.get('data', {}).get('type')
if node_type == NodeType.LLM.value and variable_selector.value_selector[1] == 'text':
generate_nodes.append(node_id)
# remove duplicates
generate_nodes = list(set(generate_nodes))
return generate_nodes
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
"""