diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 17cfaf2ed..072644e53 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -21,7 +21,7 @@ from core.plugin.entities.request import (
 )
 from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.utils.model_invocation_utils import ModelInvocationUtils
-from core.workflow.nodes.llm.node import LLMNode
+from core.workflow.nodes.llm import llm_utils
 from models.account import Tenant
 
 
@@ -55,7 +55,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
         def handle() -> Generator[LLMResultChunk, None, None]:
             for chunk in response:
                 if chunk.delta.usage:
-                    LLMNode.deduct_llm_quota(
+                    llm_utils.deduct_llm_quota(
                         tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
                     )
                 chunk.prompt_messages = []
@@ -64,7 +64,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
             return handle()
         else:
             if response.usage:
-                LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
 
             def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
                 yield LLMResultChunk(
diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py
index f0426ace1..33a283771 100644
--- a/api/core/rag/retrieval/router/multi_dataset_react_route.py
+++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py
@@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
 from core.rag.retrieval.output_parser.react_output import ReactAction
 from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
-from core.workflow.nodes.llm import LLMNode
+from core.workflow.nodes.llm import llm_utils
 
 PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
 
@@ -165,7 +165,7 @@ class ReactMultiDatasetRouter:
         text, usage = self._handle_invoke_result(invoke_result=invoke_result)
 
         # deduct quota
-        LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
 
         return text, usage
 
diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py
new file mode 100644
index 000000000..0966c87a1
--- /dev/null
+++ b/api/core/workflow/nodes/llm/llm_utils.py
@@ -0,0 +1,156 @@
+from collections.abc import Sequence
+from datetime import UTC, datetime
+from typing import Optional, cast
+
+from sqlalchemy import select, update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
+from core.entities.provider_entities import QuotaUnit
+from core.file.models import File
+from core.memory.token_buffer_memory import TokenBufferMemory
+from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
+from core.model_runtime.entities.model_entities import ModelType
+from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
+from core.plugin.entities.plugin import ModelProviderID
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes.llm.entities import ModelConfig
+from models import db
+from models.model import Conversation
+from models.provider import Provider, ProviderType
+
+from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
+
+
+def fetch_model_config(
+    tenant_id: str, node_data_model: ModelConfig
+) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
+    if not node_data_model.mode:
+        raise LLMModeRequiredError("LLM mode is required.")
+
+    model = ModelManager().get_model_instance(
+        tenant_id=tenant_id,
+        model_type=ModelType.LLM,
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+    )
+
+    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
+
+    # check model
+    provider_model = model.provider_model_bundle.configuration.get_provider_model(
+        model=node_data_model.name, model_type=ModelType.LLM
+    )
+
+    if provider_model is None:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+    provider_model.raise_for_status()
+
+    # model config
+    stop: list[str] = []
+    if "stop" in node_data_model.completion_params:
+        stop = node_data_model.completion_params.pop("stop")
+
+    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
+    if not model_schema:
+        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
+
+    return model, ModelConfigWithCredentialsEntity(
+        provider=node_data_model.provider,
+        model=node_data_model.name,
+        model_schema=model_schema,
+        mode=node_data_model.mode,
+        provider_model_bundle=model.provider_model_bundle,
+        credentials=model.credentials,
+        parameters=node_data_model.completion_params,
+        stop=stop,
+    )
+
+
+def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
+    variable = variable_pool.get(selector)
+    if variable is None:
+        return []
+    elif isinstance(variable, FileSegment):
+        return [variable.value]
+    elif isinstance(variable, ArrayFileSegment):
+        return variable.value
+    elif isinstance(variable, NoneSegment | ArrayAnySegment):
+        return []
+    raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
+
+
+def fetch_memory(
+    variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
+) -> Optional[TokenBufferMemory]:
+    if not node_data_memory:
+        return None
+
+    # get conversation id
+    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
+    if not isinstance(conversation_id_variable, StringSegment):
+        return None
+    conversation_id = conversation_id_variable.value
+
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+        conversation = session.scalar(stmt)
+        if not conversation:
+            return None
+
+    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+    return memory
+
+
+def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
+    provider_model_bundle = model_instance.provider_model_bundle
+    provider_configuration = provider_model_bundle.configuration
+
+    if provider_configuration.using_provider_type != ProviderType.SYSTEM:
+        return
+
+    system_configuration = provider_configuration.system_configuration
+
+    quota_unit = None
+    for quota_configuration in system_configuration.quota_configurations:
+        if quota_configuration.quota_type == system_configuration.current_quota_type:
+            quota_unit = quota_configuration.quota_unit
+
+            if quota_configuration.quota_limit == -1:
+                return
+
+            break
+
+    used_quota = None
+    if quota_unit:
+        if quota_unit == QuotaUnit.TOKENS:
+            used_quota = usage.total_tokens
+        elif quota_unit == QuotaUnit.CREDITS:
+            used_quota = dify_config.get_model_credits(model_instance.model)
+        else:
+            used_quota = 1
+
+    if used_quota is not None and system_configuration.current_quota_type is not None:
+        with Session(db.engine) as session:
+            stmt = (
+                update(Provider)
+                .where(
+                    Provider.tenant_id == tenant_id,
+                    # TODO: Use provider name with prefix after the data migration.
+                    Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
+                    Provider.provider_type == ProviderType.SYSTEM.value,
+                    Provider.quota_type == system_configuration.current_quota_type.value,
+                    Provider.quota_limit > Provider.quota_used,
+                )
+                .values(
+                    quota_used=Provider.quota_used + used_quota,
+                    last_used=datetime.now(tz=UTC).replace(tzinfo=None),
+                )
+            )
+            session.execute(stmt)
+            session.commit()
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index ee181cf3b..78d2ad628 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -3,16 +3,11 @@ import io
 import json
 import logging
 from collections.abc import Generator, Mapping, Sequence
-from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, Optional, cast
 
 import json_repair
-from sqlalchemy import select, update
-from sqlalchemy.orm import Session
 
-from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.entities.provider_entities import QuotaUnit
 from core.file import FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.memory.token_buffer_memory import TokenBufferMemory
@@ -40,12 +35,10 @@ from core.model_runtime.entities.model_entities import (
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.variables import (
-    ArrayAnySegment,
     ArrayFileSegment,
     ArraySegment,
     FileSegment,
@@ -75,10 +68,8 @@ from core.workflow.utils.structured_output.entities import (
 )
 from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
-from extensions.ext_database import db
-from models.model import Conversation
-from models.provider import Provider, ProviderType
 
+from . import llm_utils
 from .entities import (
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
@@ -88,7 +79,6 @@ from .entities import (
 )
 from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
-    LLMModeRequiredError,
     LLMNodeError,
     MemoryRolePrefixRequiredError,
     ModelNotExistError,
@@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]):
         result_text = ""
         usage = LLMUsage.empty_usage()
         finish_reason = None
+        variable_pool = self.graph_runtime_state.variable_pool
 
         try:
             # init messages template
@@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]):
 
             # fetch files
             files = (
-                self._fetch_files(selector=self.node_data.vision.configs.variable_selector)
+                llm_utils.fetch_files(
+                    variable_pool=variable_pool,
+                    selector=self.node_data.vision.configs.variable_selector,
+                )
                 if self.node_data.vision.enabled
                 else []
             )
@@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]):
             model_instance, model_config = self._fetch_model_config(self.node_data.model)
 
             # fetch memory
