diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index f302a0e..b7cbe9b 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -586,6 +586,7 @@ class ChoiceFieldInspector(FieldInspector): if isinstance(field, serializers.ChoiceField): enum_type = openapi.TYPE_STRING + enum_values = list(field.choices.keys()) # for ModelSerializer, try to infer the type from the associated model field serializer = get_parent_serializer(field) @@ -596,8 +597,14 @@ class ChoiceFieldInspector(FieldInspector): model_type = get_basic_type_info(model_field) if model_type: enum_type = model_type.get('type', enum_type) + else: + # Try to infer field type based on enum values + enum_value_types = {type(v) for v in enum_values} + if len(enum_value_types) == 1: + values_type = get_basic_type_info_from_hint(next(iter(enum_value_types))) + if values_type: + enum_type = values_type.get('type', enum_type) - enum_values = list(field.choices.keys()) if isinstance(field, serializers.MultipleChoiceField): result = SwaggerType( type=openapi.TYPE_ARRAY, diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index a52446d..2a74c95 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -190,3 +190,32 @@ def test_action_mapping(): assert action_ops['post']['description'] == 'mapping docstring post' assert action_ops['get']['description'] == 'mapping docstring get/delete' assert action_ops['delete']['description'] == 'mapping docstring get/delete' + + +@pytest.mark.parametrize('choices, expected_type', [ + (['A', 'B'], openapi.TYPE_STRING), + ([123, 456], openapi.TYPE_INTEGER), + ([1.2, 3.4], openapi.TYPE_NUMBER), + (['A', 456], openapi.TYPE_STRING) +]) +def test_choice_field(choices, expected_type): + class DetailSerializer(serializers.Serializer): + detail = serializers.ChoiceField(choices) + + class DetailViewSet(viewsets.ViewSet): + @swagger_auto_schema(responses={200: openapi.Response("OK", DetailSerializer)}) + def retrieve(self, request, pk=None): + return Response({'detail': None}) + + router = routers.DefaultRouter() + router.register(r'details', DetailViewSet, base_name='details') + + generator = OpenAPISchemaGenerator( + info=openapi.Info(title="Test generator", default_version="v1"), + patterns=router.urls + ) + + swagger = generator.get_schema(None, True) + property_schema = swagger['definitions']['Detail']['properties']['detail'] + + assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices)