Improve reset_polymorphic_ctype() for reliability and test it.

This function can now be safely used on a set of models.
fix_request_path_info
Diederik van der Boor 2017-08-01 12:11:42 +02:00
parent e10deeaebd
commit 171d14f369
2 changed files with 68 additions and 0 deletions

View File

@ -0,0 +1,36 @@
from unittest import TestCase
from polymorphic.models import PolymorphicTypeUndefined
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D
from polymorphic.tests.test_orm import qrepr
from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype
class UtilsTests(TestCase):
def test_sort_by_subclass(self):
self.assertEqual(
sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C),
[Model2A, Model2B, Model2C, Model2D, Model2D]
)
def test_reset_polymorphic_ctype(self):
"""
Test the the polymorphic_ctype_id can be restored.
"""
Model2A.objects.create(field1='A1')
Model2D.objects.create(field1='A1', field2='B2', field3='C3', field4='D4')
Model2B.objects.create(field1='A1', field2='B2')
Model2B.objects.create(field1='A1', field2='B2')
Model2A.objects.all().update(polymorphic_ctype_id=None)
with self.assertRaises(PolymorphicTypeUndefined):
list(Model2A.objects.all())
reset_polymorphic_ctype(Model2D, Model2B, Model2D, Model2A, Model2C)
self.assertEqual(qrepr(Model2A.objects.all()), (
'[ <Model2A: id 1, field1 (CharField)>,\n'
' <Model2D: id 2, field1 (CharField), field2 (CharField), field3 (CharField), field4 (CharField)>,\n'
' <Model2B: id 3, field1 (CharField), field2 (CharField)>,\n'
' <Model2B: id 4, field1 (CharField), field2 (CharField)> ]'
))

View File

@ -1,3 +1,5 @@
import sys
from django.contrib.contenttypes.models import ContentType from django.contrib.contenttypes.models import ContentType
from django.db import DEFAULT_DB_ALIAS from django.db import DEFAULT_DB_ALIAS
@ -13,6 +15,8 @@ def reset_polymorphic_ctype(*models, **filters):
""" """
using = filters.pop('using', DEFAULT_DB_ALIAS) using = filters.pop('using', DEFAULT_DB_ALIAS)
ignore_existing = filters.pop('ignore_existing', False) ignore_existing = filters.pop('ignore_existing', False)
models = sort_by_subclass(*models)
if ignore_existing: if ignore_existing:
# When excluding models, make sure we don't ignore the models we # When excluding models, make sure we don't ignore the models we
# just assigned the an content type to. hence, start with child first. # just assigned the an content type to. hence, start with child first.
@ -27,3 +31,31 @@ def reset_polymorphic_ctype(*models, **filters):
if filters: if filters:
qs = qs.filter(**filters) qs = qs.filter(**filters)
qs.update(polymorphic_ctype=new_ct) qs.update(polymorphic_ctype=new_ct)
def _compare_mro(cls1, cls2):
if cls1 is cls2:
return 0
try:
index1 = cls1.mro().index(cls2)
except ValueError:
return -1 # cls2 not inherited by 1
try:
index2 = cls2.mro().index(cls1)
except ValueError:
return 1 # cls1 not inherited by 2
return (index1 > index2) - (index1 < index2) # python 3 compatible cmp.
def sort_by_subclass(*classes):
"""
Sort a series of models by their inheritance order.
"""
if sys.version_info[0] == 2:
return sorted(classes, cmp=_compare_mro)
else:
from functools import cmp_to_key
return sorted(classes, key=cmp_to_key(_compare_mro))