Fix union type hint checks (#318)
Fix some obscure edge cases related to typing.Union type args. Fixes #304.master
parent
3d43ee6748
commit
76c8fe0646
|
|
@ -490,6 +490,20 @@ if typing:
|
||||||
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
|
hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
|
||||||
|
|
||||||
|
|
||||||
|
def _get_union_types(hint_class):
|
||||||
|
if typing:
|
||||||
|
origin_type = get_origin_type(hint_class)
|
||||||
|
if origin_type is typing.Union:
|
||||||
|
return hint_class.__args__
|
||||||
|
try:
|
||||||
|
# python 3.5.2 and lower compatibility
|
||||||
|
if issubclass(origin_type, typing.Union):
|
||||||
|
return hint_class.__union_params__
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_basic_type_info_from_hint(hint_class):
|
def get_basic_type_info_from_hint(hint_class):
|
||||||
"""Given a class (eg from a SerializerMethodField's return type hint,
|
"""Given a class (eg from a SerializerMethodField's return type hint,
|
||||||
return its basic type information - ``type``, ``format``, ``pattern``,
|
return its basic type information - ``type``, ``format``, ``pattern``,
|
||||||
|
|
@ -499,11 +513,11 @@ def get_basic_type_info_from_hint(hint_class):
|
||||||
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
|
||||||
:rtype: OrderedDict
|
:rtype: OrderedDict
|
||||||
"""
|
"""
|
||||||
if typing and get_origin_type(hint_class) == typing.Union:
|
union_types = _get_union_types(hint_class)
|
||||||
|
if typing and union_types:
|
||||||
# Optional is implemented as Union[T, None]
|
# Optional is implemented as Union[T, None]
|
||||||
if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None): # noqa: E721
|
if len(union_types) == 2 and isinstance(None, union_types[1]):
|
||||||
child_type = hint_class.__args__[0]
|
result = get_basic_type_info_from_hint(union_types[0])
|
||||||
result = get_basic_type_info_from_hint(child_type)
|
|
||||||
result['x-nullable'] = True
|
result['x-nullable'] = True
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,10 @@ if typing:
|
||||||
# Following cases are not 100% correct, but it should work somehow and not crash.
|
# Following cases are not 100% correct, but it should work somehow and not crash.
|
||||||
(Union[int, float], None),
|
(Union[int, float], None),
|
||||||
(List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
|
(List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
|
||||||
|
('SomeType', None),
|
||||||
|
(type('SomeType', (object,), {}), None),
|
||||||
|
(None, None),
|
||||||
|
(6, None),
|
||||||
])
|
])
|
||||||
def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
|
def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
|
||||||
type_info = get_basic_type_info_from_hint(hint_class)
|
type_info = get_basic_type_info_from_hint(hint_class)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue