From f769ed756800140d210abb2a4d4413a1a68daa05 Mon Sep 17 00:00:00 2001 From: Diederik van der Boor Date: Fri, 12 Jul 2019 10:01:40 +0200 Subject: [PATCH] Fix dealing with missing derived table data with new prefetching code --- polymorphic/query.py | 8 ++++++- polymorphic/tests/test_orm.py | 43 ++++++++++++++++++++++++++--------- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/polymorphic/query.py b/polymorphic/query.py index fcf7ae4..795f95f 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -364,6 +364,7 @@ class PolymorphicQuerySet(QuerySet): # upcast it and put it in the results resultlist.append(transmogrify(real_concrete_class, base_object)) else: + # This model has a concrete derived class, track it for bulk retrieval. real_concrete_class = content_type_manager.get_for_id(real_concrete_class_id).model_class() idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) indexlist_per_model[real_concrete_class].append((i, len(resultlist))) @@ -416,7 +417,12 @@ class PolymorphicQuerySet(QuerySet): for i, j in indices: base_object = base_result_objects[i] o_pk = getattr(base_object, pk_name) - real_object = copy.copy(real_objects_dict[o_pk]) + real_object = real_objects_dict.get(o_pk) + if real_object is None: + continue + + # need shallow copy to avoid duplication in caches (see PR #353) + real_object = copy.copy(real_object) real_class = real_object.get_real_instance_class() # If the real class is a proxy, upcast it diff --git a/polymorphic/tests/test_orm.py b/polymorphic/tests/test_orm.py index bcca159..fecdcd7 100644 --- a/polymorphic/tests/test_orm.py +++ b/polymorphic/tests/test_orm.py @@ -347,28 +347,32 @@ class PolymorphicTests(TransactionTestCase): def test_queryset_missing_derived(self): a = Model2A.objects.create(field1='A1') b = Model2B.objects.create(field1='B1', field2='B2') + c = Model2C.objects.create(field1='C1', field2='C2', field3='C3') b_base = Model2A.objects.non_polymorphic().get(pk=b.pk) + c_base = Model2A.objects.non_polymorphic().get(pk=c.pk) b.delete(keep_parents=True) # e.g. table was truncated - qs1 = Model2A.objects.order_by('field1').non_polymorphic() - qs2 = Model2A.objects.order_by('field1').all() + qs_base = Model2A.objects.order_by('field1').non_polymorphic() + qs_polymorphic = Model2A.objects.order_by('field1').all() - self.assertEqual(list(qs1), [a, b_base]) - self.assertEqual(list(qs2), [a]) + self.assertEqual(list(qs_base), [a, b_base, c_base]) + self.assertEqual(list(qs_polymorphic), [a, c]) def test_queryset_missing_contenttype(self): stale_ct = ContentType.objects.create(app_label='tests', model='nonexisting') a1 = Model2A.objects.create(field1='A1') a2 = Model2A.objects.create(field1='A2') + c = Model2C.objects.create(field1='C1', field2='C2', field3='C3') + c_base = Model2A.objects.non_polymorphic().get(pk=c.pk) Model2B.objects.filter(pk=a2.pk).update(polymorphic_ctype=stale_ct) - qs1 = Model2A.objects.order_by('field1').non_polymorphic() - qs2 = Model2A.objects.order_by('field1').all() + qs_base = Model2A.objects.order_by('field1').non_polymorphic() + qs_polymorphic = Model2A.objects.order_by('field1').all() - self.assertEqual(list(qs1), [a1, a2]) - self.assertEqual(list(qs2), [a1, a2]) + self.assertEqual(list(qs_base), [a1, a2, c_base]) + self.assertEqual(list(qs_polymorphic), [a1, a2, c]) def test_translate_polymorphic_q_object(self): self.create_model2abcd() @@ -1056,7 +1060,6 @@ class PolymorphicTests(TransactionTestCase): MultiTableDerived(field1='field1', field2='field2') ]) - def test_can_query_using_subclass_selector_on_abstract_model(self): obj = SubclassSelectorAbstractConcreteModel.objects.create(concrete_field='abc') @@ -1078,8 +1081,26 @@ class PolymorphicTests(TransactionTestCase): def test_prefetch_related_behaves_normally_with_polymorphic_model(self): b1 = RelatingModel.objects.create() b2 = RelatingModel.objects.create() - a = b1.many2many.create() - b2.many2many.add(a) + a = b1.many2many.create() # create Model2A + b2.many2many.add(a) # add same to second relating model qs = RelatingModel.objects.prefetch_related('many2many') for obj in qs: self.assertEqual(len(obj.many2many.all()), 1) + + def test_prefetch_related_with_missing(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1='A1') + rel2 = Model2B.objects.create(field1='A2', field2='B2') + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + + rel2.delete(keep_parents=True) + + qs = RelatingModel.objects.order_by('pk').prefetch_related('many2many') + objects = list(qs) + self.assertEqual(len(objects[0].many2many.all()), 1) + self.assertEqual(len(objects[1].many2many.all()), 0) # derived object was not fetched + self.assertEqual(len(objects[1].many2many.non_polymorphic()), 1) # base object does exist