Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 54 additions & 2 deletions validity/tests/test_utils/test_orm.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import pytest
from dcim.models import Device
from dcim.models import Device, DeviceType
from django.contrib.contenttypes.models import ContentType
from django.db import connection
from django.db.models import BigIntegerField
from django.db.models.fields.json import KeyTextTransform
from django.db.models.functions import Cast
from extras.models import CustomField
from factories import DeviceFactory, NameSetDBFactory, SerializerDBFactory

from validity.models import NameSet
from validity.models.serializer import Serializer
from validity.utils.orm import CustomPrefetchMixin, QuerySetMap
from validity.utils.orm import CustomFieldBuilder, CustomPrefetchMixin, QuerySetMap


@pytest.mark.parametrize("attrib", ["pk", "name"])
Expand Down Expand Up @@ -55,3 +57,53 @@ def test_custom_postfetch(monkeypatch):
NameSetDBFactory(name=f"ns{i}")
for device in custom_qs:
assert device.name.replace("dev", "ns") == device.nameset.name


@pytest.mark.django_db
def test_custom_field_builder_creates_object_field():
serializer_ct = ContentType.objects.get_for_model(Serializer)
device_ct = ContentType.objects.get_for_model(Device)
cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType)

custom_field = cf_builder.create(
name="validity_test_serializer",
label="Validity Test Serializer",
type="object",
required=False,
object_type=serializer_ct,
bind_to=[Device],
)

custom_field.refresh_from_db()
assert custom_field.name == "validity_test_serializer"
assert custom_field.related_object_type == serializer_ct
assert list(custom_field.object_types.all()) == [device_ct]


@pytest.mark.django_db
def test_custom_field_builder_reuses_existing_field():
cf_builder = CustomFieldBuilder(cf_model=CustomField, content_type_model=ContentType)
device_type_ct = ContentType.objects.get_for_model(DeviceType)

original = cf_builder.create(
name="validity_test_reused",
label="Original Label",
type="text",
required=False,
bind_to=[Device],
)
reused = cf_builder.create(
name="validity_test_reused",
label="Changed Label",
type="boolean",
required=True,
bind_to=[DeviceType],
)

original.refresh_from_db()
assert reused == original
assert CustomField.objects.filter(name="validity_test_reused").count() == 1
assert original.label == "Changed Label"
assert original.type == "boolean"
assert original.required is True
assert list(original.object_types.all()) == [device_type_ct]
13 changes: 9 additions & 4 deletions validity/utils/orm.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,15 @@ class CustomFieldBuilder:
content_type_model: type
db_alias: str = ""

def create(self, *, bind_to, object_type=None, **cf_params):
def create(self, *, bind_to, name, object_type=None, **cf_params):
db = self.db_alias or self.cf_model.objects.db
if object_type is not None:
cf_params["related_object_type"] = object_type
custom_field = self.cf_model.objects.using(db).create(**cf_params)
cf_params["related_object_type"] = object_type

# get_or_create handles #182 - compatibility with netbox-branching weird behaviour
custom_field, created = self.cf_model.objects.using(db).get_or_create(name=name, defaults=cf_params)
if not created:
for field, value in cf_params.items():
setattr(custom_field, field, value)
custom_field.save(force_update=True, update_fields=cf_params.keys())
custom_field.object_types.set(self.content_type_model.objects.get_for_model(model).pk for model in bind_to)
return custom_field
Loading