Take coerce_to_string into account when handling DecimalField

Closes #62.
openapi3
Cristi Vîjdea 2018-02-21 04:32:52 +02:00
parent 71dee6eb45
commit 97cd1b63d9
7 changed files with 65 additions and 9 deletions

View File

@ -2,6 +2,13 @@
Changelog
#########
*********
**1.5.0**
*********
- **FIXED:** the ``coerce_to_string`` is now respected when setting the type, default value and min/max values of
``DecimalField`` in the OpenAPI schema (:issue:`62`)
*********
**1.4.0**
*********

View File

@ -6,7 +6,7 @@ from rest_framework import serializers
from rest_framework.utils import encoders, json
from .. import openapi
from ..utils import is_list_view
from ..utils import decimal_as_float, is_list_view
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
NotHandled = object()
@ -224,6 +224,8 @@ class FieldInspector(BaseInspector):
# JSON roundtrip ensures that the value is valid JSON;
# for example, sets and tuples get transformed into lists
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
if decimal_as_float(field):
default = float(default)
except Exception: # pragma: no cover
logger.warning("'default' on schema for %s will not be set because "
"to_representation raised an exception", field, exc_info=True)

View File

@ -1,5 +1,6 @@
import operator
from collections import OrderedDict
from decimal import Decimal
from django.core import validators
from django.db import models
@ -8,7 +9,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import filter_none
from ..utils import decimal_as_float, filter_none
from .base import FieldInspector, NotHandled, SerializerInspector
@ -258,18 +259,29 @@ def find_limits(field):
if isinstance(field, field_class)
]
if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
return limits
for validator in field.validators:
if not hasattr(validator, 'limit_value'):
continue
limit_value = validator.limit_value
if isinstance(limit_value, Decimal) and decimal_as_float(field):
limit_value = float(limit_value)
for validator_class, attr, improves in applicable_limits:
if isinstance(validator, validator_class):
if attr not in limits or improves(validator.limit_value, limits[attr]):
limits[attr] = validator.limit_value
if attr not in limits or improves(limit_value, limits[attr]):
limits[attr] = limit_value
return OrderedDict(sorted(limits.items()))
def decimal_field_type(field):
return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
model_field_to_basic_type = [
(models.AutoField, (openapi.TYPE_INTEGER, None)),
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
@ -277,7 +289,7 @@ model_field_to_basic_type = [
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
(models.DecimalField, (openapi.TYPE_NUMBER, None)),
(models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
(models.DurationField, (openapi.TYPE_INTEGER, None)),
(models.FloatField, (openapi.TYPE_NUMBER, None)),
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
@ -300,9 +312,11 @@ serializer_field_to_basic_type = [
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(serializers.RegexField, (openapi.TYPE_STRING, None)),
(serializers.CharField, (openapi.TYPE_STRING, None)),
((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)),
(serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
(serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)),
(serializers.FloatField, (openapi.TYPE_NUMBER, None)),
(serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
(serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
@ -326,6 +340,8 @@ def get_basic_type_info(field):
for field_class, type_format in basic_type_info:
if isinstance(field, field_class):
swagger_type, format = type_format
if callable(swagger_type):
swagger_type = swagger_type(field)
if callable(format):
format = format(field)
break

View File

@ -35,6 +35,7 @@ FORMAT_URI = "uri" #:
# pulled out of my ass
FORMAT_UUID = "uuid" #:
FORMAT_SLUG = "slug" #:
FORMAT_DECIMAL = "decimal"
IN_BODY = 'body' #:
IN_PATH = 'path' #:

View File

@ -2,9 +2,11 @@ import inspect
import logging
from collections import OrderedDict
from django.db import models
from rest_framework import serializers, status
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
from rest_framework.request import is_form_media_type
from rest_framework.settings import api_settings as rest_framework_settings
from rest_framework.views import APIView
logger = logging.getLogger(__name__)
@ -273,3 +275,15 @@ def get_produces(renderer_classes):
media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types if 'html' not in encoding]
return media_types
def decimal_as_float(field):
"""
Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the
``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``.
:rtype: bool
"""
if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
return False

View File

@ -1,11 +1,13 @@
from decimal import Decimal
from django.contrib.auth import get_user_model
from rest_framework import serializers
from rest_framework.compat import MinValueValidator
from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet
class LanguageSerializer(serializers.Serializer):
name = serializers.ChoiceField(
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
@ -14,7 +16,6 @@ class LanguageSerializer(serializers.Serializer):
class ExampleProjectSerializer(serializers.Serializer):
project_name = serializers.CharField(help_text='Name of the project')
github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
@ -49,6 +50,10 @@ class SnippetSerializer(serializers.Serializer):
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
read_only=True, default=lambda: 6.9)
rate_as_string = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'),
validators=[MinValueValidator(Decimal('0.0'))])
rate = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), coerce_to_string=False,
validators=[MinValueValidator(Decimal('0.0'))])
def create(self, validated_data):
"""

View File

@ -1036,6 +1036,17 @@ definitions:
type: number
readOnly: true
default: 6.9
rateAsString:
title: Rate as string
type: string
format: decimal
default: '0.000'
rate:
title: Rate
type: number
format: decimal
default: 0.0
minimum: 0.0
UserSerializerrr:
required:
- username