Feat/workflow phase2 (#4687)

This commit is contained in:
Yeuoly
2024-05-27 22:01:11 +08:00
committed by GitHub
parent 45deaee762
commit e852a21634
139 changed files with 5997 additions and 779 deletions

View File

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Optional
from core.app.entities.queue_entities import AppQueueEvent
from core.workflow.entities.base_node_data_entities import BaseNodeData
@@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC):
Publish text chunk
"""
raise NotImplementedError
@abstractmethod
def on_workflow_iteration_started(self,
                                  node_id: str,
                                  node_type: NodeType,
                                  node_run_index: int = 1,
                                  node_data: Optional[BaseNodeData] = None,
                                  inputs: Optional[dict] = None,
                                  predecessor_node_id: Optional[str] = None,
                                  metadata: Optional[dict] = None) -> None:
    """
    Publish iteration started

    Abstract hook: this base implementation only raises, concrete callbacks
    must override it.

    :param node_id: node id
    :param node_type: node type
    :param node_run_index: run index of the node, defaults to 1
    :param node_data: node data, may be None
    :param inputs: inputs of the iteration, may be None
    :param predecessor_node_id: predecessor node id, may be None
    :param metadata: additional metadata, may be None
    :raises NotImplementedError: always, in this base implementation
    """
    # `inputs` defaults to None, so it is annotated Optional explicitly
    # instead of relying on an implicit Optional.
    raise NotImplementedError
@abstractmethod
def on_workflow_iteration_next(
    self,
    node_id: str,
    node_type: NodeType,
    index: int,
    node_run_index: int,
    output: Optional[Any],
) -> None:
    """
    Publish iteration next

    Abstract hook: this base implementation only raises, concrete callbacks
    must override it.

    :param node_id: node id
    :param node_type: node type
    :param index: iteration index
    :param node_run_index: run index of the node
    :param output: output value passed along with the event, may be None
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
@abstractmethod
def on_workflow_iteration_completed(
    self,
    node_id: str,
    node_type: NodeType,
    node_run_index: int,
    outputs: dict,
) -> None:
    """
    Publish iteration completed

    Abstract hook: this base implementation only raises, concrete callbacks
    must override it.

    :param node_id: node id
    :param node_type: node type
    :param node_run_index: run index of the node
    :param outputs: outputs passed along with the event
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
@abstractmethod
def on_event(self, event: AppQueueEvent) -> None:

View File

@@ -7,3 +7,16 @@ from pydantic import BaseModel
# Common base for node configuration models: an abstract pydantic model
# carrying the fields shared by every node data class.
class BaseNodeData(ABC, BaseModel):
    # Node title (required).
    title: str
    # Optional description; defaults to None.
    desc: Optional[str] = None
# Node data shared by iteration-style nodes (extended by IterationNodeData
# and LoopNodeData).
class BaseIterationNodeData(BaseNodeData):
    # Id of the node an iteration round starts from; IterationNode returns it
    # as the next node id while the iteration is not finished.
    start_node_id: str
# State carried between rounds of an iteration node.
class BaseIterationState(BaseModel):
    # Id of the iteration node that owns this state.
    iteration_node_id: str
    # Current iteration index (IterationNode starts it at -1 and increments
    # it before each round).
    index: int
    # Inputs recorded for the iteration.
    inputs: dict

    # Empty metadata base model; subclasses of the state extend it
    # (e.g. IterationState.MetaData adds iterator_length).
    class MetaData(BaseModel):
        pass

    # Metadata instance attached to this state (required).
    metadata: MetaData

View File

@@ -21,7 +21,11 @@ class NodeType(Enum):
QUESTION_CLASSIFIER = 'question-classifier'
HTTP_REQUEST = 'http-request'
TOOL = 'tool'
VARIABLE_AGGREGATOR = 'variable-aggregator'
VARIABLE_ASSIGNER = 'variable-assigner'
LOOP = 'loop'
ITERATION = 'iteration'
PARAMETER_EXTRACTOR = 'parameter-extractor'
@classmethod
def value_of(cls, value: str) -> 'NodeType':
@@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum):
TOTAL_PRICE = 'total_price'
CURRENCY = 'currency'
TOOL_INFO = 'tool_info'
ITERATION_ID = 'iteration_id'
ITERATION_INDEX = 'iteration_index'
class NodeRunResult(BaseModel):

View File

@@ -90,3 +90,12 @@ class VariablePool:
raise ValueError(f'Invalid value type: {target_value_type.value}')
return value
def clear_node_variables(self, node_id: str) -> None:
    """
    Clear node variables

    Drops the entry stored under ``node_id`` in ``self.variables_mapping``;
    does nothing when the node has no entry.

    :param node_id: node id
    :return:
    """
    # pop with a default is a no-op for a missing key
    self.variables_mapping.pop(node_id, None)

View File

@@ -1,5 +1,9 @@
from typing import Optional
from pydantic import BaseModel
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode, UserFrom
@@ -22,6 +26,9 @@ class WorkflowRunState:
workflow_type: WorkflowType
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
start_at: float
variable_pool: VariablePool
@@ -30,20 +37,37 @@ class WorkflowRunState:
workflow_nodes_and_results: list[WorkflowNodeAndResult]
# Record of a single node run; entries are kept in `workflow_node_runs`.
class NodeRun(BaseModel):
    # Id of the node that ran.
    node_id: str
    # Id of the iteration node associated with the run.
    # NOTE(review): presumably the enclosing iteration node - confirm with the caller.
    iteration_node_id: str
workflow_node_runs: list[NodeRun]
workflow_node_steps: int
current_iteration_state: Optional[BaseIterationState]
def __init__(self, workflow: Workflow,
             start_at: float,
             variable_pool: VariablePool,
             user_id: str,
             user_from: UserFrom,
             invoke_from: InvokeFrom,
             workflow_call_depth: int):
    """
    Initialise the run state of one workflow execution.

    :param workflow: workflow being run; its id, tenant_id, app_id and type are copied
    :param start_at: start timestamp of the run
    :param variable_pool: variable pool used by the run
    :param user_id: id of the user running the workflow
    :param user_from: where the user comes from
    :param invoke_from: where the invocation comes from
    :param workflow_call_depth: call depth of this workflow run
    """
    # identity copied from the workflow model
    self.workflow_id = workflow.id
    self.tenant_id = workflow.tenant_id
    self.app_id = workflow.app_id
    self.workflow_type = WorkflowType.value_of(workflow.type)

    # caller context
    self.user_id = user_id
    self.user_from = user_from
    self.invoke_from = invoke_from
    self.workflow_call_depth = workflow_call_depth

    self.start_at = start_at
    self.variable_pool = variable_pool

    # accumulators, empty at the start of a run
    self.total_tokens = 0
    self.workflow_nodes_and_results = []

    # iteration bookkeeping: no active iteration, step counter starts at 1
    self.current_iteration_state = None
    self.workflow_node_steps = 1
    self.workflow_node_runs = []

View File

@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
@@ -37,6 +38,9 @@ class BaseNode(ABC):
workflow_id: str
user_id: str
user_from: UserFrom
invoke_from: InvokeFrom
workflow_call_depth: int
node_id: str
node_data: BaseNodeData
@@ -49,13 +53,17 @@ class BaseNode(ABC):
workflow_id: str,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
config: dict,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: list[BaseWorkflowCallback] = None,
workflow_call_depth: int = 0) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
self.workflow_id = workflow_id
self.user_id = user_id
self.user_from = user_from
self.invoke_from = invoke_from
self.workflow_call_depth = workflow_call_depth
self.node_id = config.get("id")
if not self.node_id:
@@ -140,3 +148,38 @@ class BaseNode(ABC):
:return:
"""
return self._node_type
class BaseIterationNode(BaseNode):
    """
    Base class for nodes whose run yields an iteration state and which decide,
    round after round, what comes next.
    """

    @abstractmethod
    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run node
        :param variable_pool: variable pool
        :return: iteration state
        """
        raise NotImplementedError()

    def run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run node entry
        :param variable_pool: variable pool
        :return: iteration state produced by ``_run``
        """
        state = self._run(variable_pool=variable_pool)
        return state

    def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
        """
        Get the next step of the iteration; delegates to ``_get_next_iteration``.
        :param variable_pool: variable pool
        :param state: current iteration state
        :return: a node id (str) or a NodeRunResult
        """
        next_step = self._get_next_iteration(variable_pool, state)
        return next_step

    @abstractmethod
    def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
        """
        Get the next step of the iteration; implemented by subclasses.
        :param variable_pool: variable pool
        :param state: current iteration state
        :return: a node id (str) or a NodeRunResult
        """
        raise NotImplementedError()

View File

@@ -0,0 +1,39 @@
from typing import Any, Optional
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class IterationNodeData(BaseIterationNodeData):
    """
    Iteration Node Data.

    Configuration of an iteration node: which variable to iterate over and
    which variable to collect after each round.
    """
    parent_loop_id: Optional[str] # redundant field, not used currently
    # Selector of the variable to iterate over; IterationNode requires its
    # value to be a list.
    iterator_selector: list[str] # variable selector
    # Selector of the variable read after each round and appended to the
    # iteration outputs.
    output_selector: list[str] # output selector
