refactor(api/core/workflow/enums.py): Rename SystemVariable to SystemVariableKey. (#7445)

This commit is contained in:
-LAN-
2024-08-20 17:52:06 +08:00
committed by GitHub
parent 5e42e90abc
commit 4f5f27cf2b
16 changed files with 106 additions and 118 deletions

View File

@@ -6,20 +6,20 @@ from typing_extensions import deprecated
from core.app.segments import Segment, Variable, factory
from core.file.file_obj import FileVar
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, FileVar]
SYSTEM_VARIABLE_NODE_ID = 'sys'
ENVIRONMENT_VARIABLE_NODE_ID = 'env'
CONVERSATION_VARIABLE_NODE_ID = 'conversation'
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
class VariablePool:
def __init__(
self,
system_variables: Mapping[SystemVariable, Any],
system_variables: Mapping[SystemVariableKey, Any],
user_inputs: Mapping[str, Any],
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable] | None = None,
@@ -68,7 +68,7 @@ class VariablePool:
None
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
if value is None:
return
@@ -95,13 +95,13 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value
@deprecated('This method is deprecated, use `get` instead.')
@deprecated("This method is deprecated, use `get` instead.")
def get_any(self, selector: Sequence[str], /) -> Any | None:
"""
Retrieves the value from the variable pool based on the given selector.
@@ -116,7 +116,7 @@ class VariablePool:
ValueError: If the selector is invalid.
"""
if len(selector) < 2:
raise ValueError('Invalid selector')
raise ValueError("Invalid selector")
hash_key = hash(tuple(selector[1:]))
value = self._variable_dictionary[selector[0]].get(hash_key)
return value.to_object() if value else None

View File

@@ -1,25 +1,13 @@
from enum import Enum
class SystemVariable(str, Enum):
class SystemVariableKey(str, Enum):
"""
System Variables.
"""
QUERY = 'query'
FILES = 'files'
CONVERSATION_ID = 'conversation_id'
USER_ID = 'user_id'
DIALOGUE_COUNT = 'dialogue_count'
@classmethod
def value_of(cls, value: str):
"""
Get value of given system variable.
:param value: system variable value
:return: system variable
"""
for system_variable in cls:
if system_variable.value == value:
return system_variable
raise ValueError(f'invalid system variable value {value}')
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"

View File

@@ -24,7 +24,7 @@ from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptT
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.llm.entities import (
LLMNodeChatModelMessage,
@@ -94,7 +94,7 @@ class LLMNode(BaseNode):
# fetch prompt messages
prompt_messages, stop = self._fetch_prompt_messages(
node_data=node_data,
query=variable_pool.get_any(['sys', SystemVariable.QUERY.value])
query=variable_pool.get_any(['sys', SystemVariableKey.QUERY.value])
if node_data.memory else None,
query_prompt_template=node_data.memory.query_prompt_template if node_data.memory else None,
inputs=inputs,
@@ -335,7 +335,7 @@ class LLMNode(BaseNode):
if not node_data.vision.enabled:
return []
files = variable_pool.get_any(['sys', SystemVariable.FILES.value])
files = variable_pool.get_any(['sys', SystemVariableKey.FILES.value])
if not files:
return []
@@ -500,7 +500,7 @@ class LLMNode(BaseNode):
return None
# get conversation id
conversation_id = variable_pool.get_any(['sys', SystemVariable.CONVERSATION_ID.value])
conversation_id = variable_pool.get_any(['sys', SystemVariableKey.CONVERSATION_ID.value])
if conversation_id is None:
return None
@@ -672,10 +672,10 @@ class LLMNode(BaseNode):
variable_mapping['#context#'] = node_data.context.variable_selector
if node_data.vision.enabled:
variable_mapping['#files#'] = ['sys', SystemVariable.FILES.value]
variable_mapping['#files#'] = ['sys', SystemVariableKey.FILES.value]
if node_data.memory:
variable_mapping['#sys.query#'] = ['sys', SystemVariable.QUERY.value]
variable_mapping['#sys.query#'] = ['sys', SystemVariableKey.QUERY.value]
if node_data.prompt_config:
enable_jinja = False

View File

@@ -1,7 +1,7 @@
from core.workflow.entities.base_node_data_entities import BaseNodeData
from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.variable_pool import SYSTEM_VARIABLE_NODE_ID, VariablePool
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.start.entities import StartNodeData
from models.workflow import WorkflowNodeExecutionStatus
@@ -17,16 +17,16 @@ class StartNode(BaseNode):
:param variable_pool: variable pool
:return:
"""
# Get cleaned inputs
cleaned_inputs = dict(variable_pool.user_inputs)
node_inputs = dict(variable_pool.user_inputs)
system_inputs = variable_pool.system_variables
for var in variable_pool.system_variables:
cleaned_inputs['sys.' + var.value] = variable_pool.system_variables[var]
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + '.' + var] = system_inputs[var]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=cleaned_inputs,
outputs=cleaned_inputs
inputs=node_inputs,
outputs=node_inputs
)
@classmethod

View File

@@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariable
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.utils.variable_template_parser import VariableTemplateParser
@@ -141,7 +141,7 @@ class ToolNode(BaseNode):
return result
def _fetch_files(self, variable_pool: VariablePool) -> list[FileVar]:
variable = variable_pool.get(['sys', SystemVariable.FILES.value])
variable = variable_pool.get(['sys', SystemVariableKey.FILES.value])
assert isinstance(variable, ArrayAnyVariable)
return list(variable.value) if variable else []