diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 13913ed..b5af672 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -490,6 +490,20 @@ if typing: hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class)) +def _get_union_types(hint_class): + if typing: + origin_type = get_origin_type(hint_class) + if origin_type is typing.Union: + return hint_class.__args__ + try: + # python 3.5.2 and lower compatibility + if issubclass(origin_type, typing.Union): + return hint_class.__union_params__ + except TypeError: + pass + return None + + def get_basic_type_info_from_hint(hint_class): """Given a class (eg from a SerializerMethodField's return type hint, return its basic type information - ``type``, ``format``, ``pattern``, @@ -499,11 +513,11 @@ def get_basic_type_info_from_hint(hint_class): :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known :rtype: OrderedDict """ - if typing and get_origin_type(hint_class) == typing.Union: + union_types = _get_union_types(hint_class) + if typing and union_types: # Optional is implemented as Union[T, None] - if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None): # noqa: E721 - child_type = hint_class.__args__[0] - result = get_basic_type_info_from_hint(child_type) + if len(union_types) == 2 and isinstance(None, union_types[1]): + result = get_basic_type_info_from_hint(union_types[0]) result['x-nullable'] = True return result diff --git a/tests/test_get_basic_type_info_from_hint.py b/tests/test_get_basic_type_info_from_hint.py index db45fef..e21e597 100644 --- a/tests/test_get_basic_type_info_from_hint.py +++ b/tests/test_get_basic_type_info_from_hint.py @@ -33,6 +33,10 @@ if typing: # Following cases are not 100% correct, but it should work somehow and not crash. (Union[int, float], None), (List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}), + ('SomeType', None), + (type('SomeType', (object,), {}), None), + (None, None), + (6, None), ]) def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info): type_info = get_basic_type_info_from_hint(hint_class)