diff --git a/polymorphic/polymorphic_model.py b/polymorphic/polymorphic_model.py index 487f3b1..fef0db1 100644 --- a/polymorphic/polymorphic_model.py +++ b/polymorphic/polymorphic_model.py @@ -23,7 +23,6 @@ from .base import PolymorphicModelBase from .manager import PolymorphicManager from .query_translate import translate_polymorphic_Q_object - ################################################################################### ### PolymorphicModel @@ -185,27 +184,37 @@ class PolymorphicModel(six.with_metaclass(PolymorphicModelBase, models.Model)): """helper function for __init__: determine names of all Django inheritance accessor member functions for type(self)""" - def add_model(model, as_ptr, result): - name = model.__name__.lower() - if as_ptr: - name += '_ptr' - result[name] = model + def add_model(model, field_name, result): + result[field_name] = model - def add_model_if_regular(model, as_ptr, result): + def add_model_if_regular(model, field_name, result): if (issubclass(model, models.Model) and model != models.Model and model != self.__class__ and model != PolymorphicModel): - add_model(model, as_ptr, result) + add_model(model, field_name, result) - def add_all_super_models(model, result): - add_model_if_regular(model, True, result) - for b in model.__bases__: - add_all_super_models(b, result) + def add_all_super_models(model, result): + for super_cls, field_to_super in model._meta.parents.items(): + if field_to_super is not None: #if not a link to a proxy model + field_name = field_to_super.name #the field on model can have a different name to super_cls._meta.module_name, if the field is created manually using 'parent_link' + add_model_if_regular(super_cls, field_name, result) + add_all_super_models(super_cls, result) - def add_all_sub_models(model, result): - for b in model.__subclasses__(): - add_model_if_regular(b, False, result) + def add_all_sub_models(super_cls, result): + for sub_cls in super_cls.__subclasses__(): #go through all subclasses of model + if super_cls in sub_cls._meta.parents: #super_cls may not be in sub_cls._meta.parents if super_cls is a proxy model + field_to_super = sub_cls._meta.parents[super_cls] #get the field that links sub_cls to super_cls + if field_to_super is not None: # if filed_to_super is not a link to a proxy model + super_to_sub_related_field = field_to_super.rel + if super_to_sub_related_field.related_name is None: + #if related name is None the related field is the name of the subclass + to_subclass_fieldname = sub_cls.__name__.lower() + else: + #otherwise use the given related name + to_subclass_fieldname = super_to_sub_related_field.related_name + + add_model_if_regular(sub_cls, to_subclass_fieldname, result) result = {} add_all_super_models(self.__class__, result) diff --git a/polymorphic/tests.py b/polymorphic/tests.py index 9f34c05..096a74f 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -279,6 +279,10 @@ class ProxyModelB(ProxyModelBase): class RelatedNameClash(ShowFieldType, PolymorphicModel): ctype = models.ForeignKey(ContentType, null=True, editable=False) +#class with a parent_link to superclass, and a related_name back to subclass +class TestParentLinkAndRelatedName(ModelShow1_plain): + superclass = models.OneToOneField(ModelShow1_plain, parent_link=True, related_name = 'related_name_subclass') + class PolymorphicTests(TestCase): """ @@ -826,6 +830,24 @@ class PolymorphicTests(TestCase): o = InitTestModelSubclass.objects.create() self.assertEqual(o.bar, 'XYZ') + def test_parent_link_and_related_name(self): + t = TestParentLinkAndRelatedName(field1 = "TestParentLinkAndRelatedName") + t.save() + p = ModelShow1_plain.objects.get(field1 = "TestParentLinkAndRelatedName") + + #check that p is equal to the + self.assertIsInstance(p, TestParentLinkAndRelatedName) + self.assertEqual(p, t) + + #check that the accessors to parent and sublass work correctly and return the right object + p = ModelShow1_plain.objects.non_polymorphic().get(field1 = "TestParentLinkAndRelatedName") + self.assertNotEqual(p, t) #p should be Plain1 and t TestParentLinkAndRelatedName, so not equal + self.assertEqual(p, t.superclass) + self.assertEqual(p.related_name_subclass, t) + + #test that we can delete t + t.delete() + class RegressionTests(TestCase):