From a4ac6cc91d4f4f1263ed84d4af79a32cba553dbb Mon Sep 17 00:00:00 2001 From: Bert Constantin Date: Sat, 30 Oct 2010 15:54:13 +0200 Subject: [PATCH] fix object retrieval problem occuring with some custom primary key fields + added UUIDField as test case --- polymorphic/query.py | 18 +++-- polymorphic/test_tools.py | 145 ++++++++++++++++++++++++++++++++++++++ polymorphic/tests.py | 68 ++++++++++++++---- 3 files changed, 211 insertions(+), 20 deletions(-) create mode 100644 polymorphic/test_tools.py diff --git a/polymorphic/query.py b/polymorphic/query.py index b24de6a..1bb8b16 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -156,27 +156,36 @@ class PolymorphicQuerySet(QuerySet): else: idlist_per_model[base_object.get_real_instance_class()].append(base_object.pk) + # django's automatic ".pk" field does not always work correctly for + # custom fields in derived objects (unclear yet who to put the blame on). + # We get different type(o.pk) in this case. + # We work around this by using the real name of the field directly + # for accessing the primary key of the the derived objects. + pk_name = self.model._meta.pk.name + # For each model in "idlist_per_model" request its objects (the real model) # from the db and store them in results[]. # Then we copy the annotate fields from the base objects to the real objects. # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here for modelclass, idlist in idlist_per_model.items(): - qs = modelclass.base_objects.filter(id__in=idlist) + qs = modelclass.base_objects.filter(pk__in=idlist) # use pk__in instead #### qs.dup_select_related(self) # copy select related configuration to new qs for o in qs: + o_pk=getattr(o,pk_name) + if self.query.aggregates: for anno_field_name in self.query.aggregates.keys(): - attr = getattr(base_result_objects_by_id[o.pk], anno_field_name) + attr = getattr(base_result_objects_by_id[o_pk], anno_field_name) setattr(o, anno_field_name, attr) if self.query.extra_select: for select_field_name in self.query.extra_select.keys(): - attr = getattr(base_result_objects_by_id[o.pk], select_field_name) + attr = getattr(base_result_objects_by_id[o_pk], select_field_name) setattr(o, select_field_name, attr) - results[o.pk] = o + results[o_pk] = o # re-create correct order and return result list resultlist = [ results[ordered_id] for ordered_id in ordered_id_list if ordered_id in results ] @@ -193,7 +202,6 @@ class PolymorphicQuerySet(QuerySet): for o in resultlist: o.polymorphic_extra_select_names=extra_select_names - return resultlist def iterator(self): diff --git a/polymorphic/test_tools.py b/polymorphic/test_tools.py new file mode 100644 index 0000000..3bdb639 --- /dev/null +++ b/polymorphic/test_tools.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- + +#################################################################### + +import uuid + +from django.forms.util import ValidationError +from django import forms +from django.db import models +from django.utils.encoding import smart_unicode +from django.utils.translation import ugettext_lazy + +class UUIDVersionError(Exception): + pass + +class UUIDField(models.CharField): + """Encode and stores a Python uuid.UUID in a manner that is appropriate + for the given datatabase that we are using. + + For sqlite3 or MySQL we save it as a 36-character string value + For PostgreSQL we save it as a uuid field + + This class supports type 1, 2, 4, and 5 UUID's. + """ + __metaclass__ = models.SubfieldBase + + _CREATE_COLUMN_TYPES = { + 'postgresql_psycopg2': 'uuid', + 'postgresql': 'uuid' + } + + def __init__(self, verbose_name=None, name=None, auto=True, version=1, node=None, clock_seq=None, namespace=None, **kwargs): + """Contruct a UUIDField. + + @param verbose_name: Optional verbose name to use in place of what + Django would assign. + @param name: Override Django's name assignment + @param auto: If True, create a UUID value if one is not specified. + @param version: By default we create a version 1 UUID. + @param node: Used for version 1 UUID's. If not supplied, then the uuid.getnode() function is called to obtain it. This can be slow. + @param clock_seq: Used for version 1 UUID's. If not supplied a random 14-bit sequence number is chosen + @param namespace: Required for version 3 and version 5 UUID's. + @param name: Required for version4 and version 5 UUID's. + + See Also: + - Python Library Reference, section 18.16 for more information. + - RFC 4122, "A Universally Unique IDentifier (UUID) URN Namespace" + + If you want to use one of these as a primary key for a Django + model, do this:: + id = UUIDField(primary_key=True) + This will currently I{not} work with Jython because PostgreSQL support + in Jython is not working for uuid column types. + """ + self.max_length = 36 + kwargs['max_length'] = self.max_length + if auto: + kwargs['blank'] = True + kwargs.setdefault('editable', False) + + self.auto = auto + self.version = version + if version==1: + self.node, self.clock_seq = node, clock_seq + elif version==3 or version==5: + self.namespace, self.name = namespace, name + + super(UUIDField, self).__init__(verbose_name=verbose_name, + name=name, **kwargs) + + def create_uuid(self): + if not self.version or self.version==4: + return uuid.uuid4() + elif self.version==1: + return uuid.uuid1(self.node, self.clock_seq) + elif self.version==2: + raise UUIDVersionError("UUID version 2 is not supported.") + elif self.version==3: + return uuid.uuid3(self.namespace, self.name) + elif self.version==5: + return uuid.uuid5(self.namespace, self.name) + else: + raise UUIDVersionError("UUID version %s is not valid." % self.version) + + def db_type(self): + from django.conf import settings + return UUIDField._CREATE_COLUMN_TYPES.get(settings.DATABASE_ENGINE, "char(%s)" % self.max_length) + + def to_python(self, value): + """Return a uuid.UUID instance from the value returned by the database.""" + # + # This is the proper way... But this doesn't work correctly when + # working with an inherited model + # + if not value: + return None + if isinstance(value, uuid.UUID): + return value + # attempt to parse a UUID + return uuid.UUID(smart_unicode(value)) + + # + # If I do the following (returning a String instead of a UUID + # instance), everything works. + # + + #if not value: + # return None + #if isinstance(value, uuid.UUID): + # return smart_unicode(value) + #else: + # return value + + def pre_save(self, model_instance, add): + if self.auto and add: + value = self.create_uuid() + setattr(model_instance, self.attname, value) + else: + value = super(UUIDField, self).pre_save(model_instance,add) + if self.auto and not value: + value = self.create_uuid() + setattr(model_instance, self.attname, value) + return value + + def get_db_prep_value(self, value): + """Casts uuid.UUID values into the format expected by the back end for use in queries""" + if isinstance(value, uuid.UUID): + return smart_unicode(value) + return value + + def value_to_string(self, obj): + val = self._get_val_from_obj(obj) + if val is None: + data = '' + else: + data = smart_unicode(val) + return data + + def formfield(self, **kwargs): + defaults = { + 'form_class': forms.CharField, + 'max_length': self.max_length + } + defaults.update(kwargs) + return super(UUIDField, self).formfield(**defaults) diff --git a/polymorphic/tests.py b/polymorphic/tests.py index a97f0f5..fd27c05 100644 --- a/polymorphic/tests.py +++ b/polymorphic/tests.py @@ -71,8 +71,7 @@ class Enhance_Base(ShowFieldTypeAndContent, PolymorphicModel): field_b = models.CharField(max_length=10) class Enhance_Inherit(Enhance_Base, Enhance_Plain): field_i = models.CharField(max_length=10) - - + class DiamondBase(models.Model): field_b = models.CharField(max_length=10) class DiamondX(DiamondBase): @@ -102,7 +101,7 @@ class One2OneRelatingModel(PolymorphicModel): class One2OneRelatingModelDerived(One2OneRelatingModel): field2 = models.CharField(max_length=10) - + class MyManager(PolymorphicManager): def get_query_set(self): return super(MyManager, self).get_query_set().order_by('-field1') @@ -113,18 +112,18 @@ class ModelWithMyManager(ShowFieldTypeAndContent, Model2A): class MROBase1(ShowFieldType, PolymorphicModel): objects = MyManager() field1 = models.CharField(max_length=10) # needed as MyManager uses it -class MROBase2(MROBase1): +class MROBase2(MROBase1): pass # Django vanilla inheritance does not inherit MyManager as _default_manager here class MROBase3(models.Model): objects = PolymorphicManager() -class MRODerived(MROBase2, MROBase3): +class MRODerived(MROBase2, MROBase3): pass class MgrInheritA(models.Model): mgrA = models.Manager() mgrA2 = models.Manager() field1 = models.CharField(max_length=10) -class MgrInheritB(MgrInheritA): +class MgrInheritB(MgrInheritA): mgrB = models.Manager() field2 = models.CharField(max_length=10) class MgrInheritC(ShowFieldTypeAndContent, MgrInheritB): @@ -156,6 +155,19 @@ class InitTestModelSubclass(InitTestModel): def x(self): return 'XYZ' +try: from polymorphic.test_tools import UUIDField +except: pass +if 'UUIDField' in globals(): + import uuid + class UUIDProject(ShowFieldTypeAndContent, PolymorphicModel): + id = UUIDField(primary_key = True) + topic = models.CharField(max_length = 30) + class UUIDArtProject(UUIDProject): + artist = models.CharField(max_length = 30) + class UUIDResearchProject(UUIDProject): + supervisor = models.CharField(max_length = 30) + + # test bad field name #class TestBadFieldModel(ShowFieldType, PolymorphicModel): @@ -165,17 +177,19 @@ class InitTestModelSubclass(InitTestModel): # with related field 'ContentType.relatednameclash_set'." (reported by Andrew Ingram) # fixed with related_name class RelatedNameClash(ShowFieldType, PolymorphicModel): - ctype = models.ForeignKey(ContentType, null=True, editable=False) + ctype = models.ForeignKey(ContentType, null=True, editable=False) class testclass(TestCase): - def test_diamond_inheritance(self): + def test_diamond_inheritance(self): # Django diamond problem o = DiamondXY.objects.create(field_b='b', field_x='x', field_y='y') print 'DiamondXY fields 1: field_b "%s", field_x "%s", field_y "%s"' % (o.field_b, o.field_x, o.field_y) o = DiamondXY.objects.get() print 'DiamondXY fields 2: field_b "%s", field_x "%s", field_y "%s"' % (o.field_b, o.field_x, o.field_y) - if o.field_b != 'b': print '# Django model inheritance diamond problem detected' + if o.field_b != 'b': + print + print '# known django model inheritance diamond problem detected' def test_annotate_aggregate_order(self): @@ -199,7 +213,7 @@ class testclass(TestCase): assert o.entrycount == 2 else: assert o.entrycount == 0 - + x = BlogBase.objects.aggregate(entrycount=Count('BlogA___blogentry')) assert x['entrycount'] == 2 @@ -250,7 +264,6 @@ class testclass(TestCase): x = '\n' + repr(BlogBase.objects.order_by('-BlogA___info')) assert x == expected1 or x == expected2 - #assert False def test_limit_choices_to(self): "this is not really a testcase, as limit_choices_to only affects the Django admin" @@ -262,9 +275,34 @@ class testclass(TestCase): entry2 = BlogEntry_limit_choices_to.objects.create(blog=blog_b, text='bla2') + def test_primary_key_custom_field_problem(self): + "object retrieval problem occuring with some custom primary key fields (UUIDField as test case)" + if not 'UUIDField' in globals(): return + a=UUIDProject.objects.create(topic="John's gathering") + b=UUIDArtProject.objects.create(topic="Sculpting with Tim", artist="T. Turner") + c=UUIDResearchProject.objects.create(topic="Swallow Aerodynamics", supervisor="Dr. Winter") + qs=UUIDProject.objects.all() + ol=list(qs) + a=qs[0] + b=qs[1] + c=qs[2] + assert len(qs)==3 + assert type(a.id)==uuid.UUID and type(a.pk)==uuid.UUID + res=repr(qs) + import re + res=re.sub(' id ...................................., topic',' id, topic',res) + res_exp="""[ , + , + ]""" + assert res==res_exp + if (a.pk!= uuid.UUID or c.pk!= uuid.UUID): + print + print '# known django object inconstency with custom primary key field detected' + + def show_base_manager(model): print type(model._base_manager),model._base_manager.model - + __test__ = {"doctest": """ ####################################################### ### Tests @@ -480,7 +518,7 @@ __test__ = {"doctest": """ , , ] - + >>> oa=RelationBase.objects.get(id=2) >>> oa.fk @@ -521,7 +559,7 @@ __test__ = {"doctest": """ # check for correct default manager >>> type(MROBase1._default_manager) - + # Django vanilla inheritance does not inherit MyManager as _default_manager here >>> type(MROBase2._default_manager) @@ -550,4 +588,4 @@ __test__ = {"doctest": """ >>> settings.DEBUG=False """} - +