fix(workflow_entry): Support receive File and FileList in single step run. (#10947)
Signed-off-by: -LAN- <laipz8200@outlook.com> Co-authored-by: JzoNg <jzongcode@gmail.com>
This commit is contained in:
@@ -36,7 +36,7 @@ class NodeRunResult(BaseModel):
|
||||
|
||||
inputs: Optional[Mapping[str, Any]] = None # node inputs
|
||||
process_data: Optional[dict[str, Any]] = None # process data
|
||||
outputs: Optional[dict[str, Any]] = None # node outputs
|
||||
outputs: Optional[Mapping[str, Any]] = None # node outputs
|
||||
metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata
|
||||
llm_usage: Optional[LLMUsage] = None # llm usage
|
||||
|
||||
|
@@ -5,10 +5,9 @@ from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileUploadConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File, FileTransferMethod, ImageConfig
|
||||
from core.file.models import File
|
||||
from core.workflow.callbacks import WorkflowCallback
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -18,9 +17,8 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.base import BaseNode, BaseNodeData
|
||||
from core.workflow.nodes.base import BaseNode
|
||||
from core.workflow.nodes.event import NodeEvent
|
||||
from core.workflow.nodes.llm import LLMNodeData
|
||||
from core.workflow.nodes.node_mapping import node_type_classes_mapping
|
||||
from factories import file_factory
|
||||
from models.enums import UserFrom
|
||||
@@ -115,7 +113,12 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def single_step_run(
|
||||
cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict
|
||||
cls,
|
||||
*,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
@@ -135,13 +138,9 @@ class WorkflowEntry:
|
||||
raise ValueError("nodes not found in workflow graph")
|
||||
|
||||
# fetch node config from node id
|
||||
node_config = None
|
||||
for node in nodes:
|
||||
if node.get("id") == node_id:
|
||||
node_config = node
|
||||
break
|
||||
|
||||
if not node_config:
|
||||
try:
|
||||
node_config = next(filter(lambda node: node["id"] == node_id, nodes))
|
||||
except StopIteration:
|
||||
raise ValueError("node id not found in workflow graph")
|
||||
|
||||
# Get node class
|
||||
@@ -153,11 +152,7 @@ class WorkflowEntry:
|
||||
raise ValueError(f"Node class not found for node type {node_type}")
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
variable_pool = VariablePool(environment_variables=workflow.environment_variables)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(graph_config=workflow.graph_dict)
|
||||
@@ -183,28 +178,24 @@ class WorkflowEntry:
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data,
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict, config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
)
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
|
||||
return node_instance, generator
|
||||
except Exception as e:
|
||||
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
|
||||
return node_instance, generator
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
|
||||
@@ -231,12 +222,11 @@ class WorkflowEntry:
|
||||
@classmethod
|
||||
def mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
*,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
) -> None:
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
@@ -254,40 +244,21 @@ class WorkflowEntry:
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
variable_key_list = cast(list[str], variable_key_list)
|
||||
variable_key_list = list(variable_key_list)
|
||||
|
||||
# get input value
|
||||
input_value = user_inputs.get(node_variable)
|
||||
if not input_value:
|
||||
input_value = user_inputs.get(node_variable_key)
|
||||
|
||||
# FIXME: temp fix for image type
|
||||
if node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(input_value, list):
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in input_value:
|
||||
if isinstance(item, dict) and "type" in item and item["type"] == "image":
|
||||
transfer_method = FileTransferMethod.value_of(item.get("transfer_method"))
|
||||
mapping = {
|
||||
"id": item.get("id"),
|
||||
"transfer_method": transfer_method,
|
||||
"upload_file_id": item.get("upload_file_id"),
|
||||
"url": item.get("url"),
|
||||
}
|
||||
config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None)
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
if new_value:
|
||||
input_value = new_value
|
||||
if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value:
|
||||
input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id)
|
||||
if (
|
||||
isinstance(input_value, list)
|
||||
and all(isinstance(item, dict) for item in input_value)
|
||||
and all("type" in item and "transfer_method" in item for item in input_value)
|
||||
):
|
||||
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
|
Reference in New Issue
Block a user