refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Authored by -LAN- on 2025-07-18 10:08:51 +08:00; committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions
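
The hunks below all follow the same post-refactor pattern: a node is constructed first, its data is then attached explicitly with init_node_data(config["data"]), and the parsed data is read back through the private _node_data attribute instead of the old node_data field. A minimal, self-contained sketch of that pattern follows; the Node and CodeNodeData classes here are simplified stand-ins for illustration only, not the project's real implementations, while the init_node_data / _node_data names are taken from the diff itself.

from dataclasses import dataclass, field


@dataclass
class CodeNodeData:
    """Stand-in for a typed node-data model."""
    title: str = ""
    outputs: dict = field(default_factory=dict)


class Node:
    """Stand-in node: parsing node data is now a separate, explicit step."""

    def __init__(self, id: str, config: dict):
        self.id = id
        self._config = config
        self._node_data: CodeNodeData | None = None  # filled in by init_node_data()

    def init_node_data(self, data: dict) -> None:
        # Parse the raw "data" mapping into the typed node-data object.
        self._node_data = CodeNodeData(**data)


def init_code_node(config: dict) -> Node:
    node = Node(id="code-1", config=config)
    # Mirror the guard the diff adds to the test helpers.
    if "data" in config:
        node.init_node_data(config["data"])
    return node


node = init_code_node({"data": {"title": "123", "outputs": {"result": "number"}}})
print(node._node_data.outputs)  # reads go through _node_data, not node_data

Separating construction from data parsing appears to be why every test helper in this commit gains the explicit init_node_data call after building the node, and why internal accesses in the tests switch from node.node_data to node._node_data.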


@@ -15,7 +15,7 @@ def get_mocked_fetch_model_config(
mode: str,
credentials: dict,
):
model_provider_factory = ModelProviderFactory(tenant_id="test_tenant")
model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b")
model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM)
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(


@@ -66,6 +66,10 @@ def init_code_node(code_config: dict):
config=code_config,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node
@@ -234,10 +238,10 @@ def test_execute_code_output_validator_depth():
"object_validator": {"result": 1, "depth": {"depth": {"depth": 1}}},
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -250,7 +254,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -263,7 +267,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -276,7 +280,7 @@ def test_execute_code_output_validator_depth():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_output_object_list():
@@ -330,10 +334,10 @@ def test_execute_code_output_object_list():
]
}
node.node_data = cast(CodeNodeData, node.node_data)
node._node_data = cast(CodeNodeData, node._node_data)
# validate
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
# construct result
result = {
@@ -353,7 +357,7 @@ def test_execute_code_output_object_list():
# validate
with pytest.raises(ValueError):
node._transform_result(result, node.node_data.outputs)
node._transform_result(result, node._node_data.outputs)
def test_execute_code_scientific_notation():


@@ -52,7 +52,7 @@ def init_http_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return HttpRequestNode(
node = HttpRequestNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
@@ -60,6 +60,12 @@ def init_http_node(config: dict):
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True)
def test_get(setup_http_mock):


@@ -2,15 +2,10 @@ import json
import time
import uuid
from collections.abc import Generator
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.output_parser.structured_output import _parse_structured_output
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.graph import Graph
@@ -24,8 +19,6 @@ from models.enums import UserFrom
from models.workflow import WorkflowType
"""FOR MOCK FIXTURES, DO NOT REMOVE"""
from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
def init_llm_node(config: dict) -> LLMNode:
@@ -84,10 +77,14 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
def test_execute_llm(flask_req_ctx):
def test_execute_llm():
node = init_llm_node(
config={
"id": "llm",
@@ -95,7 +92,7 @@ def test_execute_llm(flask_req_ctx):
"title": "123",
"type": "llm",
"model": {
"provider": "langgenius/openai/openai",
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {},
@@ -114,53 +111,62 @@ def test_execute_llm(flask_req_ctx):
},
)
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
db.session.close = MagicMock()
mock_message = AssistantPromptMessage(content="This is a test response from the mocked LLM.")
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response from mock")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "langgenius/openai/openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_1),
):
# execute node
result = node._run()
@@ -168,6 +174,9 @@ def test_execute_llm(flask_req_ctx):
for item in result:
if isinstance(item, RunCompletedEvent):
if item.run_result.status != WorkflowNodeExecutionStatus.SUCCEEDED:
print(f"Error: {item.run_result.error}")
print(f"Error type: {item.run_result.error_type}")
assert item.run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert item.run_result.process_data is not None
assert item.run_result.outputs is not None
@@ -175,8 +184,7 @@ def test_execute_llm(flask_req_ctx):
assert item.run_result.outputs.get("usage", {})["total_tokens"] > 0
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
def test_execute_llm_with_jinja2():
"""
Test execute LLM node with jinja2
"""
@@ -217,53 +225,60 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock):
# Mock db.session.close()
db.session.close = MagicMock()
# Create a proper LLM result with real entities
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
# Create a simple mock model instance that doesn't call real providers
mock_model_instance = MagicMock()
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create a simple mock model config with required attributes
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.provider_model_bundle.configuration.tenant_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056b"
# Mock the _fetch_model_config method
def mock_fetch_model_config_func(_node_data_model):
def mock_fetch_model_config(**_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import AssistantPromptMessage
# Create mock model instance
mock_model_instance = MagicMock()
mock_usage = LLMUsage(
prompt_tokens=30,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal(1000),
prompt_price=Decimal("0.00003"),
completion_tokens=20,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal(1000),
completion_price=Decimal("0.00004"),
total_tokens=50,
total_price=Decimal("0.00007"),
currency="USD",
latency=0.5,
)
mock_message = AssistantPromptMessage(content="Test response: sunny weather and what's the weather today?")
mock_llm_result = LLMResult(
model="gpt-3.5-turbo",
prompt_messages=[],
message=mock_message,
usage=mock_usage,
)
mock_model_instance.invoke_llm.return_value = mock_llm_result
# Create mock model config
mock_model_config = MagicMock()
mock_model_config.mode = "chat"
mock_model_config.provider = "openai"
mock_model_config.model = "gpt-3.5-turbo"
mock_model_config.parameters = {}
return mock_model_instance, mock_model_config
# Also mock ModelManager.get_model_instance to avoid database calls
def mock_get_model_instance(_self, **kwargs):
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_2(**_kwargs):
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [
SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."),
UserPromptMessage(content="what's the weather today?"),
], []
with (
patch.object(node, "_fetch_model_config", mock_fetch_model_config_func),
patch("core.model_manager.ModelManager.get_model_instance", mock_get_model_instance),
patch.object(LLMNode, "_fetch_model_config", mock_fetch_model_config),
patch.object(LLMNode, "fetch_prompt_messages", mock_fetch_prompt_messages_2),
):
# execute node
result = node._run()


@@ -74,13 +74,15 @@ def init_parameter_extractor_node(config: dict):
variable_pool.add(["a", "b123", "args1"], 1)
variable_pool.add(["a", "b123", "args2"], 2)
return ParameterExtractorNode(
node = ParameterExtractorNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_function_calling_parameter_extractor(setup_model_mock):


@@ -76,6 +76,7 @@ def test_execute_code(setup_code_executor_mock):
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
# execute node
result = node._run()


@@ -50,13 +50,15 @@ def init_tool_node(config: dict):
conversation_variables=[],
)
return ToolNode(
node = ToolNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config=config,
)
node.init_node_data(config.get("data", {}))
return node
def test_tool_variable_invoke():