diff --git a/django_enumfield/db/fields.py b/django_enumfield/db/fields.py index 37223ad..48e99e9 100644 --- a/django_enumfield/db/fields.py +++ b/django_enumfield/db/fields.py @@ -126,10 +126,9 @@ def set_enum(self, new_value): "for enum {enum}." ).format(value=new_value, enum=enum) ) + validators.validate_valid_transition(enum, old_value, new_value) setattr(self, private_att_name, new_value) self.__dict__[att_name] = new_value - # Run validation for new value. - validators.validate_valid_transition(enum, old_value, new_value) def get_enum(self): return getattr(self, private_att_name) diff --git a/django_enumfield/tests/test_enum.py b/django_enumfield/tests/test_enum.py index 35d4dbf..9263f0e 100644 --- a/django_enumfield/tests/test_enum.py +++ b/django_enumfield/tests/test_enum.py @@ -125,6 +125,12 @@ def test_enum_field_save(self): self.assertEqual(beer.state, BeerState.FIZZY) beer.save() + def test_invalid_transition_preserves_previous_value(self): + person = Person.objects.create(status=PersonStatus.ALIVE) + with self.assertRaises(InvalidStatusOperationError): + person.status = PersonStatus.UNBORN + self.assertEqual(person.status, PersonStatus.ALIVE) + def test_enum_field_refresh_from_db(self): lamp = Lamp.objects.create(state=LampState.OFF) lamp2 = Lamp.objects.get(pk=lamp.id)