diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 6e457c2..13913ed 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -459,7 +459,7 @@ def get_origin_type(hint_class): return getattr(hint_class, '__origin__', None) or hint_class -def is_origin_type_subclasses(hint_class, check_class): +def hint_class_issubclass(hint_class, check_class): origin_type = get_origin_type(hint_class) return inspect.isclass(origin_type) and issubclass(origin_type, check_class) @@ -480,13 +480,13 @@ if typing: def inspect_collection_hint_class(hint_class): args = hint_class.__args__ child_class = args[0] if args else str - child_type_info = get_basic_type_info_from_hint(child_class) - if not child_type_info: - child_type_info = {'type': openapi.TYPE_STRING} + child_type_info = get_basic_type_info_from_hint(child_class) or {'type': openapi.TYPE_STRING} + return OrderedDict([ ('type', openapi.TYPE_ARRAY), ('items', openapi.Items(**child_type_info)), ]) + hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class)) @@ -500,18 +500,24 @@ def get_basic_type_info_from_hint(hint_class): :rtype: OrderedDict """ if typing and get_origin_type(hint_class) == typing.Union: - if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None): + # Optional is implemented as Union[T, None] + if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None): # noqa: E721 child_type = hint_class.__args__[0] - return get_basic_type_info_from_hint(child_type) + result = get_basic_type_info_from_hint(child_type) + result['x-nullable'] = True + return result + return None for check_class, info in hinting_type_info: - if is_origin_type_subclasses(hint_class, check_class): + if hint_class_issubclass(hint_class, check_class): if callable(info): return info(hint_class) + swagger_type, format = info if callable(swagger_type): swagger_type = swagger_type() + return OrderedDict([ ('type', swagger_type), ('format', format), diff --git a/tests/test_get_basic_type_info_from_hint.py b/tests/test_get_basic_type_info_from_hint.py index 49c1f96..db45fef 100644 --- a/tests/test_get_basic_type_info_from_hint.py +++ b/tests/test_get_basic_type_info_from_hint.py @@ -11,7 +11,6 @@ try: except ImportError: typing = None - if typing: @pytest.mark.parametrize('hint_class, expected_swagger_type_info', [ (int, {'type': openapi.TYPE_INTEGER, 'format': None}), @@ -24,9 +23,13 @@ if typing: (List[str], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}), (List[bool], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_BOOLEAN)}), (Set[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}), - (Optional[bool], {'type': openapi.TYPE_BOOLEAN, 'format': None}), - (Optional[List[int]], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}), - (Union[List[int], type(None)], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}), + (Optional[bool], {'type': openapi.TYPE_BOOLEAN, 'format': None, 'x-nullable': True}), + (Optional[List[int]], { + 'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER), 'x-nullable': True + }), + (Union[List[int], type(None)], { + 'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER), 'x-nullable': True + }), # Following cases are not 100% correct, but it should work somehow and not crash. (Union[int, float], None), (List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),