class IterationState(BaseIterationState):
    """
    Iteration State.
    """
    # Outputs collected so far. The default is None, so the annotation is
    # Optional; IterationNode creates the state with an explicit empty list.
    outputs: Optional[list[Any]] = None
    # Output resolved for the most recent round, None when there was none.
    current_output: Optional[Any] = None

    class MetaData(BaseIterationState.MetaData):
        """
        Data.
        """
        # Length of the list being iterated.
        iterator_length: int

    def get_last_output(self) -> Optional[Any]:
        """
        Get last output.

        :return: the last collected output, or None when `outputs` is None or empty
        """
        if self.outputs:
            return self.outputs[-1]
        return None

    def get_current_output(self) -> Optional[Any]:
        """
        Get current output.

        :return: the value of `current_output`
        """
        return self.current_output

View File

@@ -0,0 +1,119 @@
from typing import cast
from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.base_node_data_entities import BaseIterationState
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
from models.workflow import WorkflowNodeExecutionStatus
class IterationNode(BaseIterationNode):
    """
    Iteration Node.

    Iterates over a list variable: each round publishes the current `index`
    and `item` into the variable pool and collects the configured output.
    """
    _node_data_cls = IterationNodeData
    _node_type = NodeType.ITERATION

    def _run(self, variable_pool: VariablePool) -> BaseIterationState:
        """
        Run the node.

        :param variable_pool: variable pool
        :return: initial iteration state, with index -1 (no round started yet)
        :raises ValueError: when the iterator variable is not a list
        """
        iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)

        if not isinstance(iterator, list):
            raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")

        state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
            'iterator_selector': iterator
        }, outputs=[], metadata=IterationState.MetaData(
            # iterator is guaranteed to be a list by the check above
            iterator_length=len(iterator)
        ))

        self._set_current_iteration_variable(variable_pool, state)
        return state

    def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
        """
        Collect the output of the round that just ended, advance the index and
        decide what comes next.

        :param variable_pool: variable pool
        :param state: current iteration state, mutated in place
        :return: the start node id of the iteration while items remain,
                 otherwise a succeeded NodeRunResult holding all collected outputs
        """
        # resolve current output
        self._resolve_current_output(variable_pool, state)
        # move to next iteration
        self._next_iteration(variable_pool, state)

        node_data = cast(IterationNodeData, self.node_data)
        if self._reached_iteration_limit(variable_pool, state):
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                outputs={
                    'output': jsonable_encoder(state.outputs)
                }
            )

        return node_data.start_node_id

    def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
        """
        Publish the current `index` (and `item`, when the index is valid) into
        the variable pool under this node's id.

        :param variable_pool: variable pool
        :param state: current iteration state
        """
        node_data = cast(IterationNodeData, self.node_data)

        variable_pool.append_variable(self.node_id, ['index'], state.index)
        # get the iterator value
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            return

        # The state starts at index -1. Without the lower bound, -1 would be
        # treated as a valid position: it raises IndexError on an empty list
        # and publishes the last element as `item` on a non-empty one.
        if 0 <= state.index < len(iterator):
            variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])

    def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
        """
        Move to next iteration.

        :param variable_pool: variable pool
        :param state: current iteration state; its index is incremented
        """
        state.index += 1
        self._set_current_iteration_variable(variable_pool, state)

    def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
        """
        Check if iteration limit is reached.

        :param variable_pool: variable pool
        :param state: current iteration state
        :return: True if iteration limit is reached (or the iterator is no
                 longer a list), False otherwise
        """
        node_data = cast(IterationNodeData, self.node_data)
        iterator = variable_pool.get_variable_value(node_data.iterator_selector)

        if iterator is None or not isinstance(iterator, list):
            return True

        return state.index >= len(iterator)

    def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
        """
        Read the configured output variable, store it on the state and reset it
        in the pool.

        :param variable_pool: variable pool
        :param state: current iteration state; `current_output` is replaced and
                      non-None outputs are appended to `outputs`
        """
        output_selector = cast(IterationNodeData, self.node_data).output_selector
        output = variable_pool.get_variable_value(output_selector)
        # clear the output for this iteration
        variable_pool.append_variable(self.node_id, output_selector[1:], None)
        state.current_output = output
        if output is not None:
            state.outputs.append(output)

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
        """
        Extract variable selector to variable mapping
        :param node_data: node data
        :return: mapping with the single key 'input_selector' -> iterator selector
        """
        return {
            'input_selector': node_data.iterator_selector,
        }

View File

View File

@@ -0,0 +1,13 @@
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
class LoopNodeData(BaseIterationNodeData):
    """
    Loop Node Data.

    Declares no fields of its own; everything is inherited from
    BaseIterationNodeData.
    """
class LoopState(BaseIterationState):
    """
    Loop State.

    Declares no fields of its own; everything is inherited from
    BaseIterationState.
    """

View File

@@ -0,0 +1,20 @@
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseIterationNode
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
class LoopNode(BaseIterationNode):
    """
    Loop Node.

    Placeholder: neither running the node nor advancing the loop is
    implemented yet.
    """
    _node_data_cls = LoopNodeData
    _node_type = NodeType.LOOP

    def _run(self, variable_pool: VariablePool) -> LoopState:
        """
        Run the node.

        Delegates to the abstract base implementation, which raises
        NotImplementedError.

        :param variable_pool: variable pool
        """
        return super()._run(variable_pool)

    def _get_next_iteration(self, variable_pool: VariablePool, state: LoopState) -> NodeRunResult | str:
        """
        Get next iteration start node id based on the graph.

        The signature mirrors BaseIterationNode._get_next_iteration, which is
        invoked as ``self._get_next_iteration(variable_pool, state)``; without
        the ``state`` parameter that call fails with a TypeError.

        :param variable_pool: variable pool
        :param state: current loop state
        :raises NotImplementedError: always, the loop node is not implemented
        """
        # Fail explicitly instead of returning None, which is neither a node
        # id nor a NodeRunResult.
        raise NotImplementedError

View File

@@ -0,0 +1,85 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.entities.base_node_data_entities import BaseNodeData
class ModelConfig(BaseModel):
    """
    Model Config.
    """
    # Model provider name.
    provider: str
    # Model name.
    name: str
    # Model mode as a string.
    # NOTE(review): presumably a ModelMode value such as chat/completion - confirm.
    mode: str
    # Completion parameters; defaults to an empty dict.
    completion_params: dict[str, Any] = {}
class ParameterConfig(BaseModel):
    """
    Parameter Config.

    Describes one parameter to extract: its name, type, description, whether
    it is required and, for `select`, the allowed options.
    """
    name: str
    type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
    options: Optional[list[str]]
    description: str
    required: bool

    @validator('name', pre=True, always=True)
    def validate_name(cls, value):
        """
        Reject empty names and the names reserved for the node's own outputs.

        :raises ValueError: when the name is empty or reserved
        """
        if not value:
            raise ValueError('Parameter name is required')
        reserved_names = ('__reason', '__is_success')
        if value in reserved_names:
            raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
        return value
class ParameterExtractorNodeData(BaseNodeData):
    """
    Parameter Extractor Node Data.
    """
    # Model used for the extraction.
    model: ModelConfig
    # Variable selector of the query text.
    query: list[str]
    # Parameters to extract.
    parameters: list[ParameterConfig]
    # Optional extra instruction.
    instruction: Optional[str]
    # Optional memory configuration.
    memory: Optional[MemoryConfig]
    # Extraction strategy; an empty value falls back to 'function_call'.
    reasoning_mode: Literal['function_call', 'prompt']

    @validator('reasoning_mode', pre=True, always=True)
    def set_reasoning_mode(cls, v):
        """
        Default an empty reasoning mode to 'function_call'.
        """
        return v or 'function_call'

    def get_parameter_json_schema(self) -> dict:
        """
        Get parameter json schema.

        :return: parameter json schema: an object schema with one property per
                 configured parameter and the list of required parameter names
        """
        parameters = {
            'type': 'object',
            'properties': {},
            'required': []
        }

        for parameter in self.parameters:
            parameter_schema = {
                'description': parameter.description
            }

            if parameter.type in ['string', 'select']:
                parameter_schema['type'] = 'string'
            elif parameter.type.startswith('array'):
                parameter_schema['type'] = 'array'
                # 'array[<nested>]' -> '<nested>'
                nested_type = parameter.type[6:-1]
                parameter_schema['items'] = {'type': nested_type}
            elif parameter.type == 'bool':
                # JSON Schema names this primitive 'boolean'; 'bool' is not a
                # valid type name.
                parameter_schema['type'] = 'boolean'
            else:
                parameter_schema['type'] = parameter.type

            # Only constrain the value when options are configured; a null
            # `enum` is not a valid schema.
            if parameter.type == 'select' and parameter.options:
                parameter_schema['enum'] = parameter.options

            parameters['properties'][parameter.name] = parameter_schema

            if parameter.required:
                parameters['required'].append(parameter.name)

        return parameters

View File

