Take coerce_to_string into account when handling DecimalField

Closes #62.
openapi3
Cristi Vîjdea 2018-02-21 04:32:52 +02:00
parent 71dee6eb45
commit 97cd1b63d9
7 changed files with 65 additions and 9 deletions

View File

@ -2,6 +2,13 @@
Changelog Changelog
######### #########
*********
**1.5.0**
*********
- **FIXED:** the ``coerce_to_string`` is now respected when setting the type, default value and min/max values of
``DecimalField`` in the OpenAPI schema (:issue:`62`)
********* *********
**1.4.0** **1.4.0**
********* *********

View File

@ -6,7 +6,7 @@ from rest_framework import serializers
from rest_framework.utils import encoders, json from rest_framework.utils import encoders, json
from .. import openapi from .. import openapi
from ..utils import is_list_view from ..utils import decimal_as_float, is_list_view
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object #: Sentinel value that inspectors must return to signal that they do not know how to handle an object
NotHandled = object() NotHandled = object()
@ -224,6 +224,8 @@ class FieldInspector(BaseInspector):
# JSON roundtrip ensures that the value is valid JSON; # JSON roundtrip ensures that the value is valid JSON;
# for example, sets and tuples get transformed into lists # for example, sets and tuples get transformed into lists
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder)) default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
if decimal_as_float(field):
default = float(default)
except Exception: # pragma: no cover except Exception: # pragma: no cover
logger.warning("'default' on schema for %s will not be set because " logger.warning("'default' on schema for %s will not be set because "
"to_representation raised an exception", field, exc_info=True) "to_representation raised an exception", field, exc_info=True)

View File

@ -1,5 +1,6 @@
import operator import operator
from collections import OrderedDict from collections import OrderedDict
from decimal import Decimal
from django.core import validators from django.core import validators
from django.db import models from django.db import models
@ -8,7 +9,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
from .. import openapi from .. import openapi
from ..errors import SwaggerGenerationError from ..errors import SwaggerGenerationError
from ..utils import filter_none from ..utils import decimal_as_float, filter_none
from .base import FieldInspector, NotHandled, SerializerInspector from .base import FieldInspector, NotHandled, SerializerInspector
@ -258,18 +259,29 @@ def find_limits(field):
if isinstance(field, field_class) if isinstance(field, field_class)
] ]
if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
return limits
for validator in field.validators: for validator in field.validators:
if not hasattr(validator, 'limit_value'): if not hasattr(validator, 'limit_value'):
continue continue
limit_value = validator.limit_value
if isinstance(limit_value, Decimal) and decimal_as_float(field):
limit_value = float(limit_value)
for validator_class, attr, improves in applicable_limits: for validator_class, attr, improves in applicable_limits:
if isinstance(validator, validator_class): if isinstance(validator, validator_class):
if attr not in limits or improves(validator.limit_value, limits[attr]): if attr not in limits or improves(limit_value, limits[attr]):
limits[attr] = validator.limit_value limits[attr] = limit_value
return OrderedDict(sorted(limits.items())) return OrderedDict(sorted(limits.items()))
def decimal_field_type(field):
return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
model_field_to_basic_type = [ model_field_to_basic_type = [
(models.AutoField, (openapi.TYPE_INTEGER, None)), (models.AutoField, (openapi.TYPE_INTEGER, None)),
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)), (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
@ -277,7 +289,7 @@ model_field_to_basic_type = [
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
(models.DecimalField, (openapi.TYPE_NUMBER, None)), (models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
(models.DurationField, (openapi.TYPE_INTEGER, None)), (models.DurationField, (openapi.TYPE_INTEGER, None)),
(models.FloatField, (openapi.TYPE_NUMBER, None)), (models.FloatField, (openapi.TYPE_NUMBER, None)),
(models.IntegerField, (openapi.TYPE_INTEGER, None)), (models.IntegerField, (openapi.TYPE_INTEGER, None)),
@ -300,9 +312,11 @@ serializer_field_to_basic_type = [
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(serializers.RegexField, (openapi.TYPE_STRING, None)), (serializers.RegexField, (openapi.TYPE_STRING, None)),
(serializers.CharField, (openapi.TYPE_STRING, None)), (serializers.CharField, (openapi.TYPE_STRING, None)),
((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)), (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
(serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)), (serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)), (serializers.FloatField, (openapi.TYPE_NUMBER, None)),
(serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
(serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ? (serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
@ -326,6 +340,8 @@ def get_basic_type_info(field):
for field_class, type_format in basic_type_info: for field_class, type_format in basic_type_info:
if isinstance(field, field_class): if isinstance(field, field_class):
swagger_type, format = type_format swagger_type, format = type_format
if callable(swagger_type):
swagger_type = swagger_type(field)
if callable(format): if callable(format):
format = format(field) format = format(field)
break break

View File

@ -35,6 +35,7 @@ FORMAT_URI = "uri" #:
# pulled out of my ass # pulled out of my ass
FORMAT_UUID = "uuid" #: FORMAT_UUID = "uuid" #:
FORMAT_SLUG = "slug" #: FORMAT_SLUG = "slug" #:
FORMAT_DECIMAL = "decimal"
IN_BODY = 'body' #: IN_BODY = 'body' #:
IN_PATH = 'path' #: IN_PATH = 'path' #:

View File

@ -2,9 +2,11 @@ import inspect
import logging import logging
from collections import OrderedDict from collections import OrderedDict
from django.db import models
from rest_framework import serializers, status from rest_framework import serializers, status
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
from rest_framework.request import is_form_media_type from rest_framework.request import is_form_media_type
from rest_framework.settings import api_settings as rest_framework_settings
from rest_framework.views import APIView from rest_framework.views import APIView
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -273,3 +275,15 @@ def get_produces(renderer_classes):
media_types = [renderer.media_type for renderer in renderer_classes or []] media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types if 'html' not in encoding] media_types = [encoding for encoding in media_types if 'html' not in encoding]
return media_types return media_types
def decimal_as_float(field):
"""
Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the
``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``.
:rtype: bool
"""
if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
return False

View File

@ -1,11 +1,13 @@
from decimal import Decimal
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import MinValueValidator
from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet
class LanguageSerializer(serializers.Serializer): class LanguageSerializer(serializers.Serializer):
name = serializers.ChoiceField( name = serializers.ChoiceField(
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language') choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
@ -14,7 +16,6 @@ class LanguageSerializer(serializers.Serializer):
class ExampleProjectSerializer(serializers.Serializer): class ExampleProjectSerializer(serializers.Serializer):
project_name = serializers.CharField(help_text='Name of the project') project_name = serializers.CharField(help_text='Name of the project')
github_repo = serializers.CharField(required=True, help_text='Github repository of the project') github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
@ -49,6 +50,10 @@ class SnippetSerializer(serializers.Serializer):
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True) example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField", difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
read_only=True, default=lambda: 6.9) read_only=True, default=lambda: 6.9)
rate_as_string = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'),
validators=[MinValueValidator(Decimal('0.0'))])
rate = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), coerce_to_string=False,
validators=[MinValueValidator(Decimal('0.0'))])
def create(self, validated_data): def create(self, validated_data):
""" """

View File

@ -1036,6 +1036,17 @@ definitions:
type: number type: number
readOnly: true readOnly: true
default: 6.9 default: 6.9
rateAsString:
title: Rate as string
type: string
format: decimal
default: '0.000'
rate:
title: Rate
type: number
format: decimal
default: 0.0
minimum: 0.0
UserSerializerrr: UserSerializerrr:
required: required:
- username - username