feat(api): Explicitly define version method for all BaseNode subclasses (#21443)
This PR addresses issue #21441 by implementing explicit `version` method definitions for all `BaseNode` subclasses to improve code maintainability. ### Changes Added explicit `version` method definitions for all `BaseNode` subclasses: - `QuestionClassifierNode` - `KnowledgeRetrievalNode` - `AgentNode` Added comprehensive test suite to validate: 1. All subclasses of `BaseNode` have explicitly defined `version` method 2. All subclasses have required `_node_type` property 3. The `(node_type, node_version)` combination is unique across all subclasses
This commit is contained in:
@@ -39,6 +39,10 @@ class AgentNode(ToolNode):
|
||||
_node_data_cls = AgentNodeData # type: ignore
|
||||
_node_type = NodeType.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator:
|
||||
"""
|
||||
Run the agent node
|
||||
|
@@ -71,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode):
|
||||
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore
|
||||
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult: # type: ignore
|
||||
node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
|
||||
# extract variables
|
||||
|
@@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode):
|
||||
_node_data_cls = QuestionClassifierNodeData # type: ignore
|
||||
_node_type = NodeType.QUESTION_CLASSIFIER
|
||||
|
||||
@classmethod
|
||||
def version(cls):
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
node_data = cast(QuestionClassifierNodeData, self.node_data)
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
@@ -0,0 +1,36 @@
|
||||
from core.workflow.nodes.base.node import BaseNode
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
|
||||
# Ensures that all node classes are imported.
|
||||
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]:
|
||||
subclasses = []
|
||||
queue = [root]
|
||||
while queue:
|
||||
cls = queue.pop()
|
||||
|
||||
subclasses.extend(cls.__subclasses__())
|
||||
queue.extend(cls.__subclasses__())
|
||||
|
||||
return subclasses
|
||||
|
||||
|
||||
def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined():
|
||||
classes = _get_all_subclasses(BaseNode) # type: ignore
|
||||
type_version_set: set[tuple[NodeType, str]] = set()
|
||||
|
||||
for cls in classes:
|
||||
# Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__
|
||||
assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)"
|
||||
node_type = cls._node_type
|
||||
node_version = cls.version()
|
||||
|
||||
assert isinstance(cls._node_type, NodeType)
|
||||
assert isinstance(node_version, str)
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
type_version_set.add(node_type_and_version)
|
Reference in New Issue
Block a user