diff --git a/polymorphic/query.py b/polymorphic/query.py index 61071dc..4e55d71 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -7,12 +7,12 @@ from __future__ import absolute_import from collections import defaultdict import django -from django.db.models.query import QuerySet +from django.db.models.query import QuerySet, Q from django.contrib.contenttypes.models import ContentType from django.utils import six from .query_translate import translate_polymorphic_filter_definitions_in_kwargs, translate_polymorphic_filter_definitions_in_args -from .query_translate import translate_polymorphic_field_path +from .query_translate import translate_polymorphic_field_path, translate_polymorphic_Q_object # chunk-size: maximum number of objects requested per db-request # by the polymorphic queryset.iterator() implementation; we use the same chunk size as Django @@ -115,21 +115,53 @@ class PolymorphicQuerySet(QuerySet): """for aggregate and annotate kwargs: allow ModelX___field syntax for kwargs, forbid it for args. Modifies kwargs if needed (these are Aggregate objects, we translate the lookup member variable)""" - def patch_lookup(a): - if django.VERSION < (1, 8): - a.lookup = translate_polymorphic_field_path(self.model, a.lookup) + def patch_lookup_lt_18(a): + a.lookup = translate_polymorphic_field_path(self.model, a.lookup) + + + def patch_lookup_gte_18(a): + # With Django > 1.8, the field on which the aggregate operates is + # stored inside a complex query expression. + if isinstance(a, Q): + translate_polymorphic_Q_object(self.model, a) + elif hasattr(a, 'get_source_expressions'): + for source_expression in a.get_source_expressions(): + patch_lookup_gte_18(source_expression) else: - # With Django > 1.8, the field on which the aggregate operates is - # stored inside a query expression. - if hasattr(a, 'source_expressions'): - a.source_expressions[0].name = translate_polymorphic_field_path( - self.model, a.source_expressions[0].name) - - get_lookup = lambda a: a.lookup if django.VERSION < (1, 8) else a.source_expressions[0].name - + a.name = translate_polymorphic_field_path(self.model, a.name) + + ___lookup_assert_msg = 'PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only' + def test___lookup_for_args_lt_18(a): + assert '___' not in a.lookup, ___lookup_assert_msg + + def test___lookup_for_args_gte_18(a): + """ *args might be complex expressions too in django 1.8 so + the testing for a '___' is rather complex on this one """ + if isinstance(a, Q): + def tree_node_test___lookup(my_model, node): + " process all children of this Q node " + for i in range(len(node.children)): + child = node.children[i] + + if type(child) == tuple: + # this Q object child is a tuple => a kwarg like Q( instance_of=ModelB ) + assert '___' not in child[0], ___lookup_assert_msg + else: + # this Q object child is another Q object, recursively process this as well + tree_node_test___lookup(my_model, child) + + tree_node_test___lookup(self.model, a) + elif hasattr(a, 'get_source_expressions'): + for source_expression in a.get_source_expressions(): + test___lookup_for_args_gte_18(source_expression) + else: + assert '___' not in a.name, ___lookup_assert_msg + for a in args: - assert '___' not in get_lookup(a), 'PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only' + test___lookup = test___lookup_for_args_lt_18 if django.VERSION < (1, 8) else test___lookup_for_args_gte_18 + test___lookup(a) for a in six.itervalues(kwargs): + patch_lookup = patch_lookup_lt_18 if django.VERSION < (1, 8) else patch_lookup_gte_18 patch_lookup(a) def annotate(self, *args, **kwargs): diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 109eef9..c04eebd 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -15,6 +15,8 @@ from django.db.models.query import QuerySet from django.test import TestCase from django.db.models import Q, Count +if django.VERSION >= (1, 8): + from django.db.models import Case, When from django.db import models from django.contrib.contenttypes.models import ContentType from django.utils import six @@ -1021,6 +1023,50 @@ class PolymorphicTests(TestCase): # test that we can delete the object t.delete() + + def test_polymorphic__aggregate(self): + """ test ModelX___field syntax on aggregate (should work for annotate either) """ + + Model2A.objects.create(field1='A1') + Model2B.objects.create(field1='A1', field2='B2') + Model2B.objects.create(field1='A1', field2='B2') + + # aggregate using **kwargs + result = Model2A.objects.aggregate(cnt=Count('Model2B___field2')) + self.assertEqual(result, {'cnt': 2}) + + # aggregate using **args + with self.assertRaisesMessage(AssertionError, 'PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only'): + Model2A.objects.aggregate(Count('Model2B___field2')) + + + + @skipIf(django.VERSION < (1,8,), "This test needs Django >=1.8") + def test_polymorphic__complex_aggregate(self): + """ test (complex expression on) aggregate (should work for annotate either) """ + + Model2A.objects.create(field1='A1') + Model2B.objects.create(field1='A1', field2='B2') + Model2B.objects.create(field1='A1', field2='B2') + + # aggregate using **kwargs + result = Model2A.objects.aggregate( + cnt_a1=Count(Case(When(field1='A1', then=1))), + cnt_b2=Count(Case(When(Model2B___field2='B2', then=1))), + ) + self.assertEqual(result, {'cnt_b2': 2, 'cnt_a1': 3}) + + # aggregate using **args + # we have to set the defaul alias or django won't except a complex expression + # on aggregate/annotate + def ComplexAgg(expression): + complexagg = Count(expression)*10 + complexagg.default_alias = 'complexagg' + return complexagg + + with self.assertRaisesMessage(AssertionError, 'PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only'): + Model2A.objects.aggregate(ComplexAgg('Model2B___field2')) + class RegressionTests(TestCase):