diff --git a/polymorphic/base.py b/polymorphic/base.py index a407285..83cb7a4 100644 --- a/polymorphic/base.py +++ b/polymorphic/base.py @@ -77,15 +77,17 @@ class PolymorphicModelBase(ModelBase): for source_name, mgr_name, manager in inherited_managers: #print '** add inherited manager from model %s, manager %s, %s' % (source_name, mgr_name, manager.__class__.__name__) new_manager = manager._copy_to_model(new_class) - new_class.add_to_class(mgr_name, new_manager) + if mgr_name == '_default_manager': + new_class._default_manager = new_manager + else: + new_class.add_to_class(mgr_name, new_manager) # get first user defined manager; if there is one, make it the _default_manager # this value is used by the related objects, restoring access to custom queryset methods on related objects. user_manager = self.get_first_user_defined_manager(new_class) if user_manager: - def_mgr = user_manager._copy_to_model(new_class) #print '## add default manager', type(def_mgr) - new_class.add_to_class('_default_manager', def_mgr) + new_class._default_manager = user_manager._copy_to_model(new_class) new_class._default_manager._inherited = False # the default mgr was defined by the user, not inherited # validate resulting default manager diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 64968ed..1a28bcd 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -506,6 +506,29 @@ class PolymorphicTests(TestCase): self.assertEqual(show_base_manager(One2OneRelatingModelDerived), " ") + def test_instance_default_manager(self): + def show_default_manager(instance): + return "{0} {1}".format( + repr(type(instance._default_manager)), + repr(instance._default_manager.model) + ) + + plain_a = PlainA(field1='C1') + plain_b = PlainB(field2='C1') + plain_c = PlainC(field3='C1') + + model_2a = Model2A(field1='C1') + model_2b = Model2B(field2='C1') + model_2c = Model2C(field3='C1') + + self.assertEqual(show_default_manager(plain_a), " ") + self.assertEqual(show_default_manager(plain_b), " ") + self.assertEqual(show_default_manager(plain_c), " ") + + self.assertEqual(show_default_manager(model_2a), " ") + self.assertEqual(show_default_manager(model_2b), " ") + self.assertEqual(show_default_manager(model_2c), " ") + def test_foreignkey_field(self): self.create_model2abcd()