Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.
parent
96209dcb14
commit
04d4181e17
|
|
@ -1,8 +1,8 @@
|
||||||
from django.test import TransactionTestCase
|
from django.test import TransactionTestCase
|
||||||
|
|
||||||
from polymorphic.models import PolymorphicTypeUndefined
|
from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
|
||||||
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D
|
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
|
||||||
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass
|
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
|
||||||
|
|
||||||
|
|
||||||
class UtilsTests(TransactionTestCase):
|
class UtilsTests(TransactionTestCase):
|
||||||
|
|
@ -38,3 +38,39 @@ class UtilsTests(TransactionTestCase):
|
||||||
],
|
],
|
||||||
transform=lambda o: o.__class__,
|
transform=lambda o: o.__class__,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_get_base_polymorphic_model(self):
|
||||||
|
"""
|
||||||
|
Test that finding the base polymorphic model works.
|
||||||
|
"""
|
||||||
|
# Finds the base from every level (including lowest)
|
||||||
|
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
|
||||||
|
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
|
||||||
|
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
|
||||||
|
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
|
||||||
|
|
||||||
|
# Properly handles multiple inheritance
|
||||||
|
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
|
||||||
|
|
||||||
|
# Ignores PolymorphicModel itself.
|
||||||
|
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
|
||||||
|
|
||||||
|
def test_get_base_polymorphic_model_skip_abstract(self):
|
||||||
|
"""
|
||||||
|
Skipping abstract models that can't be used for querying.
|
||||||
|
"""
|
||||||
|
class A(PolymorphicModel):
|
||||||
|
class Meta:
|
||||||
|
abstract = True
|
||||||
|
|
||||||
|
class B(A):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class C(B):
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.assertIs(get_base_polymorphic_model(A), None)
|
||||||
|
self.assertIs(get_base_polymorphic_model(B), B)
|
||||||
|
self.assertIs(get_base_polymorphic_model(C), B)
|
||||||
|
|
||||||
|
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,9 @@ import sys
|
||||||
|
|
||||||
from django.contrib.contenttypes.models import ContentType
|
from django.contrib.contenttypes.models import ContentType
|
||||||
from django.db import DEFAULT_DB_ALIAS
|
from django.db import DEFAULT_DB_ALIAS
|
||||||
|
from polymorphic.models import PolymorphicModel
|
||||||
|
|
||||||
|
from polymorphic.base import PolymorphicModelBase
|
||||||
|
|
||||||
|
|
||||||
def reset_polymorphic_ctype(*models, **filters):
|
def reset_polymorphic_ctype(*models, **filters):
|
||||||
|
|
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
|
||||||
else:
|
else:
|
||||||
from functools import cmp_to_key
|
from functools import cmp_to_key
|
||||||
return sorted(classes, key=cmp_to_key(_compare_mro))
|
return sorted(classes, key=cmp_to_key(_compare_mro))
|
||||||
|
|
||||||
|
|
||||||
|
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
|
||||||
|
"""
|
||||||
|
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
|
||||||
|
"""
|
||||||
|
for Model in reversed(ChildModel.mro()):
|
||||||
|
if isinstance(Model, PolymorphicModelBase) and \
|
||||||
|
Model is not PolymorphicModel and \
|
||||||
|
(allow_abstract or not Model._meta.abstract):
|
||||||
|
return Model
|
||||||
|
return None
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue