diff --git a/docs/advanced.rst b/docs/advanced.rst index dbc7f82..c9fcac4 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -176,8 +176,9 @@ About Queryset Methods methods now, it's best if you use ``Model.base_objects.values...`` as this is guaranteed to not change. -* ``defer()`` and ``only()`` are not yet supported (support will be added - in the future). +* ``defer()`` and ``only()`` work as expected. On Django 1.5+ they support + the ``ModelX___field`` syntax, but on Django 1.4 it is only possible to + pass fields on the base model into these methods. Using enhanced Q-objects in any Places @@ -231,10 +232,10 @@ Restrictions & Caveats * Database Performance regarding concrete Model inheritance in general. Please see the :ref:`performance`. -* Queryset methods ``values()``, ``values_list()``, ``select_related()``, - ``defer()`` and ``only()`` are not yet fully supported (see above). - ``extra()`` has one restriction: the resulting objects are required to have - a unique primary key within the result set. +* Queryset methods ``values()``, ``values_list()``, and ``select_related()`` + are not yet fully supported (see above). ``extra()`` has one restriction: + the resulting objects are required to have a unique primary key within + the result set. * Diamond shaped inheritance: There seems to be a general problem with diamond shaped multiple model inheritance with Django models diff --git a/polymorphic/query.py b/polymorphic/query.py index 4e55d71..0a31f98 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -4,6 +4,7 @@ """ from __future__ import absolute_import +import copy from collections import defaultdict import django @@ -64,12 +65,21 @@ class PolymorphicQuerySet(QuerySet): def __init__(self, *args, **kwargs): "init our queryset object member variables" self.polymorphic_disabled = False + # A parallel structure to django.db.models.query.Query.deferred_loading, + # which we maintain with the untranslated field names passed to + # .defer() and .only() in order to be able to retranslate them when + # retrieving the real instance (so that the deferred fields apply + # to that queryset as well). + self.polymorphic_deferred_loading = (set([]), True) super(PolymorphicQuerySet, self).__init__(*args, **kwargs) def _clone(self, *args, **kwargs): "Django's _clone only copies its own variables, so we need to copy ours here" new = super(PolymorphicQuerySet, self)._clone(*args, **kwargs) new.polymorphic_disabled = self.polymorphic_disabled + new.polymorphic_deferred_loading = ( + copy.copy(self.polymorphic_deferred_loading[0]), + self.polymorphic_deferred_loading[1]) return new if django.VERSION >= (1, 7): @@ -111,6 +121,64 @@ class PolymorphicQuerySet(QuerySet): new_args = [translate_polymorphic_field_path(self.model, a) for a in args] return super(PolymorphicQuerySet, self).order_by(*new_args, **kwargs) + def defer(self, *fields): + """ + Translate the field paths in the args, then call vanilla defer. + + Also retain a copy of the original fields passed, which we'll need + when we're retrieving the real instance (since we'll need to translate + them again, as the model will have changed). + """ + new_fields = [translate_polymorphic_field_path(self.model, a) for a in fields] + clone = super(PolymorphicQuerySet, self).defer(*new_fields) + clone._polymorphic_add_deferred_loading(fields) + return clone + + def only(self, *fields): + """ + Translate the field paths in the args, then call vanilla only. + + Also retain a copy of the original fields passed, which we'll need + when we're retrieving the real instance (since we'll need to translate + them again, as the model will have changed). + """ + new_fields = [translate_polymorphic_field_path(self.model, a) for a in fields] + clone = super(PolymorphicQuerySet, self).only(*new_fields) + clone._polymorphic_add_immediate_loading(fields) + return clone + + def _polymorphic_add_deferred_loading(self, field_names): + """ + Follows the logic of django.db.models.query.Query.add_deferred_loading(), + but for the non-translated field names that were passed to self.defer(). + """ + existing, defer = self.polymorphic_deferred_loading + if defer: + # Add to existing deferred names. + self.polymorphic_deferred_loading = existing.union(field_names), True + else: + # Remove names from the set of any existing "immediate load" names. + self.polymorphic_deferred_loading = existing.difference(field_names), False + + def _polymorphic_add_immediate_loading(self, field_names): + """ + Follows the logic of django.db.models.query.Query.add_immediate_loading(), + but for the non-translated field names that were passed to self.only() + """ + existing, defer = self.polymorphic_deferred_loading + field_names = set(field_names) + if 'pk' in field_names: + field_names.remove('pk') + field_names.add(self.get_meta().pk.name) + + if defer: + # Remove any existing deferred names from the current set before + # setting the new names. + self.polymorphic_deferred_loading = field_names.difference(existing), False + else: + # Replace any existing "immediate load" field names. + self.polymorphic_deferred_loading = field_names, False + def _process_aggregate_args(self, args, kwargs): """for aggregate and annotate kwargs: allow ModelX___field syntax for kwargs, forbid it for args. Modifies kwargs if needed (these are Aggregate objects, we translate the lookup member variable)""" @@ -282,6 +350,26 @@ class PolymorphicQuerySet(QuerySet): }) real_objects.query.select_related = self.query.select_related # copy select related configuration to new qs + # Copy deferred fields configuration to the new queryset + deferred_loading_fields = [] + existing_fields = self.polymorphic_deferred_loading[0] + for field in existing_fields: + try: + translated_field_name = translate_polymorphic_field_path( + real_concrete_class, field) + except AssertionError: + if '___' in field: + # The originally passed argument to .defer() or .only() + # was in the form Model2B___field2, where Model2B is + # now a superclass of real_concrete_class. Thus it's + # sufficient to just use the field name. + translated_field_name = field.rpartition('___')[-1] + else: + raise + + deferred_loading_fields.append(translated_field_name) + real_objects.query.deferred_loading = (set(deferred_loading_fields), self.query.deferred_loading[1]) + for real_object in real_objects: o_pk = getattr(real_object, pk_name) real_class = real_object.get_real_instance_class() diff --git a/polymorphic/tests.py b/polymorphic/tests.py index c04eebd..671e445 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -554,6 +554,76 @@ class PolymorphicTests(TestCase): self.assertEqual(repr(objects[2]), '') self.assertEqual(repr(objects[3]), '') + def test_defer_fields(self): + self.create_model2abcd() + + objects_deferred = Model2A.objects.defer('field1') + self.assertNotIn('field1', objects_deferred[0].__dict__, + 'field1 was not deferred (using defer())') + self.assertEqual(repr(objects_deferred[0]), + '') + self.assertEqual(repr(objects_deferred[1]), + '') + self.assertEqual(repr(objects_deferred[2]), + '') + self.assertEqual(repr(objects_deferred[3]), + '') + + objects_only = Model2A.objects.only('polymorphic_ctype', 'field1') + self.assertIn('field1', objects_only[0].__dict__, + 'qs.only("field1") was used, but field1 was incorrectly deferred') + self.assertIn('field1', objects_only[3].__dict__, + 'qs.only("field1") was used, but field1 was incorrectly deferred' + ' on a child model') + self.assertNotIn('field4', objects_only[3].__dict__, + 'field4 was not deferred (using only())') + self.assertEqual(repr(objects_only[0]), + '') + self.assertEqual(repr(objects_only[1]), + '') + self.assertEqual(repr(objects_only[2]), + '') + self.assertEqual(repr(objects_only[3]), + '') + + # A bug in Django 1.4 prevents using defer across reverse relations + # . Since polymorphic + # uses reverse relations to traverse down model inheritance, deferring + # fields in child models will not work in Django 1.4. + @skipIf(django.VERSION < (1, 5), "Django 1.4 does not support defer on related fields") + def test_defer_related_fields(self): + self.create_model2abcd() + + objects_deferred_field4 = Model2A.objects.defer('Model2D___field4') + self.assertNotIn('field4', objects_deferred_field4[3].__dict__, + 'field4 was not deferred (using defer(), traversing inheritance)') + self.assertEqual(repr(objects_deferred_field4[0]), + '') + self.assertEqual(repr(objects_deferred_field4[1]), + '') + self.assertEqual(repr(objects_deferred_field4[2]), + '') + self.assertEqual(repr(objects_deferred_field4[3]), + '') + + objects_only_field4 = Model2A.objects.only( + 'polymorphic_ctype', 'field1', + 'Model2B___id', 'Model2B___field2', 'Model2B___model2a_ptr', + 'Model2C___id', 'Model2C___field3', 'Model2C___model2b_ptr', + 'Model2D___id', 'Model2D___model2c_ptr') + self.assertEqual(repr(objects_only_field4[0]), + '') + self.assertEqual(repr(objects_only_field4[1]), + '') + self.assertEqual(repr(objects_only_field4[2]), + '') + self.assertEqual(repr(objects_only_field4[3]), + '') + + def test_manual_get_real_instance(self): self.create_model2abcd()