refactor(variables): replace deprecated 'get_any' with 'get' method (#9584)
This commit is contained in:
@@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
|
||||
from core.variables import (
|
||||
ArrayAnySegment,
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
NoneSegment,
|
||||
ObjectSegment,
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
@@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
return variables
|
||||
|
||||
for variable_selector in node_data.prompt_config.jinja2_variables or []:
|
||||
variable = variable_selector.variable
|
||||
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
|
||||
variable_name = variable_selector.variable
|
||||
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
|
||||
if variable is None:
|
||||
raise ValueError(f"Variable {variable_selector.variable} not found")
|
||||
|
||||
def parse_dict(d: dict) -> str:
|
||||
def parse_dict(input_dict: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Parse dict into string
|
||||
"""
|
||||
# check if it's a context structure
|
||||
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
|
||||
return d["content"]
|
||||
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
|
||||
return input_dict["content"]
|
||||
|
||||
# else, parse the dict
|
||||
try:
|
||||
return json.dumps(d, ensure_ascii=False)
|
||||
return json.dumps(input_dict, ensure_ascii=False)
|
||||
except Exception:
|
||||
return str(d)
|
||||
return str(input_dict)
|
||||
|
||||
if isinstance(value, str):
|
||||
value = value
|
||||
elif isinstance(value, list):
|
||||
if isinstance(variable, ArraySegment):
|
||||
result = ""
|
||||
for item in value:
|
||||
for item in variable.value:
|
||||
if isinstance(item, dict):
|
||||
result += parse_dict(item)
|
||||
elif isinstance(item, str):
|
||||
result += item
|
||||
elif isinstance(item, int | float):
|
||||
result += str(item)
|
||||
else:
|
||||
result += str(item)
|
||||
result += "\n"
|
||||
value = result.strip()
|
||||
elif isinstance(value, dict):
|
||||
value = parse_dict(value)
|
||||
elif isinstance(value, int | float):
|
||||
value = str(value)
|
||||
elif isinstance(variable, ObjectSegment):
|
||||
value = parse_dict(variable.value)
|
||||
else:
|
||||
value = str(value)
|
||||
value = variable.text
|
||||
|
||||
variables[variable] = value
|
||||
variables[variable_name] = value
|
||||
|
||||
return variables
|
||||
|
||||
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
|
||||
inputs = {}
|
||||
prompt_template = node_data.prompt_template
|
||||
|
||||
@@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
if not node_data.context.variable_selector:
|
||||
return
|
||||
|
||||
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
|
||||
if context_value:
|
||||
if isinstance(context_value, str):
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
|
||||
elif isinstance(context_value, list):
|
||||
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
|
||||
if context_value_variable:
|
||||
if isinstance(context_value_variable, StringSegment):
|
||||
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
|
||||
elif isinstance(context_value_variable, ArraySegment):
|
||||
context_str = ""
|
||||
original_retriever_resource = []
|
||||
for item in context_value:
|
||||
for item in context_value_variable.value:
|
||||
if isinstance(item, str):
|
||||
context_str += item + "\n"
|
||||
else:
|
||||
@@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get_any(
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID.value]
|
||||
)
|
||||
if conversation_id is None:
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
# get conversation
|
||||
conversation = (
|
||||
|
Reference in New Issue
Block a user