-            memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance)
+            memory = llm_utils.fetch_memory(
+                variable_pool=variable_pool,
+                app_id=self.app_id,
+                node_data_memory=self.node_data.memory,
+                model_instance=model_instance,
+            )
 
             query = None
             if self.node_data.memory:
                 query = self.node_data.memory.query_prompt_template
                 if not query and (
-                    query_variable := self.graph_runtime_state.variable_pool.get(
-                        (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)
-                    )
+                    query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
                 ):
                     query = query_variable.text
 
@@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                 memory_config=self.node_data.memory,
                 vision_enabled=self.node_data.vision.enabled,
                 vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=self.graph_runtime_state.variable_pool,
+                variable_pool=variable_pool,
                 jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             )
 
@@ -251,7 +248,7 @@ class LLMNode(BaseNode[LLMNodeData]):
                     usage = event.usage
                     finish_reason = event.finish_reason
                     # deduct quota
-                    self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+                    llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
                     break
             outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
             structured_output = process_structured_output(result_text)
@@ -447,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         return inputs
 
-    def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]:
-        variable = self.graph_runtime_state.variable_pool.get(selector)
-        if variable is None:
-            return []
-        elif isinstance(variable, FileSegment):
-            return [variable.value]
-        elif isinstance(variable, ArrayFileSegment):
-            return variable.value
-        elif isinstance(variable, NoneSegment | ArrayAnySegment):
-            return []
-        raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
-
     def _fetch_context(self, node_data: LLMNodeData):
         if not node_data.context.enabled:
             return
