diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 4a58964..27369b3 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -13,7 +13,9 @@ from rest_framework.settings import api_settings as rest_framework_settings from .. import openapi from ..errors import SwaggerGenerationError -from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name +from ..utils import ( + decimal_as_float, field_value_to_representation, filter_none, get_serializer_class, get_serializer_ref_name +) from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method try: @@ -635,7 +637,14 @@ class ChoiceFieldInspector(FieldInspector): if isinstance(field, serializers.ChoiceField): enum_type = openapi.TYPE_STRING - enum_values = list(field.choices.keys()) + enum_values = [] + for choice in field.choices.keys(): + if isinstance(field, serializers.MultipleChoiceField): + choice = field_value_to_representation(field, [choice])[0] + else: + choice = field_value_to_representation(field, choice) + + enum_values.append(choice) # for ModelSerializer, try to infer the type from the associated model field serializer = get_parent_serializer(field) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 9621fc2..83270ed 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -2,6 +2,7 @@ import inspect import logging import sys from collections import OrderedDict +from decimal import Decimal from django.db import models from django.utils.encoding import force_text @@ -441,6 +442,25 @@ def force_real_str(s, encoding='utf-8', strings_only=False, errors='strict'): return s +def field_value_to_representation(field, value): + """Convert a python value related to a field (default, choices, etc.) into its OpenAPI-compatible representation. + + :param serializers.Field field: field associated with the value + :param obj value: value + :return: the converted value + """ + value = field.to_representation(value) + if isinstance(value, Decimal): + if decimal_as_float(field): + value = float(value) + else: + value = str(value) + + # JSON roundtrip ensures that the value is valid JSON; + # for example, sets and tuples get transformed into lists + return json.loads(json.dumps(value, cls=encoders.JSONEncoder)) + + def get_field_default(field): """ Get the default value for a field, converted to a JSON-compatible value while properly handling callables. @@ -462,12 +482,7 @@ def get_field_default(field): if default is not serializers.empty and default is not None: try: - default = field.to_representation(default) - # JSON roundtrip ensures that the value is valid JSON; - # for example, sets and tuples get transformed into lists - default = json.loads(json.dumps(default, cls=encoders.JSONEncoder)) - if decimal_as_float(field): - default = float(default) + default = field_value_to_representation(field, default) except Exception: # pragma: no cover logger.warning("'default' on schema for %s will not be set because " "to_representation raised an exception", field, exc_info=True)