refactor(api): Decouple ParameterExtractorNode from LLMNode (#20843)

- Extract methods used by `ParameterExtractorNode` from `LLMNode` into a separate file.
- Convert `ParameterExtractorNode` into a subclass of `BaseNode`.
- Refactor code referencing the extracted methods to ensure functionality and clarity.
- Fix the issue where `ParameterExtractorNode` returned an error when executed.
- Fix relevant test cases.

Closes #20840.
QuantumGhost
2025-06-10 11:47:50 +08:00
committed by GitHub
parent a97ff587d2
commit c439e82038
8 changed files with 226 additions and 171 deletions
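
For call sites outside the node, the migration is mechanical: the quota-deduction helper moves from a classmethod on `LLMNode` to a module-level function in `llm_utils`. A before/after sketch (the surrounding variables are illustrative, not taken from any single call site):

    # before: the helper lived on the node class
    from core.workflow.nodes.llm.node import LLMNode

    LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=usage)

    # after: the helper is a plain function in the shared module
    from core.workflow.nodes.llm import llm_utils

    llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=usage)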

View File

@@ -21,7 +21,7 @@ from core.plugin.entities.request import (
 )
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
-from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm import llm_utils
 from models.account import Tenant
@@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         def handle() -> Generator[LLMResultChunk, None, None]:
             for chunk in response:
                 if chunk.delta.usage:
-                    LLMNode.deduct_llm_quota(
+                    llm_utils.deduct_llm_quota(
                         tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                     )
                 chunk.prompt_messages = []
@@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             return handle()
         else:
             if response.usage:
-                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)

         def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
             yield LLMResultChunk(

View File

@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.llm import LLMNode
+from core.workflow.nodes.llm import llm_utils

 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
         text, usage = self._handle_invoke_result(invoke_result=invoke_result)
         # deduct quota
-        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
         return text, usage

View File

@@ -0,0 +1,156 @@
+from collections.abc import Sequence
+from datetime import UTC, datetime
+from typing import Optional, cast
+
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.provider_entities import QuotaUnit
+from core.file.models import File
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.plugin.entities.plugin import ModelProviderID
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes.llm.entities import ModelConfig
+from models import db
+from models.model import Conversation
+from models.provider import Provider, ProviderType
+
+from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
+
+
+def fetch_model_config(
+    tenant_id: str, node_data_model: ModelConfig
+) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    if not node_data_model.mode:
+        raise LLMModeRequiredError("LLM mode is required.")
+
+    model = ModelManager().get_model_instance(
+        tenant_id=tenant_id,
+        model_type=ModelType.LLM,
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+    )
+    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
+
+    # check model
+    provider_model = model.provider_model_bundle.configuration.get_provider_model(
+        model=node_data_model.name, model_type=ModelType.LLM
+    )
+    if provider_model is None:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+    provider_model.raise_for_status()
+
+    # model config
+    stop: list[str] = []
+    if "stop" in node_data_model.completion_params:
+        stop = node_data_model.completion_params.pop("stop")
+
+    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
+    if not model_schema:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+
+    return model, ModelConfigWithCredentialsEntity(
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+        model_schema=model_schema,
+        mode=node_data_model.mode,
+        provider_model_bundle=model.provider_model_bundle,
+        credentials=model.credentials,
+        parameters=node_data_model.completion_params,
+        stop=stop,
+    )
+
+
+def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
+    variable = variable_pool.get(selector)
+    if variable is None:
+        return []
+    elif isinstance(variable, FileSegment):
+        return [variable.value]
+    elif isinstance(variable, ArrayFileSegment):
+        return variable.value
+    elif isinstance(variable, NoneSegment | ArrayAnySegment):
+        return []
+    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
+
+
+def fetch_memory(
+    variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
+) -> Optional[TokenBufferMemory]:
+    if not node_data_memory:
+        return None
+
+    # get conversation id
+    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
+    if not isinstance(conversation_id_variable, StringSegment):
+        return None
+    conversation_id = conversation_id_variable.value
+
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+        conversation = session.scalar(stmt)
+
+    if not conversation:
+        return None
+
+    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+    return memory
+
+
+def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+    provider_model_bundle = model_instance.provider_model_bundle
+    provider_configuration = provider_model_bundle.configuration
+
+    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+        return
+
+    system_configuration = provider_configuration.system_configuration
+
+    quota_unit = None
+    for quota_configuration in system_configuration.quota_configurations:
+        if quota_configuration.quota_type == system_configuration.current_quota_type:
+            quota_unit = quota_configuration.quota_unit
+
+            if quota_configuration.quota_limit == -1:
+                return
+
+            break
+
+    used_quota = None
+    if quota_unit:
+        if quota_unit == QuotaUnit.TOKENS:
+            used_quota = usage.total_tokens
+        elif quota_unit == QuotaUnit.CREDITS:
+            used_quota = dify_config.get_model_credits(model_instance.model)
+        else:
+            used_quota = 1
+
+    if used_quota is not None and system_configuration.current_quota_type is not None:
+        with Session(db.engine) as session:
+            stmt = (
+                update(Provider)
+                .where(
+                    Provider.tenant_id == tenant_id,
+                    # TODO: Use provider name with prefix after the data migration.
+                    Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+                    Provider.provider_type == ProviderType.SYSTEM.value,
+                    Provider.quota_type == system_configuration.current_quota_type.value,
+                    Provider.quota_limit > Provider.quota_used,
+                )
+                .values(
+                    quota_used=Provider.quota_used + used_quota,
+                    last_used=datetime.now(tz=UTC).replace(tzinfo=None),
+                )
+            )
+            session.execute(stmt)
+            session.commit()
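
Taken together, these helpers give any workflow node the model/file/memory plumbing that previously required inheriting from `LLMNode`. A rough sketch of how a node can combine them (the `prepare_llm_call` wrapper and its arguments are hypothetical; the `llm_utils` calls mirror the call sites in this PR):

    from core.workflow.nodes.llm import llm_utils

    def prepare_llm_call(node, node_data):
        # the variable pool owned by the running graph
        variable_pool = node.graph_runtime_state.variable_pool

        # model instance plus credential-backed config for the configured provider/model
        model_instance, model_config = llm_utils.fetch_model_config(
            tenant_id=node.tenant_id, node_data_model=node_data.model
        )

        # vision inputs, if enabled, come straight from the variable pool
        files = (
            llm_utils.fetch_files(
                variable_pool=variable_pool,
                selector=node_data.vision.configs.variable_selector,
            )
            if node_data.vision.enabled
            else []
        )

        # conversation memory is resolved from sys.conversation_id, or None outside a chat context
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=node.app_id,
            node_data_memory=node_data.memory,
            model_instance=model_instance,
        )
        return model_instance, model_config, files, memory

After the model is invoked, `llm_utils.deduct_llm_quota(tenant_id=..., model_instance=..., usage=...)` settles the system-provider quota from the reported usage, exactly as the node and plugin call sites in this commit do.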

View File

@@ -3,16 +3,11 @@ import io
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
-from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast

 import json_repair
-from sqlalchemy import select, update
-from sqlalchemy.orm import Session

-from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.provider_entities import QuotaUnit
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -40,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.variables import (
-    ArrayAnySegment,
     ArrayFileSegment,
     ArraySegment,
     FileSegment,
@@ -75,10 +68,8 @@ from core.workflow.utils.structured_output.entities import (
 )
 from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from extensions.ext_database import db
-from models.model import Conversation
-from models.provider import Provider, ProviderType

+from . import llm_utils
 from .entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -88,7 +79,6 @@ from .entities import (
 from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
-    LLMModeRequiredError,
     LLMNodeError,
     MemoryRolePrefixRequiredError,
     ModelNotExistError,
@@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         result_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
+        variable_pool = self.graph_runtime_state.variable_pool

         try:
             # init messages template
@@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
             # fetch files
             files = (
-                self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
+                llm_utils.fetch_files(
+                    variable_pool=variable_pool,
+                    selector=self.node_data.vision.configs.variable_selector,
+                )
                 if self.node_data.vision.enabled
                 else []
             )
@@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
             model_instance, model_config = self._fetch_model_config(self.node_data.model)

             # fetch memory
-            memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
+            memory = llm_utils.fetch_memory(
+                variable_pool=variable_pool,
+                app_id=self.app_id,
+                node_data_memory=self.node_data.memory,
+                model_instance=model_instance,
+            )

             query = None
             if self.node_data.memory:
                 query = self.node_data.memory.query_prompt_template
                 if not query and (
-                    query_variable := self.graph_runtime_state.variable_pool.get(
-                        (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
-                    )
+                    query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                 ):
                     query = query_variable.text
@@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 memory_config=self.node_data.memory,
                 vision_enabled=self.node_data.vision.enabled,
                 vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=self.graph_runtime_state.variable_pool,
+                variable_pool=variable_pool,
                 jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             )
@@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     usage = event.usage
                     finish_reason = event.finish_reason
                     # deduct quota
-                    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break

             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
             structured_output = process_structured_output(result_text)
@@ -447,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
return inputs return inputs
def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is None:
return []
elif isinstance(variable, FileSegment):
return [variable.value]
elif isinstance(variable, ArrayFileSegment):
return variable.value
elif isinstance(variable, NoneSegment | ArrayAnySegment):
return []
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
def _fetch_context(self, node_data: LLMNodeData): def _fetch_context(self, node_data: LLMNodeData):
if not node_data.context.enabled: if not node_data.context.enabled:
return return
@@ -524,31 +509,10 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_model_config(
         self, node_data_model: ModelConfig
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        if not node_data_model.mode:
-            raise LLMModeRequiredError("LLM mode is required.")
-
-        model = ModelManager().get_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.LLM,
-            provider=node_data_model.provider,
-            model=node_data_model.name,
-        )
-
-        model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
-
-        # check model
-        provider_model = model.provider_model_bundle.configuration.get_provider_model(
-            model=node_data_model.name, model_type=ModelType.LLM
-        )
-        if provider_model is None:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-        provider_model.raise_for_status()
-
-        # model config
-        stop: list[str] = []
-        if "stop" in node_data_model.completion_params:
-            stop = node_data_model.completion_params.pop("stop")
+        model, model_config_with_cred = llm_utils.fetch_model_config(
+            tenant_id=self.tenant_id, node_data_model=node_data_model
+        )
+        completion_params = model_config_with_cred.parameters

         model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
         if not model_schema:
@@ -556,47 +520,12 @@ class LLMNode(BaseNode[LLMNodeData]):
         if self.node_data.structured_output_enabled:
             if model_schema.support_structure_output:
-                node_data_model.completion_params = self._handle_native_json_schema(
-                    node_data_model.completion_params, model_schema.parameter_rules
-                )
+                completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
             else:
                 # Set appropriate response format based on model capabilities
-                self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
+                self._set_response_format(completion_params, model_schema.parameter_rules)

-        return model, ModelConfigWithCredentialsEntity(
-            provider=node_data_model.provider,
-            model=node_data_model.name,
-            model_schema=model_schema,
-            mode=node_data_model.mode,
-            provider_model_bundle=model.provider_model_bundle,
-            credentials=model.credentials,
-            parameters=node_data_model.completion_params,
-            stop=stop,
-        )
-
-    def _fetch_memory(
-        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
-    ) -> Optional[TokenBufferMemory]:
-        if not node_data_memory:
-            return None
-
-        # get conversation id
-        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
-            ["sys", SystemVariableKey.CONVERSATION_ID.value]
-        )
-        if not isinstance(conversation_id_variable, StringSegment):
-            return None
-        conversation_id = conversation_id_variable.value
-
-        with Session(db.engine, expire_on_commit=False) as session:
-            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
-            conversation = session.scalar(stmt)
-
-        if not conversation:
-            return None
-
-        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-        return memory
+        model_config_with_cred.parameters = completion_params
+        return model, model_config_with_cred

     def _fetch_prompt_messages(
         self,
@@ -810,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             structured_output = parsed
         return structured_output

-    @classmethod
-    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-        provider_model_bundle = model_instance.provider_model_bundle
-        provider_configuration = provider_model_bundle.configuration
-
-        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-            return
-
-        system_configuration = provider_configuration.system_configuration
-
-        quota_unit = None
-        for quota_configuration in system_configuration.quota_configurations:
-            if quota_configuration.quota_type == system_configuration.current_quota_type:
-                quota_unit = quota_configuration.quota_unit
-
-                if quota_configuration.quota_limit == -1:
-                    return
-
-                break
-
-        used_quota = None
-        if quota_unit:
-            if quota_unit == QuotaUnit.TOKENS:
-                used_quota = usage.total_tokens
-            elif quota_unit == QuotaUnit.CREDITS:
-                used_quota = dify_config.get_model_credits(model_instance.model)
-            else:
-                used_quota = 1
-
-        if used_quota is not None and system_configuration.current_quota_type is not None:
-            with Session(db.engine) as session:
-                stmt = (
-                    update(Provider)
-                    .where(
-                        Provider.tenant_id == tenant_id,
-                        # TODO: Use provider name with prefix after the data migration.
-                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                        Provider.provider_type == ProviderType.SYSTEM.value,
-                        Provider.quota_type == system_configuration.current_quota_type.value,
-                        Provider.quota_limit > Provider.quota_used,
-                    )
-                    .values(
-                        quota_used=Provider.quota_used + used_quota,
-                        last_used=datetime.now(tz=UTC).replace(tzinfo=None),
-                    )
-                )
-                session.execute(stmt)
-                session.commit()
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,

View File

@@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.llm import LLMNode, ModelConfig
+from core.workflow.nodes.llm import ModelConfig, llm_utils
 from core.workflow.utils import variable_template_parser

 from .entities import ParameterExtractorNodeData
@@ -83,7 +84,7 @@ def extract_json(text):
     return None


-class ParameterExtractorNode(LLMNode):
+class ParameterExtractorNode(BaseNode):
     """
     Parameter Extractor Node.
     """
@@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
         variable = self.graph_runtime_state.variable_pool.get(node_data.query)
         query = variable.text if variable else ""
+        variable_pool = self.graph_runtime_state.variable_pool

         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")

         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
         tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None

         # deduct quota
-        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)

         if text is None:
             text = ""
@@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
         Fetch model config.
         """
         if not self._model_instance or not self._model_config:
-            self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+            self._model_instance, self._model_config = llm_utils.fetch_model_config(
+                tenant_id=self.tenant_id, node_data_model=node_data_model
+            )
         return self._model_instance, self._model_config

View File

@@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
     LLMNode,
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
+    llm_utils,
 )
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
         # fetch model config
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
         node_data.instruction = variable_pool.convert_template(node_data.instruction).text
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled

View File

@@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
     assert result["location"] == "kawaii"


-def test_chat_parameter_extractor_with_memory(setup_model_mock):
+def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
     """
     Test chat parameter extractor with memory.
     """
@@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock):
         mode="chat",
         credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")},
     )
-    node._fetch_memory = get_mocked_fetch_memory("customized memory")
+    # Test the mock before running the actual test
+    monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory"))
     db.session.close = MagicMock()

     result = node._run()
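
Because `fetch_memory` is now a module-level function rather than a method on the node, the test can no longer stub it by assigning to an attribute on the node instance; it patches the function where callers look it up, which is why the test now takes pytest's `monkeypatch` fixture. The pattern in isolation (the stub below is a placeholder, not the project's `get_mocked_fetch_memory`):

    def test_uses_patched_memory_helper(monkeypatch):
        # replace the module-level helper for the duration of this test
        monkeypatch.setattr(
            "core.workflow.nodes.llm.llm_utils.fetch_memory",
            lambda **kwargs: None,
        )
        # any code that calls llm_utils.fetch_memory(...) during this test receives the stub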

View File

@@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
 from core.workflow.nodes.answer import AnswerStreamGenerateRoute
 from core.workflow.nodes.end import EndStreamParam
+from core.workflow.nodes.llm import llm_utils
 from core.workflow.nodes.llm.entities import (
     ContextConfig,
     LLMNodeChatModelMessage,
@@ -170,7 +171,7 @@ def model_config():
     )


-def test_fetch_files_with_file_segment(llm_node):
+def test_fetch_files_with_file_segment():
     file = File(
         id="1",
         tenant_id="test",
@@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node):
         related_id="1",
         storage_key="",
     )
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file)
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], file)

-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])

     assert result == [file]


-def test_fetch_files_with_array_file_segment(llm_node):
+def test_fetch_files_with_array_file_segment():
     files = [
         File(
             id="1",
@@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node):
             storage_key="",
         ),
     ]
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], ArrayFileSegment(value=files))

-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])

     assert result == files


-def test_fetch_files_with_none_segment(llm_node):
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment())
+def test_fetch_files_with_none_segment():
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], NoneSegment())

-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])

     assert result == []


-def test_fetch_files_with_array_any_segment(llm_node):
-    llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))
+def test_fetch_files_with_array_any_segment():
+    variable_pool = VariablePool()
+    variable_pool.add(["sys", "files"], ArrayAnySegment(value=[]))

-    result = llm_node._fetch_files(selector=["sys", "files"])
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])

     assert result == []


-def test_fetch_files_with_non_existent_variable(llm_node):
-    result = llm_node._fetch_files(selector=["sys", "files"])
+def test_fetch_files_with_non_existent_variable():
+    variable_pool = VariablePool()
+    result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"])

     assert result == []