refactor: decouple Node and NodeData (#22581)

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
This commit is contained in:
-LAN-
2025-07-18 10:08:51 +08:00
committed by GitHub
parent 54c56f2d05
commit 460a825ef1
65 changed files with 2305 additions and 1146 deletions

View File

@@ -58,21 +58,26 @@ def test_execute_answer():
pool.add(["start", "weather"], "sunny")
pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@@ -57,12 +57,15 @@ def test_http_request_node_binary_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -90,6 +93,9 @@ def test_http_request_node_binary_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -145,12 +151,15 @@ def test_http_request_node_form_with_file(monkeypatch):
),
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -178,6 +187,10 @@ def test_http_request_node_form_with_file(monkeypatch):
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda *args, **kwargs: b"test",
@@ -257,12 +270,14 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = HttpRequestNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -291,6 +306,9 @@ def test_http_request_node_form_with_multiple_files(monkeypatch):
),
)
# Initialize node data
node.init_node_data(node_config["data"])
monkeypatch.setattr(
"core.workflow.nodes.http_request.executor.file_manager.download",
lambda file: b"test_image_data" if file.mime_type == "image/jpeg" else b"test_pdf_data",

View File

@@ -162,25 +162,30 @@ def test_run():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "tt",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -379,25 +384,30 @@ def test_run_parallel():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
},
"id": "iteration-1",
},
config=node_config,
)
# Initialize node data
iteration_node.init_node_data(node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -595,45 +605,55 @@ def test_iteration_run_in_parallel_mode():
)
pool.add(["pe", "list_output"], ["dify-1", "dify-2"])
parallel_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
parallel_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=parallel_node_config,
)
# Initialize node data
parallel_iteration_node.init_node_data(parallel_node_config["data"])
sequential_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
}
sequential_iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "迭代",
"type": "iteration",
"is_parallel": True,
},
"id": "iteration-1",
},
config=sequential_node_config,
)
# Initialize node data
sequential_iteration_node.init_node_data(sequential_node_config["data"])
def tt_generator(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -645,8 +665,8 @@ def test_iteration_run_in_parallel_mode():
# execute node
parallel_result = parallel_iteration_node._run()
sequential_result = sequential_iteration_node._run()
assert parallel_iteration_node.node_data.parallel_nums == 10
assert parallel_iteration_node.node_data.error_handle_mode == ErrorHandleMode.TERMINATED
assert parallel_iteration_node._node_data.parallel_nums == 10
assert parallel_iteration_node._node_data.error_handle_mode == ErrorHandleMode.TERMINATED
count = 0
parallel_arr = []
sequential_arr = []
@@ -818,26 +838,31 @@ def test_iteration_run_error_handle():
environment_variables=[],
)
pool.add(["pe", "list_output"], ["1", "1"])
error_node_config = {
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
}
iteration_node = IterationNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"data": {
"iterator_selector": ["pe", "list_output"],
"output_selector": ["tt", "output"],
"output_type": "array[string]",
"startNodeType": "template-transform",
"start_node_id": "iteration-start",
"title": "iteration",
"type": "iteration",
"is_parallel": True,
"error_handle_mode": ErrorHandleMode.CONTINUE_ON_ERROR,
},
"id": "iteration-1",
},
config=error_node_config,
)
# Initialize node data
iteration_node.init_node_data(error_node_config["data"])
# execute continue on error node
result = iteration_node._run()
result_arr = []
@@ -851,7 +876,7 @@ def test_iteration_run_error_handle():
assert count == 14
# execute remove abnormal output
iteration_node.node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
iteration_node._node_data.error_handle_mode = ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT
result = iteration_node._run()
count = 0
for item in result:

View File

@@ -119,17 +119,20 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@@ -488,7 +491,7 @@ def test_handle_list_messages_basic(llm_node):
variable_pool = llm_node.graph_runtime_state.variable_pool
vision_detail_config = ImagePromptMessageContent.DETAIL.HIGH
result = llm_node._handle_list_messages(
result = llm_node.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
@@ -506,17 +509,20 @@ def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
}
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
config=node_config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver
@@ -540,7 +546,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
@@ -566,7 +577,12 @@ class TestLLMNodeSaveMultiModalImageOutput:
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
file = llm_node.save_multimodal_image_output(
content=content,
file_saver=mock_file_saver,
)
# Manually append to _file_outputs since the static method doesn't do it
llm_node._file_outputs.append(file)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
@@ -582,7 +598,9 @@ def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents="hello world", file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
@@ -590,7 +608,7 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
contents=[TextPromptMessageContent(data="hello world")], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
@@ -616,13 +634,15 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
contents=[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
],
file_saver=mock_file_saver,
file_outputs=llm_node._file_outputs,
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
@@ -645,21 +665,27 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=frozenset(["hello world"]), file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=[frozenset(["hello world"])], file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
contents=None, file_saver=mock_file_saver, file_outputs=[]
)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

View File

