parent
71dee6eb45
commit
97cd1b63d9
|
|
@ -2,6 +2,13 @@
|
|||
Changelog
|
||||
#########
|
||||
|
||||
*********
|
||||
**1.5.0**
|
||||
*********
|
||||
|
||||
- **FIXED:** the ``coerce_to_string`` is now respected when setting the type, default value and min/max values of
|
||||
``DecimalField`` in the OpenAPI schema (:issue:`62`)
|
||||
|
||||
*********
|
||||
**1.4.0**
|
||||
*********
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from rest_framework import serializers
|
|||
from rest_framework.utils import encoders, json
|
||||
|
||||
from .. import openapi
|
||||
from ..utils import is_list_view
|
||||
from ..utils import decimal_as_float, is_list_view
|
||||
|
||||
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
|
||||
NotHandled = object()
|
||||
|
|
@ -224,6 +224,8 @@ class FieldInspector(BaseInspector):
|
|||
# JSON roundtrip ensures that the value is valid JSON;
|
||||
# for example, sets and tuples get transformed into lists
|
||||
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
|
||||
if decimal_as_float(field):
|
||||
default = float(default)
|
||||
except Exception: # pragma: no cover
|
||||
logger.warning("'default' on schema for %s will not be set because "
|
||||
"to_representation raised an exception", field, exc_info=True)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import operator
|
||||
from collections import OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from django.core import validators
|
||||
from django.db import models
|
||||
|
|
@ -8,7 +9,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
|
|||
|
||||
from .. import openapi
|
||||
from ..errors import SwaggerGenerationError
|
||||
from ..utils import filter_none
|
||||
from ..utils import decimal_as_float, filter_none
|
||||
from .base import FieldInspector, NotHandled, SerializerInspector
|
||||
|
||||
|
||||
|
|
@ -258,18 +259,29 @@ def find_limits(field):
|
|||
if isinstance(field, field_class)
|
||||
]
|
||||
|
||||
if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
|
||||
return limits
|
||||
|
||||
for validator in field.validators:
|
||||
if not hasattr(validator, 'limit_value'):
|
||||
continue
|
||||
|
||||
limit_value = validator.limit_value
|
||||
if isinstance(limit_value, Decimal) and decimal_as_float(field):
|
||||
limit_value = float(limit_value)
|
||||
|
||||
for validator_class, attr, improves in applicable_limits:
|
||||
if isinstance(validator, validator_class):
|
||||
if attr not in limits or improves(validator.limit_value, limits[attr]):
|
||||
limits[attr] = validator.limit_value
|
||||
if attr not in limits or improves(limit_value, limits[attr]):
|
||||
limits[attr] = limit_value
|
||||
|
||||
return OrderedDict(sorted(limits.items()))
|
||||
|
||||
|
||||
def decimal_field_type(field):
|
||||
return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
|
||||
|
||||
|
||||
model_field_to_basic_type = [
|
||||
(models.AutoField, (openapi.TYPE_INTEGER, None)),
|
||||
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
|
||||
|
|
@ -277,7 +289,7 @@ model_field_to_basic_type = [
|
|||
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||
(models.DecimalField, (openapi.TYPE_NUMBER, None)),
|
||||
(models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||
(models.DurationField, (openapi.TYPE_INTEGER, None)),
|
||||
(models.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||
|
|
@ -300,9 +312,11 @@ serializer_field_to_basic_type = [
|
|||
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||
(serializers.RegexField, (openapi.TYPE_STRING, None)),
|
||||
(serializers.CharField, (openapi.TYPE_STRING, None)),
|
||||
((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)),
|
||||
(serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
(serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||
((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)),
|
||||
(serializers.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||
(serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||
(serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
|
||||
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||
|
|
@ -326,6 +340,8 @@ def get_basic_type_info(field):
|
|||
for field_class, type_format in basic_type_info:
|
||||
if isinstance(field, field_class):
|
||||
swagger_type, format = type_format
|
||||
if callable(swagger_type):
|
||||
swagger_type = swagger_type(field)
|
||||
if callable(format):
|
||||
format = format(field)
|
||||
break
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ FORMAT_URI = "uri" #:
|
|||
# pulled out of my ass
|
||||
FORMAT_UUID = "uuid" #:
|
||||
FORMAT_SLUG = "slug" #:
|
||||
FORMAT_DECIMAL = "decimal"
|
||||
|
||||
IN_BODY = 'body' #:
|
||||
IN_PATH = 'path' #:
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@ import inspect
|
|||
import logging
|
||||
from collections import OrderedDict
|
||||
|
||||
from django.db import models
|
||||
from rest_framework import serializers, status
|
||||
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
|
||||
from rest_framework.request import is_form_media_type
|
||||
from rest_framework.settings import api_settings as rest_framework_settings
|
||||
from rest_framework.views import APIView
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -273,3 +275,15 @@ def get_produces(renderer_classes):
|
|||
media_types = [renderer.media_type for renderer in renderer_classes or []]
|
||||
media_types = [encoding for encoding in media_types if 'html' not in encoding]
|
||||
return media_types
|
||||
|
||||
|
||||
def decimal_as_float(field):
|
||||
"""
|
||||
Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the
|
||||
``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``.
|
||||
|
||||
:rtype: bool
|
||||
"""
|
||||
if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
|
||||
return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
from decimal import Decimal
|
||||
|
||||
from django.contrib.auth import get_user_model
|
||||
from rest_framework import serializers
|
||||
from rest_framework.compat import MinValueValidator
|
||||
|
||||
from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet
|
||||
|
||||
|
||||
class LanguageSerializer(serializers.Serializer):
|
||||
|
||||
name = serializers.ChoiceField(
|
||||
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
|
||||
|
||||
|
|
@ -14,7 +16,6 @@ class LanguageSerializer(serializers.Serializer):
|
|||
|
||||
|
||||
class ExampleProjectSerializer(serializers.Serializer):
|
||||
|
||||
project_name = serializers.CharField(help_text='Name of the project')
|
||||
github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
|
||||
|
||||
|
|
@ -49,6 +50,10 @@ class SnippetSerializer(serializers.Serializer):
|
|||
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
|
||||
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
|
||||
read_only=True, default=lambda: 6.9)
|
||||
rate_as_string = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'),
|
||||
validators=[MinValueValidator(Decimal('0.0'))])
|
||||
rate = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), coerce_to_string=False,
|
||||
validators=[MinValueValidator(Decimal('0.0'))])
|
||||
|
||||
def create(self, validated_data):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1036,6 +1036,17 @@ definitions:
|
|||
type: number
|
||||
readOnly: true
|
||||
default: 6.9
|
||||
rateAsString:
|
||||
title: Rate as string
|
||||
type: string
|
||||
format: decimal
|
||||
default: '0.000'
|
||||
rate:
|
||||
title: Rate
|
||||
type: number
|
||||
format: decimal
|
||||
default: 0.0
|
||||
minimum: 0.0
|
||||
UserSerializerrr:
|
||||
required:
|
||||
- username
|
||||
|
|
|
|||
Loading…
Reference in New Issue