From 467e6f517e53cbb8858fca6408d168693d045e7c Mon Sep 17 00:00:00 2001 From: Diederik van der Boor Date: Tue, 1 Aug 2017 12:11:42 +0200 Subject: [PATCH] Improve reset_polymorphic_ctype() for reliability and test it. This function can now be safely used on a set of models. (cherry picked from commit 171d14f36948ca7a625565621824b992eaff8178) --- polymorphic/tests/test_utils.py | 37 +++++++++++++++++++++++++++++++++ polymorphic/utils.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 polymorphic/tests/test_utils.py diff --git a/polymorphic/tests/test_utils.py b/polymorphic/tests/test_utils.py new file mode 100644 index 0000000..ac56de9 --- /dev/null +++ b/polymorphic/tests/test_utils.py @@ -0,0 +1,37 @@ +from unittest import TestCase + +from polymorphic.models import PolymorphicTypeUndefined +from polymorphic.tests import Model2A, Model2B, Model2C, Model2D +from polymorphic.tests.test_orm import qrepr +from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype + + +class UtilsTests(TestCase): + maxDiff = 1000 + + def test_sort_by_subclass(self): + self.assertEqual( + sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C), + [Model2A, Model2B, Model2C, Model2D, Model2D] + ) + + def test_reset_polymorphic_ctype(self): + """ + Test the the polymorphic_ctype_id can be restored. + """ + Model2A.objects.create(field1='A1') + Model2D.objects.create(field1='A1', field2='B2', field3='C3', field4='D4') + Model2B.objects.create(field1='A1', field2='B2') + Model2B.objects.create(field1='A1', field2='B2') + Model2A.objects.all().update(polymorphic_ctype_id=None) + + with self.assertRaises(PolymorphicTypeUndefined): + list(Model2A.objects.all()) + + reset_polymorphic_ctype(Model2D, Model2B, Model2D, Model2A, Model2C) + self.assertEqual(repr(list(Model2A.objects.all())), ( + '[,' + ' ,' + ' ,' + ' ]' + )) diff --git a/polymorphic/utils.py b/polymorphic/utils.py index 8f1ac75..a8d968b 100644 --- a/polymorphic/utils.py +++ b/polymorphic/utils.py @@ -1,3 +1,5 @@ +import sys + from django.contrib.contenttypes.models import ContentType from django.db import DEFAULT_DB_ALIAS @@ -13,6 +15,8 @@ def reset_polymorphic_ctype(*models, **filters): """ using = filters.pop('using', DEFAULT_DB_ALIAS) ignore_existing = filters.pop('ignore_existing', False) + + models = sort_by_subclass(*models) if ignore_existing: # When excluding models, make sure we don't ignore the models we # just assigned the an content type to. hence, start with child first. @@ -27,3 +31,31 @@ def reset_polymorphic_ctype(*models, **filters): if filters: qs = qs.filter(**filters) qs.update(polymorphic_ctype=new_ct) + + +def _compare_mro(cls1, cls2): + if cls1 is cls2: + return 0 + + try: + index1 = cls1.mro().index(cls2) + except ValueError: + return -1 # cls2 not inherited by 1 + + try: + index2 = cls2.mro().index(cls1) + except ValueError: + return 1 # cls1 not inherited by 2 + + return (index1 > index2) - (index1 < index2) # python 3 compatible cmp. + + +def sort_by_subclass(*classes): + """ + Sort a series of models by their inheritance order. + """ + if sys.version_info[0] == 2: + return sorted(classes, cmp=_compare_mro) + else: + from functools import cmp_to_key + return sorted(classes, key=cmp_to_key(_compare_mro))