@@ -0,0 +1,711 @@
import json
import uuid
from typing import Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from core.workflow.nodes.parameter_extractor.prompts import (
CHAT_EXAMPLE,
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
COMPLETION_GENERATE_JSON_PROMPT,
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
FUNCTION_CALLING_EXTRACTOR_NAME,
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
)
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionStatus
class ParameterExtractorNode(LLMNode):
"""
Parameter Extractor Node.
"""
_node_data_cls = ParameterExtractorNodeData
_node_type = NodeType.PARAMETER_EXTRACTOR
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
"model": {
"prompt_templates": {
"completion_model": {
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
},
"stop": ["Human:"]
}
}
}
}
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
    """
    Run the node.

    Builds the prompt (function calling when the model supports it and the
    node asks for it, prompt engineering otherwise), invokes the model and
    converts its answer into the configured parameters.

    :param variable_pool: variable pool
    :return: node run result; outputs always carry `__is_success` and
             `__reason` next to the extracted parameters
    :raises ValueError: when the query is empty, the model is not an LLM or
                        its schema cannot be found
    """
    node_data = cast(ParameterExtractorNodeData, self.node_data)
    query = variable_pool.get_variable_value(node_data.query)
    if not query:
        raise ValueError("Query not found")

    inputs = {
        'query': query,
        'parameters': jsonable_encoder(node_data.parameters),
        'instruction': jsonable_encoder(node_data.instruction),
    }

    model_instance, model_config = self._fetch_model_config(node_data.model)
    if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
        raise ValueError("Model is not a Large Language Model")

    llm_model = model_instance.model_type_instance
    model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
    if not model_schema:
        raise ValueError("Model schema not found")

    # fetch memory
    memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)

    # Function calling is usable with either single or multi tool-call
    # support. Listing MULTI_TOOL_CALL twice excluded models that only
    # advertise single tool-call.
    # NOTE(review): relies on ModelFeature.TOOL_CALL being defined next to
    # MULTI_TOOL_CALL - confirm in model_entities.
    tool_call_features = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}
    if set(model_schema.features or []) & tool_call_features \
            and node_data.reasoning_mode == 'function_call':
        # use function call
        prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
            node_data, query, variable_pool, model_config, memory
        )
    else:
        # use prompt engineering
        prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory)
        prompt_message_tools = []

    process_data = {
        'model_mode': model_config.mode,
        'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_config.mode,
            prompt_messages=prompt_messages
        ),
        'usage': None,
        'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
        'tool_call': None,
    }

    try:
        text, usage, tool_call = self._invoke_llm(
            node_data_model=node_data.model,
            model_instance=model_instance,
            prompt_messages=prompt_messages,
            tools=prompt_message_tools,
            stop=model_config.stop,
        )
        process_data['usage'] = jsonable_encoder(usage)
        process_data['tool_call'] = jsonable_encoder(tool_call)
        process_data['llm_text'] = text
    except Exception as e:
        # the invocation failed: report a failed run, with the reason in the outputs
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.FAILED,
            inputs=inputs,
            process_data={},
            outputs={
                '__is_success': 0,
                '__reason': str(e)
            },
            error=str(e),
            metadata={}
        )

    error = None

    # prefer the structured tool call; fall back to JSON found in the text
    if tool_call:
        result = self._extract_json_from_tool_call(tool_call)
    else:
        result = self._extract_complete_json_response(text)
        if not result:
            result = self._generate_default_result(node_data)
            error = "Failed to extract result from function call or text response, using empty result."

    try:
        result = self._validate_result(node_data, result)
    except Exception as e:
        # validation problems are reported through `__reason`, the run still succeeds
        error = str(e)

    # transform result into standard format
    result = self._transform_result(node_data, result)

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        inputs=inputs,
        process_data=process_data,
        outputs={
            '__is_success': 1 if not error else 0,
            '__reason': error,
            **result
        },
        metadata={
            NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
            NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
            NodeRunMetadataKey.CURRENCY: usage.currency
        }
    )
def _invoke_llm(self, node_data_model: ModelConfig,
                model_instance: ModelInstance,
                prompt_messages: list[PromptMessage],
                tools: list[PromptMessageTool],
                stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
    """
    Invoke large language model
    :param node_data_model: node data model
    :param model_instance: model instance
    :param prompt_messages: prompt messages
    :param tools: tools offered to the model
    :param stop: stop
    :return: message text, usage and the first tool call (None when there is none)
    :raises ValueError: when the invocation does not return an LLMResult
    """
    # NOTE(review): presumably releases the DB session before the blocking
    # model call - confirm.
    db.session.close()

    response = model_instance.invoke_llm(
        prompt_messages=prompt_messages,
        model_parameters=node_data_model.completion_params,
        tools=tools,
        stop=stop,
        stream=False,
        user=self.user_id,
    )

    # handle invoke result
    if not isinstance(response, LLMResult):
        raise ValueError(f"Invalid invoke result: {response}")

    message = response.message
    llm_usage = response.usage
    # only the first tool call is used
    first_tool_call = message.tool_calls[0] if message.tool_calls else None

    # deduct quota
    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=llm_usage)

    return message.content, llm_usage, first_tool_call
def _generate_function_call_prompt(self,
                                   node_data: ParameterExtractorNodeData,
                                   query: str,
                                   variable_pool: VariablePool,
                                   model_config: ModelConfigWithCredentialsEntity,
                                   memory: Optional[TokenBufferMemory],
                                   ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
    """
    Generate function call prompt.

    :return: the prompt messages, with the few-shot function-call examples
             inserted before the last user message, and a one-element list
             holding the extractor tool
    """
    schema_json = json.dumps(node_data.get_parameter_json_schema())
    query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=schema_json)

    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
    prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token)
    # memory is passed to the template builder above, not to the transform
    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=None,
        model_config=model_config
    )

    # find last user message (stays -1 when there is none)
    user_positions = [
        position for position, message in enumerate(prompt_messages)
        if message.role == PromptMessageRole.USER
    ]
    insert_at = user_positions[-1] if user_positions else -1

    # add function call messages before last user message
    example_messages = []
    for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
        # each example gets its own id, linking the tool call and its response
        tool_call_id = uuid.uuid4().hex
        function_call = example['assistant']['function_call']
        example_tool_call = AssistantPromptMessage.ToolCall(
            id=tool_call_id,
            type='function',
            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                name=function_call['name'],
                arguments=json.dumps(function_call['parameters'])
            )
        )
        example_messages.append(UserPromptMessage(content=example['user']['query']))
        example_messages.append(AssistantPromptMessage(
            content=example['assistant']['text'],
            tool_calls=[example_tool_call]
        ))
        example_messages.append(ToolPromptMessage(
            content='Great! You have called the function with the correct parameters.',
            tool_call_id=tool_call_id
        ))
        example_messages.append(AssistantPromptMessage(
            content='I have extracted the parameters, let\'s move on.',
        ))

    prompt_messages = prompt_messages[:insert_at] + example_messages + prompt_messages[insert_at:]

    # generate tool
    extractor_tool = PromptMessageTool(
        name=FUNCTION_CALLING_EXTRACTOR_NAME,
        description='Extract parameters from the natural language text',
        parameters=node_data.get_parameter_json_schema(),
    )

    return prompt_messages, [extractor_tool]
def _generate_prompt_engineering_prompt(self,
                                        data: ParameterExtractorNodeData,
                                        query: str,
                                        variable_pool: VariablePool,
                                        model_config: ModelConfigWithCredentialsEntity,
                                        memory: Optional[TokenBufferMemory],
                                        ) -> list[PromptMessage]:
    """
    Generate prompt engineering prompt.

    Dispatches on the model mode of the node to the completion or chat
    prompt builder.

    :raises ValueError: when the model mode is neither completion nor chat
    """
    model_mode = ModelMode.value_of(data.model.mode)

    if model_mode == ModelMode.COMPLETION:
        build_prompt = self._generate_prompt_engineering_completion_prompt
    elif model_mode == ModelMode.CHAT:
        build_prompt = self._generate_prompt_engineering_chat_prompt
    else:
        raise ValueError(f"Invalid model mode: {model_mode}")

    return build_prompt(data, query, variable_pool, model_config, memory)
def _generate_prompt_engineering_completion_prompt(self,
                                                   node_data: ParameterExtractorNodeData,
                                                   query: str,
                                                   variable_pool: VariablePool,
                                                   model_config: ModelConfigWithCredentialsEntity,
                                                   memory: Optional[TokenBufferMemory],
                                                   ) -> list[PromptMessage]:
    """
    Generate completion prompt.

    The parameter JSON schema is supplied to the prompt transform as the
    `structure` input.
    """
    transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
    template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token)
    template_inputs = {
        'structure': json.dumps(node_data.get_parameter_json_schema())
    }

    return transform.get_prompt(
        prompt_template=template,
        inputs=template_inputs,
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config
    )
