Fix dealing with missing derived table data with new prefetching code

fix_request_path_info
Diederik van der Boor 2019-07-12 10:01:40 +02:00
parent 3d9587acfb
commit f769ed7568
No known key found for this signature in database
GPG Key ID: 4FA014E0305E73C1
2 changed files with 39 additions and 12 deletions

View File

@ -364,6 +364,7 @@ class PolymorphicQuerySet(QuerySet):
# upcast it and put it in the results
resultlist.append(transmogrify(real_concrete_class, base_object))
else:
# This model has a concrete derived class, track it for bulk retrieval.
real_concrete_class = content_type_manager.get_for_id(real_concrete_class_id).model_class()
idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name))
indexlist_per_model[real_concrete_class].append((i, len(resultlist)))
@ -416,7 +417,12 @@ class PolymorphicQuerySet(QuerySet):
for i, j in indices:
base_object = base_result_objects[i]
o_pk = getattr(base_object, pk_name)
real_object = copy.copy(real_objects_dict[o_pk])
real_object = real_objects_dict.get(o_pk)
if real_object is None:
continue
# need shallow copy to avoid duplication in caches (see PR #353)
real_object = copy.copy(real_object)
real_class = real_object.get_real_instance_class()
# If the real class is a proxy, upcast it

View File

@ -347,28 +347,32 @@ class PolymorphicTests(TransactionTestCase):
def test_queryset_missing_derived(self):
a = Model2A.objects.create(field1='A1')
b = Model2B.objects.create(field1='B1', field2='B2')
c = Model2C.objects.create(field1='C1', field2='C2', field3='C3')
b_base = Model2A.objects.non_polymorphic().get(pk=b.pk)
c_base = Model2A.objects.non_polymorphic().get(pk=c.pk)
b.delete(keep_parents=True) # e.g. table was truncated
qs1 = Model2A.objects.order_by('field1').non_polymorphic()
qs2 = Model2A.objects.order_by('field1').all()
qs_base = Model2A.objects.order_by('field1').non_polymorphic()
qs_polymorphic = Model2A.objects.order_by('field1').all()
self.assertEqual(list(qs1), [a, b_base])
self.assertEqual(list(qs2), [a])
self.assertEqual(list(qs_base), [a, b_base, c_base])
self.assertEqual(list(qs_polymorphic), [a, c])
def test_queryset_missing_contenttype(self):
stale_ct = ContentType.objects.create(app_label='tests', model='nonexisting')
a1 = Model2A.objects.create(field1='A1')
a2 = Model2A.objects.create(field1='A2')
c = Model2C.objects.create(field1='C1', field2='C2', field3='C3')
c_base = Model2A.objects.non_polymorphic().get(pk=c.pk)
Model2B.objects.filter(pk=a2.pk).update(polymorphic_ctype=stale_ct)
qs1 = Model2A.objects.order_by('field1').non_polymorphic()
qs2 = Model2A.objects.order_by('field1').all()
qs_base = Model2A.objects.order_by('field1').non_polymorphic()
qs_polymorphic = Model2A.objects.order_by('field1').all()
self.assertEqual(list(qs1), [a1, a2])
self.assertEqual(list(qs2), [a1, a2])
self.assertEqual(list(qs_base), [a1, a2, c_base])
self.assertEqual(list(qs_polymorphic), [a1, a2, c])
def test_translate_polymorphic_q_object(self):
self.create_model2abcd()
@ -1056,7 +1060,6 @@ class PolymorphicTests(TransactionTestCase):
MultiTableDerived(field1='field1', field2='field2')
])
def test_can_query_using_subclass_selector_on_abstract_model(self):
obj = SubclassSelectorAbstractConcreteModel.objects.create(concrete_field='abc')
@ -1078,8 +1081,26 @@ class PolymorphicTests(TransactionTestCase):
def test_prefetch_related_behaves_normally_with_polymorphic_model(self):
b1 = RelatingModel.objects.create()
b2 = RelatingModel.objects.create()
a = b1.many2many.create()
b2.many2many.add(a)
a = b1.many2many.create() # create Model2A
b2.many2many.add(a) # add same to second relating model
qs = RelatingModel.objects.prefetch_related('many2many')
for obj in qs:
self.assertEqual(len(obj.many2many.all()), 1)
def test_prefetch_related_with_missing(self):
b1 = RelatingModel.objects.create()
b2 = RelatingModel.objects.create()
rel1 = Model2A.objects.create(field1='A1')
rel2 = Model2B.objects.create(field1='A2', field2='B2')
b1.many2many.add(rel1)
b2.many2many.add(rel2)
rel2.delete(keep_parents=True)
qs = RelatingModel.objects.order_by('pk').prefetch_related('many2many')
objects = list(qs)
self.assertEqual(len(objects[0].many2many.all()), 1)
self.assertEqual(len(objects[1].many2many.all()), 0) # derived object was not fetched
self.assertEqual(len(objects[1].many2many.non_polymorphic()), 1) # base object does exist