diff --git a/docs/third-party.rst b/docs/third-party.rst index 7b274f5..33c94a6 100644 --- a/docs/third-party.rst +++ b/docs/third-party.rst @@ -97,7 +97,16 @@ This doesn't work, since it needs to look for revisions of the child model. Usin the view of the actual child model is used, similar to the way the regular change and delete views are redirected. +django-guardian support +----------------------- + +You can enable the content type of the base model to be used for the object levels permissions by setting the +django-guardian_ option `GUARDIAN_GET_CONTENT_TYPE` to `polymorphic.contrib.get_polymorphic_base_content_type`. Read +more about this option in the `django-guardian documentation `_. + + .. _django-reversion: https://github.com/etianen/django-reversion .. _django-reversion-compare: https://github.com/jedie/django-reversion-compare .. _django-mptt: https://github.com/django-mptt/django-mptt .. _django-polymorphic-tree: https://github.com/django-polymorphic/django-polymorphic-tree +.. _django-guardian: https://github.com/django-guardian/django-guardian diff --git a/polymorphic/contrib/__init__.py b/polymorphic/contrib/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/polymorphic/contrib/guardian.py b/polymorphic/contrib/guardian.py new file mode 100644 index 0000000..55b1601 --- /dev/null +++ b/polymorphic/contrib/guardian.py @@ -0,0 +1,35 @@ +from django.contrib.contenttypes.models import ContentType + + +def get_polymorphic_base_content_type(obj): + """ + Helper function to return the base polymorphic content type id. This should used with django-guardian and the + GUARDIAN_GET_CONTENT_TYPE option. + + See the django-guardian documentation for more information: + + https://django-guardian.readthedocs.io/en/latest/configuration.html#guardian-get-content-type + """ + if hasattr(obj, 'polymorphic_model_marker'): + try: + superclasses = list(obj.__class__.mro()) + except TypeError: + # obj is an object so mro() need to be called with the obj. + superclasses = list(obj.__class__.mro(obj)) + + polymorphic_superclasses = list() + for sclass in superclasses: + if hasattr(sclass, 'polymorphic_model_marker'): + polymorphic_superclasses.append(sclass) + + # PolymorphicMPTT adds an additional class between polymorphic and base class. + if hasattr(obj, 'can_have_children'): + root_polymorphic_class = polymorphic_superclasses[-3] + else: + root_polymorphic_class = polymorphic_superclasses[-2] + ctype = ContentType.objects.get_for_model(root_polymorphic_class) + + else: + ctype = ContentType.objects.get_for_model(obj) + + return ctype diff --git a/polymorphic/tests.py b/polymorphic/tests.py index d048313..9c2cb3d 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -21,6 +21,7 @@ from django.db import models from django.contrib.contenttypes.models import ContentType from django.utils import six +from polymorphic.contrib.guardian import get_polymorphic_base_content_type from polymorphic.models import PolymorphicModel from polymorphic.managers import PolymorphicManager from polymorphic.query import PolymorphicQuerySet @@ -195,6 +196,7 @@ class ModelWithMyManagerNoDefault(ShowFieldTypeAndContent, Model2A): my_objects = MyManager() field4 = models.CharField(max_length=10) + class ModelWithMyManagerDefault(ShowFieldTypeAndContent, Model2A): my_objects = MyManager() objects = PolymorphicManager() @@ -1194,6 +1196,24 @@ class PolymorphicTests(TestCase): result = Model2B.objects.annotate(val=Concat('field1', 'field2')) self.assertEqual(list(result), []) + def test_contrib_guardian(self): + # Regular Django inheritance should return the child model content type. + obj = PlainC() + ctype = get_polymorphic_base_content_type(obj) + self.assertEqual(ctype.name, 'plain c') + + ctype = get_polymorphic_base_content_type(PlainC) + self.assertEqual(ctype.name, 'plain c') + + # Polymorphic inheritance should return the parent model content type. + obj = Model2D() + ctype = get_polymorphic_base_content_type(obj) + self.assertEqual(ctype.name, 'model2a') + + ctype = get_polymorphic_base_content_type(Model2D) + self.assertEqual(ctype.name, 'model2a') + + class RegressionTests(TestCase): def test_for_query_result_incomplete_with_inheritance(self): @@ -1215,6 +1235,7 @@ class RegressionTests(TestCase): expected_queryset = [bottom] self.assertQuerysetEqual(Bottom.objects.all(), [repr(r) for r in expected_queryset]) + class MultipleDatabasesTests(TestCase): multi_db = True