def _generate_prompt_engineering_chat_prompt(self,
                                             node_data: ParameterExtractorNodeData,
                                             query: str,
                                             variable_pool: VariablePool,
                                             model_config: ModelConfigWithCredentialsEntity,
                                             memory: Optional[TokenBufferMemory],
                                             ) -> list[PromptMessage]:
    """
    Generate chat prompt.

    :return: the prompt messages with the few-shot JSON examples inserted
             before the last user message
    """
    transform = AdvancedPromptTransform(with_variable_tmpl=True)
    rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')

    # user message: the query wrapped together with the expected JSON structure
    user_message_text = CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
        structure=json.dumps(node_data.get_parameter_json_schema()),
        text=query
    )
    template = self._get_prompt_engineering_prompt_template(
        node_data, user_message_text, variable_pool, memory, rest_token
    )

    prompt_messages = transform.get_prompt(
        prompt_template=template,
        inputs={},
        query='',
        files=[],
        context='',
        memory_config=node_data.memory,
        memory=memory,
        model_config=model_config
    )

    # find last user message (stays -1 when there is none)
    user_positions = [
        position for position, message in enumerate(prompt_messages)
        if message.role == PromptMessageRole.USER
    ]
    insert_at = user_positions[-1] if user_positions else -1

    # add example messages before last user message
    example_messages = []
    for example in CHAT_EXAMPLE:
        example_question = CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
            structure=json.dumps(example['user']['json']),
            text=example['user']['query'],
        )
        example_messages.append(UserPromptMessage(content=example_question))
        example_messages.append(AssistantPromptMessage(
            content=json.dumps(example['assistant']['json']),
        ))

    return prompt_messages[:insert_at] + example_messages + prompt_messages[insert_at:]
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
    """
    Validate result.

    :param data: node data holding the parameter definitions
    :param result: extracted values keyed by parameter name
    :return: the unchanged result when every check passes
    :raises ValueError: when the number of entries differs from the number of
                        parameters, a required parameter is missing, or a
                        value does not match its declared type
    """
    if len(result) != len(data.parameters):
        raise ValueError("Invalid number of parameters")

    # accepted python types per declared scalar type / array item type
    scalar_types = {'number': (int, float), 'bool': bool, 'string': str}
    item_types = {'number': (int, float), 'string': str, 'object': dict}

    for parameter in data.parameters:
        name = parameter.name
        if parameter.required and name not in result:
            raise ValueError(f"Parameter {name} is required")

        value = result.get(name)

        if parameter.type == 'select':
            # options are only enforced when some are configured
            if parameter.options and value not in parameter.options:
                raise ValueError(f"Invalid `select` value for parameter {name}")
        elif parameter.type in scalar_types:
            if not isinstance(value, scalar_types[parameter.type]):
                raise ValueError(f"Invalid `{parameter.type}` value for parameter {name}")
        elif parameter.type.startswith('array'):
            if not isinstance(value, list):
                raise ValueError(f"Invalid `array` value for parameter {name}")
            # 'array[<nested>]' -> '<nested>'
            nested_type = parameter.type[6:-1]
            expected = item_types.get(nested_type)
            if expected is not None:
                for item in value:
                    if not isinstance(item, expected):
                        raise ValueError(f"Invalid `array[{nested_type}]` value for parameter {name}")

    return result
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
"""
Transform result into standard format.
"""
transformed_result = {}
for parameter in data.parameters:
if parameter.name in result:
# transform value
if parameter.type == 'number':
if isinstance(result[parameter.name], int | float):
transformed_result[parameter.name] = result[parameter.name]
elif isinstance(result[parameter.name], str):
try:
if '.' in result[parameter.name]:
result[parameter.name] = float(result[parameter.name])
else:
result[parameter.name] = int(result[parameter.name])
except ValueError:
pass
else:
pass
# TODO: bool is not supported in the current version
# elif parameter.type == 'bool':
# if isinstance(result[parameter.name], bool):
# transformed_result[parameter.name] = bool(result[parameter.name])
# elif isinstance(result[parameter.name], str):
# if result[parameter.name].lower() in ['true', 'false']:
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
# elif isinstance(result[parameter.name], int):
# transformed_result[parameter.name] = bool(result[parameter.name])
elif parameter.type in ['string', 'select']:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
elif parameter.type.startswith('array'):
if isinstance(result[parameter.name], list):
nested_type = parameter.type[6:-1]
transformed_result[parameter.name] = []
for item in result[parameter.name]:
if nested_type == 'number':
if isinstance(item, int | float):
transformed_result[parameter.name].append(item)
elif isinstance(item, str):
try:
if '.' in item:
transformed_result[parameter.name].append(float(item))
else:
transformed_result[parameter.name].append(int(item))
except ValueError:
pass
elif nested_type == 'string':
if isinstance(item, str):
transformed_result[parameter.name].append(item)
elif nested_type == 'object':
if isinstance(item, dict):
transformed_result[parameter.name].append(item)
if parameter.name not in transformed_result:
if parameter.type == 'number':
transformed_result[parameter.name] = 0
elif parameter.type == 'bool':
transformed_result[parameter.name] = False
elif parameter.type in ['string', 'select']:
transformed_result[parameter.name] = ''
elif parameter.type.startswith('array'):
transformed_result[parameter.name] = []
return transformed_result
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
    """
    Extract complete json response.

    Scans the text for the first position from which a complete JSON object
    can be decoded and returns it.

    :param result: raw model text, may be empty or None
    :return: the first decodable JSON object, or None when there is none
    """
    if not result:
        return None

    decoder = json.JSONDecoder()
    for idx, char in enumerate(result):
        if char != '{' and char != '[':
            continue
        try:
            # raw_decode parses one JSON value starting at idx and ignores
            # trailing text; unlike bracket counting it is not confused by
            # braces inside string values
            candidate, _ = decoder.raw_decode(result, idx)
        except json.JSONDecodeError:
            continue
        # callers treat the result as a mapping, so arrays are skipped and
        # the scan continues (into the array's content)
        if isinstance(candidate, dict):
            return candidate

    return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
    """
    Extract json from tool call.

    :param tool_call: tool call returned by the model, may be None
    :return: the decoded arguments, or None when there is no tool call, no
             arguments, or the arguments are not valid JSON
    """
    if not tool_call or not tool_call.function.arguments:
        return None

    try:
        return json.loads(tool_call.function.arguments)
    except json.JSONDecodeError:
        # Malformed arguments are reported as "nothing extracted" instead of
        # raising out of the node run.
        return None
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
    """
    Generate default result.

    :param data: node data holding the parameter definitions
    :return: one default per parameter: 0 for number, False for bool, '' for
             string/select and [] for array types
    """
    result = {}
    for parameter in data.parameters:
        if parameter.type == 'number':
            result[parameter.name] = 0
        elif parameter.type == 'bool':
            result[parameter.name] = False
        elif parameter.type in ['string', 'select']:
            result[parameter.name] = ''
        elif parameter.type.startswith('array'):
            # same default as _transform_result; omitting array parameters
            # made the default result fail the length check in validation
            result[parameter.name] = []

    return result
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
    """
    Render instruction.

    Resolves every variable selector referenced by the instruction template
    against the variable pool and formats the template with those values.

    :param instruction: instruction template
    :param variable_pool: variable pool
    :return: formatted instruction
    """
    parser = VariableTemplateParser(instruction)
    values = {
        selector.variable: variable_pool.get_variable_value(selector.value_selector)
        for selector in parser.extract_variable_selectors()
    }
    return parser.format(values)
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                          variable_pool: VariablePool,
                                          memory: Optional[TokenBufferMemory],
                                          max_token_limit: int = 2000) \
        -> list[ChatModelMessage]:
    """
    Build the prompt template used for function-calling extraction.

    :param node_data: node data providing the model mode, instruction and memory window
    :param query: text sent as the user message
    :param variable_pool: pool used to render the instruction
    :param memory: optional conversation memory rendered into the system prompt
    :param max_token_limit: token budget passed to the memory when rendering history
    :return: a system message followed by a user message
    :raises ValueError: when the model mode is not chat
    """
    model_mode = ModelMode.value_of(node_data.model.mode)
    instruction = self._render_instruction(node_data.instruction or '', variable_pool)

    histories = ''
    if memory:
        histories = memory.get_history_prompt_text(
            max_token_limit=max_token_limit,
            message_limit=node_data.memory.window.size
        )

    # only the chat mode has a template here
    if model_mode != ModelMode.CHAT:
        raise ValueError(f"Model mode {model_mode} not support.")

    return [
        ChatModelMessage(
            role=PromptMessageRole.SYSTEM,
            text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=histories, instruction=instruction)
        ),
        ChatModelMessage(
            role=PromptMessageRole.USER,
            text=query
        ),
    ]
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
                                            variable_pool: VariablePool,
                                            memory: Optional[TokenBufferMemory],
                                            max_token_limit: int = 2000) \
        -> list[ChatModelMessage]:
    """
    Build the prompt template used for prompt-engineering extraction.

    :param node_data: node data providing the model mode, instruction and memory window
    :param query: input text to extract parameters from
    :param variable_pool: pool used to render the instruction
    :param memory: optional conversation memory rendered into the prompt
    :param max_token_limit: token budget passed to the memory when rendering history
    :return: for chat mode a system message followed by a user message; for
             completion mode a single CompletionModelPromptTemplate
             (NOTE(review): this does not match the annotated return type)
    :raises ValueError: when the model mode is neither chat nor completion
    """
    model_mode = ModelMode.value_of(node_data.model.mode)
    instruction = self._render_instruction(node_data.instruction or '', variable_pool)

    histories = ''
    if memory:
        histories = memory.get_history_prompt_text(
            max_token_limit=max_token_limit,
            message_limit=node_data.memory.window.size
        )

    if model_mode == ModelMode.CHAT:
        # NOTE(review): this reuses the function-calling system prompt; a JSON
        # generation prompt may have been intended here - confirm.
        return [
            ChatModelMessage(
                role=PromptMessageRole.SYSTEM,
                text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=histories, instruction=instruction)
            ),
            ChatModelMessage(
                role=PromptMessageRole.USER,
                text=query
            ),
        ]

    if model_mode == ModelMode.COMPLETION:
        completion_text = COMPLETION_GENERATE_JSON_PROMPT.format(
            histories=histories,
            text=query,
            instruction=instruction
        )
        # drop the marker sequences left over from the template's structure sample
        completion_text = completion_text.replace('{γγγ', '').replace('}γγγ', '')
        return CompletionModelPromptTemplate(text=completion_text)

    raise ValueError(f"Model mode {model_mode} not support.")
