Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.

(cherry picked from commit 04d4181e17)
fix_request_path_info
Diederik van der Boor 2017-09-30 16:21:21 +02:00
parent fa9612d49c
commit 874b60ec40
2 changed files with 54 additions and 3 deletions

View File

@ -1,9 +1,9 @@
from unittest import TestCase from unittest import TestCase
from polymorphic.models import PolymorphicTypeUndefined from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D from polymorphic.tests import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
from polymorphic.tests.test_orm import qrepr from polymorphic.tests.test_orm import qrepr
from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
class UtilsTests(TestCase): class UtilsTests(TestCase):
@ -35,3 +35,39 @@ class UtilsTests(TestCase):
' <Model2B: id 3, field1 (CharField), field2 (CharField)>,' ' <Model2B: id 3, field1 (CharField), field2 (CharField)>,'
' <Model2B: id 4, field1 (CharField), field2 (CharField)>]' ' <Model2B: id 4, field1 (CharField), field2 (CharField)>]'
)) ))
def test_get_base_polymorphic_model(self):
"""
Test that finding the base polymorphic model works.
"""
# Finds the base from every level (including lowest)
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
# Properly handles multiple inheritance
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
# Ignores PolymorphicModel itself.
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
def test_get_base_polymorphic_model_skip_abstract(self):
"""
Skipping abstract models that can't be used for querying.
"""
class A(PolymorphicModel):
class Meta:
abstract = True
class B(A):
pass
class C(B):
pass
self.assertIs(get_base_polymorphic_model(A), None)
self.assertIs(get_base_polymorphic_model(B), B)
self.assertIs(get_base_polymorphic_model(C), B)
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)

View File

@ -2,6 +2,9 @@ import sys
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
from polymorphic.models import PolymorphicModel
from polymorphic.base import PolymorphicModelBase
def reset_polymorphic_ctype(*models, **filters): def reset_polymorphic_ctype(*models, **filters):
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
else: else:
from functools import cmp_to_key from functools import cmp_to_key
return sorted(classes, key=cmp_to_key(_compare_mro)) return sorted(classes, key=cmp_to_key(_compare_mro))
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
"""
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
"""
for Model in reversed(ChildModel.mro()):
if isinstance(Model, PolymorphicModelBase) and \
Model is not PolymorphicModel and \
(allow_abstract or not Model._meta.abstract):
return Model
return None