Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.

fix_request_path_info
Diederik van der Boor 2017-09-30 16:21:21 +02:00
parent 96209dcb14
commit 04d4181e17
2 changed files with 54 additions and 3 deletions

View File

@ -1,8 +1,8 @@
from django.test import TransactionTestCase from django.test import TransactionTestCase
from polymorphic.models import PolymorphicTypeUndefined from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
class UtilsTests(TransactionTestCase): class UtilsTests(TransactionTestCase):
@ -38,3 +38,39 @@ class UtilsTests(TransactionTestCase):
], ],
transform=lambda o: o.__class__, transform=lambda o: o.__class__,
) )
def test_get_base_polymorphic_model(self):
"""
Test that finding the base polymorphic model works.
"""
# Finds the base from every level (including lowest)
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
# Properly handles multiple inheritance
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
# Ignores PolymorphicModel itself.
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
def test_get_base_polymorphic_model_skip_abstract(self):
"""
Skipping abstract models that can't be used for querying.
"""
class A(PolymorphicModel):
class Meta:
abstract = True
class B(A):
pass
class C(B):
pass
self.assertIs(get_base_polymorphic_model(A), None)
self.assertIs(get_base_polymorphic_model(B), B)
self.assertIs(get_base_polymorphic_model(C), B)
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)

View File

@ -2,6 +2,9 @@ import sys
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from polymorphic.models import PolymorphicModel
from polymorphic.base import PolymorphicModelBase
def reset_polymorphic_ctype(*models, **filters): def reset_polymorphic_ctype(*models, **filters):
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
else: else:
from functools import cmp_to_key from functools import cmp_to_key
return sorted(classes, key=cmp_to_key(_compare_mro)) return sorted(classes, key=cmp_to_key(_compare_mro))
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
"""
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
"""
for Model in reversed(ChildModel.mro()):
if isinstance(Model, PolymorphicModelBase) and \
Model is not PolymorphicModel and \
(allow_abstract or not Model._meta.abstract):
return Model
return None