def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
                          variable_pool: VariablePool,
                          model_config: ModelConfigWithCredentialsEntity,
                          context: Optional[str]) -> int:
    """
    Estimate how many tokens remain once the extraction prompt is accounted for.

    :param node_data: node data providing the model and memory configuration
    :param query: input text placed into the prompt
    :param variable_pool: pool used to render the instruction
    :param model_config: replaced inside the method by the config returned from
                         ``self._fetch_model_config``
    :param context: context passed on to the prompt transform
    :return: remaining tokens, never negative; 2000 when the model declares no context size
    :raises ValueError: when the model is not a large language model or has no schema
    """
    prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
    model_instance, model_config = self._fetch_model_config(node_data.model)

    llm_model = model_instance.model_type_instance
    if not isinstance(llm_model, LargeLanguageModel):
        raise ValueError("Model is not a Large Language Model")

    model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
    if not model_schema:
        raise ValueError("Model schema not found")

    # NOTE(review): the original listed MULTI_TOOL_CALL twice in this set; a
    # second, different feature was probably intended - confirm.
    if set(model_schema.features or []) & {ModelFeature.MULTI_TOOL_CALL}:
        build_template = self._get_function_calling_prompt_template
    else:
        build_template = self._get_prompt_engineering_prompt_template
    prompt_template = build_template(node_data, query, variable_pool, None, 2000)

    prompt_messages = prompt_transform.get_prompt(
        prompt_template=prompt_template,
        inputs={},
        query='',
        files=[],
        context=context,
        memory_config=node_data.memory,
        memory=None,
        model_config=model_config
    )

    context_size = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
    if not context_size:
        # no declared context size: fall back to the fixed budget
        return 2000

    token_counter = cast(LargeLanguageModel, model_config.provider_model_bundle.model_type_instance)
    used_tokens = token_counter.get_num_tokens(
        model_config.model,
        model_config.credentials,
        prompt_messages
    ) + 1000  # add 1000 to ensure tool call messages

    reserved_tokens = 0
    for rule in model_config.model_schema.parameter_rules:
        if rule.name == 'max_tokens' or (rule.use_template and rule.use_template == 'max_tokens'):
            # no early exit: the last matching rule decides the value
            reserved_tokens = (model_config.parameters.get(rule.name)
                               or model_config.parameters.get(rule.use_template)) or 0

    return max(context_size - reserved_tokens - used_tokens, 0)
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
"""
Fetch model config.
"""
if not self._model_instance or not self._model_config:
self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
return self._model_instance, self._model_config
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
"""
Extract variable selector to variable mapping
:param node_data: node data
:return:
"""
node_data = node_data
variable_mapping = {
'query': node_data.query
}
if node_data.instruction:
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
for selector in variable_template_parser.extract_variable_selectors():
variable_mapping[selector.variable] = selector.value_selector
return variable_mapping

View File

@@ -0,0 +1,206 @@
# Name of the function the model is instructed to call when extracting parameters.
FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'

# System prompt for function-calling extraction.
# This is an f-string: the function name is interpolated when the module loads,
# while `\x7b` / `\x7d` produce literal braces so that `{histories}` and
# `{instruction}` remain as placeholders for a later `str.format` call.
# Every mention of the function uses FUNCTION_CALLING_EXTRACTOR_NAME so the
# prompt cannot drift from the real function name.
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
### Task
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
### Memory
Here is the chat history between the human and assistant, provided within <histories> tags:
<histories>
\x7bhistories\x7d
</histories>
### Instructions:
Some additional information is provided below. Always adhere to these instructions as closely as possible:
<instruction>
\x7binstruction\x7d
</instruction>
Steps:
1. Review the chat history provided within the <histories> tags.
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function to create structured outputs with appropriate parameters.
5. Do not include any XML tags in your output.
### Example
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
### Final Output
Produce well-formatted function calls in json without XML tags, as shown in the example.
"""
# User message template for function-calling extraction.
# This is an f-string: FUNCTION_CALLING_EXTRACTOR_NAME is interpolated when the
# module loads, while `\x7b` / `\x7d` produce literal braces so that `{content}`
# and `{structure}` remain as placeholders for a later `str.format` call.
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
<context>
\x7bcontent\x7d
</context>
<structure>
\x7bstructure\x7d
</structure>
"""
# Worked examples for function-calling extraction.
# Each entry pairs a 'user' turn (a query plus the function definition with its
# parameter schema) with the 'assistant' turn (explanatory text plus the
# function call and its extracted parameters).
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
    'user': {
        'query': 'What is the weather today in SF?',
        'function': {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'type': 'object',
                'properties': {
                    'location': {
                        'type': 'string',
                        'description': 'The location to get the weather information',
                        'required': True
                    },
                },
                'required': ['location']
            }
        }
    },
    'assistant': {
        'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.',
        'function_call' : {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'location': 'San Francisco'
            }
        }
    }
}, {
    'user': {
        'query': 'I want to eat some apple pie.',
        'function': {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'type': 'object',
                'properties': {
                    'food': {
                        'type': 'string',
                        'description': 'The food to eat',
                        'required': True
                    }
                },
                'required': ['food']
            }
        }
    },
    'assistant': {
        'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.',
        'function_call' : {
            'name': FUNCTION_CALLING_EXTRACTOR_NAME,
            'parameters': {
                'food': 'apple pie'
            }
        }
    }
}]
# Prompt for completion-mode models that must answer with a JSON object.
# `str.format` placeholders: {instruction}, {histories}, {text}.
# Doubled braces are escapes: after formatting, `{{ structure }}` becomes
# `{ structure }` and the `{{γγγ` / `}}γγγ` pair becomes `{γγγ` / `}γγγ`.
# _get_prompt_engineering_prompt_template removes those two marker sequences
# after formatting.
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
Some extra information are provided below, I should always follow the instructions as possible as I can.
<instructions>
{instruction}
</instructions>
### Extract parameter Workflow
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
<information to be extracted>
{{ structure }}
</information to be extracted>
Step 1: Carefully read the input and understand the structure of the expected output.
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Structure
Here is the structure of the expected output, I should always follow the output structure.
{{γγγ
'properties1': 'relevant text extracted from input',
'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
Inside <text></text> XML tags, there is a text that I should extract parameters and convert to a JSON object.
<text>
{text}
</text>
### Answer
I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
"""
# System prompt for chat-mode models that must answer with a JSON object.
# `{histories}` is a `str.format` placeholder. `{{instructions}}` is an escaped
# brace pair, so `str.format` leaves a literal `{instructions}` in the text;
# presumably it is substituted in a separate step - verify against the caller.
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
The structure of the JSON object you can found in the instructions.
### Memory
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
<histories>
{histories}
</histories>
### Instructions:
Some extra information are provided below, you should always follow the instructions as possible as you can.
<instructions>
{{instructions}}
</instructions>
"""
# User message template for chat-mode JSON generation.
# `str.format` placeholders: {structure} (the expected JSON structure) and
# {text} (the text to convert to JSON).
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
Here is the structure of the JSON object, you should always follow the structure.
<structure>
{structure}
</structure>
### Text to be converted to JSON
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
<text>
{text}
</text>
"""
# Worked examples for chat-mode JSON generation.
# Each entry pairs a 'user' turn (a query plus the JSON schema to fill) with the
# 'assistant' turn (explanatory text plus the JSON answer). The answer keys
# must match the properties declared in the same example's schema, otherwise
# the example demonstrates an output that violates its own structure.
CHAT_EXAMPLE = [{
    'user': {
        'query': 'What is the weather today in SF?',
        'json': {
            'type': 'object',
            'properties': {
                'location': {
                    'type': 'string',
                    'description': 'The location to get the weather information',
                    'required': True
                }
            },
            'required': ['location']
        }
    },
    'assistant': {
        'text': 'I need to output a valid JSON object.',
        'json': {
            'location': 'San Francisco'
        }
    }
}, {
    'user': {
        'query': 'I want to eat some apple pie.',
        'json': {
            'type': 'object',
            'properties': {
                'food': {
                    'type': 'string',
                    'description': 'The food to eat',
                    'required': True
                }
            },
            'required': ['food']
        }
    },
    'assistant': {
        'text': 'I need to output a valid JSON object.',
        'json': {
            # keyed by the schema's required property ('food'), not 'result'
            'food': 'apple pie'
        }
    }
}]

View File

@@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
class ToolEntity(BaseModel):
provider_id: str
provider_type: Literal['builtin', 'api']
provider_type: Literal['builtin', 'api', 'workflow']
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy

View File

@@ -1,13 +1,14 @@
from os import path
from typing import cast
from typing import Optional, cast
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool.tool import Tool
from core.tools.tool_engine import ToolEngine
from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
@@ -35,20 +36,23 @@ class ToolNode(BaseNode):
'provider_id': node_data.provider_id
}
# get parameters
parameters = self._generate_parameters(variable_pool, node_data)
# get tool runtime
try:
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
tool_runtime = ToolManager.get_workflow_tool_runtime(
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
)
except Exception as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters,
inputs={},
metadata={
NodeRunMetadataKey.TOOL_INFO: tool_info
},
error=f'Failed to get tool runtime: {str(e)}'
)
# get parameters
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
try:
messages = ToolEngine.workflow_invoke(
@@ -56,7 +60,8 @@ class ToolNode(BaseNode):
tool_parameters=parameters,
user_id=self.user_id,
workflow_id=self.workflow_id,
workflow_tool_callback=DifyWorkflowCallbackHandler()
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth + 1
)
except Exception as e:
return NodeRunResult(
@@ -83,19 +88,32 @@ class ToolNode(BaseNode):
inputs=parameters
)
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
"""
Generate parameters
"""
tool_parameters = tool_runtime.get_all_runtime_parameters()
def fetch_parameter(name: str) -> Optional[ToolParameter]:
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
result = {}
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
parameter = fetch_parameter(parameter_name)
if not parameter:
continue
if parameter.type == ToolParameter.ToolParameterType.FILE:
result[parameter_name] = [
v.to_dict() for v in self._fetch_files(variable_pool)
]
else:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
return result
@@ -109,6 +127,13 @@ class ToolNode(BaseNode):
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return template_parser.format(inputs)
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
    """
    Read the files stored under the system FILES variable of the pool.

    :param variable_pool: pool to read from
    :return: the stored files, or an empty list when the variable is missing or empty
    """
    return variable_pool.get_variable_value(['sys', SystemVariable.FILES.value]) or []
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
"""

View File

@@ -0,0 +1,33 @@
from typing import Literal, Optional
from pydantic import BaseModel
from core.workflow.entities.base_node_data_entities import BaseNodeData
class AdvancedSetting(BaseModel):
    """
    Advanced setting.

    Holds the grouping configuration of a variable aggregator node.
    """
    # flag read by the aggregator node when choosing between the flat
    # variable list and the groups below
    group_enabled: bool

    class Group(BaseModel):
        """
        Group.

        One named set of candidate variable selectors.
        """
        # declared type of the group's output
        output_type: Literal['string', 'number', 'array', 'object']
        # candidate variable selectors, each a list of path segments
        variables: list[list[str]]
        # the aggregator node writes the group's value to '<group_name>_output'
        group_name: str

    # groups iterated by the aggregator node
    groups: list[Group]
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable assigner / variable aggregator Node Data.
    """
    type: str = 'variable-assigner'
    # declared type of the node output
    output_type: str
    # candidate variable selectors, each a list of path segments
    variables: list[list[str]]
    # optional grouping configuration
    advanced_setting: Optional[AdvancedSetting]

View File

@@ -0,0 +1,52 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAggregatorNode(BaseNode):
    """
    Node that outputs the first available value among candidate variables.

    Without grouping, the first non-None variable becomes 'output'. With
    grouping enabled, each group contributes its first non-None variable as
    '<group_name>_output'.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_AGGREGATOR

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Pick the first non-None candidate value(s) from the variable pool.

        :param variable_pool: pool the candidate variables are read from
        :return: a succeeded run result; outputs and inputs stay empty when
                 no candidate has a value
        """
        node_data = cast(VariableAssignerNodeData, self.node_data)
        # Get variables
        outputs = {}
        inputs = {}

        advanced_setting = node_data.advanced_setting
        # Groups are only used when grouping is explicitly enabled. The
        # previous condition was inverted and used the groups exactly when
        # `group_enabled` was False.
        if not advanced_setting or not advanced_setting.group_enabled:
            for variable in node_data.variables:
                value = variable_pool.get_variable_value(variable)

                if value is not None:
                    outputs = {
                        "output": value
                    }

                    # the input key is the selector without its first segment
                    inputs = {
                        '.'.join(variable[1:]): value
                    }
                    break
        else:
            for group in advanced_setting.groups:
                for variable in group.variables:
                    value = variable_pool.get_variable_value(variable)

                    if value is not None:
                        outputs[f'{group.group_name}_output'] = value
                        inputs['.'.join(variable[1:])] = value
                        # first available value wins within a group
                        break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Return the variable mapping for this node, which is always empty.

        :param node_data: node data (unused)
        :return: an empty mapping
        """
        return {}

View File

@@ -1,12 +0,0 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
class VariableAssignerNodeData(BaseNodeData):
    """
    Variable assigner Node Data.
    """
    type: str = 'variable-assigner'
    # declared type of the node output
    output_type: str
    # candidate variable selectors, each a list of path segments
    variables: list[list[str]]

View File

@@ -1,41 +0,0 @@
from typing import cast
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
from models.workflow import WorkflowNodeExecutionStatus
class VariableAssignerNode(BaseNode):
    """
    Node that outputs the first candidate variable holding a non-None value.
    """
    _node_data_cls = VariableAssignerNodeData
    _node_type = NodeType.VARIABLE_ASSIGNER

    def _run(self, variable_pool: VariablePool) -> NodeRunResult:
        """
        Pick the first non-None candidate value from the variable pool.

        :param variable_pool: pool the candidate variables are read from
        :return: a succeeded run result; outputs and inputs stay empty when
                 no candidate has a value
        """
        node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)

        outputs, inputs = {}, {}
        for selector in node_data.variables:
            resolved = variable_pool.get_variable_value(selector)
            if resolved is None:
                continue

            outputs = {"output": resolved}
            # the input key is the selector without its first segment
            inputs = {'.'.join(selector[1:]): resolved}
            break

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs=outputs,
            inputs=inputs
        )

    @classmethod
    def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
        """
        Return the variable mapping for this node, which is always empty.
        """
        return {}

