diff --git a/validity/tests/test_utils/test_orm.py b/validity/tests/test_utils/test_orm.py index 7075399..d744d27 100644 --- a/validity/tests/test_utils/test_orm.py +++ b/validity/tests/test_utils/test_orm.py @@ -1,14 +1,16 @@ import pytest -from dcim.models import Device +from dcim.models import Device, DeviceType +from django.contrib.contenttypes.models import ContentType from django.db import connection from django.db.models import BigIntegerField from django.db.models.fields.json import KeyTextTransform from django.db.models.functions import Cast +from extras.models import CustomField from factories import DeviceFactory, NameSetDBFactory, SerializerDBFactory from validity.models import NameSet from validity.models.serializer import Serializer -from validity.utils.orm import CustomPrefetchMixin, QuerySetMap +from validity.utils.orm import CustomFieldBuilder, CustomPrefetchMixin, QuerySetMap @pytest.mark.parametrize("attrib", ["pk", "name"]) @@ -55,3 +57,53 @@ def test_custom_postfetch(monkeypatch): NameSetDBFactory(name=f"ns{i}") for device in custom_qs: assert device.name.replace("dev", "ns") == device.nameset.name + + +@pytest.mark.django_db +def test_custom_field_builder_creates_object_field(): + serializer_ct = ContentType.objects.get_for_model(Serializer) + device_ct = ContentType.objects.get_for_model(Device) + cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType) + + custom_field = cf_builder.create( + name="validity_test_serializer", + label="Validity Test Serializer", + type="object", + required=False, + object_type=serializer_ct, + bind_to=[Device], + ) + + custom_field.refresh_from_db() + assert custom_field.name == "validity_test_serializer" + assert custom_field.related_object_type == serializer_ct + assert list(custom_field.object_types.all()) == [device_ct] + + +@pytest.mark.django_db +def test_custom_field_builder_reuses_existing_field(): + cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType) + device_type_ct = ContentType.objects.get_for_model(DeviceType) + + original = cf_builder.create( + name="validity_test_reused", + label="Original Label", + type="text", + required=False, + bind_to=[Device], + ) + reused = cf_builder.create( + name="validity_test_reused", + label="Changed Label", + type="boolean", + required=True, + bind_to=[DeviceType], + ) + + original.refresh_from_db() + assert reused == original + assert CustomField.objects.filter(name="validity_test_reused").count() == 1 + assert original.label == "Changed Label" + assert original.type == "boolean" + assert original.required is True + assert list(original.object_types.all()) == [device_type_ct] diff --git a/validity/utils/orm.py b/validity/utils/orm.py index 2f3b9af..30f9217 100644 --- a/validity/utils/orm.py +++ b/validity/utils/orm.py @@ -206,10 +206,15 @@ class CustomFieldBuilder: content_type_model: type db_alias: str = "" - def create(self, *, bind_to, object_type=None, **cf_params): + def create(self, *, bind_to, name, object_type=None, **cf_params): db = self.db_alias or self.cf_model.objects.db - if object_type is not None: - cf_params["related_object_type"] = object_type - custom_field = self.cf_model.objects.using(db).create(**cf_params) + cf_params["related_object_type"] = object_type + + # get_or_create handles #182 - compatibility with netbox-branching weird behaviour + custom_field, created = self.cf_model.objects.using(db).get_or_create(name=name, defaults=cf_params) + if not created: + for field, value in cf_params.items(): + setattr(custom_field, field, value) + custom_field.save(force_update=True, update_fields=cf_params.keys()) custom_field.object_types.set(self.content_type_model.objects.get_for_model(model).pk for model in bind_to) return custom_field