diff --git a/polymorphic/base.py b/polymorphic/base.py index a7db9b1..c36b080 100644 --- a/polymorphic/base.py +++ b/polymorphic/base.py @@ -8,6 +8,7 @@ import inspect from django.db import models from django.db.models.base import ModelBase +from django.db.models.manager import ManagerDescriptor from manager import PolymorphicManager from query import PolymorphicQuerySet @@ -68,7 +69,8 @@ class PolymorphicModelBase(ModelBase): new_class.add_to_class(mgr_name, new_manager) # get first user defined manager; if there is one, make it the _default_manager - user_manager = new_class.get_first_user_defined_manager() + # this value is used by the related objects, restoring access to custom queryset methods on related objects. + user_manager = self.get_first_user_defined_manager(new_class) if user_manager: def_mgr = user_manager._copy_to_model(new_class) #print '## add default manager', type(def_mgr) @@ -141,13 +143,16 @@ class PolymorphicModelBase(ModelBase): return add_managers @classmethod - def get_first_user_defined_manager(self): + def get_first_user_defined_manager(mcs, new_class): # See if there is a manager attribute directly stored at this inheritance level. mgr_list = [] - for key, val in self.__dict__.items(): - item = getattr(self, key) - if not isinstance(item, models.Manager): continue - mgr_list.append((item.creation_counter, key, item)) + for key, val in new_class.__dict__.items(): + if isinstance(val, ManagerDescriptor): + val = val.manager + if not isinstance(val, PolymorphicManager) or type(val) is PolymorphicManager: + continue + + mgr_list.append((val.creation_counter, key, val)) # if there are user defined managers, use first one as _default_manager if mgr_list: diff --git a/polymorphic/manager.py b/polymorphic/manager.py index cdef305..83db77b 100644 --- a/polymorphic/manager.py +++ b/polymorphic/manager.py @@ -2,7 +2,7 @@ """ PolymorphicManager Please see README.rst or DOCS.rst or http://chrisglass.github.com/django_polymorphic/ """ - +import warnings from django.db import models from polymorphic.query import PolymorphicQuerySet @@ -14,17 +14,23 @@ class PolymorphicManager(models.Manager): Usually not explicitly needed, except if a custom manager or a custom queryset class is to be used. """ + # Tell Django that related fields also need to use this manager: use_for_related_fields = True + queryset_class = PolymorphicQuerySet def __init__(self, queryset_class=None, *args, **kwrags): - if not queryset_class: - self.queryset_class = PolymorphicQuerySet - else: + # Up till polymorphic 0.4, the queryset class could be specified as parameter to __init__. + # However, this doesn't work for related managers which instantiate a new version of this class. + # Hence, for custom managers the new default is using the 'queryset_class' attribute at class level instead. + if queryset_class: + warnings.warn("Using PolymorphicManager(queryset_class=..) is deprecated; override the queryset_class attribute instead", DeprecationWarning) + # For backwards compatibility, still allow the parameter: self.queryset_class = queryset_class + super(PolymorphicManager, self).__init__(*args, **kwrags) def get_query_set(self): - return self.queryset_class(self.model) + return self.queryset_class(self.model, using=self._db) # Proxy all unknown method calls to the queryset, so that its members are # directly accessible as PolymorphicModel.objects.* diff --git a/polymorphic/tests.py b/polymorphic/tests.py index c6188e0..ecf4c52 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -4,13 +4,14 @@ """ import uuid import re +from django.db.models.query import QuerySet from django.test import TestCase from django.db.models import Q,Count from django.db import models from django.contrib.contenttypes.models import ContentType -from polymorphic import PolymorphicModel, PolymorphicManager +from polymorphic import PolymorphicModel, PolymorphicManager, PolymorphicQuerySet from polymorphic import ShowFieldContent, ShowFieldType, ShowFieldTypeAndContent from polymorphic.tools_for_tests import UUIDField @@ -81,7 +82,7 @@ class DiamondXY(DiamondX, DiamondY): class RelationBase(ShowFieldTypeAndContent, PolymorphicModel): field_base = models.CharField(max_length=10) - fk = models.ForeignKey('self', null=True) + fk = models.ForeignKey('self', null=True, related_name='relationbase_set') m2m = models.ManyToManyField('self') class RelationA(RelationBase): field_a = models.CharField(max_length=10) @@ -100,9 +101,16 @@ class One2OneRelatingModel(PolymorphicModel): class One2OneRelatingModelDerived(One2OneRelatingModel): field2 = models.CharField(max_length=10) +class MyManagerQuerySet(PolymorphicQuerySet): + def my_queryset_foo(self): + return self.all() # Just a method to prove the existance of the custom queryset. + class MyManager(PolymorphicManager): + queryset_class = MyManagerQuerySet + def get_query_set(self): return super(MyManager, self).get_query_set().order_by('-field1') + class ModelWithMyManager(ShowFieldTypeAndContent, Model2A): objects = MyManager() field4 = models.CharField(max_length=10) @@ -117,6 +125,33 @@ class MROBase3(models.Model): class MRODerived(MROBase2, MROBase3): pass +class ParentModelWithManager(PolymorphicModel): + pass +class ChildModelWithManager(PolymorphicModel): + # Also test whether foreign keys receive the manager: + fk = models.ForeignKey(ParentModelWithManager, related_name='childmodel_set') + objects = MyManager() + + +class PlainMyManagerQuerySet(QuerySet): + def my_queryset_foo(self): + return self.all() # Just a method to prove the existance of the custom queryset. + +class PlainMyManager(models.Manager): + def my_queryset_foo(self): + return self.get_query_set().my_queryset_foo() + + def get_query_set(self): + return PlainMyManagerQuerySet(self.model, using=self._db) + +class PlainParentModelWithManager(models.Model): + pass + +class PlainChildModelWithManager(models.Model): + fk = models.ForeignKey(PlainParentModelWithManager, related_name='childmodel_set') + objects = PlainMyManager() + + class MgrInheritA(models.Model): mgrA = models.Manager() mgrA2 = models.Manager() @@ -409,9 +444,11 @@ class PolymorphicTests(TestCase): self.assertEqual(show_base_manager(PlainA), " ") self.assertEqual(show_base_manager(PlainB), " ") self.assertEqual(show_base_manager(PlainC), " ") + self.assertEqual(show_base_manager(Model2A), " ") self.assertEqual(show_base_manager(Model2B), " ") self.assertEqual(show_base_manager(Model2C), " ") + self.assertEqual(show_base_manager(One2OneRelatingModel), " ") self.assertEqual(show_base_manager(One2OneRelatingModelDerived), " ") @@ -604,23 +641,49 @@ class PolymorphicTests(TestCase): ModelWithMyManager.objects.create(field1='D1a', field4='D4a') ModelWithMyManager.objects.create(field1='D1b', field4='D4b') - objects = ModelWithMyManager.objects.all() + objects = ModelWithMyManager.objects.all() # MyManager should reverse the sorting of field1 self.assertEqual(repr(objects[0]), '') self.assertEqual(repr(objects[1]), '') self.assertEqual(len(objects), 2) - self.assertEqual(repr(type(ModelWithMyManager.objects)), "") - self.assertEqual(repr(type(ModelWithMyManager._default_manager)), "") + self.assertIs(type(ModelWithMyManager.objects), MyManager) + self.assertIs(type(ModelWithMyManager._default_manager), MyManager) + self.assertIs(type(ModelWithMyManager.base_objects), models.Manager) def test_manager_inheritance(self): - self.assertEqual(repr(type(MRODerived.objects)), "") # MRO + # by choice of MRO, should be MyManager from MROBase1. + self.assertIs(type(MRODerived.objects), MyManager) # check for correct default manager - self.assertEqual(repr(type(MROBase1._default_manager)), "") + self.assertIs(type(MROBase1._default_manager), MyManager) # Django vanilla inheritance does not inherit MyManager as _default_manager here - self.assertEqual(repr(type(MROBase2._default_manager)), "") + self.assertIs(type(MROBase2._default_manager), MyManager) + + + def test_queryset_assignment(self): + # This is just a consistency check for now, testing standard Django behavior. + parent = PlainParentModelWithManager.objects.create() + child = PlainChildModelWithManager.objects.create(fk=parent) + self.assertIs(type(PlainParentModelWithManager._default_manager), models.Manager) + self.assertIs(type(PlainChildModelWithManager._default_manager), PlainMyManager) + self.assertIs(type(PlainChildModelWithManager.objects), PlainMyManager) + self.assertIs(type(PlainChildModelWithManager.objects.all()), PlainMyManagerQuerySet) + + # A related set is created using the model's _default_manager, so does gain extra methods. + self.assertIs(type(parent.childmodel_set.my_queryset_foo()), PlainMyManagerQuerySet) + + # For polymorphic models, the same should happen. + parent = ParentModelWithManager.objects.create() + child = ChildModelWithManager.objects.create(fk=parent) + self.assertIs(type(ParentModelWithManager._default_manager), PolymorphicManager) + self.assertIs(type(ChildModelWithManager._default_manager), MyManager) + self.assertIs(type(ChildModelWithManager.objects), MyManager) + self.assertIs(type(ChildModelWithManager.objects.my_queryset_foo()), MyManagerQuerySet) + + # A related set is created using the model's _default_manager, so does gain extra methods. + self.assertIs(type(parent.childmodel_set.my_queryset_foo()), MyManagerQuerySet) def test_proxy_model_inheritance(self):