View File

@@ -6,6 +6,7 @@ from flask import current_app
from core.app.app_config.entities import FileExtraConfig
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.file_obj import FileTransferMethod, FileType, FileVar
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
@@ -13,19 +14,22 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.base_node import BaseNode, UserFrom
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
from core.workflow.nodes.if_else.if_else_node import IfElseNode
from core.workflow.nodes.iteration.entities import IterationState
from core.workflow.nodes.iteration.iteration_node import IterationNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.llm.entities import LLMNodeData
from core.workflow.nodes.llm.llm_node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
from core.workflow.nodes.tool.tool_node import ToolNode
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
from extensions.ext_database import db
from models.workflow import (
Workflow,
@@ -44,9 +48,14 @@ node_classes = {
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
NodeType.HTTP_REQUEST: HttpRequestNode,
NodeType.TOOL: ToolNode,
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
NodeType.ITERATION: IterationNode,
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
}
WORKFLOW_CALL_MAX_DEPTH = 5
logger = logging.getLogger(__name__)
@@ -83,18 +92,20 @@ class WorkflowEngineManager:
def run_workflow(self, workflow: Workflow,
user_id: str,
user_from: UserFrom,
invoke_from: InvokeFrom,
user_inputs: dict,
system_inputs: Optional[dict] = None,
callbacks: list[BaseWorkflowCallback] = None) -> None:
callbacks: list[BaseWorkflowCallback] = None,
call_depth: Optional[int] = 0,
variable_pool: Optional[VariablePool] = None) -> None:
"""
Run workflow
:param workflow: Workflow instance
:param user_id: user id
:param user_from: user from
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:return:
:param call_depth: call depth
"""
# fetch workflow graph
graph = workflow.graph_dict
@@ -109,57 +120,185 @@ class WorkflowEngineManager:
if not isinstance(graph.get('edges'), list):
raise ValueError('edges in workflow graph must be a list')
# init variable pool
if not variable_pool:
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs
)
if call_depth > WORKFLOW_CALL_MAX_DEPTH:
raise ValueError('Max workflow call depth reached.')
# init workflow run state
workflow_run_state = WorkflowRunState(
workflow=workflow,
start_at=time.perf_counter(),
variable_pool=variable_pool,
user_id=user_id,
user_from=user_from,
invoke_from=invoke_from,
workflow_call_depth=call_depth
)
# init workflow run
if callbacks:
for callback in callbacks:
callback.on_workflow_run_started()
# init workflow run state
workflow_run_state = WorkflowRunState(
# run workflow
self._run_workflow(
workflow=workflow,
start_at=time.perf_counter(),
variable_pool=VariablePool(
system_variables=system_inputs,
user_inputs=user_inputs
),
user_id=user_id,
user_from=user_from
workflow_run_state=workflow_run_state,
callbacks=callbacks,
)
def _run_workflow(self, workflow: Workflow,
workflow_run_state: WorkflowRunState,
callbacks: list[BaseWorkflowCallback] = None,
start_at: Optional[str] = None,
end_at: Optional[str] = None) -> None:
"""
Run workflow
:param workflow: Workflow instance
:param user_id: user id
:param user_from: user from
:param user_inputs: user variables inputs
:param system_inputs: system inputs, like: query, files
:param callbacks: workflow callbacks
:param call_depth: call depth
:param start_at: force specific start node
:param end_at: force specific end node
:return:
"""
graph = workflow.graph_dict
try:
predecessor_node = None
predecessor_node: BaseNode = None
current_iteration_node: BaseIterationNode = None
has_entry_node = False
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
while True:
# get next node, multiple target nodes in the future
next_node = self._get_next_node(
next_node = self._get_next_overall_node(
workflow_run_state=workflow_run_state,
graph=graph,
predecessor_node=predecessor_node,
callbacks=callbacks
callbacks=callbacks,
start_at=start_at,
end_at=end_at
)
if not next_node:
# reached loop/iteration end or overall end
if current_iteration_node and workflow_run_state.current_iteration_state:
# reached loop/iteration end
# get next iteration
next_iteration = current_iteration_node.get_next_iteration(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_next(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
if isinstance(next_iteration, NodeRunResult):
if next_iteration.outputs:
for variable_key, variable_value in next_iteration.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
variable_pool=workflow_run_state.variable_pool,
node_id=current_iteration_node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value
)
self._workflow_iteration_completed(
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
# iteration has ended
next_node = self._get_next_overall_node(
workflow_run_state=workflow_run_state,
graph=graph,
predecessor_node=current_iteration_node,
callbacks=callbacks,
start_at=start_at,
end_at=end_at
)
current_iteration_node = None
workflow_run_state.current_iteration_state = None
# continue overall process
elif isinstance(next_iteration, str):
# move to next iteration
next_node_id = next_iteration
# get next id
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
if not next_node:
break
# check is already ran
if next_node.node_id in [node_and_result.node.node_id
for node_and_result in workflow_run_state.workflow_nodes_and_results]:
if self._check_node_has_ran(workflow_run_state, next_node.node_id):
predecessor_node = next_node
continue
has_entry_node = True
# max steps reached
if len(workflow_run_state.workflow_nodes_and_results) > max_execution_steps:
if workflow_run_state.workflow_node_steps > max_execution_steps:
raise ValueError('Max steps {} reached.'.format(max_execution_steps))
# or max execution time reached
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
# handle iteration nodes
if isinstance(next_node, BaseIterationNode):
current_iteration_node = next_node
workflow_run_state.current_iteration_state = next_node.run(
variable_pool=workflow_run_state.variable_pool
)
self._workflow_iteration_started(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
callbacks=callbacks
)
predecessor_node = next_node
# move to start node of iteration
next_node_id = next_node.get_next_iteration(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_next(
graph=graph,
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
if isinstance(next_node_id, NodeRunResult):
# iteration has ended
current_iteration_node.set_output(
variable_pool=workflow_run_state.variable_pool,
state=workflow_run_state.current_iteration_state
)
self._workflow_iteration_completed(
current_iteration_node=current_iteration_node,
workflow_run_state=workflow_run_state,
callbacks=callbacks
)
current_iteration_node = None
workflow_run_state.current_iteration_state = None
continue
else:
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
# run workflow, run multiple target nodes in the future
self._run_workflow_node(
workflow_run_state=workflow_run_state,
@@ -235,7 +374,9 @@ class WorkflowEngineManager:
workflow_id=workflow.id,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
config=node_config
invoke_from=InvokeFrom.DEBUGGER,
config=node_config,
workflow_call_depth=0
)
try:
@@ -251,49 +392,14 @@ class WorkflowEngineManager:
except NotImplementedError:
variable_mapping = {}
for variable_key, variable_selector in variable_mapping.items():
if variable_key not in user_inputs:
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
# fetch variable node id from variable selector
variable_node_id = variable_selector[0]
variable_key_list = variable_selector[1:]
# get value
value = user_inputs.get(variable_key)
# temp fix for image type
if node_type == NodeType.LLM:
new_value = []
if isinstance(value, list):
node_data = node_instance.node_data
node_data = cast(LLMNodeData, node_data)
detail = node_data.vision.configs.detail if node_data.vision.configs else None
for item in value:
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
file = FileVar(
tenant_id=workflow.tenant_id,
type=FileType.IMAGE,
transfer_method=transfer_method,
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
related_id=item.get(
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
)
new_value.append(file)
if new_value:
value = new_value
# append variable and value to variable pool
variable_pool.append_variable(
node_id=variable_node_id,
variable_key_list=variable_key_list,
value=value
)
self._mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
node_instance=node_instance
)
# run node
node_run_result = node_instance.run(
variable_pool=variable_pool
@@ -311,6 +417,126 @@ class WorkflowEngineManager:
return node_instance, node_run_result
def single_step_run_iteration_workflow_node(self, workflow: Workflow,
                                            node_id: str,
                                            user_id: str,
                                            user_inputs: dict,
                                            callbacks: Optional[list[BaseWorkflowCallback]] = None,
                                            ) -> None:
    """
    Single iteration run workflow node

    Checks that ``node_id`` is an iteration (or loop) node, fills a fresh variable
    pool from ``user_inputs`` for the iteration node and every node nested in it,
    then runs the workflow starting at the iteration node and passing the target
    of its first outgoing edge as ``end_at``.

    :param workflow: workflow whose ``graph_dict`` holds the iteration node
    :param node_id: id of the iteration node
    :param user_id: id of the user starting the run
    :param user_inputs: input values keyed by ``<node id>.<variable key>``
    :param callbacks: workflow callbacks
    :raises ValueError: if the graph or its nodes are missing, the node is not an
        iteration node, a nested node has no node class, a required key is absent
        from ``user_inputs``, or no outgoing edge of the iteration node is found
    :return:
    """
    # fetch node info from workflow graph
    graph = workflow.graph_dict
    if not graph:
        raise ValueError('workflow graph not found')

    nodes = graph.get('nodes')
    if not nodes:
        raise ValueError('nodes not found in workflow graph')

    # the requested node must be an iteration-style node
    for node in nodes:
        if node.get('id') != node_id:
            continue
        if node.get('data', {}).get('type') not in (NodeType.ITERATION.value, NodeType.LOOP.value):
            raise ValueError('node id is not an iteration node')

    # init variable pool
    variable_pool = VariablePool(
        system_variables={},
        user_inputs={}
    )

    # the iteration node itself plus every node nested inside it
    iteration_nested_nodes = [
        node for node in nodes
        if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
    ]
    iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]

    if not iteration_nested_nodes:
        raise ValueError('iteration has no nested nodes')

    # init workflow run
    for callback in callbacks or []:
        callback.on_workflow_run_started()

    for node_config in iteration_nested_nodes:
        # mapping user inputs to variable pool
        node_type_value = node_config.get('data', {}).get('type')
        node_cls = node_classes.get(NodeType.value_of(node_type_value))
        if node_cls is None:
            # .get() yields None for a type without a registered class; fail
            # clearly instead of with an AttributeError on None
            raise ValueError(f'node class not found for node type {node_type_value}')

        try:
            variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
        except NotImplementedError:
            variable_mapping = {}

        # Keep only selectors that point outside the iteration: values produced by
        # the iteration node or its nested nodes are not user inputs.
        # Keys are prefixed with the node id so they match the user input keys.
        variable_mapping = {
            f'{node_config.get("id")}.{key}': selector
            for key, selector in variable_mapping.items()
            if selector[0] != node_id and selector[0] not in iteration_nested_node_ids
        }

        # append variables to variable pool
        node_instance = node_cls(
            tenant_id=workflow.tenant_id,
            app_id=workflow.app_id,
            workflow_id=workflow.id,
            user_id=user_id,
            user_from=UserFrom.ACCOUNT,
            invoke_from=InvokeFrom.DEBUGGER,
            config=node_config,
            callbacks=callbacks,
            workflow_call_depth=0
        )

        self._mapping_user_inputs_to_variable_pool(
            variable_mapping=variable_mapping,
            user_inputs=user_inputs,
            variable_pool=variable_pool,
            tenant_id=workflow.tenant_id,
            node_instance=node_instance
        )

    # fetch end node of iteration: target of the first edge leaving the iteration node
    # (a graph without an 'edges' list is treated as having no edges)
    end_node_id = next(
        (edge.get('target') for edge in graph.get('edges') or [] if edge.get('source') == node_id),
        None
    )

    if not end_node_id:
        raise ValueError('end node of iteration not found')

    # init workflow run state
    workflow_run_state = WorkflowRunState(
        workflow=workflow,
        start_at=time.perf_counter(),
        variable_pool=variable_pool,
        user_id=user_id,
        user_from=UserFrom.ACCOUNT,
        invoke_from=InvokeFrom.DEBUGGER,
        workflow_call_depth=0
    )

    # run workflow
    self._run_workflow(
        workflow=workflow,
        workflow_run_state=workflow_run_state,
        callbacks=callbacks,
        start_at=node_id,
        end_at=end_node_id
    )
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
"""
Workflow run success
@@ -336,10 +562,96 @@ class WorkflowEngineManager:
error=error
)
def _get_next_node(self, workflow_run_state: WorkflowRunState,
def _workflow_iteration_started(self, graph: dict,
                                current_iteration_node: BaseIterationNode,
                                workflow_run_state: WorkflowRunState,
                                predecessor_node_id: Optional[str] = None,
                                callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
    """
    Workflow iteration started

    Publishes the iteration-started event to every callback (only when the run
    state currently holds an ``IterationState``) and advances the step counter.

    :param graph: workflow graph
    :param current_iteration_node: current iteration node
    :param workflow_run_state: workflow run state
    :param predecessor_node_id: forwarded unchanged to the callbacks
    :param callbacks: workflow callbacks
    :raises ValueError: if no node in the graph is nested in the iteration node
    :return:
    """
    iteration_node_id = current_iteration_node.node_id

    # get nested nodes; a graph without a 'nodes' list counts as having none,
    # so the ValueError below is raised instead of a TypeError on None
    has_nested_nodes = any(
        node.get('data', {}).get('iteration_id') == iteration_node_id
        for node in graph.get('nodes') or []
    )
    if not has_nested_nodes:
        raise ValueError('iteration has no nested nodes')

    if callbacks and isinstance(workflow_run_state.current_iteration_state, IterationState):
        iteration_state = workflow_run_state.current_iteration_state
        for callback in callbacks:
            callback.on_workflow_iteration_started(
                node_id=iteration_node_id,
                node_type=NodeType.ITERATION,
                node_run_index=workflow_run_state.workflow_node_steps,
                node_data=current_iteration_node.node_data,
                inputs=iteration_state.inputs,
                predecessor_node_id=predecessor_node_id,
                metadata=iteration_state.metadata.dict()
            )

    # add steps
    workflow_run_state.workflow_node_steps += 1
def _workflow_iteration_next(self, graph: dict,
                             current_iteration_node: BaseIterationNode,
                             workflow_run_state: WorkflowRunState,
                             callbacks: Optional[list[BaseWorkflowCallback]] = None) -> None:
    """
    Workflow iteration next

    Publishes the iteration-next event to every callback (only when the run state
    currently holds an ``IterationState``), then resets the per-round state:
    node-run records belonging to this iteration are dropped and the variables of
    every node nested in the iteration are cleared from the variable pool.

    :param graph: workflow graph
    :param current_iteration_node: current iteration node
    :param workflow_run_state: workflow run state
    :param callbacks: workflow callbacks
    :return:
    """
    iteration_node_id = current_iteration_node.node_id

    if callbacks and isinstance(workflow_run_state.current_iteration_state, IterationState):
        iteration_state = workflow_run_state.current_iteration_state
        for callback in callbacks:
            callback.on_workflow_iteration_next(
                node_id=iteration_node_id,
                node_type=NodeType.ITERATION,
                index=iteration_state.index,
                node_run_index=workflow_run_state.workflow_node_steps,
                output=iteration_state.get_current_output()
            )

    # clear ran nodes
    workflow_run_state.workflow_node_runs = [
        node_run for node_run in workflow_run_state.workflow_node_runs
        if node_run.iteration_node_id != iteration_node_id
    ]

    # clear variables in current iteration
    # (a graph without a 'nodes' list simply has nothing to clear)
    for node in graph.get('nodes') or []:
        if node.get('data', {}).get('iteration_id') == iteration_node_id:
            workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
                                  workflow_run_state: WorkflowRunState,
                                  callbacks: list[BaseWorkflowCallback] = None) -> None:
    """
    Workflow iteration completed

    Publishes the iteration-completed event, carrying the outputs held by the
    current iteration state, to every callback. Nothing is published when there
    are no callbacks or the run state holds no ``IterationState``.

    :param current_iteration_node: current iteration node
    :param workflow_run_state: workflow run state
    :param callbacks: workflow callbacks
    :return:
    """
    if not callbacks:
        return

    if not isinstance(workflow_run_state.current_iteration_state, IterationState):
        return

    for listener in callbacks:
        listener.on_workflow_iteration_completed(
            node_id=current_iteration_node.node_id,
            node_type=NodeType.ITERATION,
            node_run_index=workflow_run_state.workflow_node_steps,
            outputs={'output': workflow_run_state.current_iteration_state.outputs}
        )
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
graph: dict,
predecessor_node: Optional[BaseNode] = None,
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
callbacks: list[BaseWorkflowCallback] = None,
start_at: Optional[str] = None,
end_at: Optional[str] = None) -> Optional[BaseNode]:
"""
Get next node
multiple target nodes in the future.
@@ -354,16 +666,26 @@ class WorkflowEngineManager:
if not predecessor_node:
for node_config in nodes:
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
return StartNode(
node_cls = None
if start_at:
if node_config.get('id') == start_at:
node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
else:
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
node_cls = StartNode
if node_cls:
return node_cls(
tenant_id=workflow_run_state.tenant_id,
app_id=workflow_run_state.app_id,
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
invoke_from=workflow_run_state.invoke_from,
config=node_config,
callbacks=callbacks
callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth
)
else:
edges = graph.get('edges')
source_node_id = predecessor_node.node_id
@@ -390,6 +712,9 @@ class WorkflowEngineManager:
target_node_id = outgoing_edge.get('target')
if end_at and target_node_id == end_at:
return None
# fetch target node from target node id
target_node_config = None
for node in nodes:
@@ -409,9 +734,40 @@ class WorkflowEngineManager:
workflow_id=workflow_run_state.workflow_id,
user_id=workflow_run_state.user_id,
user_from=workflow_run_state.user_from,
invoke_from=workflow_run_state.invoke_from,
config=target_node_config,
callbacks=callbacks
callbacks=callbacks,
workflow_call_depth=workflow_run_state.workflow_call_depth
)
def _get_node(self, workflow_run_state: WorkflowRunState,
              graph: dict,
              node_id: str,
              callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
    """
    Get node from graph by node id

    :param workflow_run_state: workflow run state; supplies tenant, app, workflow,
        user and invocation details for the new node instance
    :param graph: workflow graph
    :param node_id: id of the node to instantiate
    :param callbacks: workflow callbacks handed to the node instance
    :raises ValueError: if no node class is available for the node's type
    :return: node instance for the first node with a matching id, or None when the
        graph has no nodes or no node matches
    """
    for node_config in graph.get('nodes') or []:
        if node_config.get('id') != node_id:
            continue

        node_type_value = node_config.get('data', {}).get('type')
        node_cls = node_classes.get(NodeType.value_of(node_type_value))
        if node_cls is None:
            # .get() yields None for a type without a registered class; raise a
            # descriptive error rather than "'NoneType' object is not callable"
            raise ValueError(f'node class not found for node type {node_type_value}')

        return node_cls(
            tenant_id=workflow_run_state.tenant_id,
            app_id=workflow_run_state.app_id,
            workflow_id=workflow_run_state.workflow_id,
            user_id=workflow_run_state.user_id,
            user_from=workflow_run_state.user_from,
            invoke_from=workflow_run_state.invoke_from,
            config=node_config,
            callbacks=callbacks,
            workflow_call_depth=workflow_run_state.workflow_call_depth
        )

    return None
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
@@ -422,6 +778,15 @@ class WorkflowEngineManager:
"""
return time.perf_counter() - start_at > max_execution_time
def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool:
    """
    Check whether a node has already run

    :param workflow_run_state: workflow run state holding the recorded node runs
    :param node_id: id of the node to look for
    :return: True if any recorded node run has this node id, otherwise False
    """
    # any() stops at the first match instead of building a throwaway list
    return any(
        node_run.node_id == node_id
        for node_run in workflow_run_state.workflow_node_runs
    )
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
node: BaseNode,
predecessor_node: Optional[BaseNode] = None,
@@ -432,7 +797,7 @@ class WorkflowEngineManager:
node_id=node.node_id,
node_type=node.node_type,
node_data=node.node_data,
node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1,
node_run_index=workflow_run_state.workflow_node_steps,
predecessor_node_id=predecessor_node.node_id if predecessor_node else None
)
@@ -446,6 +811,16 @@ class WorkflowEngineManager:
# add to workflow_nodes_and_results
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
# add steps
workflow_run_state.workflow_node_steps += 1
# mark node as running
if workflow_run_state.current_iteration_state:
workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun(
node_id=node.node_id,
iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id
))
try:
# run node, result must have inputs, process_data, outputs, execution_metadata
node_run_result = node.run(
@@ -565,3 +940,53 @@ class WorkflowEngineManager:
new_value[key] = new_val
return new_value
def _mapping_user_inputs_to_variable_pool(self,
                                          variable_mapping: dict,
                                          user_inputs: dict,
                                          variable_pool: VariablePool,
                                          tenant_id: str,
                                          node_instance: BaseNode):
    """
    Write user inputs into the variable pool

    For every entry of ``variable_mapping`` the value stored under the same key in
    ``user_inputs`` is appended to the pool at the location given by the selector
    (first element: node id, remaining elements: key list).

    :param variable_mapping: input key -> variable selector
    :param user_inputs: input key -> value
    :param variable_pool: variable pool to fill
    :param tenant_id: tenant id attached to converted image files
    :param node_instance: node the inputs belong to
    :raises ValueError: if a mapped key is missing from ``user_inputs``
    :return:
    """
    for variable_key, selector in variable_mapping.items():
        if variable_key not in user_inputs:
            raise ValueError(f'Variable key {variable_key} not found in user inputs.')

        input_value = user_inputs.get(variable_key)

        # temp fix for image type: for LLM nodes, image dicts in a list input are
        # turned into FileVar objects; if any were found they replace the list
        if node_instance.node_type == NodeType.LLM and isinstance(input_value, list):
            llm_node_data = cast(LLMNodeData, node_instance.node_data)
            detail = llm_node_data.vision.configs.detail if llm_node_data.vision.configs else None

            image_files = []
            for entry in input_value:
                if not (isinstance(entry, dict) and 'type' in entry and entry['type'] == 'image'):
                    continue

                transfer_method = FileTransferMethod.value_of(entry.get('transfer_method'))
                remote_url = entry.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None
                upload_id = (
                    entry.get('upload_file_id')
                    if transfer_method == FileTransferMethod.LOCAL_FILE else None
                )
                image_files.append(FileVar(
                    tenant_id=tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=transfer_method,
                    url=remote_url,
                    related_id=upload_id,
                    extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
                ))

            if image_files:
                input_value = image_files

        # append variable and value to variable pool
        variable_pool.append_variable(
            node_id=selector[0],
            variable_key_list=selector[1:],
            value=input_value
        )