From 04d4181e17cd9239b667ec24e45fec137cb4a40b Mon Sep 17 00:00:00 2001 From: Diederik van der Boor Date: Sat, 30 Sep 2017 16:21:21 +0200 Subject: [PATCH] Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model. --- polymorphic/tests/test_utils.py | 42 ++++++++++++++++++++++++++++++--- polymorphic/utils.py | 15 ++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/polymorphic/tests/test_utils.py b/polymorphic/tests/test_utils.py index 410ef26..fffd2fc 100644 --- a/polymorphic/tests/test_utils.py +++ b/polymorphic/tests/test_utils.py @@ -1,8 +1,8 @@ from django.test import TransactionTestCase -from polymorphic.models import PolymorphicTypeUndefined -from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D -from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass +from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel +from polymorphic.tests.models import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base +from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model class UtilsTests(TransactionTestCase): @@ -38,3 +38,39 @@ class UtilsTests(TransactionTestCase): ], transform=lambda o: o.__class__, ) + + def test_get_base_polymorphic_model(self): + """ + Test that finding the base polymorphic model works. + """ + # Finds the base from every level (including lowest) + self.assertIs(get_base_polymorphic_model(Model2D), Model2A) + self.assertIs(get_base_polymorphic_model(Model2C), Model2A) + self.assertIs(get_base_polymorphic_model(Model2B), Model2A) + self.assertIs(get_base_polymorphic_model(Model2A), Model2A) + + # Properly handles multiple inheritance + self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base) + + # Ignores PolymorphicModel itself. + self.assertIs(get_base_polymorphic_model(PolymorphicModel), None) + + def test_get_base_polymorphic_model_skip_abstract(self): + """ + Skipping abstract models that can't be used for querying. + """ + class A(PolymorphicModel): + class Meta: + abstract = True + + class B(A): + pass + + class C(B): + pass + + self.assertIs(get_base_polymorphic_model(A), None) + self.assertIs(get_base_polymorphic_model(B), B) + self.assertIs(get_base_polymorphic_model(C), B) + + self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A) diff --git a/polymorphic/utils.py b/polymorphic/utils.py index a8d968b..0bc0c41 100644 --- a/polymorphic/utils.py +++ b/polymorphic/utils.py @@ -2,6 +2,9 @@ import sys from django.contrib.contenttypes.models import ContentType from django.db import DEFAULT_DB_ALIAS +from polymorphic.models import PolymorphicModel + +from polymorphic.base import PolymorphicModelBase def reset_polymorphic_ctype(*models, **filters): @@ -59,3 +62,15 @@ def sort_by_subclass(*classes): else: from functools import cmp_to_key return sorted(classes, key=cmp_to_key(_compare_mro)) + + +def get_base_polymorphic_model(ChildModel, allow_abstract=False): + """ + First the first concrete model in the inheritance chain that inherited from the PolymorphicModel. + """ + for Model in reversed(ChildModel.mro()): + if isinstance(Model, PolymorphicModelBase) and \ + Model is not PolymorphicModel and \ + (allow_abstract or not Model._meta.abstract): + return Model + return None