From 97cd1b63d9ff34b893f2e1929863e96f03b97858 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Wed, 21 Feb 2018 04:32:52 +0200 Subject: [PATCH] Take coerce_to_string into account when handling DecimalField Closes #62. --- docs/changelog.rst | 7 +++++++ src/drf_yasg/inspectors/base.py | 4 +++- src/drf_yasg/inspectors/field.py | 28 ++++++++++++++++++++++------ src/drf_yasg/openapi.py | 1 + src/drf_yasg/utils.py | 14 ++++++++++++++ testproj/snippets/serializers.py | 9 +++++++-- tests/reference.yaml | 11 +++++++++++ 7 files changed, 65 insertions(+), 9 deletions(-) diff --git a/docs/changelog.rst b/docs/changelog.rst index 787c878..f43e328 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -2,6 +2,13 @@ Changelog ######### +********* +**1.5.0** +********* + +- **FIXED:** the ``coerce_to_string`` is now respected when setting the type, default value and min/max values of + ``DecimalField`` in the OpenAPI schema (:issue:`62`) + ********* **1.4.0** ********* diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index 8175cc9..f23f178 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -6,7 +6,7 @@ from rest_framework import serializers from rest_framework.utils import encoders, json from .. import openapi -from ..utils import is_list_view +from ..utils import decimal_as_float, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object NotHandled = object() @@ -224,6 +224,8 @@ class FieldInspector(BaseInspector): # JSON roundtrip ensures that the value is valid JSON; # for example, sets and tuples get transformed into lists default = json.loads(json.dumps(default, cls=encoders.JSONEncoder)) + if decimal_as_float(field): + default = float(default) except Exception: # pragma: no cover logger.warning("'default' on schema for %s will not be set because " "to_representation raised an exception", field, exc_info=True) diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 39928f2..e7b84ab 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -1,5 +1,6 @@ import operator from collections import OrderedDict +from decimal import Decimal from django.core import validators from django.db import models @@ -8,7 +9,7 @@ from rest_framework.settings import api_settings as rest_framework_settings from .. import openapi from ..errors import SwaggerGenerationError -from ..utils import filter_none +from ..utils import decimal_as_float, filter_none from .base import FieldInspector, NotHandled, SerializerInspector @@ -258,18 +259,29 @@ def find_limits(field): if isinstance(field, field_class) ] + if isinstance(field, serializers.DecimalField) and not decimal_as_float(field): + return limits + for validator in field.validators: if not hasattr(validator, 'limit_value'): continue + limit_value = validator.limit_value + if isinstance(limit_value, Decimal) and decimal_as_float(field): + limit_value = float(limit_value) + for validator_class, attr, improves in applicable_limits: if isinstance(validator, validator_class): - if attr not in limits or improves(validator.limit_value, limits[attr]): - limits[attr] = validator.limit_value + if attr not in limits or improves(limit_value, limits[attr]): + limits[attr] = limit_value return OrderedDict(sorted(limits.items())) +def decimal_field_type(field): + return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING + + model_field_to_basic_type = [ (models.AutoField, (openapi.TYPE_INTEGER, None)), (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)), @@ -277,7 +289,7 @@ model_field_to_basic_type = [ (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), - (models.DecimalField, (openapi.TYPE_NUMBER, None)), + (models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)), (models.DurationField, (openapi.TYPE_INTEGER, None)), (models.FloatField, (openapi.TYPE_NUMBER, None)), (models.IntegerField, (openapi.TYPE_INTEGER, None)), @@ -300,9 +312,11 @@ serializer_field_to_basic_type = [ (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), (serializers.RegexField, (openapi.TYPE_STRING, None)), (serializers.CharField, (openapi.TYPE_STRING, None)), - ((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)), + (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)), + (serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), (serializers.IntegerField, (openapi.TYPE_INTEGER, None)), - ((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)), + (serializers.FloatField, (openapi.TYPE_NUMBER, None)), + (serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)), (serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ? (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), @@ -326,6 +340,8 @@ def get_basic_type_info(field): for field_class, type_format in basic_type_info: if isinstance(field, field_class): swagger_type, format = type_format + if callable(swagger_type): + swagger_type = swagger_type(field) if callable(format): format = format(field) break diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py index d1ed257..4397257 100644 --- a/src/drf_yasg/openapi.py +++ b/src/drf_yasg/openapi.py @@ -35,6 +35,7 @@ FORMAT_URI = "uri" #: # pulled out of my ass FORMAT_UUID = "uuid" #: FORMAT_SLUG = "slug" #: +FORMAT_DECIMAL = "decimal" IN_BODY = 'body' #: IN_PATH = 'path' #: diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index a366473..3f713b4 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -2,9 +2,11 @@ import inspect import logging from collections import OrderedDict +from django.db import models from rest_framework import serializers, status from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin from rest_framework.request import is_form_media_type +from rest_framework.settings import api_settings as rest_framework_settings from rest_framework.views import APIView logger = logging.getLogger(__name__) @@ -273,3 +275,15 @@ def get_produces(renderer_classes): media_types = [renderer.media_type for renderer in renderer_classes or []] media_types = [encoding for encoding in media_types if 'html' not in encoding] return media_types + + +def decimal_as_float(field): + """ + Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the + ``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``. + + :rtype: bool + """ + if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField): + return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING) + return False diff --git a/testproj/snippets/serializers.py b/testproj/snippets/serializers.py index a83dced..8f89a7a 100644 --- a/testproj/snippets/serializers.py +++ b/testproj/snippets/serializers.py @@ -1,11 +1,13 @@ +from decimal import Decimal + from django.contrib.auth import get_user_model from rest_framework import serializers +from rest_framework.compat import MinValueValidator from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet class LanguageSerializer(serializers.Serializer): - name = serializers.ChoiceField( choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language') @@ -14,7 +16,6 @@ class LanguageSerializer(serializers.Serializer): class ExampleProjectSerializer(serializers.Serializer): - project_name = serializers.CharField(help_text='Name of the project') github_repo = serializers.CharField(required=True, help_text='Github repository of the project') @@ -49,6 +50,10 @@ class SnippetSerializer(serializers.Serializer): example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True) difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField", read_only=True, default=lambda: 6.9) + rate_as_string = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), + validators=[MinValueValidator(Decimal('0.0'))]) + rate = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), coerce_to_string=False, + validators=[MinValueValidator(Decimal('0.0'))]) def create(self, validated_data): """ diff --git a/tests/reference.yaml b/tests/reference.yaml index 1990d7c..bded591 100644 --- a/tests/reference.yaml +++ b/tests/reference.yaml @@ -1036,6 +1036,17 @@ definitions: type: number readOnly: true default: 6.9 + rateAsString: + title: Rate as string + type: string + format: decimal + default: '0.000' + rate: + title: Rate + type: number + format: decimal + default: 0.0 + minimum: 0.0 UserSerializerrr: required: - username