refactor(variables): replace deprecated 'get_any' with 'get' method (#9584)

This commit is contained in:
-LAN-
2024-10-22 10:49:19 +08:00
committed by GitHub
parent 5838345f48
commit 8f670f31b8
8 changed files with 95 additions and 77 deletions

View File

@@ -22,7 +22,15 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.variables import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from core.variables import (
ArrayAnySegment,
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
from core.workflow.enums import SystemVariableKey
@@ -263,50 +271,44 @@ class LLMNode(BaseNode[LLMNodeData]):
return variables
for variable_selector in node_data.prompt_config.jinja2_variables or []:
variable = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get_any(variable_selector.value_selector)
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if variable is None:
raise ValueError(f"Variable {variable_selector.variable} not found")
def parse_dict(d: dict) -> str:
def parse_dict(input_dict: Mapping[str, Any]) -> str:
"""
Parse dict into string
"""
# check if it's a context structure
if "metadata" in d and "_source" in d["metadata"] and "content" in d:
return d["content"]
if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict:
return input_dict["content"]
# else, parse the dict
try:
return json.dumps(d, ensure_ascii=False)
return json.dumps(input_dict, ensure_ascii=False)
except Exception:
return str(d)
return str(input_dict)
if isinstance(value, str):
value = value
elif isinstance(value, list):
if isinstance(variable, ArraySegment):
result = ""
for item in value:
for item in variable.value:
if isinstance(item, dict):
result += parse_dict(item)
elif isinstance(item, str):
result += item
elif isinstance(item, int | float):
result += str(item)
else:
result += str(item)
result += "\n"
value = result.strip()
elif isinstance(value, dict):
value = parse_dict(value)
elif isinstance(value, int | float):
value = str(value)
elif isinstance(variable, ObjectSegment):
value = parse_dict(variable.value)
else:
value = str(value)
value = variable.text
variables[variable] = value
variables[variable_name] = value
return variables
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]:
inputs = {}
prompt_template = node_data.prompt_template
@@ -363,14 +365,14 @@ class LLMNode(BaseNode[LLMNodeData]):
if not node_data.context.variable_selector:
return
context_value = self.graph_runtime_state.variable_pool.get_any(node_data.context.variable_selector)
if context_value:
if isinstance(context_value, str):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value)
elif isinstance(context_value, list):
context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector)
if context_value_variable:
if isinstance(context_value_variable, StringSegment):
yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value)
elif isinstance(context_value_variable, ArraySegment):
context_str = ""
original_retriever_resource = []
for item in context_value:
for item in context_value_variable.value:
if isinstance(item, str):
context_str += item + "\n"
else:
@@ -484,11 +486,12 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get_any(
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID.value]
)
if conversation_id is None:
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# get conversation
conversation = (