parent
71dee6eb45
commit
97cd1b63d9
|
|
@ -2,6 +2,13 @@
|
||||||
Changelog
|
Changelog
|
||||||
#########
|
#########
|
||||||
|
|
||||||
|
*********
|
||||||
|
**1.5.0**
|
||||||
|
*********
|
||||||
|
|
||||||
|
- **FIXED:** the ``coerce_to_string`` is now respected when setting the type, default value and min/max values of
|
||||||
|
``DecimalField`` in the OpenAPI schema (:issue:`62`)
|
||||||
|
|
||||||
*********
|
*********
|
||||||
**1.4.0**
|
**1.4.0**
|
||||||
*********
|
*********
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from rest_framework import serializers
|
||||||
from rest_framework.utils import encoders, json
|
from rest_framework.utils import encoders, json
|
||||||
|
|
||||||
from .. import openapi
|
from .. import openapi
|
||||||
from ..utils import is_list_view
|
from ..utils import decimal_as_float, is_list_view
|
||||||
|
|
||||||
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
|
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
|
||||||
NotHandled = object()
|
NotHandled = object()
|
||||||
|
|
@ -224,6 +224,8 @@ class FieldInspector(BaseInspector):
|
||||||
# JSON roundtrip ensures that the value is valid JSON;
|
# JSON roundtrip ensures that the value is valid JSON;
|
||||||
# for example, sets and tuples get transformed into lists
|
# for example, sets and tuples get transformed into lists
|
||||||
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
|
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
|
||||||
|
if decimal_as_float(field):
|
||||||
|
default = float(default)
|
||||||
except Exception: # pragma: no cover
|
except Exception: # pragma: no cover
|
||||||
logger.warning("'default' on schema for %s will not be set because "
|
logger.warning("'default' on schema for %s will not be set because "
|
||||||
"to_representation raised an exception", field, exc_info=True)
|
"to_representation raised an exception", field, exc_info=True)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import operator
|
import operator
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
from django.core import validators
|
from django.core import validators
|
||||||
from django.db import models
|
from django.db import models
|
||||||
|
|
@ -8,7 +9,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
|
||||||
|
|
||||||
from .. import openapi
|
from .. import openapi
|
||||||
from ..errors import SwaggerGenerationError
|
from ..errors import SwaggerGenerationError
|
||||||
from ..utils import filter_none
|
from ..utils import decimal_as_float, filter_none
|
||||||
from .base import FieldInspector, NotHandled, SerializerInspector
|
from .base import FieldInspector, NotHandled, SerializerInspector
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -258,18 +259,29 @@ def find_limits(field):
|
||||||
if isinstance(field, field_class)
|
if isinstance(field, field_class)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
|
||||||
|
return limits
|
||||||
|
|
||||||
for validator in field.validators:
|
for validator in field.validators:
|
||||||
if not hasattr(validator, 'limit_value'):
|
if not hasattr(validator, 'limit_value'):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
limit_value = validator.limit_value
|
||||||
|
if isinstance(limit_value, Decimal) and decimal_as_float(field):
|
||||||
|
limit_value = float(limit_value)
|
||||||
|
|
||||||
for validator_class, attr, improves in applicable_limits:
|
for validator_class, attr, improves in applicable_limits:
|
||||||
if isinstance(validator, validator_class):
|
if isinstance(validator, validator_class):
|
||||||
if attr not in limits or improves(validator.limit_value, limits[attr]):
|
if attr not in limits or improves(limit_value, limits[attr]):
|
||||||
limits[attr] = validator.limit_value
|
limits[attr] = limit_value
|
||||||
|
|
||||||
return OrderedDict(sorted(limits.items()))
|
return OrderedDict(sorted(limits.items()))
|
||||||
|
|
||||||
|
|
||||||
|
def decimal_field_type(field):
|
||||||
|
return openapi.TYPE_NUMBER if decimal_as_float(field) else openapi.TYPE_STRING
|
||||||
|
|
||||||
|
|
||||||
model_field_to_basic_type = [
|
model_field_to_basic_type = [
|
||||||
(models.AutoField, (openapi.TYPE_INTEGER, None)),
|
(models.AutoField, (openapi.TYPE_INTEGER, None)),
|
||||||
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
|
(models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
|
||||||
|
|
@ -277,7 +289,7 @@ model_field_to_basic_type = [
|
||||||
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
(models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||||
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
(models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||||
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
(models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||||
(models.DecimalField, (openapi.TYPE_NUMBER, None)),
|
(models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||||
(models.DurationField, (openapi.TYPE_INTEGER, None)),
|
(models.DurationField, (openapi.TYPE_INTEGER, None)),
|
||||||
(models.FloatField, (openapi.TYPE_NUMBER, None)),
|
(models.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||||
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
|
(models.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||||
|
|
@ -300,9 +312,11 @@ serializer_field_to_basic_type = [
|
||||||
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
(serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||||
(serializers.RegexField, (openapi.TYPE_STRING, None)),
|
(serializers.RegexField, (openapi.TYPE_STRING, None)),
|
||||||
(serializers.CharField, (openapi.TYPE_STRING, None)),
|
(serializers.CharField, (openapi.TYPE_STRING, None)),
|
||||||
((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)),
|
(serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||||
|
(serializers.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
|
||||||
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
|
(serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
|
||||||
((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)),
|
(serializers.FloatField, (openapi.TYPE_NUMBER, None)),
|
||||||
|
(serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
|
||||||
(serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
|
(serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
|
||||||
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
(serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||||
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
(serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||||
|
|
@ -326,6 +340,8 @@ def get_basic_type_info(field):
|
||||||
for field_class, type_format in basic_type_info:
|
for field_class, type_format in basic_type_info:
|
||||||
if isinstance(field, field_class):
|
if isinstance(field, field_class):
|
||||||
swagger_type, format = type_format
|
swagger_type, format = type_format
|
||||||
|
if callable(swagger_type):
|
||||||
|
swagger_type = swagger_type(field)
|
||||||
if callable(format):
|
if callable(format):
|
||||||
format = format(field)
|
format = format(field)
|
||||||
break
|
break
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,7 @@ FORMAT_URI = "uri" #:
|
||||||
# pulled out of my ass
|
# pulled out of my ass
|
||||||
FORMAT_UUID = "uuid" #:
|
FORMAT_UUID = "uuid" #:
|
||||||
FORMAT_SLUG = "slug" #:
|
FORMAT_SLUG = "slug" #:
|
||||||
|
FORMAT_DECIMAL = "decimal"
|
||||||
|
|
||||||
IN_BODY = 'body' #:
|
IN_BODY = 'body' #:
|
||||||
IN_PATH = 'path' #:
|
IN_PATH = 'path' #:
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,11 @@ import inspect
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from django.db import models
|
||||||
from rest_framework import serializers, status
|
from rest_framework import serializers, status
|
||||||
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
|
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
|
||||||
from rest_framework.request import is_form_media_type
|
from rest_framework.request import is_form_media_type
|
||||||
|
from rest_framework.settings import api_settings as rest_framework_settings
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -273,3 +275,15 @@ def get_produces(renderer_classes):
|
||||||
media_types = [renderer.media_type for renderer in renderer_classes or []]
|
media_types = [renderer.media_type for renderer in renderer_classes or []]
|
||||||
media_types = [encoding for encoding in media_types if 'html' not in encoding]
|
media_types = [encoding for encoding in media_types if 'html' not in encoding]
|
||||||
return media_types
|
return media_types
|
||||||
|
|
||||||
|
|
||||||
|
def decimal_as_float(field):
|
||||||
|
"""
|
||||||
|
Returns true if ``field`` is a django-rest-framework DecimalField and its ``coerce_to_string`` attribute or the
|
||||||
|
``COERCE_DECIMAL_TO_STRING`` setting is set to ``False``.
|
||||||
|
|
||||||
|
:rtype: bool
|
||||||
|
"""
|
||||||
|
if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
|
||||||
|
return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
|
||||||
|
return False
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,13 @@
|
||||||
|
from decimal import Decimal
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
|
from rest_framework.compat import MinValueValidator
|
||||||
|
|
||||||
from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet
|
from snippets.models import LANGUAGE_CHOICES, STYLE_CHOICES, Snippet
|
||||||
|
|
||||||
|
|
||||||
class LanguageSerializer(serializers.Serializer):
|
class LanguageSerializer(serializers.Serializer):
|
||||||
|
|
||||||
name = serializers.ChoiceField(
|
name = serializers.ChoiceField(
|
||||||
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
|
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
|
||||||
|
|
||||||
|
|
@ -14,7 +16,6 @@ class LanguageSerializer(serializers.Serializer):
|
||||||
|
|
||||||
|
|
||||||
class ExampleProjectSerializer(serializers.Serializer):
|
class ExampleProjectSerializer(serializers.Serializer):
|
||||||
|
|
||||||
project_name = serializers.CharField(help_text='Name of the project')
|
project_name = serializers.CharField(help_text='Name of the project')
|
||||||
github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
|
github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
|
||||||
|
|
||||||
|
|
@ -49,6 +50,10 @@ class SnippetSerializer(serializers.Serializer):
|
||||||
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
|
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
|
||||||
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
|
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
|
||||||
read_only=True, default=lambda: 6.9)
|
read_only=True, default=lambda: 6.9)
|
||||||
|
rate_as_string = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'),
|
||||||
|
validators=[MinValueValidator(Decimal('0.0'))])
|
||||||
|
rate = serializers.DecimalField(max_digits=6, decimal_places=3, default=Decimal('0.0'), coerce_to_string=False,
|
||||||
|
validators=[MinValueValidator(Decimal('0.0'))])
|
||||||
|
|
||||||
def create(self, validated_data):
|
def create(self, validated_data):
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1036,6 +1036,17 @@ definitions:
|
||||||
type: number
|
type: number
|
||||||
readOnly: true
|
readOnly: true
|
||||||
default: 6.9
|
default: 6.9
|
||||||
|
rateAsString:
|
||||||
|
title: Rate as string
|
||||||
|
type: string
|
||||||
|
format: decimal
|
||||||
|
default: '0.000'
|
||||||
|
rate:
|
||||||
|
title: Rate
|
||||||
|
type: number
|
||||||
|
format: decimal
|
||||||
|
default: 0.0
|
||||||
|
minimum: 0.0
|
||||||
UserSerializerrr:
|
UserSerializerrr:
|
||||||
required:
|
required:
|
||||||
- username
|
- username
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue