diff --git a/polymorphic/managers.py b/polymorphic/managers.py index 0f6c9de..2f38a7d 100644 --- a/polymorphic/managers.py +++ b/polymorphic/managers.py @@ -32,7 +32,7 @@ class PolymorphicManager(models.Manager): super(PolymorphicManager, self).__init__(*args, **kwrags) def get_queryset(self): - qs = self.queryset_class(self.model, using=self._db) + qs = self.queryset_class(self.model, using=self._db, hints=self._hints) if self.model._meta.proxy: qs = qs.instance_of(self.model) return qs diff --git a/polymorphic/query.py b/polymorphic/query.py index b776f55..ace48c2 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -112,8 +112,8 @@ class PolymorphicQuerySet(QuerySet): def _filter_or_exclude(self, negate, *args, **kwargs): "We override this internal Django functon as it is used for all filter member functions." - q_objects = translate_polymorphic_filter_definitions_in_args(self.model, args, using=self._db) # the Q objects - additional_args = translate_polymorphic_filter_definitions_in_kwargs(self.model, kwargs, using=self._db) # filter_field='data' + q_objects = translate_polymorphic_filter_definitions_in_args(self.model, args, using=self.db) # the Q objects + additional_args = translate_polymorphic_filter_definitions_in_kwargs(self.model, kwargs, using=self.db) # filter_field='data' return super(PolymorphicQuerySet, self)._filter_or_exclude(negate, *(list(q_objects) + additional_args), **kwargs) def order_by(self, *args, **kwargs): @@ -309,7 +309,7 @@ class PolymorphicQuerySet(QuerySet): # - also record the correct result order in "ordered_id_list" # - store objects that already have the correct class into "results" base_result_objects_by_id = {} - content_type_manager = ContentType.objects.db_manager(self._db) + content_type_manager = ContentType.objects.db_manager(self.db) self_model_class_id = content_type_manager.get_for_model(self.model, for_concrete_model=False).pk self_concrete_model_class_id = content_type_manager.get_for_model(self.model, for_concrete_model=True).pk @@ -345,7 +345,7 @@ class PolymorphicQuerySet(QuerySet): # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here for real_concrete_class, idlist in idlist_per_model.items(): - real_objects = real_concrete_class.base_objects.db_manager(self._db).filter(**{ + real_objects = real_concrete_class.base_objects.db_manager(self.db).filter(**{ ('%s__in' % pk_name): idlist, }) real_objects.query.select_related = self.query.select_related # copy select related configuration to new qs diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 672d01b..361d797 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -429,9 +429,6 @@ class PolymorphicTests(TestCase): """ The test suite """ - - multi_db = True - def test_annotate_aggregate_order(self): # create a blog of type BlogA # create two blog entries in BlogA @@ -1182,6 +1179,30 @@ class PolymorphicTests(TestCase): result = DateModel.objects.annotate(val=DateTime('date', 'day', utc)) self.assertEqual(list(result), []) +class RegressionTests(TestCase): + + def test_for_query_result_incomplete_with_inheritance(self): + """ https://github.com/bconstantin/django_polymorphic/issues/15 """ + + top = Top() + top.save() + middle = Middle() + middle.save() + bottom = Bottom() + bottom.save() + + expected_queryset = [top, middle, bottom] + self.assertQuerysetEqual(Top.objects.all(), [repr(r) for r in expected_queryset]) + + expected_queryset = [middle, bottom] + self.assertQuerysetEqual(Middle.objects.all(), [repr(r) for r in expected_queryset]) + + expected_queryset = [bottom] + self.assertQuerysetEqual(Bottom.objects.all(), [repr(r) for r in expected_queryset]) + +class MultipleDatabasesTests(TestCase): + multi_db = True + def test_save_to_non_default_database(self): Model2A.objects.db_manager('secondary').create(field1='A1') Model2C(field1='C1', field2='C2', field3='C3').save(using='secondary') @@ -1222,24 +1243,47 @@ class PolymorphicTests(TestCase): self.assertEqual(repr(objects[0]), '') self.assertEqual(repr(objects[1]), '') -class RegressionTests(TestCase): + def test_forward_many_to_one_descriptor_on_non_default_database(self): + def func(): + blog = BlogA.objects.db_manager('secondary').create(name='Blog', info='Info') + entry = BlogEntry.objects.db_manager('secondary').create(blog=blog, text='Text') + ContentType.objects.clear_cache() + entry = BlogEntry.objects.db_manager('secondary').get(pk=entry.id) + self.assertEqual(blog, entry.blog) - def test_for_query_result_incomplete_with_inheritance(self): - """ https://github.com/bconstantin/django_polymorphic/issues/15 """ + # Ensure no queries are made using the default database. + self.assertNumQueries(0, func) - top = Top() - top.save() - middle = Middle() - middle.save() - bottom = Bottom() - bottom.save() + def test_reverse_many_to_one_descriptor_on_non_default_database(self): + def func(): + blog = BlogA.objects.db_manager('secondary').create(name='Blog', info='Info') + entry = BlogEntry.objects.db_manager('secondary').create(blog=blog, text='Text') + ContentType.objects.clear_cache() + blog = BlogA.objects.db_manager('secondary').get(pk=blog.id) + self.assertEqual(entry, blog.blogentry_set.using('secondary').get()) - expected_queryset = [top, middle, bottom] - self.assertQuerysetEqual(Top.objects.all(), [repr(r) for r in expected_queryset]) + # Ensure no queries are made using the default database. + self.assertNumQueries(0, func) - expected_queryset = [middle, bottom] - self.assertQuerysetEqual(Middle.objects.all(), [repr(r) for r in expected_queryset]) + def test_reverse_one_to_one_descriptor_on_non_default_database(self): + def func(): + m2a = Model2A.objects.db_manager('secondary').create(field1='A1') + one2one = One2OneRelatingModel.objects.db_manager('secondary').create(one2one=m2a, field1='121') + ContentType.objects.clear_cache() + m2a = Model2A.objects.db_manager('secondary').get(pk=m2a.id) + self.assertEqual(one2one, m2a.one2onerelatingmodel) - expected_queryset = [bottom] - self.assertQuerysetEqual(Bottom.objects.all(), [repr(r) for r in expected_queryset]) + # Ensure no queries are made using the default database. + self.assertNumQueries(0, func) + def test_many_to_many_descriptor_on_non_default_database(self): + def func(): + m2a = Model2A.objects.db_manager('secondary').create(field1='A1') + rm = RelatingModel.objects.db_manager('secondary').create() + rm.many2many.add(m2a) + ContentType.objects.clear_cache() + m2a = Model2A.objects.db_manager('secondary').get(pk=m2a.id) + self.assertEqual(rm, m2a.relatingmodel_set.using('secondary').get()) + + # Ensure no queries are made using the default database. + self.assertNumQueries(0, func)