diff --git a/polymorphic/query.py b/polymorphic/query.py index 97fc21a..8cac5e0 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -127,6 +127,12 @@ class PolymorphicQuerySet(QuerySet): as_manager.queryset_only = True as_manager = classmethod(as_manager) + def bulk_create(self, objs, batch_size=None): + objs = list(objs) + for obj in objs: + obj.pre_save_polymorphic() + return super(PolymorphicQuerySet, self).bulk_create(objs, batch_size) + def non_polymorphic(self): """switch off polymorphic behaviour for this query. When the queryset is evaluated, only objects of the type of the diff --git a/polymorphic/tests/migrations/0001_initial.py b/polymorphic/tests/migrations/0001_initial.py index b4cd2a2..4e51629 100644 --- a/polymorphic/tests/migrations/0001_initial.py +++ b/polymorphic/tests/migrations/0001_initial.py @@ -1055,4 +1055,81 @@ class Migration(migrations.Migration): name='polymorphic_ctype', field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_tests.inlinemodela_set+', to='contenttypes.ContentType'), ), + migrations.CreateModel( + name='ArtProject', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('topic', models.CharField(max_length=30)), + ('artist', models.CharField(max_length=30)), + ('polymorphic_ctype', + models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, + related_name='polymorphic_tests.artproject_set+', to='contenttypes.ContentType')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='Duck', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=30)), + ('polymorphic_ctype', + models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, + related_name='polymorphic_tests.duck_set+', to='contenttypes.ContentType')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='MultiTableBase', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('field1', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='MultiTableDerived', + fields=[ + ('multitablebase_ptr', + models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, + primary_key=True, serialize=False, to='tests.MultiTableBase')), + ('field2', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + }, + bases=('tests.multitablebase',), + ), + migrations.AddField( + model_name='multitablebase', + name='polymorphic_ctype', + field=models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, + related_name='polymorphic_tests.multitablebase_set+', + to='contenttypes.ContentType'), + ), + migrations.CreateModel( + name='RedheadDuck', + fields=[ + ], + options={ + 'proxy': True, + 'indexes': [], + }, + bases=('tests.duck',), + ), + migrations.CreateModel( + name='RubberDuck', + fields=[ + ], + options={ + 'proxy': True, + 'indexes': [], + }, + bases=('tests.duck',), + ), ] diff --git a/polymorphic/tests/models.py b/polymorphic/tests/models.py index 651e64e..47f21e4 100644 --- a/polymorphic/tests/models.py +++ b/polymorphic/tests/models.py @@ -428,3 +428,38 @@ class InlineModelA(PolymorphicModel): class InlineModelB(InlineModelA): field2 = models.CharField(max_length=10) + + +class AbstractProject(PolymorphicModel): + topic = models.CharField(max_length=30) + + class Meta: + abstract = True + + +class ArtProject(AbstractProject): + artist = models.CharField(max_length=30) + + +class Duck(PolymorphicModel): + name = models.CharField(max_length=30) + + +class RedheadDuck(Duck): + + class Meta: + proxy = True + + +class RubberDuck(Duck): + + class Meta: + proxy = True + + +class MultiTableBase(PolymorphicModel): + field1 = models.CharField(max_length=10) + + +class MultiTableDerived(MultiTableBase): + field2 = models.CharField(max_length=10) diff --git a/polymorphic/tests/test_orm.py b/polymorphic/tests/test_orm.py index f3da078..441877e 100644 --- a/polymorphic/tests/test_orm.py +++ b/polymorphic/tests/test_orm.py @@ -10,6 +10,7 @@ from django.utils import six from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicTypeUndefined from polymorphic.tests.models import ( + ArtProject, Base, BlogA, BlogB, @@ -19,6 +20,7 @@ from polymorphic.tests.models import ( ChildModelWithManager, CustomPkBase, CustomPkInherit, + Duck, Enhance_Base, Enhance_Inherit, InitTestModelSubclass, @@ -45,6 +47,7 @@ from polymorphic.tests.models import ( ModelWithMyManagerNoDefault, ModelX, ModelY, + MultiTableDerived, MyManager, MyManagerQuerySet, NonProxyChild, @@ -65,10 +68,12 @@ from polymorphic.tests.models import ( ProxyModelB, ProxyModelBase, QuerySet, + RedheadDuck, RelationA, RelationB, RelationBC, RelationBase, + RubberDuck, TestParentLinkAndRelatedName, UUIDArtProject, UUIDPlainA, @@ -953,6 +958,46 @@ class PolymorphicTests(TransactionTestCase): with self.assertRaises(PolymorphicTypeUndefined): list(Model2A.objects.all()) + def test_bulk_create_abstract_inheritance(self): + ArtProject.objects.bulk_create([ + ArtProject(topic='Painting with Tim', artist='T. Turner'), + ArtProject(topic='Sculpture with Tim', artist='T. Turner'), + ]) + self.assertEqual( + sorted(ArtProject.objects.values_list('topic', 'artist')), + [('Painting with Tim', 'T. Turner'), ('Sculpture with Tim', 'T. Turner')] + ) + + def test_bulk_create_proxy_inheritance(self): + RedheadDuck.objects.bulk_create([ + RedheadDuck(name='redheadduck1'), + Duck(name='duck1'), + RubberDuck(name='rubberduck1'), + ]) + RubberDuck.objects.bulk_create([ + RedheadDuck(name='redheadduck2'), + RubberDuck(name='rubberduck2'), + Duck(name='duck2'), + ]) + self.assertEqual( + sorted(RedheadDuck.objects.values_list('name', flat=True)), + ['redheadduck1', 'redheadduck2'], + ) + self.assertEqual( + sorted(RubberDuck.objects.values_list('name', flat=True)), + ['rubberduck1', 'rubberduck2'], + ) + self.assertEqual( + sorted(Duck.objects.values_list('name', flat=True)), + ['duck1', 'duck2', 'redheadduck1', 'redheadduck2', 'rubberduck1', 'rubberduck2'], + ) + + def test_bulk_create_unsupported_multi_table_inheritance(self): + with self.assertRaises(ValueError): + MultiTableDerived.objects.bulk_create([ + MultiTableDerived(field1='field1', field2='field2') + ]) + def qrepr(data): """