Added `get_base_polymorphic_model()` to detect the common base class for a polymorphic model.
(cherry picked from commit 04d4181e17)
fix_request_path_info
parent
fa9612d49c
commit
874b60ec40
|
|
@ -1,9 +1,9 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from polymorphic.models import PolymorphicTypeUndefined
|
||||
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D
|
||||
from polymorphic.models import PolymorphicTypeUndefined, PolymorphicModel
|
||||
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D, Enhance_Inherit, Enhance_Base
|
||||
from polymorphic.tests.test_orm import qrepr
|
||||
from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype
|
||||
from polymorphic.utils import reset_polymorphic_ctype, sort_by_subclass, get_base_polymorphic_model
|
||||
|
||||
|
||||
class UtilsTests(TestCase):
|
||||
|
|
@ -35,3 +35,39 @@ class UtilsTests(TestCase):
|
|||
' <Model2B: id 3, field1 (CharField), field2 (CharField)>,'
|
||||
' <Model2B: id 4, field1 (CharField), field2 (CharField)>]'
|
||||
))
|
||||
|
||||
def test_get_base_polymorphic_model(self):
|
||||
"""
|
||||
Test that finding the base polymorphic model works.
|
||||
"""
|
||||
# Finds the base from every level (including lowest)
|
||||
self.assertIs(get_base_polymorphic_model(Model2D), Model2A)
|
||||
self.assertIs(get_base_polymorphic_model(Model2C), Model2A)
|
||||
self.assertIs(get_base_polymorphic_model(Model2B), Model2A)
|
||||
self.assertIs(get_base_polymorphic_model(Model2A), Model2A)
|
||||
|
||||
# Properly handles multiple inheritance
|
||||
self.assertIs(get_base_polymorphic_model(Enhance_Inherit), Enhance_Base)
|
||||
|
||||
# Ignores PolymorphicModel itself.
|
||||
self.assertIs(get_base_polymorphic_model(PolymorphicModel), None)
|
||||
|
||||
def test_get_base_polymorphic_model_skip_abstract(self):
|
||||
"""
|
||||
Skipping abstract models that can't be used for querying.
|
||||
"""
|
||||
class A(PolymorphicModel):
|
||||
class Meta:
|
||||
abstract = True
|
||||
|
||||
class B(A):
|
||||
pass
|
||||
|
||||
class C(B):
|
||||
pass
|
||||
|
||||
self.assertIs(get_base_polymorphic_model(A), None)
|
||||
self.assertIs(get_base_polymorphic_model(B), B)
|
||||
self.assertIs(get_base_polymorphic_model(C), B)
|
||||
|
||||
self.assertIs(get_base_polymorphic_model(C, allow_abstract=True), A)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import sys
|
|||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
from polymorphic.models import PolymorphicModel
|
||||
|
||||
from polymorphic.base import PolymorphicModelBase
|
||||
|
||||
|
||||
def reset_polymorphic_ctype(*models, **filters):
|
||||
|
|
@ -59,3 +62,15 @@ def sort_by_subclass(*classes):
|
|||
else:
|
||||
from functools import cmp_to_key
|
||||
return sorted(classes, key=cmp_to_key(_compare_mro))
|
||||
|
||||
|
||||
def get_base_polymorphic_model(ChildModel, allow_abstract=False):
|
||||
"""
|
||||
First the first concrete model in the inheritance chain that inherited from the PolymorphicModel.
|
||||
"""
|
||||
for Model in reversed(ChildModel.mro()):
|
||||
if isinstance(Model, PolymorphicModelBase) and \
|
||||
Model is not PolymorphicModel and \
|
||||
(allow_abstract or not Model._meta.abstract):
|
||||
return Model
|
||||
return None
|
||||
|
|
|
|||
Loading…
Reference in New Issue