diff --git a/polymorphic/query.py b/polymorphic/query.py index 8cac5e0..6c87a93 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -20,40 +20,6 @@ from .query_translate import translate_polymorphic_field_path, translate_polymor Polymorphic_QuerySet_objects_per_request = 100 -def _polymorphic_iterator(queryset, base_iter): - """ - Here we do the same as:: - - real_results = queryset._get_real_instances(list(base_iter)) - for o in real_results: yield o - - but it requests the objects in chunks from the database, - with Polymorphic_QuerySet_objects_per_request per chunk - """ - while True: - base_result_objects = [] - reached_end = False - - # Make sure the base iterator is read in chunks instead of - # reading it completely, in case our caller read only a few objects. - for i in range(Polymorphic_QuerySet_objects_per_request): - - try: - o = next(base_iter) - base_result_objects.append(o) - except StopIteration: - reached_end = True - break - - real_results = queryset._get_real_instances(base_result_objects) - - for o in real_results: - yield o - - if reached_end: - return - - class PolymorphicModelIterable(ModelIterable): """ ModelIterable for PolymorphicModel @@ -66,7 +32,40 @@ class PolymorphicModelIterable(ModelIterable): base_iter = super(PolymorphicModelIterable, self).__iter__() if self.queryset.polymorphic_disabled: return base_iter - return _polymorphic_iterator(self.queryset, base_iter) + return self._polymorphic_iterator(base_iter) + + def _polymorphic_iterator(self, base_iter): + """ + Here we do the same as:: + + real_results = queryset._get_real_instances(list(base_iter)) + for o in real_results: yield o + + but it requests the objects in chunks from the database, + with Polymorphic_QuerySet_objects_per_request per chunk + """ + while True: + base_result_objects = [] + reached_end = False + + # Make sure the base iterator is read in chunks instead of + # reading it completely, in case our caller read only a few objects. + for i in range(Polymorphic_QuerySet_objects_per_request): + + try: + o = next(base_iter) + base_result_objects.append(o) + except StopIteration: + reached_end = True + break + + real_results = self.queryset._get_real_instances(base_result_objects) + + for o in real_results: + yield o + + if reached_end: + return def transmogrify(cls, obj):