diff --git a/polymorphic/query.py b/polymorphic/query.py index c842414..fb1dfa8 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -25,6 +25,25 @@ except ImportError: Polymorphic_QuerySet_objects_per_request = CHUNK_SIZE +if django.VERSION >= (1, 9): + # We ignore this on django < 1.9, as ModelIterable didn't yet exist. + from django.db.models.query import ModelIterable + + + class PolymorphicModelIterable(ModelIterable): + def __iter__(self): + base_iter = super(PolymorphicModelIterable, self).__iter__() + + if self.queryset.polymorphic_disabled: + for o in base_iter: + yield o + return + + real_instances = self.queryset._get_real_instances(base_iter) + for obj in real_instances: + yield obj + + def transmogrify(cls, obj): """ Upcast a class to a different type without asking questions. @@ -72,6 +91,9 @@ class PolymorphicQuerySet(QuerySet): # to that queryset as well). self.polymorphic_deferred_loading = (set([]), True) super(PolymorphicQuerySet, self).__init__(*args, **kwargs) + if django.VERSION >= (1, 9): + # On django < 1.9 we override the iterator() method instead + self._iterable_class = PolymorphicModelIterable def _clone(self, *args, **kwargs): # Django's _clone only copies its own variables, so we need to copy ours here @@ -407,49 +429,51 @@ class PolymorphicQuerySet(QuerySet): return resultlist - def iterator(self): - """ - This function is used by Django for all object retrieval. - By overriding it, we modify the objects that this queryset returns - when it is evaluated (or its get method or other object-returning methods are called). + if django.VERSION < (1, 9): + # On django 1.9+, we can define self._iterator_class instead of iterator() + def iterator(self): + """ + This function is used by Django for all object retrieval. + By overriding it, we modify the objects that this queryset returns + when it is evaluated (or its get method or other object-returning methods are called). - Here we do the same as:: + Here we do the same as:: - base_result_objects = list(super(PolymorphicQuerySet, self).iterator()) - real_results = self._get_real_instances(base_result_objects) - for o in real_results: yield o + base_result_objects = list(super(PolymorphicQuerySet, self).iterator()) + real_results = self._get_real_instances(base_result_objects) + for o in real_results: yield o - but it requests the objects in chunks from the database, - with Polymorphic_QuerySet_objects_per_request per chunk - """ - base_iter = super(PolymorphicQuerySet, self).iterator() + but it requests the objects in chunks from the database, + with Polymorphic_QuerySet_objects_per_request per chunk + """ + base_iter = super(PolymorphicQuerySet, self).iterator() - # disabled => work just like a normal queryset - if self.polymorphic_disabled: - for o in base_iter: - yield o - return - - while True: - base_result_objects = [] - reached_end = False - - for i in range(Polymorphic_QuerySet_objects_per_request): - try: - o = next(base_iter) - base_result_objects.append(o) - except StopIteration: - reached_end = True - break - - real_results = self._get_real_instances(base_result_objects) - - for o in real_results: - yield o - - if reached_end: + # disabled => work just like a normal queryset + if self.polymorphic_disabled: + for o in base_iter: + yield o return + while True: + base_result_objects = [] + reached_end = False + + for i in range(Polymorphic_QuerySet_objects_per_request): + try: + o = next(base_iter) + base_result_objects.append(o) + except StopIteration: + reached_end = True + break + + real_results = self._get_real_instances(base_result_objects) + + for o in real_results: + yield o + + if reached_end: + return + def __repr__(self, *args, **kwargs): if self.model.polymorphic_query_multiline_output: result = [repr(o) for o in self.all()]