# -*- coding: utf-8 -*-
""" QuerySet for PolymorphicModel
Please see README.rst or DOCS.rst or http://chrisglass.github.com/django_polymorphic/
"""
from __future__ import absolute_import
import copy
from collections import defaultdict
import django
from django.db.models.query import QuerySet, Q
from django.contrib.contenttypes.models import ContentType
from django.utils import six
from .query_translate import translate_polymorphic_filter_definitions_in_kwargs, translate_polymorphic_filter_definitions_in_args
from .query_translate import translate_polymorphic_field_path, translate_polymorphic_Q_object
# chunk-size: maximum number of objects requested per db-request
# by the polymorphic queryset.iterator() implementation; we use the same chunk size as Django
try:
from django.db.models.query import CHUNK_SIZE # this is 100 for Django 1.1/1.2
except ImportError:
# CHUNK_SIZE was removed in Django 1.6
CHUNK_SIZE = 100
Polymorphic_QuerySet_objects_per_request = CHUNK_SIZE
def transmogrify(cls, obj):
"""
Upcast an object to a different class without asking questions.
"""
if '__init__' not in obj.__dict__:
# Just assign __class__ to a different value.
new = obj
new.__class__ = cls
else:
# Run constructor, reassign values
new = cls()
for k, v in obj.__dict__.items():
new.__dict__[k] = v
return new
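# Usage sketch (hypothetical classes, not part of this module), showing how
# transmogrify() reassigns the class of an existing instance:
#
#     class Animal(object): pass
#     class Dog(Animal): pass
#
#     a = Animal()
#     a.name = 'Rex'
#     d = transmogrify(Dog, a)
#     # An Animal instance carries no '__init__' in its instance __dict__, so
#     # d is the very same object with d.__class__ now set to Dog and d.name
#     # still 'Rex'; only when '__init__' appears in obj.__dict__ is a fresh
#     # cls() built and the attributes copied over.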
###################################################################################
# PolymorphicQuerySet
def _query_annotations(query):
try:
return query.annotations
except AttributeError:
# Django < 1.8
return query.aggregates
class PolymorphicQuerySet(QuerySet):
"""
QuerySet for PolymorphicModel
Contains the core functionality for PolymorphicModel
Usually not explicitly needed, except if a custom queryset class
is to be used.
"""
def __init__(self, *args, **kwargs):
"init our queryset object member variables"
self.polymorphic_disabled = False
# A parallel structure to django.db.models.query.Query.deferred_loading,
# which we maintain with the untranslated field names passed to
# .defer() and .only() in order to be able to retranslate them when
# retrieving the real instance (so that the deferred fields apply
# to that queryset as well).
self.polymorphic_deferred_loading = (set([]), True)
super(PolymorphicQuerySet, self).__init__(*args, **kwargs)
def _clone(self, *args, **kwargs):
"Django's _clone only copies its own variables, so we need to copy ours here"
new = super(PolymorphicQuerySet, self)._clone(*args, **kwargs)
new.polymorphic_disabled = self.polymorphic_disabled
new.polymorphic_deferred_loading = (
copy.copy(self.polymorphic_deferred_loading[0]),
self.polymorphic_deferred_loading[1])
return new
if django.VERSION >= (1, 7):
def as_manager(cls):
# Make sure the Django 1.7 way of creating managers works.
from .managers import PolymorphicManager
manager = PolymorphicManager.from_queryset(cls)()
manager._built_with_as_manager = True
return manager
as_manager.queryset_only = True
as_manager = classmethod(as_manager)
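# Usage sketch (Django >= 1.7, hypothetical queryset/model subclasses, not
# part of this module):
#
#     class MyQuerySet(PolymorphicQuerySet):
#         def active(self):
#             return self.filter(is_active=True)
#
#     class MyModel(PolymorphicModel):
#         is_active = models.BooleanField(default=True)
#         objects = MyQuerySet.as_manager()
#
# The override above routes as_manager() through
# PolymorphicManager.from_queryset(), so the resulting manager keeps the
# polymorphic behaviour while exposing the custom queryset methods.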
def non_polymorphic(self):
"""switch off polymorphic behaviour for this query.
When the queryset is evaluated, only objects of the type of the
base class used for this query are returned."""
qs = self._clone()
qs.polymorphic_disabled = True
return qs
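# Usage sketch (hypothetical models ModelA and its subclass ModelB, not part
# of this module):
#
#     ModelA.objects.non_polymorphic()
#     # evaluates to plain ModelA instances, one per matching row, without the
#     # per-subclass re-fetching and upcasting done by _get_real_instances()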
def instance_of(self, *args):
"""Filter the queryset to only include the classes in args (and their subclasses).
Implementation in _translate_polymorphic_filter_definition."""
return self.filter(instance_of=args)
def not_instance_of(self, *args):
"""Filter the queryset to exclude the classes in args (and their subclasses).
Implementation in _translate_polymorphic_filter_definition."""
return self.filter(not_instance_of=args)
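# Usage sketch (hypothetical models ModelA and its subclass ModelB, not part
# of this module):
#
#     ModelA.objects.instance_of(ModelB)         # only objects saved as ModelB (or a subclass of it)
#     ModelA.objects.not_instance_of(ModelB)     # everything except ModelB and its subclasses
#     ModelA.objects.filter(instance_of=ModelB)  # equivalent keyword form, usable inside Q() objects too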
def _filter_or_exclude(self, negate, *args, **kwargs):
"We override this internal Django functon as it is used for all filter member functions."
translate_polymorphic_filter_definitions_in_args(self.model, args) # the Q objects
additional_args = translate_polymorphic_filter_definitions_in_kwargs(self.model, kwargs) # filter_field='data'
return super(PolymorphicQuerySet, self)._filter_or_exclude(negate, *(list(args) + additional_args), **kwargs)
def order_by(self, *args, **kwargs):
"""translate the field paths in the args, then call vanilla order_by."""
new_args = [translate_polymorphic_field_path(self.model, a) for a in args]
return super(PolymorphicQuerySet, self).order_by(*new_args, **kwargs)
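# Usage sketch (hypothetical models; field2 is assumed to live on the subclass
# ModelB, not part of this module):
#
#     ModelA.objects.order_by('ModelB___field2')
#     # the ModelB___field2 path is translated to the underlying Django lookup
#     # before the ordering is applied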
def defer(self, *fields):
"""
Translate the field paths in the args, then call vanilla defer.
Also retain a copy of the original fields passed, which we'll need
when we're retrieving the real instance (since we'll need to translate
them again, as the model will have changed).
"""
new_fields = [translate_polymorphic_field_path(self.model, a) for a in fields]
clone = super(PolymorphicQuerySet, self).defer(*new_fields)
clone._polymorphic_add_deferred_loading(fields)
return clone
def only(self, *fields):
"""
Translate the field paths in the args, then call vanilla only.
Also retain a copy of the original fields passed, which we'll need
when we're retrieving the real instance (since we'll need to translate
them again, as the model will have changed).
"""
new_fields = [translate_polymorphic_field_path(self.model, a) for a in fields]
clone = super(PolymorphicQuerySet, self).only(*new_fields)
clone._polymorphic_add_immediate_loading(fields)
return clone
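# Usage sketch (hypothetical models; field2 is assumed to live on the subclass
# ModelB, not part of this module):
#
#     ModelA.objects.defer('ModelB___field2')
#     ModelA.objects.only('field1', 'ModelB___field2')
#
# The untranslated names are kept in polymorphic_deferred_loading so they can
# be re-translated against the real subclass when the full instances are
# fetched in _get_real_instances().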
def _polymorphic_add_deferred_loading(self, field_names):
"""
Follows the logic of django.db.models.query.Query.add_deferred_loading(),
but for the non-translated field names that were passed to self.defer().
"""
existing, defer = self.polymorphic_deferred_loading
if defer:
# Add to existing deferred names.
self.polymorphic_deferred_loading = existing.union(field_names), True
else:
# Remove names from the set of any existing "immediate load" names.
self.polymorphic_deferred_loading = existing.difference(field_names), False
def _polymorphic_add_immediate_loading(self, field_names):
"""
Follows the logic of django.db.models.query.Query.add_immediate_loading(),
but for the non-translated field names that were passed to self.only()
"""
existing, defer = self.polymorphic_deferred_loading
field_names = set(field_names)
if 'pk' in field_names:
field_names.remove('pk')
field_names.add(self.model._meta.pk.name)
if defer:
# Remove any existing deferred names from the current set before
# setting the new names.
self.polymorphic_deferred_loading = field_names.difference(existing), False
else:
# Replace any existing "immediate load" field names.
self.polymorphic_deferred_loading = field_names, False
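# Bookkeeping sketch (hypothetical field names, not part of this module):
# starting from the initial (set(), True) state,
#
#     qs = ModelA.objects.only('field1', 'field2')   # -> ({'field1', 'field2'}, False)
#     qs = qs.defer('field2')                        # -> ({'field1'}, False)
#
# mirroring the semantics of Django's own Query.deferred_loading, but with the
# untranslated polymorphic field paths.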
def _process_aggregate_args(self, args, kwargs):
"""for aggregate and annotate kwargs: allow ModelX___field syntax for kwargs, forbid it for args.
Modifies kwargs if needed (these are Aggregate objects, we translate the lookup member variable)"""
def patch_lookup_lt_18(a):
a.lookup = translate_polymorphic_field_path(self.model, a.lookup)
def patch_lookup_gte_18(a):
# With Django > 1.8, the field on which the aggregate operates is
# stored inside a complex query expression.
if isinstance(a, Q):
translate_polymorphic_Q_object(self.model, a)
elif hasattr(a, 'get_source_expressions'):
for source_expression in a.get_source_expressions():
if source_expression is not None:
patch_lookup_gte_18(source_expression)
else:
a.name = translate_polymorphic_field_path(self.model, a.name)
___lookup_assert_msg = 'PolymorphicModel: annotate()/aggregate(): ___ model lookup supported for keyword arguments only'
def test___lookup_for_args_lt_18(a):
assert '___' not in a.lookup, ___lookup_assert_msg
def test___lookup_for_args_gte_18(a):
""" *args might be complex expressions too in django 1.8 so
the testing for a '___' is rather complex on this one """
if isinstance(a, Q):
def tree_node_test___lookup(my_model, node):
" process all children of this Q node "
for i in range(len(node.children)):
child = node.children[i]
if type(child) == tuple:
# this Q object child is a tuple => a kwarg like Q( instance_of=ModelB )
assert '___' not in child[0], ___lookup_assert_msg
else:
# this Q object child is another Q object, recursively process this as well
tree_node_test___lookup(my_model, child)
tree_node_test___lookup(self.model, a)
elif hasattr(a, 'get_source_expressions'):
for source_expression in a.get_source_expressions():
test___lookup_for_args_gte_18(source_expression)
else:
assert '___' not in a.name, ___lookup_assert_msg
for a in args:
test___lookup = test___lookup_for_args_lt_18 if django.VERSION < (1, 8) else test___lookup_for_args_gte_18
test___lookup(a)
for a in six.itervalues(kwargs):
patch_lookup = patch_lookup_lt_18 if django.VERSION < (1, 8) else patch_lookup_gte_18
patch_lookup(a)
def annotate(self, *args, **kwargs):
"""translate the polymorphic field paths in the kwargs, then call vanilla annotate.
_get_real_instances will do the rest of the job after executing the query."""
self._process_aggregate_args(args, kwargs)
return super(PolymorphicQuerySet, self).annotate(*args, **kwargs)
def aggregate(self, *args, **kwargs):
"""translate the polymorphic field paths in the kwargs, then call vanilla aggregate.
We need no polymorphic object retrieval for aggregate => switch it off."""
self._process_aggregate_args(args, kwargs)
qs = self.non_polymorphic()
return super(PolymorphicQuerySet, qs).aggregate(*args, **kwargs)
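# Usage sketch (hypothetical models; field2 is assumed to live on the subclass
# ModelB, not part of this module):
#
#     from django.db.models import Count
#
#     ModelA.objects.annotate(cnt=Count('ModelB___field2'))
#     ModelA.objects.aggregate(cnt=Count('ModelB___field2'))
#
# The '___' subclass path is only supported in keyword arguments; passing such
# a lookup in a positional aggregate argument trips the assertion above.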
if django.VERSION >= (1, 9):
# On Django < 1.9, 'qs.values(...)' returned a new special ValuesQuerySet
# object, which our polymorphic modifications didn't apply to.
# Starting with Django 1.9, the copy returned by 'qs.values(...)' has the
# same class as 'qs', so our polymorphic modifications would apply.
# We want to leave values queries untouched, so we set 'polymorphic_disabled'.
def _values(self, *args, **kwargs):
clone = super(PolymorphicQuerySet, self)._values(*args, **kwargs)
clone.polymorphic_disabled = True
return clone
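# Usage note (hypothetical model, not part of this module): on Django >= 1.9,
#
#     ModelA.objects.values('field1')
#
# passes through the _values() override above, so the resulting dicts are
# built from the base query alone, without any per-subclass re-fetching.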
# Since django_polymorphic 'V1.0 beta2', extra() always returns polymorphic results.
# The resulting objects are required to have a unique primary key within the result set
# (otherwise an error is thrown).
# The "polymorphic" keyword argument is not supported anymore.
# def extra(self, *args, **kwargs):
def _get_real_instances(self, base_result_objects):
"""
Polymorphic object loader
Does the same as:
return [ o.get_real_instance() for o in base_result_objects ]
but more efficiently.
The list base_result_objects contains the objects from the executed
base class query. The class of all of them is self.model (our base model).
Some, many or all of these objects were not created and stored as
class self.model, but as a class derived from self.model. We want to re-fetch
these objects from the db as their original class so we can return them
just as they were created/saved.
We identify these objects by looking at o.polymorphic_ctype, which specifies
the real class of these objects (the class at the time they were saved).
First, we sort the result objects in base_result_objects for their
subclass (from o.polymorphic_ctype), and then we execute one db query per
subclass of objects. Here, we handle any annotations from annotate().
Finally we re-sort the resulting objects into the correct order and
return them as a list.
"""
ordered_id_list = [] # list of ids of result-objects in correct order
results = {} # polymorphic dict of result-objects, keyed with their id (no order)
# dict contains one entry per unique model type occurring in result,
# in the format idlist_per_model[modelclass]=[list-of-object-ids]
idlist_per_model = defaultdict(list)
# Django's automatic ".pk" field does not always work correctly for
# custom fields in derived objects (it is not yet clear where the blame lies):
# we get a different type(o.pk) in this case.
# We work around this by using the real name of the field directly
# for accessing the primary key of the derived objects.
# We might assume that self.model._meta.pk.name gives us the name of the primary key field,
# but it doesn't. Therefore we use polymorphic_primary_key_name, which we set up in base.py.
pk_name = self.model.polymorphic_primary_key_name
# - sort base_result_object ids into idlist_per_model lists, depending on their real class;
# - also record the correct result order in "ordered_id_list"
# - store objects that already have the correct class into "results"
base_result_objects_by_id = {}
self_model_class_id = ContentType.objects.get_for_model(self.model, for_concrete_model=False).pk
self_concrete_model_class_id = ContentType.objects.get_for_model(self.model, for_concrete_model=True).pk
for base_object in base_result_objects:
ordered_id_list.append(base_object.pk)
# check if the id of the result object occurs more than once - this can happen e.g. with base_objects.extra(tables=...)
if base_object.pk not in base_result_objects_by_id:
base_result_objects_by_id[base_object.pk] = base_object
if base_object.polymorphic_ctype_id == self_model_class_id:
# Real class is exactly the same as base class, go straight to results
results[base_object.pk] = base_object
else:
real_concrete_class = base_object.get_real_instance_class()
real_concrete_class_id = base_object.get_real_concrete_instance_class_id()
if real_concrete_class_id is None:
# Dealing with a stale content type
continue
elif real_concrete_class_id == self_concrete_model_class_id:
# Real and base classes share the same concrete ancestor,
# upcast it and put it in the results
results[base_object.pk] = transmogrify(real_concrete_class, base_object)
else:
real_concrete_class = ContentType.objects.get_for_id(real_concrete_class_id).model_class()
idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name))
# For each model in "idlist_per_model" request its objects (the real model)
# from the db and store them in results[].
# Then we copy the annotate fields from the base objects to the real objects.
# Then we copy the extra() select fields from the base objects to the real objects.
# TODO: defer(), only(): support for these would be around here
for real_concrete_class, idlist in idlist_per_model.items():
real_objects = real_concrete_class.base_objects.filter(**{
('%s__in' % pk_name): idlist,
})
real_objects.query.select_related = self.query.select_related # copy select related configuration to new qs
# Copy deferred fields configuration to the new queryset
deferred_loading_fields = []
existing_fields = self.polymorphic_deferred_loading[0]
for field in existing_fields:
try:
translated_field_name = translate_polymorphic_field_path(
real_concrete_class, field)
except AssertionError:
if '___' in field:
# The originally passed argument to .defer() or .only()
# was in the form Model2B___field2, where Model2B is
# now a superclass of real_concrete_class. Thus it's
# sufficient to just use the field name.
translated_field_name = field.rpartition('___')[-1]
else:
raise
deferred_loading_fields.append(translated_field_name)
real_objects.query.deferred_loading = (set(deferred_loading_fields), self.query.deferred_loading[1])
for real_object in real_objects:
o_pk = getattr(real_object, pk_name)
real_class = real_object.get_real_instance_class()
# If the real class is a proxy, upcast it
if real_class != real_concrete_class:
real_object = transmogrify(real_class, real_object)
if _query_annotations(self.query):
for anno_field_name in six.iterkeys(_query_annotations(self.query)):
attr = getattr(base_result_objects_by_id[o_pk], anno_field_name)
setattr(real_object, anno_field_name, attr)
if self.query.extra_select:
for select_field_name in six.iterkeys(self.query.extra_select):
attr = getattr(base_result_objects_by_id[o_pk], select_field_name)
setattr(real_object, select_field_name, attr)
results[o_pk] = real_object
# re-create correct order and return result list
resultlist = [results[ordered_id] for ordered_id in ordered_id_list if ordered_id in results]
# set polymorphic_annotate_names in all objects (currently just used for debugging/printing)
if _query_annotations(self.query):
annotate_names = list(six.iterkeys(_query_annotations(self.query))) # get annotate field list
for real_object in resultlist:
real_object.polymorphic_annotate_names = annotate_names
# set polymorphic_extra_select_names in all objects (currently just used for debugging/printing)
if self.query.extra_select:
extra_select_names = list(six.iterkeys(self.query.extra_select)) # get extra select field list
for real_object in resultlist:
real_object.polymorphic_extra_select_names = extra_select_names
return resultlist
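# Illustration (hypothetical models, not part of this module): if the base
# query yields rows whose polymorphic_ctype points at the subclasses ModelB
# and ModelC, the loop above issues roughly one extra query per subclass,
#
#     ModelB.base_objects.filter(pk__in=[pks of the ModelB rows])
#     ModelC.base_objects.filter(pk__in=[pks of the ModelC rows])
#
# (using the real primary key field name instead of 'pk'), and the fetched
# objects are merged back into the original result order.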
def iterator(self):
"""
This function is used by Django for all object retrieval.
By overriding it, we modify the objects that this queryset returns
when it is evaluated (or its get method or other object-returning methods are called).
Here we do the same as:
base_result_objects=list(super(PolymorphicQuerySet, self).iterator())
real_results=self._get_real_instances(base_result_objects)
for o in real_results: yield o
but it requests the objects in chunks from the database,
with Polymorphic_QuerySet_objects_per_request per chunk
"""
base_iter = super(PolymorphicQuerySet, self).iterator()
# disabled => work just like a normal queryset
if self.polymorphic_disabled:
for o in base_iter:
yield o
return
while True:
base_result_objects = []
reached_end = False
for i in range(Polymorphic_QuerySet_objects_per_request):
try:
o = next(base_iter)
base_result_objects.append(o)
except StopIteration:
reached_end = True
break
real_results = self._get_real_instances(base_result_objects)
for o in real_results:
yield o
if reached_end:
return
def __repr__(self, *args, **kwargs):
if self.model.polymorphic_query_multiline_output:
result = [repr(o) for o in self.all()]
return '[ ' + ',\n '.join(result) + ' ]'
else:
return super(PolymorphicQuerySet, self).__repr__(*args, **kwargs)
class _p_list_class(list):
def __repr__(self, *args, **kwargs):
result = [repr(o) for o in self]
return '[ ' + ',\n '.join(result) + ' ]'
def get_real_instances(self, base_result_objects=None):
"same as _get_real_instances, but make sure that __repr__ for ShowField... creates correct output"
if not base_result_objects:
base_result_objects = self
olist = self._get_real_instances(base_result_objects)
if not self.model.polymorphic_query_multiline_output:
return olist
clist = PolymorphicQuerySet._p_list_class(olist)
return clist
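# Usage sketch (hypothetical model, not part of this module):
#
#     base_qs = ModelA.objects.non_polymorphic()
#     real_objects = base_qs.get_real_instances()
#     # or, starting from an explicit list of base-class objects:
#     real_objects = ModelA.objects.non_polymorphic().get_real_instances(list(base_qs))
#
# Each returned object carries the class it was saved as (ModelB, ModelC, ...).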