from django.test import TransactionTestCase from polymorphic.models import PolymorphicModel, PolymorphicTypeUndefined from polymorphic.tests.models import ( Enhance_Base, Enhance_Inherit, Model2A, Model2B, Model2C, Model2D, ) from polymorphic.utils import ( get_base_polymorphic_model, reset_polymorphic_ctype, sort_by_subclass, ) class UtilsTests(TransactionTestCase): def test_sort_by_subclass(self): self.assertEqual( sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C), [Model2A, Model2B, Model2C, Model2D, Model2D], ) def test_reset_polymorphic_ctype(self): """ Test the the polymorphic_ctype_id can be restored. """ Model2A.objects.create(field1="A1") Model2D.objects.create(field1="A1", field2="B2", field3="C3", field4="D4") Model2B.objects.create(field1="A1", field2="B2") Model2B.objects.create(field1="A1", field2="B2") Model2A.objects.all().update(polymorphic_ctype_id=None) with self.assertRaises(PolymorphicTypeUndefined): list(Model2A.objects.all()) reset_polymorphic_ctype(Model2D, Model2B, Model2D, Model2A, Model2C) self.assertQuerysetEqual( Model2A.objects.order_by("pk"), [Model2A, Model2D, Model2B, Model2B], transform=lambda o: o.__class__, ) def test_get_base_polymorphic_model(self): """ Test that finding the base polymorphic model works. """ # Finds the base from every level (including lowest) self.assertIs(get_base_polymorphic_model(Model2D), Model2A) self.assertIs(get_base_polymorphic_model(Model2C), Model2A) self.assertIs(get_base_polymorphic_model(Model2B), Model2A) self.assertIs(get_base_polymorphic_model(Model2A), Model2A) # Properly handles multiple inheritance self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base) # Ignores PolymorphicModel itself. self.assertIs(get_base_polymorphic_model(PolymorphicModel), None) def test_get_base_polymorphic_model_skip_abstract(self): """ Skipping abstract models that can't be used for querying. """ class A(PolymorphicModel): class Meta: abstract = True class B(A): pass class C(B): pass self.assertIs(get_base_polymorphic_model(A), None) self.assertIs(get_base_polymorphic_model(B), B) self.assertIs(get_base_polymorphic_model(C), B) self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)