diff --git a/polymorphic/query.py b/polymorphic/query.py index eab634e..150c505 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -167,6 +167,15 @@ class PolymorphicQuerySet(QuerySet): # in the format idlist_per_model[modelclass]=[list-of-object-ids] idlist_per_model = defaultdict(list) + # django's automatic ".pk" field does not always work correctly for + # custom fields in derived objects (unclear yet who to put the blame on). + # We get different type(o.pk) in this case. + # We work around this by using the real name of the field directly + # for accessing the primary key of the the derived objects. + # We might assume that self.model._meta.pk.name gives us the name of the primary key field, + # but it doesn't. Therefore we use polymorphic_primary_key_name, which we set up in base.py. + pk_name = self.model.polymorphic_primary_key_name + # - sort base_result_object ids into idlist_per_model lists, depending on their real class; # - also record the correct result order in "ordered_id_list" # - store objects that already have the correct class into "results" @@ -198,16 +207,7 @@ class PolymorphicQuerySet(QuerySet): results[base_object.pk] = transmogrify(real_concrete_class, base_object) else: real_concrete_class = ContentType.objects.get_for_id(real_concrete_class_id).model_class() - idlist_per_model[real_concrete_class].append(base_object.pk) - - # django's automatic ".pk" field does not always work correctly for - # custom fields in derived objects (unclear yet who to put the blame on). - # We get different type(o.pk) in this case. - # We work around this by using the real name of the field directly - # for accessing the primary key of the the derived objects. - # We might assume that self.model._meta.pk.name gives us the name of the primary key field, - # but it doesn't. Therefore we use polymorphic_primary_key_name, which we set up in base.py. - pk_name = self.model.polymorphic_primary_key_name + idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) # For each model in "idlist_per_model" request its objects (the real model) # from the db and store them in results[]. @@ -215,7 +215,9 @@ class PolymorphicQuerySet(QuerySet): # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here for real_concrete_class, idlist in idlist_per_model.items(): - real_objects = real_concrete_class.base_objects.filter(pk__in=idlist) # use pk__in instead #### + real_objects = real_concrete_class.base_objects.filter(**{ + ('%s__in' % pk_name): idlist, + }) real_objects.query.select_related = self.query.select_related # copy select related configuration to new qs for real_object in real_objects: diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 1a28bcd..c9493fc 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -283,6 +283,12 @@ class RelatedNameClash(ShowFieldType, PolymorphicModel): class TestParentLinkAndRelatedName(ModelShow1_plain): superclass = models.OneToOneField(ModelShow1_plain, parent_link=True, related_name='related_name_subclass') +class CustomPkBase(ShowFieldTypeAndContent, PolymorphicModel): + b = models.CharField(max_length=1) +class CustomPkInherit(CustomPkBase): + custom_id = models.AutoField(primary_key=True) + i = models.CharField(max_length=1) + class PolymorphicTests(TestCase): """ @@ -304,7 +310,6 @@ class PolymorphicTests(TestCase): print('DiamondXY fields 1: field_b "{0}", field_x "{1}", field_y "{2}"'.format(o1.field_b, o1.field_x, o1.field_y)) print('DiamondXY fields 2: field_b "{0}", field_x "{1}", field_y "{2}"'.format(o2.field_b, o2.field_x, o2.field_y)) - def test_annotate_aggregate_order(self): # create a blog of type BlogA # create two blog entries in BlogA @@ -842,6 +847,13 @@ class PolymorphicTests(TestCase): self.assertIsInstance(objects[0], ProxyModelA) self.assertIsInstance(objects[1], ProxyModelB) + def test_custom_pk(self): + CustomPkBase.objects.create(b='b') + CustomPkInherit.objects.create(b='b', i='i') + qs = CustomPkBase.objects.all() + self.assertEqual(len(qs), 2) + self.assertEqual(repr(qs[0]), '') + self.assertEqual(repr(qs[1]), '') def test_fix_getattribute(self): ### fixed issue in PolymorphicModel.__getattribute__: field name same as model name