Feat/workflow phase2 (#4687)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
@@ -71,6 +71,42 @@ class BaseWorkflowCallback(ABC):
|
||||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
|
@@ -7,3 +7,16 @@ from pydantic import BaseModel
|
||||
class BaseNodeData(ABC, BaseModel):
|
||||
title: str
|
||||
desc: Optional[str] = None
|
||||
|
||||
class BaseIterationNodeData(BaseNodeData):
|
||||
start_node_id: str
|
||||
|
||||
class BaseIterationState(BaseModel):
|
||||
iteration_node_id: str
|
||||
index: int
|
||||
inputs: dict
|
||||
|
||||
class MetaData(BaseModel):
|
||||
pass
|
||||
|
||||
metadata: MetaData
|
@@ -21,7 +21,11 @@ class NodeType(Enum):
|
||||
QUESTION_CLASSIFIER = 'question-classifier'
|
||||
HTTP_REQUEST = 'http-request'
|
||||
TOOL = 'tool'
|
||||
VARIABLE_AGGREGATOR = 'variable-aggregator'
|
||||
VARIABLE_ASSIGNER = 'variable-assigner'
|
||||
LOOP = 'loop'
|
||||
ITERATION = 'iteration'
|
||||
PARAMETER_EXTRACTOR = 'parameter-extractor'
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> 'NodeType':
|
||||
@@ -68,6 +72,8 @@ class NodeRunMetadataKey(Enum):
|
||||
TOTAL_PRICE = 'total_price'
|
||||
CURRENCY = 'currency'
|
||||
TOOL_INFO = 'tool_info'
|
||||
ITERATION_ID = 'iteration_id'
|
||||
ITERATION_INDEX = 'iteration_index'
|
||||
|
||||
|
||||
class NodeRunResult(BaseModel):
|
||||
|
@@ -90,3 +90,12 @@ class VariablePool:
|
||||
raise ValueError(f'Invalid value type: {target_value_type.value}')
|
||||
|
||||
return value
|
||||
|
||||
def clear_node_variables(self, node_id: str) -> None:
|
||||
"""
|
||||
Clear node variables
|
||||
:param node_id: node id
|
||||
:return:
|
||||
"""
|
||||
if node_id in self.variables_mapping:
|
||||
self.variables_mapping.pop(node_id)
|
@@ -1,5 +1,9 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||
@@ -22,6 +26,9 @@ class WorkflowRunState:
|
||||
workflow_type: WorkflowType
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
workflow_call_depth: int
|
||||
|
||||
start_at: float
|
||||
variable_pool: VariablePool
|
||||
@@ -30,20 +37,37 @@ class WorkflowRunState:
|
||||
|
||||
workflow_nodes_and_results: list[WorkflowNodeAndResult]
|
||||
|
||||
class NodeRun(BaseModel):
|
||||
node_id: str
|
||||
iteration_node_id: str
|
||||
|
||||
workflow_node_runs: list[NodeRun]
|
||||
workflow_node_steps: int
|
||||
|
||||
current_iteration_state: Optional[BaseIterationState]
|
||||
|
||||
def __init__(self, workflow: Workflow,
|
||||
start_at: float,
|
||||
variable_pool: VariablePool,
|
||||
user_id: str,
|
||||
user_from: UserFrom):
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_call_depth: int):
|
||||
self.workflow_id = workflow.id
|
||||
self.tenant_id = workflow.tenant_id
|
||||
self.app_id = workflow.app_id
|
||||
self.workflow_type = WorkflowType.value_of(workflow.type)
|
||||
self.user_id = user_id
|
||||
self.user_from = user_from
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.start_at = start_at
|
||||
self.variable_pool = variable_pool
|
||||
|
||||
self.total_tokens = 0
|
||||
self.workflow_nodes_and_results = []
|
||||
|
||||
self.current_iteration_state = None
|
||||
self.workflow_node_steps = 1
|
||||
self.workflow_node_runs = []
|
@@ -2,8 +2,9 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState, BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
|
||||
@@ -37,6 +38,9 @@ class BaseNode(ABC):
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
user_from: UserFrom
|
||||
invoke_from: InvokeFrom
|
||||
|
||||
workflow_call_depth: int
|
||||
|
||||
node_id: str
|
||||
node_data: BaseNodeData
|
||||
@@ -49,13 +53,17 @@ class BaseNode(ABC):
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
config: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
workflow_call_depth: int = 0) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.app_id = app_id
|
||||
self.workflow_id = workflow_id
|
||||
self.user_id = user_id
|
||||
self.user_from = user_from
|
||||
self.invoke_from = invoke_from
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
|
||||
self.node_id = config.get("id")
|
||||
if not self.node_id:
|
||||
@@ -140,3 +148,38 @@ class BaseNode(ABC):
|
||||
:return:
|
||||
"""
|
||||
return self._node_type
|
||||
|
||||
class BaseIterationNode(BaseNode):
|
||||
@abstractmethod
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run node
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run node entry
|
||||
:param variable_pool: variable pool
|
||||
:return:
|
||||
"""
|
||||
return self._run(variable_pool=variable_pool)
|
||||
|
||||
def get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
return self._get_next_iteration(variable_pool, state)
|
||||
|
||||
@abstractmethod
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: BaseIterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
39
api/core/workflow/nodes/iteration/entities.py
Normal file
39
api/core/workflow/nodes/iteration/entities.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class IterationNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Iteration Node Data.
|
||||
"""
|
||||
parent_loop_id: Optional[str] # redundant field, not used currently
|
||||
iterator_selector: list[str] # variable selector
|
||||
output_selector: list[str] # output selector
|
||||
|
||||
class IterationState(BaseIterationState):
|
||||
"""
|
||||
Iteration State.
|
||||
"""
|
||||
outputs: list[Any] = None
|
||||
current_output: Optional[Any] = None
|
||||
|
||||
class MetaData(BaseIterationState.MetaData):
|
||||
"""
|
||||
Data.
|
||||
"""
|
||||
iterator_length: int
|
||||
|
||||
def get_last_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
if self.outputs:
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Optional[Any]:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
return self.current_output
|
119
api/core/workflow/nodes/iteration/iteration_node.py
Normal file
119
api/core/workflow/nodes/iteration/iteration_node.py
Normal file
@@ -0,0 +1,119 @@
|
||||
from typing import cast
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationNode(BaseIterationNode):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
iterator = variable_pool.get_variable_value(cast(IterationNodeData, self.node_data).iterator_selector)
|
||||
|
||||
if not isinstance(iterator, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
|
||||
|
||||
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
|
||||
'iterator_selector': iterator
|
||||
}, outputs=[], metadata=IterationState.MetaData(
|
||||
iterator_length=len(iterator) if iterator is not None else 0
|
||||
))
|
||||
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
return state
|
||||
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
# resolve current output
|
||||
self._resolve_current_output(variable_pool, state)
|
||||
# move to next iteration
|
||||
self._next_iteration(variable_pool, state)
|
||||
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
if self._reached_iteration_limit(variable_pool, state):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'output': jsonable_encoder(state.outputs)
|
||||
}
|
||||
)
|
||||
|
||||
return node_data.start_node_id
|
||||
|
||||
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Set current iteration variable.
|
||||
:variable_pool: variable pool
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
|
||||
variable_pool.append_variable(self.node_id, ['index'], state.index)
|
||||
# get the iterator value
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return
|
||||
|
||||
if state.index < len(iterator):
|
||||
variable_pool.append_variable(self.node_id, ['item'], iterator[state.index])
|
||||
|
||||
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Move to next iteration.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
state.index += 1
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
|
||||
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Check if iteration limit is reached.
|
||||
:return: True if iteration limit is reached, False otherwise
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_variable_value(node_data.iterator_selector)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return True
|
||||
|
||||
return state.index >= len(iterator)
|
||||
|
||||
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Resolve current output.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
output_selector = cast(IterationNodeData, self.node_data).output_selector
|
||||
output = variable_pool.get_variable_value(output_selector)
|
||||
# clear the output for this iteration
|
||||
variable_pool.append_variable(self.node_id, output_selector[1:], None)
|
||||
state.current_output = output
|
||||
if output is not None:
|
||||
state.outputs.append(output)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
'input_selector': node_data.iterator_selector,
|
||||
}
|
0
api/core/workflow/nodes/loop/__init__.py
Normal file
0
api/core/workflow/nodes/loop/__init__.py
Normal file
13
api/core/workflow/nodes/loop/entities.py
Normal file
13
api/core/workflow/nodes/loop/entities.py
Normal file
@@ -0,0 +1,13 @@
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationNodeData, BaseIterationState
|
||||
|
||||
|
||||
class LoopNodeData(BaseIterationNodeData):
|
||||
"""
|
||||
Loop Node Data.
|
||||
"""
|
||||
|
||||
class LoopState(BaseIterationState):
|
||||
"""
|
||||
Loop State.
|
||||
"""
|
20
api/core/workflow/nodes/loop/loop_node.py
Normal file
20
api/core/workflow/nodes/loop/loop_node.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
|
||||
|
||||
class LoopNode(BaseIterationNode):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> LoopState:
|
||||
return super()._run(variable_pool)
|
||||
|
||||
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
"""
|
85
api/core/workflow/nodes/parameter_extractor/entities.py
Normal file
85
api/core/workflow/nodes/parameter_extractor/entities.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
"""
|
||||
Model Config.
|
||||
"""
|
||||
provider: str
|
||||
name: str
|
||||
mode: str
|
||||
completion_params: dict[str, Any] = {}
|
||||
|
||||
class ParameterConfig(BaseModel):
|
||||
"""
|
||||
Parameter Config.
|
||||
"""
|
||||
name: str
|
||||
type: Literal['string', 'number', 'bool', 'select', 'array[string]', 'array[number]', 'array[object]']
|
||||
options: Optional[list[str]]
|
||||
description: str
|
||||
required: bool
|
||||
|
||||
@validator('name', pre=True, always=True)
|
||||
def validate_name(cls, value):
|
||||
if not value:
|
||||
raise ValueError('Parameter name is required')
|
||||
if value in ['__reason', '__is_success']:
|
||||
raise ValueError('Invalid parameter name, __reason and __is_success are reserved')
|
||||
return value
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
"""
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
parameters: list[ParameterConfig]
|
||||
instruction: Optional[str]
|
||||
memory: Optional[MemoryConfig]
|
||||
reasoning_mode: Literal['function_call', 'prompt']
|
||||
|
||||
@validator('reasoning_mode', pre=True, always=True)
|
||||
def set_reasoning_mode(cls, v):
|
||||
return v or 'function_call'
|
||||
|
||||
def get_parameter_json_schema(self) -> dict:
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters = {
|
||||
'type': 'object',
|
||||
'properties': {},
|
||||
'required': []
|
||||
}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema = {
|
||||
'description': parameter.description
|
||||
}
|
||||
|
||||
if parameter.type in ['string', 'select']:
|
||||
parameter_schema['type'] = 'string'
|
||||
elif parameter.type.startswith('array'):
|
||||
parameter_schema['type'] = 'array'
|
||||
nested_type = parameter.type[6:-1]
|
||||
parameter_schema['items'] = {'type': nested_type}
|
||||
else:
|
||||
parameter_schema['type'] = parameter.type
|
||||
|
||||
if parameter.type == 'select':
|
||||
parameter_schema['enum'] = parameter.options
|
||||
|
||||
parameters['properties'][parameter.name] = parameter_schema
|
||||
|
||||
if parameter.required:
|
||||
parameters['required'].append(parameter.name)
|
||||
|
||||
return parameters
|
@@ -0,0 +1,711 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
from core.prompt.simple_prompt_transform import ModelMode
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from core.workflow.nodes.parameter_extractor.prompts import (
|
||||
CHAT_EXAMPLE,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE,
|
||||
COMPLETION_GENERATE_JSON_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE,
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT,
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ParameterExtractorNode(LLMNode):
|
||||
"""
|
||||
Parameter Extractor Node.
|
||||
"""
|
||||
_node_data_cls = ParameterExtractorNodeData
|
||||
_node_type = NodeType.PARAMETER_EXTRACTOR
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
return {
|
||||
"model": {
|
||||
"prompt_templates": {
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant"
|
||||
},
|
||||
"stop": ["Human:"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
|
||||
node_data = cast(ParameterExtractorNodeData, self.node_data)
|
||||
query = variable_pool.get_variable_value(node_data.query)
|
||||
if not query:
|
||||
raise ValueError("Query not found")
|
||||
|
||||
inputs={
|
||||
'query': query,
|
||||
'parameters': jsonable_encoder(node_data.parameters),
|
||||
'instruction': jsonable_encoder(node_data.instruction),
|
||||
}
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
# fetch memory
|
||||
memory = self._fetch_memory(node_data.memory, variable_pool, model_instance)
|
||||
|
||||
if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]) \
|
||||
and node_data.reasoning_mode == 'function_call':
|
||||
# use function call
|
||||
prompt_messages, prompt_message_tools = self._generate_function_call_prompt(
|
||||
node_data, query, variable_pool, model_config, memory
|
||||
)
|
||||
else:
|
||||
# use prompt engineering
|
||||
prompt_messages = self._generate_prompt_engineering_prompt(node_data, query, variable_pool, model_config, memory)
|
||||
prompt_message_tools = []
|
||||
|
||||
process_data = {
|
||||
'model_mode': model_config.mode,
|
||||
'prompts': PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode,
|
||||
prompt_messages=prompt_messages
|
||||
),
|
||||
'usage': None,
|
||||
'function': {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]),
|
||||
'tool_call': None,
|
||||
}
|
||||
|
||||
try:
|
||||
text, usage, tool_call = self._invoke_llm(
|
||||
node_data_model=node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=prompt_message_tools,
|
||||
stop=model_config.stop,
|
||||
)
|
||||
process_data['usage'] = jsonable_encoder(usage)
|
||||
process_data['tool_call'] = jsonable_encoder(tool_call)
|
||||
process_data['llm_text'] = text
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=inputs,
|
||||
process_data={},
|
||||
outputs={
|
||||
'__is_success': 0,
|
||||
'__reason': str(e)
|
||||
},
|
||||
error=str(e),
|
||||
metadata={}
|
||||
)
|
||||
|
||||
error = None
|
||||
|
||||
if tool_call:
|
||||
result = self._extract_json_from_tool_call(tool_call)
|
||||
else:
|
||||
result = self._extract_complete_json_response(text)
|
||||
if not result:
|
||||
result = self._generate_default_result(node_data)
|
||||
error = "Failed to extract result from function call or text response, using empty result."
|
||||
|
||||
try:
|
||||
result = self._validate_result(node_data, result)
|
||||
except Exception as e:
|
||||
error = str(e)
|
||||
|
||||
# transform result into standard format
|
||||
result = self._transform_result(node_data, result)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs={
|
||||
'__is_success': 1 if not error else 0,
|
||||
'__reason': error,
|
||||
**result
|
||||
},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
NodeRunMetadataKey.CURRENCY: usage.currency
|
||||
}
|
||||
)
|
||||
|
||||
def _invoke_llm(self, node_data_model: ModelConfig,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: list[PromptMessageTool],
|
||||
stop: list[str]) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]:
|
||||
"""
|
||||
Invoke large language model
|
||||
:param node_data_model: node data model
|
||||
:param model_instance: model instance
|
||||
:param prompt_messages: prompt messages
|
||||
:param stop: stop
|
||||
:return:
|
||||
"""
|
||||
db.session.close()
|
||||
|
||||
invoke_result = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=node_data_model.completion_params,
|
||||
tools=tools,
|
||||
stop=stop,
|
||||
stream=False,
|
||||
user=self.user_id,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
if not isinstance(invoke_result, LLMResult):
|
||||
raise ValueError(f"Invalid invoke result: {invoke_result}")
|
||||
|
||||
text = invoke_result.message.content
|
||||
usage = invoke_result.usage
|
||||
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
|
||||
|
||||
# deduct quota
|
||||
self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
|
||||
return text, usage, tool_call
|
||||
|
||||
def _generate_function_call_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
"""
|
||||
Generate function call prompt.
|
||||
"""
|
||||
query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format(content=query, structure=json.dumps(node_data.get_parameter_json_schema()))
|
||||
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, memory, rest_token)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
context='',
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
# find last user message
|
||||
last_user_message_idx = -1
|
||||
for i, prompt_message in enumerate(prompt_messages):
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
last_user_message_idx = i
|
||||
|
||||
# add function call messages before last user message
|
||||
example_messages = []
|
||||
for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE:
|
||||
id = uuid.uuid4().hex
|
||||
example_messages.extend([
|
||||
UserPromptMessage(content=example['user']['query']),
|
||||
AssistantPromptMessage(
|
||||
content=example['assistant']['text'],
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=id,
|
||||
type='function',
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=example['assistant']['function_call']['name'],
|
||||
arguments=json.dumps(example['assistant']['function_call']['parameters']
|
||||
)
|
||||
))
|
||||
]
|
||||
),
|
||||
ToolPromptMessage(
|
||||
content='Great! You have called the function with the correct parameters.',
|
||||
tool_call_id=id
|
||||
),
|
||||
AssistantPromptMessage(
|
||||
content='I have extracted the parameters, let\'s move on.',
|
||||
)
|
||||
])
|
||||
|
||||
prompt_messages = prompt_messages[:last_user_message_idx] + \
|
||||
example_messages + prompt_messages[last_user_message_idx:]
|
||||
|
||||
# generate tool
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description='Extract parameters from the natural language text',
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
|
||||
def _generate_prompt_engineering_prompt(self,
|
||||
data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate prompt engineering prompt.
|
||||
"""
|
||||
model_mode = ModelMode.value_of(data.model.mode)
|
||||
|
||||
if model_mode == ModelMode.COMPLETION:
|
||||
return self._generate_prompt_engineering_completion_prompt(
|
||||
data, query, variable_pool, model_config, memory
|
||||
)
|
||||
elif model_mode == ModelMode.CHAT:
|
||||
return self._generate_prompt_engineering_chat_prompt(
|
||||
data, query, variable_pool, model_config, memory
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid model mode: {model_mode}")
|
||||
|
||||
def _generate_prompt_engineering_completion_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate completion prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, memory, rest_token)
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={
|
||||
'structure': json.dumps(node_data.get_parameter_json_schema())
|
||||
},
|
||||
query='',
|
||||
files=[],
|
||||
context='',
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _generate_prompt_engineering_chat_prompt(self,
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Generate chat prompt.
|
||||
"""
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
rest_token = self._calculate_rest_token(node_data, query, variable_pool, model_config, '')
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(
|
||||
node_data,
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(node_data.get_parameter_json_schema()),
|
||||
text=query
|
||||
),
|
||||
variable_pool, memory, rest_token
|
||||
)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
context='',
|
||||
memory_config=node_data.memory,
|
||||
memory=memory,
|
||||
model_config=model_config
|
||||
)
|
||||
|
||||
# find last user message
|
||||
last_user_message_idx = -1
|
||||
for i, prompt_message in enumerate(prompt_messages):
|
||||
if prompt_message.role == PromptMessageRole.USER:
|
||||
last_user_message_idx = i
|
||||
|
||||
# add example messages before last user message
|
||||
example_messages = []
|
||||
for example in CHAT_EXAMPLE:
|
||||
example_messages.extend([
|
||||
UserPromptMessage(content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format(
|
||||
structure=json.dumps(example['user']['json']),
|
||||
text=example['user']['query'],
|
||||
)),
|
||||
AssistantPromptMessage(
|
||||
content=json.dumps(example['assistant']['json']),
|
||||
)
|
||||
])
|
||||
|
||||
prompt_messages = prompt_messages[:last_user_message_idx] + \
|
||||
example_messages + prompt_messages[last_user_message_idx:]
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _validate_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Validate result.
|
||||
"""
|
||||
if len(data.parameters) != len(result):
|
||||
raise ValueError("Invalid number of parameters")
|
||||
|
||||
for parameter in data.parameters:
|
||||
if parameter.required and parameter.name not in result:
|
||||
raise ValueError(f"Parameter {parameter.name} is required")
|
||||
|
||||
if parameter.type == 'select' and parameter.options and result.get(parameter.name) not in parameter.options:
|
||||
raise ValueError(f"Invalid `select` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'number' and not isinstance(result.get(parameter.name), int | float):
|
||||
raise ValueError(f"Invalid `number` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'bool' and not isinstance(result.get(parameter.name), bool):
|
||||
raise ValueError(f"Invalid `bool` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type == 'string' and not isinstance(result.get(parameter.name), str):
|
||||
raise ValueError(f"Invalid `string` value for parameter {parameter.name}")
|
||||
|
||||
if parameter.type.startswith('array'):
|
||||
if not isinstance(result.get(parameter.name), list):
|
||||
raise ValueError(f"Invalid `array` value for parameter {parameter.name}")
|
||||
nested_type = parameter.type[6:-1]
|
||||
for item in result.get(parameter.name):
|
||||
if nested_type == 'number' and not isinstance(item, int | float):
|
||||
raise ValueError(f"Invalid `array[number]` value for parameter {parameter.name}")
|
||||
if nested_type == 'string' and not isinstance(item, str):
|
||||
raise ValueError(f"Invalid `array[string]` value for parameter {parameter.name}")
|
||||
if nested_type == 'object' and not isinstance(item, dict):
|
||||
raise ValueError(f"Invalid `array[object]` value for parameter {parameter.name}")
|
||||
return result
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict) -> dict:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
# transform value
|
||||
if parameter.type == 'number':
|
||||
if isinstance(result[parameter.name], int | float):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif isinstance(result[parameter.name], str):
|
||||
try:
|
||||
if '.' in result[parameter.name]:
|
||||
result[parameter.name] = float(result[parameter.name])
|
||||
else:
|
||||
result[parameter.name] = int(result[parameter.name])
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
# TODO: bool is not supported in the current version
|
||||
# elif parameter.type == 'bool':
|
||||
# if isinstance(result[parameter.name], bool):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
# elif isinstance(result[parameter.name], str):
|
||||
# if result[parameter.name].lower() in ['true', 'false']:
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name].lower() == 'true')
|
||||
# elif isinstance(result[parameter.name], int):
|
||||
# transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
elif parameter.type in ['string', 'select']:
|
||||
if isinstance(result[parameter.name], str):
|
||||
transformed_result[parameter.name] = result[parameter.name]
|
||||
elif parameter.type.startswith('array'):
|
||||
if isinstance(result[parameter.name], list):
|
||||
nested_type = parameter.type[6:-1]
|
||||
transformed_result[parameter.name] = []
|
||||
for item in result[parameter.name]:
|
||||
if nested_type == 'number':
|
||||
if isinstance(item, int | float):
|
||||
transformed_result[parameter.name].append(item)
|
||||
elif isinstance(item, str):
|
||||
try:
|
||||
if '.' in item:
|
||||
transformed_result[parameter.name].append(float(item))
|
||||
else:
|
||||
transformed_result[parameter.name].append(int(item))
|
||||
except ValueError:
|
||||
pass
|
||||
elif nested_type == 'string':
|
||||
if isinstance(item, str):
|
||||
transformed_result[parameter.name].append(item)
|
||||
elif nested_type == 'object':
|
||||
if isinstance(item, dict):
|
||||
transformed_result[parameter.name].append(item)
|
||||
|
||||
if parameter.name not in transformed_result:
|
||||
if parameter.type == 'number':
|
||||
transformed_result[parameter.name] = 0
|
||||
elif parameter.type == 'bool':
|
||||
transformed_result[parameter.name] = False
|
||||
elif parameter.type in ['string', 'select']:
|
||||
transformed_result[parameter.name] = ''
|
||||
elif parameter.type.startswith('array'):
|
||||
transformed_result[parameter.name] = []
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> Optional[dict]:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
def extract_json(text):
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
stack = []
|
||||
for i, c in enumerate(text):
|
||||
if c == '{' or c == '[':
|
||||
stack.append(c)
|
||||
elif c == '}' or c == ']':
|
||||
# check if stack is empty
|
||||
if not stack:
|
||||
return text[:i]
|
||||
# check if the last element in stack is matching
|
||||
if (c == '}' and stack[-1] == '{') or (c == ']' and stack[-1] == '['):
|
||||
stack.pop()
|
||||
if not stack:
|
||||
return text[:i+1]
|
||||
else:
|
||||
return text[:i]
|
||||
return None
|
||||
|
||||
# extract json from the text
|
||||
for idx in range(len(result)):
|
||||
if result[idx] == '{' or result[idx] == '[':
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
try:
|
||||
return json.loads(json_str)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> Optional[dict]:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
if not tool_call or not tool_call.function.arguments:
|
||||
return None
|
||||
|
||||
return json.loads(tool_call.function.arguments)
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == 'number':
|
||||
result[parameter.name] = 0
|
||||
elif parameter.type == 'bool':
|
||||
result[parameter.name] = False
|
||||
elif parameter.type in ['string', 'select']:
|
||||
result[parameter.name] = ''
|
||||
|
||||
return result
|
||||
|
||||
def _render_instruction(self, instruction: str, variable_pool: VariablePool) -> str:
|
||||
"""
|
||||
Render instruction.
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(instruction)
|
||||
inputs = {}
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
|
||||
return variable_template_parser.format(inputs)
|
||||
|
||||
def _get_function_calling_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000) \
|
||||
-> list[ChatModelMessage]:
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ''
|
||||
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
|
||||
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=input_text
|
||||
)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _get_prompt_engineering_prompt_template(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: Optional[TokenBufferMemory],
|
||||
max_token_limit: int = 2000) \
|
||||
-> list[ChatModelMessage]:
|
||||
|
||||
model_mode = ModelMode.value_of(node_data.model.mode)
|
||||
input_text = query
|
||||
memory_str = ''
|
||||
instruction = self._render_instruction(node_data.instruction or '', variable_pool)
|
||||
|
||||
if memory:
|
||||
memory_str = memory.get_history_prompt_text(max_token_limit=max_token_limit,
|
||||
message_limit=node_data.memory.window.size)
|
||||
if model_mode == ModelMode.CHAT:
|
||||
system_prompt_messages = ChatModelMessage(
|
||||
role=PromptMessageRole.SYSTEM,
|
||||
text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction)
|
||||
)
|
||||
user_prompt_message = ChatModelMessage(
|
||||
role=PromptMessageRole.USER,
|
||||
text=input_text
|
||||
)
|
||||
return [system_prompt_messages, user_prompt_message]
|
||||
elif model_mode == ModelMode.COMPLETION:
|
||||
return CompletionModelPromptTemplate(
|
||||
text=COMPLETION_GENERATE_JSON_PROMPT.format(histories=memory_str,
|
||||
text=input_text,
|
||||
instruction=instruction)
|
||||
.replace('{γγγ', '')
|
||||
.replace('}γγγ', '')
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Model mode {model_mode} not support.")
|
||||
|
||||
def _calculate_rest_token(self, node_data: ParameterExtractorNodeData, query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context: Optional[str]) -> int:
|
||||
prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True)
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(node_data.model)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise ValueError("Model is not a Large Language Model")
|
||||
|
||||
llm_model = model_instance.model_type_instance
|
||||
model_schema = llm_model.get_model_schema(model_config.model, model_config.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
if set(model_schema.features or []) & set([ModelFeature.MULTI_TOOL_CALL, ModelFeature.MULTI_TOOL_CALL]):
|
||||
prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
else:
|
||||
prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000)
|
||||
|
||||
prompt_messages = prompt_transform.get_prompt(
|
||||
prompt_template=prompt_template,
|
||||
inputs={},
|
||||
query='',
|
||||
files=[],
|
||||
context=context,
|
||||
memory_config=node_data.memory,
|
||||
memory=None,
|
||||
model_config=model_config
|
||||
)
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_type_instance = model_config.provider_model_bundle.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
model_config.model,
|
||||
model_config.credentials,
|
||||
prompt_messages
|
||||
) + 1000 # add 1000 to ensure tool call messages
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
if (parameter_rule.name == 'max_tokens'
|
||||
or (parameter_rule.use_template and parameter_rule.use_template == 'max_tokens')):
|
||||
max_tokens = (model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(parameter_rule.use_template)) or 0
|
||||
|
||||
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
|
||||
rest_tokens = max(rest_tokens, 0)
|
||||
|
||||
return rest_tokens
|
||||
|
||||
def _fetch_model_config(self, node_data_model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
"""
|
||||
Fetch model config.
|
||||
"""
|
||||
if not self._model_instance or not self._model_config:
|
||||
self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
|
||||
|
||||
return self._model_instance, self._model_config
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ParameterExtractorNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = node_data
|
||||
|
||||
variable_mapping = {
|
||||
'query': node_data.query
|
||||
}
|
||||
|
||||
if node_data.instruction:
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.instruction)
|
||||
for selector in variable_template_parser.extract_variable_selectors():
|
||||
variable_mapping[selector.variable] = selector.value_selector
|
||||
|
||||
return variable_mapping
|
206
api/core/workflow/nodes/parameter_extractor/prompts.py
Normal file
206
api/core/workflow/nodes/parameter_extractor/prompts.py
Normal file
@@ -0,0 +1,206 @@
|
||||
FUNCTION_CALLING_EXTRACTOR_NAME = 'extract_parameters'
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy.
|
||||
### Task
|
||||
Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. Ensure that the information extraction is contextual and aligns with the provided criteria.
|
||||
### Memory
|
||||
Here is the chat history between the human and assistant, provided within <histories> tags:
|
||||
<histories>
|
||||
\x7bhistories\x7d
|
||||
</histories>
|
||||
### Instructions:
|
||||
Some additional information is provided below. Always adhere to these instructions as closely as possible:
|
||||
<instruction>
|
||||
\x7binstruction\x7d
|
||||
</instruction>
|
||||
Steps:
|
||||
1. Review the chat history provided within the <histories> tags.
|
||||
2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
|
||||
3. Generate a well-formatted output using the defined functions and arguments.
|
||||
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
|
||||
5. Do not include any XML tags in your output.
|
||||
### Example
|
||||
To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples.
|
||||
### Final Output
|
||||
Produce well-formatted function calls in json without XML tags, as shown in the example.
|
||||
"""
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside <context></context> XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside <structure></structure> XML tags.
|
||||
<context>
|
||||
\x7bcontent\x7d
|
||||
</context>
|
||||
|
||||
<structure>
|
||||
\x7bstructure\x7d
|
||||
</structure>
|
||||
"""
|
||||
|
||||
FUNCTION_CALLING_EXTRACTOR_EXAMPLE = [{
|
||||
'user': {
|
||||
'query': 'What is the weather today in SF?',
|
||||
'function': {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The location to get the weather information',
|
||||
'required': True
|
||||
},
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
}
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the location parameter.',
|
||||
'function_call' : {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'location': 'San Francisco'
|
||||
}
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'user': {
|
||||
'query': 'I want to eat some apple pie.',
|
||||
'function': {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'food': {
|
||||
'type': 'string',
|
||||
'description': 'The food to eat',
|
||||
'required': True
|
||||
}
|
||||
},
|
||||
'required': ['food']
|
||||
}
|
||||
}
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need always call the function with the correct parameters. in this case, I need to call the function with the food parameter.',
|
||||
'function_call' : {
|
||||
'name': FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
'parameters': {
|
||||
'food': 'apple pie'
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
|
||||
COMPLETION_GENERATE_JSON_PROMPT = """### Instructions:
|
||||
Some extra information are provided below, I should always follow the instructions as possible as I can.
|
||||
<instructions>
|
||||
{instruction}
|
||||
</instructions>
|
||||
|
||||
### Extract parameter Workflow
|
||||
I need to extract the following information from the input text. The <information to be extracted> tag specifies the 'type', 'description' and 'required' of the information to be extracted.
|
||||
<information to be extracted>
|
||||
{{ structure }}
|
||||
</information to be extracted>
|
||||
|
||||
Step 1: Carefully read the input and understand the structure of the expected output.
|
||||
Step 2: Extract relevant parameters from the provided text based on the name and description of object.
|
||||
Step 3: Structure the extracted parameters to JSON object as specified in <structure>.
|
||||
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
|
||||
|
||||
### Memory
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
### Structure
|
||||
Here is the structure of the expected output, I should always follow the output structure.
|
||||
{{γγγ
|
||||
'properties1': 'relevant text extracted from input',
|
||||
'properties2': 'relevant text extracted from input',
|
||||
}}γγγ
|
||||
|
||||
### Input Text
|
||||
Inside <text></text> XML tags, there is a text that I should extract parameters and convert to a JSON object.
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
|
||||
### Answer
|
||||
I should always output a valid JSON object. Output nothing other than the JSON object.
|
||||
```JSON
|
||||
"""
|
||||
|
||||
CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object.
|
||||
The structure of the JSON object you can found in the instructions.
|
||||
|
||||
### Memory
|
||||
Here is the chat histories between human and assistant, inside <histories></histories> XML tags.
|
||||
<histories>
|
||||
{histories}
|
||||
</histories>
|
||||
|
||||
### Instructions:
|
||||
Some extra information are provided below, you should always follow the instructions as possible as you can.
|
||||
<instructions>
|
||||
{{instructions}}
|
||||
</instructions>
|
||||
"""
|
||||
|
||||
CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure
|
||||
Here is the structure of the JSON object, you should always follow the structure.
|
||||
<structure>
|
||||
{structure}
|
||||
</structure>
|
||||
|
||||
### Text to be converted to JSON
|
||||
Inside <text></text> XML tags, there is a text that you should convert to a JSON object.
|
||||
<text>
|
||||
{text}
|
||||
</text>
|
||||
"""
|
||||
|
||||
CHAT_EXAMPLE = [{
|
||||
'user': {
|
||||
'query': 'What is the weather today in SF?',
|
||||
'json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'location': {
|
||||
'type': 'string',
|
||||
'description': 'The location to get the weather information',
|
||||
'required': True
|
||||
}
|
||||
},
|
||||
'required': ['location']
|
||||
}
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need to output a valid JSON object.',
|
||||
'json': {
|
||||
'location': 'San Francisco'
|
||||
}
|
||||
}
|
||||
}, {
|
||||
'user': {
|
||||
'query': 'I want to eat some apple pie.',
|
||||
'json': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'food': {
|
||||
'type': 'string',
|
||||
'description': 'The food to eat',
|
||||
'required': True
|
||||
}
|
||||
},
|
||||
'required': ['food']
|
||||
}
|
||||
},
|
||||
'assistant': {
|
||||
'text': 'I need to output a valid JSON object.',
|
||||
'json': {
|
||||
'result': 'apple pie'
|
||||
}
|
||||
}
|
||||
}]
|
@@ -7,7 +7,7 @@ from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: Literal['builtin', 'api']
|
||||
provider_type: Literal['builtin', 'api', 'workflow']
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
|
@@ -1,13 +1,14 @@
|
||||
from os import path
|
||||
from typing import cast
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType, SystemVariable
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
@@ -35,20 +36,23 @@ class ToolNode(BaseNode):
|
||||
'provider_id': node_data.provider_id
|
||||
}
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data)
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, self.app_id, self.node_id, node_data)
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(
|
||||
self.tenant_id, self.app_id, self.node_id, node_data, self.invoke_from
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters,
|
||||
inputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOOL_INFO: tool_info
|
||||
},
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
)
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_parameters(variable_pool, node_data, tool_runtime)
|
||||
|
||||
try:
|
||||
messages = ToolEngine.workflow_invoke(
|
||||
@@ -56,7 +60,8 @@ class ToolNode(BaseNode):
|
||||
tool_parameters=parameters,
|
||||
user_id=self.user_id,
|
||||
workflow_id=self.workflow_id,
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler()
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth + 1
|
||||
)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
@@ -83,19 +88,32 @@ class ToolNode(BaseNode):
|
||||
inputs=parameters
|
||||
)
|
||||
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData, tool_runtime: Tool) -> dict:
|
||||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
tool_parameters = tool_runtime.get_all_runtime_parameters()
|
||||
|
||||
def fetch_parameter(name: str) -> Optional[ToolParameter]:
|
||||
return next((parameter for parameter in tool_parameters if parameter.name == name), None)
|
||||
|
||||
result = {}
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
parameter = fetch_parameter(parameter_name)
|
||||
if not parameter:
|
||||
continue
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
result[parameter_name] = [
|
||||
v.to_dict() for v in self._fetch_files(variable_pool)
|
||||
]
|
||||
else:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
|
||||
return result
|
||||
|
||||
@@ -109,6 +127,13 @@ class ToolNode(BaseNode):
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
|
||||
return template_parser.format(inputs)
|
||||
|
||||
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
|
||||
files = variable_pool.get_variable_value(['sys', SystemVariable.FILES.value])
|
||||
if not files:
|
||||
return []
|
||||
|
||||
return files
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
|
33
api/core/workflow/nodes/variable_aggregator/entities.py
Normal file
33
api/core/workflow/nodes/variable_aggregator/entities.py
Normal file
@@ -0,0 +1,33 @@
|
||||
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class AdvancedSetting(BaseModel):
|
||||
"""
|
||||
Advanced setting.
|
||||
"""
|
||||
group_enabled: bool
|
||||
|
||||
class Group(BaseModel):
|
||||
"""
|
||||
Group.
|
||||
"""
|
||||
output_type: Literal['string', 'number', 'array', 'object']
|
||||
variables: list[list[str]]
|
||||
group_name: str
|
||||
|
||||
groups: list[Group]
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
type: str = 'variable-assigner'
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
||||
advanced_setting: Optional[AdvancedSetting]
|
@@ -0,0 +1,52 @@
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAggregatorNode(BaseNode):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_AGGREGATOR
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
node_data = cast(VariableAssignerNodeData, self.node_data)
|
||||
# Get variables
|
||||
outputs = {}
|
||||
inputs = {}
|
||||
|
||||
if not node_data.advanced_setting or node_data.advanced_setting.group_enabled:
|
||||
for variable in node_data.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
outputs = {
|
||||
"output": value
|
||||
}
|
||||
|
||||
inputs = {
|
||||
'.'.join(variable[1:]): value
|
||||
}
|
||||
break
|
||||
else:
|
||||
for group in node_data.advanced_setting.groups:
|
||||
for variable in group.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
outputs[f'{group.group_name}_output'] = value
|
||||
inputs['.'.join(variable[1:])] = value
|
||||
break
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
inputs=inputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
return {}
|
@@ -1,12 +0,0 @@
|
||||
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
|
||||
|
||||
class VariableAssignerNodeData(BaseNodeData):
|
||||
"""
|
||||
Knowledge retrieval Node Data.
|
||||
"""
|
||||
type: str = 'variable-assigner'
|
||||
output_type: str
|
||||
variables: list[list[str]]
|
@@ -1,41 +0,0 @@
|
||||
from typing import cast
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.variable_assigner.entities import VariableAssignerNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class VariableAssignerNode(BaseNode):
|
||||
_node_data_cls = VariableAssignerNodeData
|
||||
_node_type = NodeType.VARIABLE_ASSIGNER
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
node_data: VariableAssignerNodeData = cast(self._node_data_cls, self.node_data)
|
||||
# Get variables
|
||||
outputs = {}
|
||||
inputs = {}
|
||||
for variable in node_data.variables:
|
||||
value = variable_pool.get_variable_value(variable)
|
||||
|
||||
if value is not None:
|
||||
outputs = {
|
||||
"output": value
|
||||
}
|
||||
|
||||
inputs = {
|
||||
'.'.join(variable[1:]): value
|
||||
}
|
||||
break
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
inputs=inputs
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[str, list[str]]:
|
||||
return {}
|
@@ -6,6 +6,7 @@ from flask import current_app
|
||||
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import BaseWorkflowCallback
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
@@ -13,19 +14,22 @@ from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.entities.workflow_entities import WorkflowNodeAndResult, WorkflowRunState
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.base_node import BaseNode, UserFrom
|
||||
from core.workflow.nodes.base_node import BaseIterationNode, BaseNode, UserFrom
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationState
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_assigner.variable_assigner_node import VariableAssignerNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
@@ -44,9 +48,14 @@ node_classes = {
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode,
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode
|
||||
}
|
||||
|
||||
WORKFLOW_CALL_MAX_DEPTH = 5
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -83,18 +92,20 @@ class WorkflowEngineManager:
|
||||
def run_workflow(self, workflow: Workflow,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
user_inputs: dict,
|
||||
system_inputs: Optional[dict] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
call_depth: Optional[int] = 0,
|
||||
variable_pool: Optional[VariablePool] = None) -> None:
|
||||
"""
|
||||
Run workflow
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
:param call_depth: call depth
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph = workflow.graph_dict
|
||||
@@ -109,57 +120,185 @@ class WorkflowEngineManager:
|
||||
|
||||
if not isinstance(graph.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# init variable pool
|
||||
if not variable_pool:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs
|
||||
)
|
||||
|
||||
if call_depth > WORKFLOW_CALL_MAX_DEPTH:
|
||||
raise ValueError('Max workflow call depth reached.')
|
||||
|
||||
# init workflow run state
|
||||
workflow_run_state = WorkflowRunState(
|
||||
workflow=workflow,
|
||||
start_at=time.perf_counter(),
|
||||
variable_pool=variable_pool,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
workflow_call_depth=call_depth
|
||||
)
|
||||
|
||||
# init workflow run
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started()
|
||||
|
||||
# init workflow run state
|
||||
workflow_run_state = WorkflowRunState(
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
workflow=workflow,
|
||||
start_at=time.perf_counter(),
|
||||
variable_pool=VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=user_inputs
|
||||
),
|
||||
user_id=user_id,
|
||||
user_from=user_from
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
def _run_workflow(self, workflow: Workflow,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run workflow
|
||||
:param workflow: Workflow instance
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param user_inputs: user variables inputs
|
||||
:param system_inputs: system inputs, like: query, files
|
||||
:param callbacks: workflow callbacks
|
||||
:param call_depth: call depth
|
||||
:param start_at: force specific start node
|
||||
:param end_at: force specific end node
|
||||
:return:
|
||||
"""
|
||||
graph = workflow.graph_dict
|
||||
|
||||
try:
|
||||
predecessor_node = None
|
||||
predecessor_node: BaseNode = None
|
||||
current_iteration_node: BaseIterationNode = None
|
||||
has_entry_node = False
|
||||
max_execution_steps = current_app.config.get("WORKFLOW_MAX_EXECUTION_STEPS")
|
||||
max_execution_time = current_app.config.get("WORKFLOW_MAX_EXECUTION_TIME")
|
||||
while True:
|
||||
# get next node, multiple target nodes in the future
|
||||
next_node = self._get_next_node(
|
||||
next_node = self._get_next_overall_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph=graph,
|
||||
predecessor_node=predecessor_node,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
start_at=start_at,
|
||||
end_at=end_at
|
||||
)
|
||||
|
||||
if not next_node:
|
||||
# reached loop/iteration end or overall end
|
||||
if current_iteration_node and workflow_run_state.current_iteration_state:
|
||||
# reached loop/iteration end
|
||||
# get next iteration
|
||||
next_iteration = current_iteration_node.get_next_iteration(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_next(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
if isinstance(next_iteration, NodeRunResult):
|
||||
if next_iteration.outputs:
|
||||
for variable_key, variable_value in next_iteration.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
node_id=current_iteration_node.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value
|
||||
)
|
||||
self._workflow_iteration_completed(
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
# iteration has ended
|
||||
next_node = self._get_next_overall_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
graph=graph,
|
||||
predecessor_node=current_iteration_node,
|
||||
callbacks=callbacks,
|
||||
start_at=start_at,
|
||||
end_at=end_at
|
||||
)
|
||||
current_iteration_node = None
|
||||
workflow_run_state.current_iteration_state = None
|
||||
# continue overall process
|
||||
elif isinstance(next_iteration, str):
|
||||
# move to next iteration
|
||||
next_node_id = next_iteration
|
||||
# get next id
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
|
||||
if not next_node:
|
||||
break
|
||||
|
||||
# check is already ran
|
||||
if next_node.node_id in [node_and_result.node.node_id
|
||||
for node_and_result in workflow_run_state.workflow_nodes_and_results]:
|
||||
if self._check_node_has_ran(workflow_run_state, next_node.node_id):
|
||||
predecessor_node = next_node
|
||||
continue
|
||||
|
||||
has_entry_node = True
|
||||
|
||||
# max steps reached
|
||||
if len(workflow_run_state.workflow_nodes_and_results) > max_execution_steps:
|
||||
if workflow_run_state.workflow_node_steps > max_execution_steps:
|
||||
raise ValueError('Max steps {} reached.'.format(max_execution_steps))
|
||||
|
||||
# or max execution time reached
|
||||
if self._is_timed_out(start_at=workflow_run_state.start_at, max_execution_time=max_execution_time):
|
||||
raise ValueError('Max execution time {}s reached.'.format(max_execution_time))
|
||||
|
||||
# handle iteration nodes
|
||||
if isinstance(next_node, BaseIterationNode):
|
||||
current_iteration_node = next_node
|
||||
workflow_run_state.current_iteration_state = next_node.run(
|
||||
variable_pool=workflow_run_state.variable_pool
|
||||
)
|
||||
self._workflow_iteration_started(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
predecessor_node_id=predecessor_node.node_id if predecessor_node else None,
|
||||
callbacks=callbacks
|
||||
)
|
||||
predecessor_node = next_node
|
||||
# move to start node of iteration
|
||||
next_node_id = next_node.get_next_iteration(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_next(
|
||||
graph=graph,
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
if isinstance(next_node_id, NodeRunResult):
|
||||
# iteration has ended
|
||||
current_iteration_node.set_output(
|
||||
variable_pool=workflow_run_state.variable_pool,
|
||||
state=workflow_run_state.current_iteration_state
|
||||
)
|
||||
self._workflow_iteration_completed(
|
||||
current_iteration_node=current_iteration_node,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks
|
||||
)
|
||||
current_iteration_node = None
|
||||
workflow_run_state.current_iteration_state = None
|
||||
continue
|
||||
else:
|
||||
next_node = self._get_node(workflow_run_state, graph, next_node_id, callbacks)
|
||||
|
||||
# run workflow, run multiple target nodes in the future
|
||||
self._run_workflow_node(
|
||||
workflow_run_state=workflow_run_state,
|
||||
@@ -235,7 +374,9 @@ class WorkflowEngineManager:
|
||||
workflow_id=workflow.id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
config=node_config
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config=node_config,
|
||||
workflow_call_depth=0
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -251,49 +392,14 @@ class WorkflowEngineManager:
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
for variable_key, variable_selector in variable_mapping.items():
|
||||
if variable_key not in user_inputs:
|
||||
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
||||
# get value
|
||||
value = user_inputs.get(variable_key)
|
||||
|
||||
# temp fix for image type
|
||||
if node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(value, list):
|
||||
node_data = node_instance.node_data
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in value:
|
||||
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
|
||||
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
|
||||
file = FileVar(
|
||||
tenant_id=workflow.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=item.get(
|
||||
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
if new_value:
|
||||
value = new_value
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.append_variable(
|
||||
node_id=variable_node_id,
|
||||
variable_key_list=variable_key_list,
|
||||
value=value
|
||||
)
|
||||
self._mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_instance=node_instance
|
||||
)
|
||||
|
||||
# run node
|
||||
node_run_result = node_instance.run(
|
||||
variable_pool=variable_pool
|
||||
@@ -311,6 +417,126 @@ class WorkflowEngineManager:
|
||||
|
||||
return node_instance, node_run_result
|
||||
|
||||
def single_step_run_iteration_workflow_node(self, workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict,
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Single iteration run workflow node
|
||||
"""
|
||||
# fetch node info from workflow graph
|
||||
graph = workflow.graph_dict
|
||||
if not graph:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
nodes = graph.get('nodes')
|
||||
if not nodes:
|
||||
raise ValueError('nodes not found in workflow graph')
|
||||
|
||||
for node in nodes:
|
||||
if node.get('id') == node_id:
|
||||
if node.get('data', {}).get('type') in [
|
||||
NodeType.ITERATION.value,
|
||||
NodeType.LOOP.value,
|
||||
]:
|
||||
node_config = node
|
||||
else:
|
||||
raise ValueError('node id is not an iteration node')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={}
|
||||
)
|
||||
|
||||
# variable selector to variable mapping
|
||||
iteration_nested_nodes = [
|
||||
node for node in nodes
|
||||
if node.get('data', {}).get('iteration_id') == node_id or node.get('id') == node_id
|
||||
]
|
||||
iteration_nested_node_ids = [node.get('id') for node in iteration_nested_nodes]
|
||||
|
||||
if not iteration_nested_nodes:
|
||||
raise ValueError('iteration has no nested nodes')
|
||||
|
||||
# init workflow run
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_run_started()
|
||||
|
||||
for node_config in iteration_nested_nodes:
|
||||
# mapping user inputs to variable pool
|
||||
node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(node_config)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
# remove iteration variables
|
||||
variable_mapping = {
|
||||
f'{node_config.get("id")}.{key}': value for key, value in variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items()
|
||||
if value[0] not in iteration_nested_node_ids
|
||||
}
|
||||
|
||||
# append variables to variable pool
|
||||
node_instance = node_cls(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_id=workflow.id,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
config=node_config,
|
||||
callbacks=callbacks,
|
||||
workflow_call_depth=0
|
||||
)
|
||||
|
||||
self._mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_instance=node_instance
|
||||
)
|
||||
|
||||
# fetch end node of iteration
|
||||
end_node_id = None
|
||||
for edge in graph.get('edges'):
|
||||
if edge.get('source') == node_id:
|
||||
end_node_id = edge.get('target')
|
||||
break
|
||||
|
||||
if not end_node_id:
|
||||
raise ValueError('end node of iteration not found')
|
||||
|
||||
# init workflow run state
|
||||
workflow_run_state = WorkflowRunState(
|
||||
workflow=workflow,
|
||||
start_at=time.perf_counter(),
|
||||
variable_pool=variable_pool,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
workflow_call_depth=0
|
||||
)
|
||||
|
||||
# run workflow
|
||||
self._run_workflow(
|
||||
workflow=workflow,
|
||||
workflow_run_state=workflow_run_state,
|
||||
callbacks=callbacks,
|
||||
start_at=node_id,
|
||||
end_at=end_node_id
|
||||
)
|
||||
|
||||
def _workflow_run_success(self, callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Workflow run success
|
||||
@@ -336,10 +562,96 @@ class WorkflowEngineManager:
|
||||
error=error
|
||||
)
|
||||
|
||||
def _get_next_node(self, workflow_run_state: WorkflowRunState,
|
||||
def _workflow_iteration_started(self, graph: dict,
|
||||
current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Workflow iteration started
|
||||
:param current_iteration_node: current iteration node
|
||||
:param workflow_run_state: workflow run state
|
||||
:param callbacks: workflow callbacks
|
||||
:return:
|
||||
"""
|
||||
# get nested nodes
|
||||
iteration_nested_nodes = [
|
||||
node for node in graph.get('nodes')
|
||||
if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id
|
||||
]
|
||||
|
||||
if not iteration_nested_nodes:
|
||||
raise ValueError('iteration has no nested nodes')
|
||||
|
||||
if callbacks:
|
||||
if isinstance(workflow_run_state.current_iteration_state, IterationState):
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_iteration_started(
|
||||
node_id=current_iteration_node.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_run_index=workflow_run_state.workflow_node_steps,
|
||||
node_data=current_iteration_node.node_data,
|
||||
inputs=workflow_run_state.current_iteration_state.inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=workflow_run_state.current_iteration_state.metadata.dict()
|
||||
)
|
||||
|
||||
# add steps
|
||||
workflow_run_state.workflow_node_steps += 1
|
||||
|
||||
def _workflow_iteration_next(self, graph: dict,
|
||||
current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
"""
|
||||
Workflow iteration next
|
||||
:param workflow_run_state: workflow run state
|
||||
:return:
|
||||
"""
|
||||
if callbacks:
|
||||
if isinstance(workflow_run_state.current_iteration_state, IterationState):
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_iteration_next(
|
||||
node_id=current_iteration_node.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
index=workflow_run_state.current_iteration_state.index,
|
||||
node_run_index=workflow_run_state.workflow_node_steps,
|
||||
output=workflow_run_state.current_iteration_state.get_current_output()
|
||||
)
|
||||
# clear ran nodes
|
||||
workflow_run_state.workflow_node_runs = [
|
||||
node_run for node_run in workflow_run_state.workflow_node_runs
|
||||
if node_run.iteration_node_id != current_iteration_node.node_id
|
||||
]
|
||||
|
||||
# clear variables in current iteration
|
||||
nodes = graph.get('nodes')
|
||||
nodes = [node for node in nodes if node.get('data', {}).get('iteration_id') == current_iteration_node.node_id]
|
||||
|
||||
for node in nodes:
|
||||
workflow_run_state.variable_pool.clear_node_variables(node_id=node.get('id'))
|
||||
|
||||
def _workflow_iteration_completed(self, current_iteration_node: BaseIterationNode,
|
||||
workflow_run_state: WorkflowRunState,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> None:
|
||||
if callbacks:
|
||||
if isinstance(workflow_run_state.current_iteration_state, IterationState):
|
||||
for callback in callbacks:
|
||||
callback.on_workflow_iteration_completed(
|
||||
node_id=current_iteration_node.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_run_index=workflow_run_state.workflow_node_steps,
|
||||
outputs={
|
||||
'output': workflow_run_state.current_iteration_state.outputs
|
||||
}
|
||||
)
|
||||
|
||||
def _get_next_overall_node(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
callbacks: list[BaseWorkflowCallback] = None) -> Optional[BaseNode]:
|
||||
callbacks: list[BaseWorkflowCallback] = None,
|
||||
start_at: Optional[str] = None,
|
||||
end_at: Optional[str] = None) -> Optional[BaseNode]:
|
||||
"""
|
||||
Get next node
|
||||
multiple target nodes in the future.
|
||||
@@ -354,16 +666,26 @@ class WorkflowEngineManager:
|
||||
|
||||
if not predecessor_node:
|
||||
for node_config in nodes:
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
|
||||
return StartNode(
|
||||
node_cls = None
|
||||
if start_at:
|
||||
if node_config.get('id') == start_at:
|
||||
node_cls = node_classes.get(NodeType.value_of(node_config.get('data', {}).get('type')))
|
||||
else:
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value:
|
||||
node_cls = StartNode
|
||||
if node_cls:
|
||||
return node_cls(
|
||||
tenant_id=workflow_run_state.tenant_id,
|
||||
app_id=workflow_run_state.app_id,
|
||||
workflow_id=workflow_run_state.workflow_id,
|
||||
user_id=workflow_run_state.user_id,
|
||||
user_from=workflow_run_state.user_from,
|
||||
invoke_from=workflow_run_state.invoke_from,
|
||||
config=node_config,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
|
||||
else:
|
||||
edges = graph.get('edges')
|
||||
source_node_id = predecessor_node.node_id
|
||||
@@ -390,6 +712,9 @@ class WorkflowEngineManager:
|
||||
|
||||
target_node_id = outgoing_edge.get('target')
|
||||
|
||||
if end_at and target_node_id == end_at:
|
||||
return None
|
||||
|
||||
# fetch target node from target node id
|
||||
target_node_config = None
|
||||
for node in nodes:
|
||||
@@ -409,9 +734,40 @@ class WorkflowEngineManager:
|
||||
workflow_id=workflow_run_state.workflow_id,
|
||||
user_id=workflow_run_state.user_id,
|
||||
user_from=workflow_run_state.user_from,
|
||||
invoke_from=workflow_run_state.invoke_from,
|
||||
config=target_node_config,
|
||||
callbacks=callbacks
|
||||
callbacks=callbacks,
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
|
||||
def _get_node(self, workflow_run_state: WorkflowRunState,
|
||||
graph: dict,
|
||||
node_id: str,
|
||||
callbacks: list[BaseWorkflowCallback]) -> Optional[BaseNode]:
|
||||
"""
|
||||
Get node from graph by node id
|
||||
"""
|
||||
nodes = graph.get('nodes')
|
||||
if not nodes:
|
||||
return None
|
||||
|
||||
for node_config in nodes:
|
||||
if node_config.get('id') == node_id:
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
return node_cls(
|
||||
tenant_id=workflow_run_state.tenant_id,
|
||||
app_id=workflow_run_state.app_id,
|
||||
workflow_id=workflow_run_state.workflow_id,
|
||||
user_id=workflow_run_state.user_id,
|
||||
user_from=workflow_run_state.user_from,
|
||||
invoke_from=workflow_run_state.invoke_from,
|
||||
config=node_config,
|
||||
callbacks=callbacks,
|
||||
workflow_call_depth=workflow_run_state.workflow_call_depth
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
@@ -422,6 +778,15 @@ class WorkflowEngineManager:
|
||||
"""
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
def _check_node_has_ran(self, workflow_run_state: WorkflowRunState, node_id: str) -> bool:
|
||||
"""
|
||||
Check node has ran
|
||||
"""
|
||||
return bool([
|
||||
node_and_result for node_and_result in workflow_run_state.workflow_node_runs
|
||||
if node_and_result.node_id == node_id
|
||||
])
|
||||
|
||||
def _run_workflow_node(self, workflow_run_state: WorkflowRunState,
|
||||
node: BaseNode,
|
||||
predecessor_node: Optional[BaseNode] = None,
|
||||
@@ -432,7 +797,7 @@ class WorkflowEngineManager:
|
||||
node_id=node.node_id,
|
||||
node_type=node.node_type,
|
||||
node_data=node.node_data,
|
||||
node_run_index=len(workflow_run_state.workflow_nodes_and_results) + 1,
|
||||
node_run_index=workflow_run_state.workflow_node_steps,
|
||||
predecessor_node_id=predecessor_node.node_id if predecessor_node else None
|
||||
)
|
||||
|
||||
@@ -446,6 +811,16 @@ class WorkflowEngineManager:
|
||||
# add to workflow_nodes_and_results
|
||||
workflow_run_state.workflow_nodes_and_results.append(workflow_nodes_and_result)
|
||||
|
||||
# add steps
|
||||
workflow_run_state.workflow_node_steps += 1
|
||||
|
||||
# mark node as running
|
||||
if workflow_run_state.current_iteration_state:
|
||||
workflow_run_state.workflow_node_runs.append(WorkflowRunState.NodeRun(
|
||||
node_id=node.node_id,
|
||||
iteration_node_id=workflow_run_state.current_iteration_state.iteration_node_id
|
||||
))
|
||||
|
||||
try:
|
||||
# run node, result must have inputs, process_data, outputs, execution_metadata
|
||||
node_run_result = node.run(
|
||||
@@ -565,3 +940,53 @@ class WorkflowEngineManager:
|
||||
new_value[key] = new_val
|
||||
|
||||
return new_value
|
||||
|
||||
def _mapping_user_inputs_to_variable_pool(self,
|
||||
variable_mapping: dict,
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_instance: BaseNode):
|
||||
for variable_key, variable_selector in variable_mapping.items():
|
||||
if variable_key not in user_inputs:
|
||||
raise ValueError(f'Variable key {variable_key} not found in user inputs.')
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
|
||||
# get value
|
||||
value = user_inputs.get(variable_key)
|
||||
|
||||
# temp fix for image type
|
||||
if node_instance.node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(value, list):
|
||||
node_data = node_instance.node_data
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in value:
|
||||
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
|
||||
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
|
||||
file = FileVar(
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=item.get(
|
||||
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
if new_value:
|
||||
value = new_value
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.append_variable(
|
||||
node_id=variable_node_id,
|
||||
variable_key_list=variable_key_list,
|
||||
value=value
|
||||
)
|
Reference in New Issue
Block a user