@@ -524,31 +509,10 @@ class LLMNode(BaseNode[LLMNodeData]):
     def _fetch_model_config(
         self, node_data_model: ModelConfig
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
-        if not node_data_model.mode:
-            raise LLMModeRequiredError("LLM mode is required.")
-
-        model = ModelManager().get_model_instance(
-            tenant_id=self.tenant_id,
-            model_type=ModelType.LLM,
-            provider=node_data_model.provider,
-            model=node_data_model.name,
+        model, model_config_with_cred = llm_utils.fetch_model_config(
+            tenant_id=self.tenant_id, node_data_model=node_data_model
         )
-
-        model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
-
-        # check model
-        provider_model = model.provider_model_bundle.configuration.get_provider_model(
-            model=node_data_model.name, model_type=ModelType.LLM
-        )
-
-        if provider_model is None:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-        provider_model.raise_for_status()
-
-        # model config
-        stop: list[str] = []
-        if "stop" in node_data_model.completion_params:
-            stop = node_data_model.completion_params.pop("stop")
+        completion_params = model_config_with_cred.parameters
 
         model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
         if not model_schema:
@@ -556,47 +520,12 @@ class LLMNode(BaseNode[LLMNodeData]):
 
         if self.node_data.structured_output_enabled:
             if model_schema.support_structure_output:
-                node_data_model.completion_params = self._handle_native_json_schema(
-                    node_data_model.completion_params, model_schema.parameter_rules
-                )
+                completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
             else:
                 # Set appropriate response format based on model capabilities
-                self._set_response_format(node_data_model.completion_params, model_schema.parameter_rules)
-
-        return model, ModelConfigWithCredentialsEntity(
-            provider=node_data_model.provider,
-            model=node_data_model.name,
-            model_schema=model_schema,
-            mode=node_data_model.mode,
-            provider_model_bundle=model.provider_model_bundle,
-            credentials=model.credentials,
-            parameters=node_data_model.completion_params,
-            stop=stop,
-        )
-
-    def _fetch_memory(
-        self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
-    ) -> Optional[TokenBufferMemory]:
-        if not node_data_memory:
-            return None
-
-        # get conversation id
-        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
-            ["sys", SystemVariableKey.CONVERSATION_ID.value]
-        )
-        if not isinstance(conversation_id_variable, StringSegment):
-            return None
-        conversation_id = conversation_id_variable.value
-
-        with Session(db.engine, expire_on_commit=False) as session:
-            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
-            conversation = session.scalar(stmt)
-            if not conversation:
-                return None
-
-        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-
-        return memory
+                self._set_response_format(completion_params, model_schema.parameter_rules)
+        model_config_with_cred.parameters = completion_params
+        return model, model_config_with_cred
 
     def _fetch_prompt_messages(
         self,
@@ -810,55 +739,6 @@ class LLMNode(BaseNode[LLMNodeData]):
             structured_output = parsed
         return structured_output
 
-    @classmethod
-    def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
-        provider_model_bundle = model_instance.provider_model_bundle
-        provider_configuration = provider_model_bundle.configuration
-
-        if provider_configuration.using_provider_type != ProviderType.SYSTEM:
-            return
-
-        system_configuration = provider_configuration.system_configuration
-
-        quota_unit = None
-        for quota_configuration in system_configuration.quota_configurations:
-            if quota_configuration.quota_type == system_configuration.current_quota_type:
-                quota_unit = quota_configuration.quota_unit
-
-                if quota_configuration.quota_limit == -1:
-                    return
-
-                break
-
-        used_quota = None
-        if quota_unit:
-            if quota_unit == QuotaUnit.TOKENS:
-                used_quota = usage.total_tokens
-            elif quota_unit == QuotaUnit.CREDITS:
-                used_quota = dify_config.get_model_credits(model_instance.model)
-            else:
-                used_quota = 1
-
-        if used_quota is not None and system_configuration.current_quota_type is not None:
-            with Session(db.engine) as session:
-                stmt = (
-                    update(Provider)
-                    .where(
-                        Provider.tenant_id == tenant_id,
-                        # TODO: Use provider name with prefix after the data migration.
-                        Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
-                        Provider.provider_type == ProviderType.SYSTEM.value,
-                        Provider.quota_type == system_configuration.current_quota_type.value,
-                        Provider.quota_limit > Provider.quota_used,
-                    )
-                    .values(
-                        quota_used=Provider.quota_used + used_quota,
-                        last_used=datetime.now(tz=UTC).replace(tzinfo=None),
-                    )
-                )
-                session.execute(stmt)
-                session.commit()
-
     @classmethod
     def _extract_variable_selector_to_variable_mapping(
         cls,
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 4f31258b1..255278476 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -28,8 +28,9 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.node import BaseNode
 from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.llm import LLMNode, ModelConfig
+from core.workflow.nodes.llm import ModelConfig, llm_utils
 from core.workflow.utils import variable_template_parser
 
 from .entities import ParameterExtractorNodeData
@@ -83,7 +84,7 @@ def extract_json(text):
     return None
 
 
-class ParameterExtractorNode(LLMNode):
+class ParameterExtractorNode(BaseNode):
     """
     Parameter Extractor Node.
     """
@@ -116,8 +117,11 @@ class ParameterExtractorNode(LLMNode):
         variable = self.graph_runtime_state.variable_pool.get(node_data.query)
         query = variable.text if variable else ""
 
+        variable_pool = self.graph_runtime_state.variable_pool
+
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
@@ -137,7 +141,9 @@ class ParameterExtractorNode(LLMNode):
             raise ModelSchemaNotFoundError("Model schema not found")
 
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -279,7 +285,7 @@ class ParameterExtractorNode(LLMNode):
         tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
 
         # deduct quota
-        self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
+        llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
         if text is None:
             text = ""
@@ -794,7 +800,9 @@ class ParameterExtractorNode(LLMNode):
         Fetch model config.
         """
         if not self._model_instance or not self._model_config:
-            self._model_instance, self._model_config = super()._fetch_model_config(node_data_model)
+            self._model_instance, self._model_config = llm_utils.fetch_model_config(
+                tenant_id=self.tenant_id, node_data_model=node_data_model
+            )
 
         return self._model_instance, self._model_config
 
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 28d8b3b86..1f50700c7 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -18,6 +18,7 @@ from core.workflow.nodes.llm import (
     LLMNode,
     LLMNodeChatModelMessage,
     LLMNodeCompletionModelPromptTemplate,
+    llm_utils,
 )
 from core.workflow.utils.variable_template_parser import VariableTemplateParser
 from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode):
         # fetch model config
         model_instance, model_config = self._fetch_model_config(node_data.model)
         # fetch memory
-        memory = self._fetch_memory(
+        memory = llm_utils.fetch_memory(
+            variable_pool=variable_pool,
+            app_id=self.app_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
         )
@@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode):
         node_data.instruction = variable_pool.convert_template(node_data.instruction).text
 
         files = (
-            self._fetch_files(
+            llm_utils.fetch_files(
+                variable_pool=variable_pool,
                 selector=node_data.vision.configs.variable_selector,
             )
             if node_data.vision.enabled
diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
index e89e03ae8..0df8e8b14 100644
--- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
+++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py
@@ -353,7 +353,7 @@ def test_extract_json_from_tool_call():
     assert result["location"] == "kawaii"
 
 
-def test_chat_parameter_extractor_with_memory(setup_model_mock):
+def test_chat_parameter_extractor_with_memory(setup_model_mock, monkeypatch):
     """
     Test chat parameter extractor with memory.
""" @@ -384,7 +384,8 @@ def test_chat_parameter_extractor_with_memory(setup_model_mock): mode="chat", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}, ) - node._fetch_memory = get_mocked_fetch_memory("customized memory") + # Test the mock before running the actual test + monkeypatch.setattr("core.workflow.nodes.llm.llm_utils.fetch_memory", get_mocked_fetch_memory("customized memory")) db.session.close = MagicMock() result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 519dd7378..336c2befc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -25,6 +25,7 @@ from core.workflow.entities.variable_pool import VariablePool from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.nodes.answer import AnswerStreamGenerateRoute from core.workflow.nodes.end import EndStreamParam +from core.workflow.nodes.llm import llm_utils from core.workflow.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, @@ -170,7 +171,7 @@ def model_config(): ) -def test_fetch_files_with_file_segment(llm_node): +def test_fetch_files_with_file_segment(): file = File( id="1", tenant_id="test", @@ -180,13 +181,14 @@ def test_fetch_files_with_file_segment(llm_node): related_id="1", storage_key="", ) - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], file) + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], file) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [file] -def test_fetch_files_with_array_file_segment(llm_node): +def test_fetch_files_with_array_file_segment(): files = [ File( id="1", @@ -207,28 +209,32 @@ def test_fetch_files_with_array_file_segment(llm_node): storage_key="", ), ] - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == files -def test_fetch_files_with_none_segment(llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], NoneSegment()) +def test_fetch_files_with_none_segment(): + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], NoneSegment()) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] -def test_fetch_files_with_array_any_segment(llm_node): - llm_node.graph_runtime_state.variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) +def test_fetch_files_with_array_any_segment(): + variable_pool = VariablePool() + variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) - result = llm_node._fetch_files(selector=["sys", "files"]) + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] -def test_fetch_files_with_non_existent_variable(llm_node): - result = llm_node._fetch_files(selector=["sys", "files"]) +def test_fetch_files_with_non_existent_variable(): + variable_pool = VariablePool() + result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", 
"files"]) assert result == []