Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.

fix_request_path_info
Diederik van der Boor 2017-09-30 16:21:21 +02:00
parent 96209dcb14
commit 04d4181e17
2 changed files with 54 additions and 3 deletions

View File

@ -1,8 +1,8 @@
from django.test import TransactionTestCase
from polymorphic.models import PolymorphicTypeUndefined
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass
from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
class UtilsTests(TransactionTestCase):
@ -38,3 +38,39 @@ class UtilsTests(TransactionTestCase):
],
transform=lambda o: o.__class__,
)
def test_get_base_polymorphic_model(self):
"""
Test that finding the base polymorphic model works.
"""
# Finds the base from every level (including lowest)
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
# Properly handles multiple inheritance
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
# Ignores PolymorphicModel itself.
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
def test_get_base_polymorphic_model_skip_abstract(self):
"""
Skipping abstract models that can't be used for querying.
"""
class A(PolymorphicModel):
class Meta:
abstract = True
class B(A):
pass
class C(B):
pass
self.assertIs(get_base_polymorphic_model(A), None)
self.assertIs(get_base_polymorphic_model(B), B)
self.assertIs(get_base_polymorphic_model(C), B)
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)

View File

@ -2,6 +2,9 @@ import sys
from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS
from polymorphic.models import PolymorphicModel
from polymorphic.base import PolymorphicModelBase
def reset_polymorphic_ctype(*models, **filters):
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
else:
from functools import cmp_to_key
return sorted(classes, key=cmp_to_key(_compare_mro))
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
"""
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
"""
for Model in reversed(ChildModel.mro()):
if isinstance(Model, PolymorphicModelBase) and \
Model is not PolymorphicModel and \
(allow_abstract or not Model._meta.abstract):
return Model
return None