Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.

(cherry picked from commit 04d4181e17)
fix_request_path_info
Diederik van der Boor 2017-09-30 16:21:21 +02:00
parent fa9612d49c
commit 874b60ec40
2 changed files with 54 additions and 3 deletions

View File

@ -1,9 +1,9 @@
from unittest import TestCase
from polymorphic.models import PolymorphicTypeUndefined
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D
from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
from polymorphic.tests.test_orm import qrepr
from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
class UtilsTests(TestCase):
@ -35,3 +35,39 @@ class UtilsTests(TestCase):
' <Model2B: id 3, field1 (CharField), field2 (CharField)>,'
' <Model2B: id 4, field1 (CharField), field2 (CharField)>]'
))
def test_get_base_polymorphic_model(self):
"""
Test that finding the base polymorphic model works.
"""
# Finds the base from every level (including lowest)
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
# Properly handles multiple inheritance
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
# Ignores PolymorphicModel itself.
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
def test_get_base_polymorphic_model_skip_abstract(self):
"""
Skipping abstract models that can't be used for querying.
"""
class A(PolymorphicModel):
class Meta:
abstract = True
class B(A):
pass
class C(B):
pass
self.assertIs(get_base_polymorphic_model(A), None)
self.assertIs(get_base_polymorphic_model(B), B)
self.assertIs(get_base_polymorphic_model(C), B)
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)

View File

@ -2,6 +2,9 @@ import sys
from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS
from polymorphic.models import PolymorphicModel
from polymorphic.base import PolymorphicModelBase
def reset_polymorphic_ctype(*models, **filters):
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
else:
from functools import cmp_to_key
return sorted(classes, key=cmp_to_key(_compare_mro))
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
"""
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
"""
for Model in reversed(ChildModel.mro()):
if isinstance(Model, PolymorphicModelBase) and \
Model is not PolymorphicModel and \
(allow_abstract or not Model._meta.abstract):
return Model
return None