@@ -61,21 +61,26 @@ def test_execute_answer():
variable_pool.add(["start", "weather"], "sunny")
variable_pool.add(["llm", "text"], "You are a helpful AI.")
node_config = {
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
}
node = AnswerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "answer",
"data": {
"title": "123",
"type": "answer",
"answer": "Today's weather is {{#start.weather#}}\n{{#llm.text#}}\n{{img}}\nFin.",
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@@ -27,13 +27,17 @@ def document_extractor_node():
title="Test Document Extractor",
variable_selector=["node_id", "variable_name"],
)
return DocumentExtractorNode(
node_config = {"id": "test_node_id", "data": node_data.model_dump()}
node = DocumentExtractorNode(
id="test_node_id",
config={"id": "test_node_id", "data": node_data.model_dump()},
config=node_config,
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@pytest.fixture

View File

@@ -57,57 +57,62 @@ def test_execute_if_else_result_true():
pool.add(["start", "null"], None)
pool.add(["start", "not_null"], "1212")
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "and",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
{"comparison_operator": "contains", "variable_selector": ["start", "contains"], "value": "ab"},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "not_contains"],
"value": "ab",
},
{"comparison_operator": "start with", "variable_selector": ["start", "start_with"], "value": "ab"},
{"comparison_operator": "end with", "variable_selector": ["start", "end_with"], "value": "ab"},
{"comparison_operator": "is", "variable_selector": ["start", "is"], "value": "ab"},
{"comparison_operator": "is not", "variable_selector": ["start", "is_not"], "value": "ab"},
{"comparison_operator": "empty", "variable_selector": ["start", "empty"], "value": "ab"},
{"comparison_operator": "not empty", "variable_selector": ["start", "not_empty"], "value": "ab"},
{"comparison_operator": "=", "variable_selector": ["start", "equals"], "value": "22"},
{"comparison_operator": "", "variable_selector": ["start", "not_equals"], "value": "22"},
{"comparison_operator": ">", "variable_selector": ["start", "greater_than"], "value": "22"},
{"comparison_operator": "<", "variable_selector": ["start", "less_than"], "value": "22"},
{
"comparison_operator": "",
"variable_selector": ["start", "greater_than_or_equal"],
"value": "22",
},
{"comparison_operator": "", "variable_selector": ["start", "less_than_or_equal"], "value": "22"},
{"comparison_operator": "null", "variable_selector": ["start", "null"]},
{"comparison_operator": "not null", "variable_selector": ["start", "not_null"]},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -162,33 +167,38 @@ def test_execute_if_else_result_false():
pool.add(["start", "array_contains"], ["1ab", "def"])
pool.add(["start", "array_not_contains"], ["ab", "def"])
node_config = {
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
config={
"id": "if-else",
"data": {
"title": "123",
"type": "if-else",
"logical_operator": "or",
"conditions": [
{
"comparison_operator": "contains",
"variable_selector": ["start", "array_contains"],
"value": "ab",
},
{
"comparison_operator": "not contains",
"variable_selector": ["start", "array_not_contains"],
"value": "ab",
},
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@@ -228,17 +238,22 @@ def test_array_file_contains_file_name():
],
)
node_config = {
"id": "if-else",
"data": node_data.model_dump(),
}
node = IfElseNode(
id=str(uuid.uuid4()),
graph_init_params=Mock(),
graph=Mock(),
graph_runtime_state=Mock(),
config={
"id": "if-else",
"data": node_data.model_dump(),
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(

View File

@@ -33,16 +33,19 @@ def list_operator_node():
"title": "Test Title",
}
node_data = ListOperatorNodeData(**config)
node_config = {
"id": "test_node_id",
"data": node_data.model_dump(),
}
node = ListOperatorNode(
id="test_node_id",
config={
"id": "test_node_id",
"data": node_data.model_dump(),
},
config=node_config,
graph_init_params=MagicMock(),
graph=MagicMock(),
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node

View File

@@ -38,12 +38,13 @@ def _create_tool_node():
system_variables=SystemVariable.empty(),
user_inputs={},
)
node_config = {
"id": "1",
"data": data.model_dump(),
}
node = ToolNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
config=node_config,
graph_init_params=GraphInitParams(
tenant_id="1",
app_id="1",
@@ -71,6 +72,8 @@ def _create_tool_node():
start_at=0,
),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node

View File

@@ -82,23 +82,28 @@ def test_overwrite_string_variable():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.OVER_WRITE.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@@ -178,23 +183,28 @@ def test_append_variable_to_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.APPEND.value,
"input_variable_selector": [DEFAULT_NODE_ID, input_variable.name],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@@ -265,23 +275,28 @@ def test_clear_array():
mock_conv_var_updater = mock.Mock(spec=ConversationVariableUpdater)
mock_conv_var_updater_factory = mock.Mock(return_value=mock_conv_var_updater)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"assigned_variable_selector": ["conversation", conversation_variable.name],
"write_mode": WriteMode.CLEAR.value,
"input_variable_selector": [],
},
},
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,

View File

@@ -115,28 +115,33 @@ def test_remove_first_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
@@ -202,28 +207,33 @@ def test_remove_last_from_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -281,28 +291,33 @@ def test_remove_first_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
@@ -360,28 +375,33 @@ def test_remove_last_from_empty_array():
conversation_variables=[conversation_variable],
)
node_config = {
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
}
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())