From bcb8b0d3a4c6871e1dda1b0321013683a7ed82ba Mon Sep 17 00:00:00 2001 From: Diederik van der Boor Date: Mon, 22 May 2017 12:52:17 +0200 Subject: [PATCH] Allow .order_by() to pass expressions unchanged Fixes: #257 --- polymorphic/query.py | 10 +++++++--- polymorphic/query_translate.py | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/polymorphic/query.py b/polymorphic/query.py index ab6cbea..8643308 100644 --- a/polymorphic/query.py +++ b/polymorphic/query.py @@ -172,10 +172,14 @@ class PolymorphicQuerySet(QuerySet): additional_args = translate_polymorphic_filter_definitions_in_kwargs(self.model, kwargs, using=self.db) # filter_field='data' return super(PolymorphicQuerySet, self)._filter_or_exclude(negate, *(list(q_objects) + additional_args), **kwargs) - def order_by(self, *args, **kwargs): + def order_by(self, *field_names): """translate the field paths in the args, then call vanilla order_by.""" - new_args = [translate_polymorphic_field_path(self.model, a) for a in args] - return super(PolymorphicQuerySet, self).order_by(*new_args, **kwargs) + field_names = [ + translate_polymorphic_field_path(self.model, a) + if isinstance(a, six.string_types) else a # allow expressions to pass unchanged + for a in field_names + ] + return super(PolymorphicQuerySet, self).order_by(*field_names) def defer(self, *fields): """ diff --git a/polymorphic/query_translate.py b/polymorphic/query_translate.py index 01ffd4c..9c69ba8 100644 --- a/polymorphic/query_translate.py +++ b/polymorphic/query_translate.py @@ -11,6 +11,7 @@ from django.db import models from django.contrib.contenttypes.models import ContentType from django.db.models import Q, FieldDoesNotExist from django.db.utils import DEFAULT_DB_ALIAS +from django.utils import six from django.db.models.fields.related import RelatedField if django.VERSION < (1, 6): @@ -145,6 +146,9 @@ def translate_polymorphic_field_path(queryset_model, field_path): into modela__modelb__modelc__field3. Returns: translated path (unchanged, if no translation needed) """ + if not isinstance(field_path, six.string_types): + raise ValueError("Expected field name as string: {0}".format(field_path)) + classname, sep, pure_field_path = field_path.partition('___') if not sep: return field_path