Add support for custom and collection type hint classes (#272)
parent
58e6dae548
commit
3806d6efd5
|
|
@ -455,18 +455,39 @@ def decimal_return_type():
|
||||||
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
|
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
|
||||||
|
|
||||||
|
|
||||||
raw_type_info = [
|
def get_origin_type(hint_class):
|
||||||
|
return getattr(hint_class, '__origin__', None) or hint_class
|
||||||
|
|
||||||
|
|
||||||
|
def is_origin_type_subclasses(hint_class, check_class):
|
||||||
|
origin_type = get_origin_type(hint_class)
|
||||||
|
return inspect.isclass(origin_type) and issubclass(origin_type, check_class)
|
||||||
|
|
||||||
|
|
||||||
|
hinting_type_info = [
|
||||||
(bool, (openapi.TYPE_BOOLEAN, None)),
|
(bool, (openapi.TYPE_BOOLEAN, None)),
|
||||||
(int, (openapi.TYPE_INTEGER, None)),
|
(int, (openapi.TYPE_INTEGER, None)),
|
||||||
|
(str, (openapi.TYPE_STRING, None)),
|
||||||
(float, (openapi.TYPE_NUMBER, None)),
|
(float, (openapi.TYPE_NUMBER, None)),
|
||||||
|
(dict, (openapi.TYPE_OBJECT, None)),
|
||||||
(Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
|
(Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
|
||||||
(uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
(uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
|
||||||
(datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
(datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
|
||||||
(datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
(datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
|
||||||
# TODO - support typing.List etc
|
|
||||||
]
|
]
|
||||||
|
|
||||||
hinting_type_info = raw_type_info
|
if typing:
|
||||||
|
def inspect_collection_hint_class(hint_class):
|
||||||
|
args = hint_class.__args__
|
||||||
|
child_class = args[0] if args else str
|
||||||
|
child_type_info = get_basic_type_info_from_hint(child_class)
|
||||||
|
if not child_type_info:
|
||||||
|
child_type_info = {'type': openapi.TYPE_STRING}
|
||||||
|
return OrderedDict([
|
||||||
|
('type', openapi.TYPE_ARRAY),
|
||||||
|
('items', openapi.Items(**child_type_info)),
|
||||||
|
])
|
||||||
|
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
|
||||||
|
|
||||||
|
|
||||||
def get_basic_type_info_from_hint(hint_class):
|
def get_basic_type_info_from_hint(hint_class):
|
||||||
|
|
@ -478,27 +499,25 @@ def get_basic_type_info_from_hint(hint_class):
|
||||||
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
||||||
:rtype: OrderedDict
|
:rtype: OrderedDict
|
||||||
"""
|
"""
|
||||||
|
if typing and get_origin_type(hint_class) == typing.Union:
|
||||||
for check_class, type_format in hinting_type_info:
|
if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None):
|
||||||
if issubclass(hint_class, check_class):
|
child_type = hint_class.__args__[0]
|
||||||
swagger_type, format = type_format
|
return get_basic_type_info_from_hint(child_type)
|
||||||
if callable(swagger_type):
|
|
||||||
swagger_type = swagger_type()
|
|
||||||
# if callable(format):
|
|
||||||
# format = format(klass)
|
|
||||||
break
|
|
||||||
else: # pragma: no cover
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
pattern = None
|
for check_class, info in hinting_type_info:
|
||||||
|
if is_origin_type_subclasses(hint_class, check_class):
|
||||||
result = OrderedDict([
|
if callable(info):
|
||||||
|
return info(hint_class)
|
||||||
|
swagger_type, format = info
|
||||||
|
if callable(swagger_type):
|
||||||
|
swagger_type = swagger_type()
|
||||||
|
return OrderedDict([
|
||||||
('type', swagger_type),
|
('type', swagger_type),
|
||||||
('format', format),
|
('format', format),
|
||||||
('pattern', pattern)
|
|
||||||
])
|
])
|
||||||
|
|
||||||
return result
|
return None
|
||||||
|
|
||||||
|
|
||||||
class SerializerMethodFieldInspector(FieldInspector):
|
class SerializerMethodFieldInspector(FieldInspector):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from drf_yasg import openapi
|
||||||
|
from drf_yasg.inspectors.field import get_basic_type_info_from_hint
|
||||||
|
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
from typing import Dict, List, Union, Optional, Set
|
||||||
|
except ImportError:
|
||||||
|
typing = None
|
||||||
|
|
||||||
|
|
||||||
|
if typing:
|
||||||
|
@pytest.mark.parametrize('hint_class, expected_swagger_type_info', [
|
||||||
|
(int, {'type': openapi.TYPE_INTEGER, 'format': None}),
|
||||||
|
(str, {'type': openapi.TYPE_STRING, 'format': None}),
|
||||||
|
(bool, {'type': openapi.TYPE_BOOLEAN, 'format': None}),
|
||||||
|
(dict, {'type': openapi.TYPE_OBJECT, 'format': None}),
|
||||||
|
(Dict[int, int], {'type': openapi.TYPE_OBJECT, 'format': None}),
|
||||||
|
(uuid.UUID, {'type': openapi.TYPE_STRING, 'format': openapi.FORMAT_UUID}),
|
||||||
|
(List[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
|
||||||
|
(List[str], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
|
||||||
|
(List[bool], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_BOOLEAN)}),
|
||||||
|
(Set[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
|
||||||
|
(Optional[bool], {'type': openapi.TYPE_BOOLEAN, 'format': None}),
|
||||||
|
(Optional[List[int]], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
|
||||||
|
(Union[List[int], type(None)], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
|
||||||
|
# Following cases are not 100% correct, but it should work somehow and not crash.
|
||||||
|
(Union[int, float], None),
|
||||||
|
(List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
|
||||||
|
])
|
||||||
|
def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
|
||||||
|
type_info = get_basic_type_info_from_hint(hint_class)
|
||||||
|
assert type_info == expected_swagger_type_info
|
||||||
Loading…
Reference in New Issue