Fix union type hint checks (#318)

Fix some obscure edge cases related to typing.Union type args.

Fixes #304.
master
Roman Sichny 2019-02-22 01:00:14 +02:00 committed by Cristi Vîjdea
parent 3d43ee6748
commit 76c8fe0646
2 changed files with 22 additions and 4 deletions

View File

@ -490,6 +490,20 @@ if typing:
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class)) hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
def _get_union_types(hint_class):
if typing:
origin_type = get_origin_type(hint_class)
if origin_type is typing.Union:
return hint_class.__args__
try:
# python 3.5.2 and lower compatibility
if issubclass(origin_type, typing.Union):
return hint_class.__union_params__
except TypeError:
pass
return None
def get_basic_type_info_from_hint(hint_class): def get_basic_type_info_from_hint(hint_class):
"""Given a class (eg from a SerializerMethodField's return type hint, """Given a class (eg from a SerializerMethodField's return type hint,
return its basic type information - ``type``, ``format``, ``pattern``, return its basic type information - ``type``, ``format``, ``pattern``,
@ -499,11 +513,11 @@ def get_basic_type_info_from_hint(hint_class):
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
:rtype: OrderedDict :rtype: OrderedDict
""" """
if typing and get_origin_type(hint_class) == typing.Union: union_types = _get_union_types(hint_class)
if typing and union_types:
# Optional is implemented as Union[T, None] # Optional is implemented as Union[T, None]
if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None): # noqa: E721 if len(union_types) == 2 and isinstance(None, union_types[1]):
child_type = hint_class.__args__[0] result = get_basic_type_info_from_hint(union_types[0])
result = get_basic_type_info_from_hint(child_type)
result['x-nullable'] = True result['x-nullable'] = True
return result return result

View File

@ -33,6 +33,10 @@ if typing:
# Following cases are not 100% correct, but it should work somehow and not crash. # Following cases are not 100% correct, but it should work somehow and not crash.
(Union[int, float], None), (Union[int, float], None),
(List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}), (List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
('SomeType', None),
(type('SomeType', (object,), {}), None),
(None, None),
(6, None),
]) ])
def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info): def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
type_info = get_basic_type_info_from_hint(hint_class) type_info = get_basic_type_info_from_hint(hint_class)