Improve reset_polymorphic_ctype() for reliability and test it.
This function can now be safely used on a set of models.
(cherry picked from commit 171d14f369)
fix_request_path_info
parent
db46dbb446
commit
467e6f517e
|
|
@ -0,0 +1,37 @@
|
|||
from unittest import TestCase
|
||||
|
||||
from polymorphic.models import PolymorphicTypeUndefined
|
||||
from polymorphic.tests import Model2A, Model2B, Model2C, Model2D
|
||||
from polymorphic.tests.test_orm import qrepr
|
||||
from polymorphic.utils import sort_by_subclass, reset_polymorphic_ctype
|
||||
|
||||
|
||||
class UtilsTests(TestCase):
|
||||
maxDiff = 1000
|
||||
|
||||
def test_sort_by_subclass(self):
|
||||
self.assertEqual(
|
||||
sort_by_subclass(Model2D, Model2B, Model2D, Model2A, Model2C),
|
||||
[Model2A, Model2B, Model2C, Model2D, Model2D]
|
||||
)
|
||||
|
||||
def test_reset_polymorphic_ctype(self):
|
||||
"""
|
||||
Test the the polymorphic_ctype_id can be restored.
|
||||
"""
|
||||
Model2A.objects.create(field1='A1')
|
||||
Model2D.objects.create(field1='A1', field2='B2', field3='C3', field4='D4')
|
||||
Model2B.objects.create(field1='A1', field2='B2')
|
||||
Model2B.objects.create(field1='A1', field2='B2')
|
||||
Model2A.objects.all().update(polymorphic_ctype_id=None)
|
||||
|
||||
with self.assertRaises(PolymorphicTypeUndefined):
|
||||
list(Model2A.objects.all())
|
||||
|
||||
reset_polymorphic_ctype(Model2D, Model2B, Model2D, Model2A, Model2C)
|
||||
self.assertEqual(repr(list(Model2A.objects.all())), (
|
||||
'[<Model2A: id 1, field1 (CharField)>,'
|
||||
' <Model2D: id 2, field1 (CharField), field2 (CharField), field3 (CharField), field4 (CharField)>,'
|
||||
' <Model2B: id 3, field1 (CharField), field2 (CharField)>,'
|
||||
' <Model2B: id 4, field1 (CharField), field2 (CharField)>]'
|
||||
))
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
import sys
|
||||
|
||||
from django.contrib.contenttypes.models import ContentType
|
||||
from django.db import DEFAULT_DB_ALIAS
|
||||
|
||||
|
|
@ -13,6 +15,8 @@ def reset_polymorphic_ctype(*models, **filters):
|
|||
"""
|
||||
using = filters.pop('using', DEFAULT_DB_ALIAS)
|
||||
ignore_existing = filters.pop('ignore_existing', False)
|
||||
|
||||
models = sort_by_subclass(*models)
|
||||
if ignore_existing:
|
||||
# When excluding models, make sure we don't ignore the models we
|
||||
# just assigned the an content type to. hence, start with child first.
|
||||
|
|
@ -27,3 +31,31 @@ def reset_polymorphic_ctype(*models, **filters):
|
|||
if filters:
|
||||
qs = qs.filter(**filters)
|
||||
qs.update(polymorphic_ctype=new_ct)
|
||||
|
||||
|
||||
def _compare_mro(cls1, cls2):
|
||||
if cls1 is cls2:
|
||||
return 0
|
||||
|
||||
try:
|
||||
index1 = cls1.mro().index(cls2)
|
||||
except ValueError:
|
||||
return -1 # cls2 not inherited by 1
|
||||
|
||||
try:
|
||||
index2 = cls2.mro().index(cls1)
|
||||
except ValueError:
|
||||
return 1 # cls1 not inherited by 2
|
||||
|
||||
return (index1 > index2) - (index1 < index2) # python 3 compatible cmp.
|
||||
|
||||
|
||||
def sort_by_subclass(*classes):
|
||||
"""
|
||||
Sort a series of models by their inheritance order.
|
||||
"""
|
||||
if sys.version_info[0] == 2:
|
||||
return sorted(classes, cmp=_compare_mro)
|
||||
else:
|
||||
from functools import cmp_to_key
|
||||
return sorted(classes, key=cmp_to_key(_compare_mro))
|
||||
|
|
|
|||
Loading…
Reference in New Issue