From 873372f61eb04e6148ce431285962c67cfce1b42 Mon Sep 17 00:00:00 2001 From: Jeremy Stretch Date: Tue, 9 Sep 2025 12:56:08 -0400 Subject: [PATCH] Closes #20241: Record A & B terminations on cable changelog records (#20246) --- netbox/dcim/models/cables.py | 165 +++++++++++++++++++++--------- netbox/utilities/testing/api.py | 8 +- netbox/utilities/testing/views.py | 4 +- 3 files changed, 122 insertions(+), 55 deletions(-) diff --git a/netbox/dcim/models/cables.py b/netbox/dcim/models/cables.py index 69e07ed94..89c9a99b4 100644 --- a/netbox/dcim/models/cables.py +++ b/netbox/dcim/models/cables.py @@ -18,6 +18,7 @@ from utilities.conversion import to_meters from utilities.exceptions import AbortRequest from utilities.fields import ColorField, GenericArrayForeignKey from utilities.querysets import RestrictedQuerySet +from utilities.serialization import deserialize_object, serialize_object from wireless.models import WirelessLink from .device_components import FrontPort, RearPort, PathEndpoint @@ -119,43 +120,61 @@ class Cable(PrimaryModel): pk = self.pk or self._pk return self.label or f'#{pk}' - @property - def a_terminations(self): - if hasattr(self, '_a_terminations'): - return self._a_terminations + def get_status_color(self): + return LinkStatusChoices.colors.get(self.status) + def _get_x_terminations(self, side): + """ + Return the terminating objects for the given cable end (A or B). + """ + if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B): + raise ValueError(f"Unknown cable side: {side}") + attr = f'_{side.lower()}_terminations' + + if hasattr(self, attr): + return getattr(self, attr) if not self.pk: return [] - - # Query self.terminations.all() to leverage cached results return [ - ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_A + # Query self.terminations.all() to leverage cached results + ct.termination for ct in self.terminations.all() if ct.cable_end == side ] + def _set_x_terminations(self, side, value): + """ + Set the terminating objects for the given cable end (A or B). + """ + if side not in (CableEndChoices.SIDE_A, CableEndChoices.SIDE_B): + raise ValueError(f"Unknown cable side: {side}") + _attr = f'_{side.lower()}_terminations' + + # If the provided value is a list of CableTermination IDs, resolve them + # to their corresponding termination objects. + if all(isinstance(item, int) for item in value): + value = [ + ct.termination for ct in CableTermination.objects.filter(pk__in=value).prefetch_related('termination') + ] + + if not self.pk or getattr(self, _attr, []) != list(value): + self._terminations_modified = True + + setattr(self, _attr, value) + + @property + def a_terminations(self): + return self._get_x_terminations(CableEndChoices.SIDE_A) + @a_terminations.setter def a_terminations(self, value): - if not self.pk or self.a_terminations != list(value): - self._terminations_modified = True - self._a_terminations = value + self._set_x_terminations(CableEndChoices.SIDE_A, value) @property def b_terminations(self): - if hasattr(self, '_b_terminations'): - return self._b_terminations - - if not self.pk: - return [] - - # Query self.terminations.all() to leverage cached results - return [ - ct.termination for ct in self.terminations.all() if ct.cable_end == CableEndChoices.SIDE_B - ] + return self._get_x_terminations(CableEndChoices.SIDE_B) @b_terminations.setter def b_terminations(self, value): - if not self.pk or self.b_terminations != list(value): - self._terminations_modified = True - self._b_terminations = value + self._set_x_terminations(CableEndChoices.SIDE_B, value) @property def color_name(self): @@ -208,7 +227,7 @@ class Cable(PrimaryModel): for termination in self.b_terminations: CableTermination(cable=self, cable_end='B', termination=termination).clean() - def save(self, *args, **kwargs): + def save(self, *args, force_insert=False, force_update=False, using=None, update_fields=None): _created = self.pk is None # Store the given length (if any) in meters for use in database ordering @@ -221,39 +240,87 @@ class Cable(PrimaryModel): if self.length is None: self.length_unit = None - super().save(*args, **kwargs) + # If this is a new Cable, save it before attempting to create its CableTerminations + if self._state.adding: + super().save(*args, force_insert=True, using=using, update_fields=update_fields) + # Update the private PK used in __str__() + self._pk = self.pk - # Update the private pk used in __str__ in case this is a new object (i.e. just got its pk) - self._pk = self.pk - - # Retrieve existing A/B terminations for the Cable - a_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='A')} - b_terminations = {ct.termination: ct for ct in self.terminations.filter(cable_end='B')} - - # Delete stale CableTerminations if self._terminations_modified: - for termination, ct in a_terminations.items(): - if termination.pk and termination not in self.a_terminations: - ct.delete() - for termination, ct in b_terminations.items(): - if termination.pk and termination not in self.b_terminations: - ct.delete() + self.update_terminations() + + super().save(*args, force_update=True, using=using, update_fields=update_fields) - # Save new CableTerminations (if any) - if self._terminations_modified: - for termination in self.a_terminations: - if not termination.pk or termination not in a_terminations: - CableTermination(cable=self, cable_end='A', termination=termination).save() - for termination in self.b_terminations: - if not termination.pk or termination not in b_terminations: - CableTermination(cable=self, cable_end='B', termination=termination).save() try: trace_paths.send(Cable, instance=self, created=_created) except UnsupportedCablePath as e: raise AbortRequest(e) - def get_status_color(self): - return LinkStatusChoices.colors.get(self.status) + def serialize_object(self, exclude=None): + data = serialize_object(self, exclude=exclude or []) + + # Add A & B terminations to the serialized data + a_terminations, b_terminations = self.get_terminations() + data['a_terminations'] = sorted([ct.pk for ct in a_terminations.values()]) + data['b_terminations'] = sorted([ct.pk for ct in b_terminations.values()]) + + return data + + @classmethod + def deserialize_object(cls, data, pk=None): + a_terminations = data.pop('a_terminations', []) + b_terminations = data.pop('b_terminations', []) + + instance = deserialize_object(cls, data, pk=pk) + + # Assign A & B termination objects to the Cable instance + queryset = CableTermination.objects.prefetch_related('termination') + instance.a_terminations = [ + ct.termination for ct in queryset.filter(pk__in=a_terminations) + ] + instance.b_terminations = [ + ct.termination for ct in queryset.filter(pk__in=b_terminations) + ] + + return instance + + def get_terminations(self): + """ + Return two dictionaries mapping A & B side terminating objects to their corresponding CableTerminations + for this Cable. + """ + a_terminations = {} + b_terminations = {} + + for ct in CableTermination.objects.filter(cable=self).prefetch_related('termination'): + if ct.cable_end == CableEndChoices.SIDE_A: + a_terminations[ct.termination] = ct + else: + b_terminations[ct.termination] = ct + + return a_terminations, b_terminations + + def update_terminations(self): + """ + Create/delete CableTerminations for this Cable to reflect its current state. + """ + a_terminations, b_terminations = self.get_terminations() + + # Delete any stale CableTerminations + for termination, ct in a_terminations.items(): + if termination.pk and termination not in self.a_terminations: + ct.delete() + for termination, ct in b_terminations.items(): + if termination.pk and termination not in self.b_terminations: + ct.delete() + + # Save any new CableTerminations + for termination in self.a_terminations: + if not termination.pk or termination not in a_terminations: + CableTermination(cable=self, cable_end='A', termination=termination).save() + for termination in self.b_terminations: + if not termination.pk or termination not in b_terminations: + CableTermination(cable=self, cable_end='B', termination=termination).save() class CableTermination(ChangeLoggedModel): diff --git a/netbox/utilities/testing/api.py b/netbox/utilities/testing/api.py index 8df8f4438..1fe881367 100644 --- a/netbox/utilities/testing/api.py +++ b/netbox/utilities/testing/api.py @@ -247,9 +247,9 @@ class APIViewTestCases: if issubclass(self.model, ChangeLoggingMixin): objectchange = ObjectChange.objects.get( changed_object_type=ContentType.objects.get_for_model(instance), - changed_object_id=instance.pk + changed_object_id=instance.pk, + action=ObjectChangeActionChoices.ACTION_CREATE, ) - self.assertEqual(objectchange.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(objectchange.message, data['changelog_message']) def test_bulk_create_objects(self): @@ -298,11 +298,11 @@ class APIViewTestCases: ] objectchanges = ObjectChange.objects.filter( changed_object_type=ContentType.objects.get_for_model(self.model), - changed_object_id__in=id_list + changed_object_id__in=id_list, + action=ObjectChangeActionChoices.ACTION_CREATE, ) self.assertEqual(len(objectchanges), len(self.create_data)) for oc in objectchanges: - self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(oc.message, changelog_message) class UpdateObjectViewTestCase(APITestCase): diff --git a/netbox/utilities/testing/views.py b/netbox/utilities/testing/views.py index da8a87098..99a6dd43a 100644 --- a/netbox/utilities/testing/views.py +++ b/netbox/utilities/testing/views.py @@ -655,11 +655,11 @@ class ViewTestCases: self.assertIsNotNone(request_id, "Unable to determine request ID from response") objectchanges = ObjectChange.objects.filter( changed_object_type=ContentType.objects.get_for_model(self.model), - request_id=request_id + request_id=request_id, + action=ObjectChangeActionChoices.ACTION_CREATE, ) self.assertEqual(len(objectchanges), len(self.csv_data) - 1) for oc in objectchanges: - self.assertEqual(oc.action, ObjectChangeActionChoices.ACTION_CREATE) self.assertEqual(oc.message, data['changelog_message']) @override_settings(EXEMPT_VIEW_PERMISSIONS=['*'])