Add enum type inference based on choices values (#264)
parent
f587785eb4
commit
f6544654ab
|
|
@ -586,6 +586,7 @@ class ChoiceFieldInspector(FieldInspector):
|
||||||
|
|
||||||
if isinstance(field, serializers.ChoiceField):
|
if isinstance(field, serializers.ChoiceField):
|
||||||
enum_type = openapi.TYPE_STRING
|
enum_type = openapi.TYPE_STRING
|
||||||
|
enum_values = list(field.choices.keys())
|
||||||
|
|
||||||
# for ModelSerializer, try to infer the type from the associated model field
|
# for ModelSerializer, try to infer the type from the associated model field
|
||||||
serializer = get_parent_serializer(field)
|
serializer = get_parent_serializer(field)
|
||||||
|
|
@ -596,8 +597,14 @@ class ChoiceFieldInspector(FieldInspector):
|
||||||
model_type = get_basic_type_info(model_field)
|
model_type = get_basic_type_info(model_field)
|
||||||
if model_type:
|
if model_type:
|
||||||
enum_type = model_type.get('type', enum_type)
|
enum_type = model_type.get('type', enum_type)
|
||||||
|
else:
|
||||||
|
# Try to infer field type based on enum values
|
||||||
|
enum_value_types = {type(v) for v in enum_values}
|
||||||
|
if len(enum_value_types) == 1:
|
||||||
|
values_type = get_basic_type_info_from_hint(next(iter(enum_value_types)))
|
||||||
|
if values_type:
|
||||||
|
enum_type = values_type.get('type', enum_type)
|
||||||
|
|
||||||
enum_values = list(field.choices.keys())
|
|
||||||
if isinstance(field, serializers.MultipleChoiceField):
|
if isinstance(field, serializers.MultipleChoiceField):
|
||||||
result = SwaggerType(
|
result = SwaggerType(
|
||||||
type=openapi.TYPE_ARRAY,
|
type=openapi.TYPE_ARRAY,
|
||||||
|
|
|
||||||
|
|
@ -190,3 +190,32 @@ def test_action_mapping():
|
||||||
assert action_ops['post']['description'] == 'mapping docstring post'
|
assert action_ops['post']['description'] == 'mapping docstring post'
|
||||||
assert action_ops['get']['description'] == 'mapping docstring get/delete'
|
assert action_ops['get']['description'] == 'mapping docstring get/delete'
|
||||||
assert action_ops['delete']['description'] == 'mapping docstring get/delete'
|
assert action_ops['delete']['description'] == 'mapping docstring get/delete'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize('choices, expected_type', [
|
||||||
|
(['A', 'B'], openapi.TYPE_STRING),
|
||||||
|
([123, 456], openapi.TYPE_INTEGER),
|
||||||
|
([1.2, 3.4], openapi.TYPE_NUMBER),
|
||||||
|
(['A', 456], openapi.TYPE_STRING)
|
||||||
|
])
|
||||||
|
def test_choice_field(choices, expected_type):
|
||||||
|
class DetailSerializer(serializers.Serializer):
|
||||||
|
detail = serializers.ChoiceField(choices)
|
||||||
|
|
||||||
|
class DetailViewSet(viewsets.ViewSet):
|
||||||
|
@swagger_auto_schema(responses={200: openapi.Response("OK", DetailSerializer)})
|
||||||
|
def retrieve(self, request, pk=None):
|
||||||
|
return Response({'detail': None})
|
||||||
|
|
||||||
|
router = routers.DefaultRouter()
|
||||||
|
router.register(r'details', DetailViewSet, base_name='details')
|
||||||
|
|
||||||
|
generator = OpenAPISchemaGenerator(
|
||||||
|
info=openapi.Info(title="Test generator", default_version="v1"),
|
||||||
|
patterns=router.urls
|
||||||
|
)
|
||||||
|
|
||||||
|
swagger = generator.get_schema(None, True)
|
||||||
|
property_schema = swagger['definitions']['Detail']['properties']['detail']
|
||||||
|
|
||||||
|
assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue