Add enum type inference based on choices values (#264)
parent
f587785eb4
commit
f6544654ab
|
|
@ -586,6 +586,7 @@ class ChoiceFieldInspector(FieldInspector):
|
|||
|
||||
if isinstance(field, serializers.ChoiceField):
|
||||
enum_type = openapi.TYPE_STRING
|
||||
enum_values = list(field.choices.keys())
|
||||
|
||||
# for ModelSerializer, try to infer the type from the associated model field
|
||||
serializer = get_parent_serializer(field)
|
||||
|
|
@ -596,8 +597,14 @@ class ChoiceFieldInspector(FieldInspector):
|
|||
model_type = get_basic_type_info(model_field)
|
||||
if model_type:
|
||||
enum_type = model_type.get('type', enum_type)
|
||||
else:
|
||||
# Try to infer field type based on enum values
|
||||
enum_value_types = {type(v) for v in enum_values}
|
||||
if len(enum_value_types) == 1:
|
||||
values_type = get_basic_type_info_from_hint(next(iter(enum_value_types)))
|
||||
if values_type:
|
||||
enum_type = values_type.get('type', enum_type)
|
||||
|
||||
enum_values = list(field.choices.keys())
|
||||
if isinstance(field, serializers.MultipleChoiceField):
|
||||
result = SwaggerType(
|
||||
type=openapi.TYPE_ARRAY,
|
||||
|
|
|
|||
|
|
@ -190,3 +190,32 @@ def test_action_mapping():
|
|||
assert action_ops['post']['description'] == 'mapping docstring post'
|
||||
assert action_ops['get']['description'] == 'mapping docstring get/delete'
|
||||
assert action_ops['delete']['description'] == 'mapping docstring get/delete'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('choices, expected_type', [
|
||||
(['A', 'B'], openapi.TYPE_STRING),
|
||||
([123, 456], openapi.TYPE_INTEGER),
|
||||
([1.2, 3.4], openapi.TYPE_NUMBER),
|
||||
(['A', 456], openapi.TYPE_STRING)
|
||||
])
|
||||
def test_choice_field(choices, expected_type):
|
||||
class DetailSerializer(serializers.Serializer):
|
||||
detail = serializers.ChoiceField(choices)
|
||||
|
||||
class DetailViewSet(viewsets.ViewSet):
|
||||
@swagger_auto_schema(responses={200: openapi.Response("OK", DetailSerializer)})
|
||||
def retrieve(self, request, pk=None):
|
||||
return Response({'detail': None})
|
||||
|
||||
router = routers.DefaultRouter()
|
||||
router.register(r'details', DetailViewSet, base_name='details')
|
||||
|
||||
generator = OpenAPISchemaGenerator(
|
||||
info=openapi.Info(title="Test generator", default_version="v1"),
|
||||
patterns=router.urls
|
||||
)
|
||||
|
||||
swagger = generator.get_schema(None, True)
|
||||
property_schema = swagger['definitions']['Detail']['properties']['detail']
|
||||
|
||||
assert property_schema == openapi.Schema(title='Detail', type=expected_type, enum=choices)
|
||||
|
|
|
|||
Loading…
Reference in New Issue