diff --git a/.idea/drf-yasg.iml b/.idea/drf-yasg.iml
index b60958d..4cb4d4e 100644
--- a/.idea/drf-yasg.iml
+++ b/.idea/drf-yasg.iml
@@ -44,4 +44,4 @@
-
\ No newline at end of file
+
diff --git a/.idea/misc.xml b/.idea/misc.xml
index e71a223..ab1b8f5 100644
--- a/.idea/misc.xml
+++ b/.idea/misc.xml
@@ -73,4 +73,4 @@
-
\ No newline at end of file
+
diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py
index 770bbf3..6e457c2 100644
--- a/src/drf_yasg/inspectors/field.py
+++ b/src/drf_yasg/inspectors/field.py
@@ -455,18 +455,39 @@ def decimal_return_type():
return openapi.TYPE_STRING if rest_framework_settings.COERCE_DECIMAL_TO_STRING else openapi.TYPE_NUMBER
-raw_type_info = [
+def get_origin_type(hint_class):
+ return getattr(hint_class, '__origin__', None) or hint_class
+
+
+def is_origin_type_subclasses(hint_class, check_class):
+ origin_type = get_origin_type(hint_class)
+ return inspect.isclass(origin_type) and issubclass(origin_type, check_class)
+
+
+hinting_type_info = [
(bool, (openapi.TYPE_BOOLEAN, None)),
(int, (openapi.TYPE_INTEGER, None)),
+ (str, (openapi.TYPE_STRING, None)),
(float, (openapi.TYPE_NUMBER, None)),
+ (dict, (openapi.TYPE_OBJECT, None)),
(Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
(uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
(datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
(datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
- # TODO - support typing.List etc
]
-hinting_type_info = raw_type_info
+if typing:
+ def inspect_collection_hint_class(hint_class):
+ args = hint_class.__args__
+ child_class = args[0] if args else str
+ child_type_info = get_basic_type_info_from_hint(child_class)
+ if not child_type_info:
+ child_type_info = {'type': openapi.TYPE_STRING}
+ return OrderedDict([
+ ('type', openapi.TYPE_ARRAY),
+ ('items', openapi.Items(**child_type_info)),
+ ])
+ hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class))
def get_basic_type_info_from_hint(hint_class):
@@ -478,27 +499,25 @@ def get_basic_type_info_from_hint(hint_class):
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
:rtype: OrderedDict
"""
-
- for check_class, type_format in hinting_type_info:
- if issubclass(hint_class, check_class):
- swagger_type, format = type_format
- if callable(swagger_type):
- swagger_type = swagger_type()
- # if callable(format):
- # format = format(klass)
- break
- else: # pragma: no cover
+ if typing and get_origin_type(hint_class) == typing.Union:
+ if len(hint_class.__args__) == 2 and hint_class.__args__[1] == type(None):
+ child_type = hint_class.__args__[0]
+ return get_basic_type_info_from_hint(child_type)
return None
- pattern = None
+ for check_class, info in hinting_type_info:
+ if is_origin_type_subclasses(hint_class, check_class):
+ if callable(info):
+ return info(hint_class)
+ swagger_type, format = info
+ if callable(swagger_type):
+ swagger_type = swagger_type()
+ return OrderedDict([
+ ('type', swagger_type),
+ ('format', format),
+ ])
- result = OrderedDict([
- ('type', swagger_type),
- ('format', format),
- ('pattern', pattern)
- ])
-
- return result
+ return None
class SerializerMethodFieldInspector(FieldInspector):
diff --git a/tests/test_get_basic_type_info_from_hint.py b/tests/test_get_basic_type_info_from_hint.py
new file mode 100644
index 0000000..49c1f96
--- /dev/null
+++ b/tests/test_get_basic_type_info_from_hint.py
@@ -0,0 +1,36 @@
+import uuid
+
+import pytest
+
+from drf_yasg import openapi
+from drf_yasg.inspectors.field import get_basic_type_info_from_hint
+
+try:
+ import typing
+ from typing import Dict, List, Union, Optional, Set
+except ImportError:
+ typing = None
+
+
+if typing:
+ @pytest.mark.parametrize('hint_class, expected_swagger_type_info', [
+ (int, {'type': openapi.TYPE_INTEGER, 'format': None}),
+ (str, {'type': openapi.TYPE_STRING, 'format': None}),
+ (bool, {'type': openapi.TYPE_BOOLEAN, 'format': None}),
+ (dict, {'type': openapi.TYPE_OBJECT, 'format': None}),
+ (Dict[int, int], {'type': openapi.TYPE_OBJECT, 'format': None}),
+ (uuid.UUID, {'type': openapi.TYPE_STRING, 'format': openapi.FORMAT_UUID}),
+ (List[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
+ (List[str], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
+ (List[bool], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_BOOLEAN)}),
+ (Set[int], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
+ (Optional[bool], {'type': openapi.TYPE_BOOLEAN, 'format': None}),
+ (Optional[List[int]], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
+ (Union[List[int], type(None)], {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER)}),
+ # Following cases are not 100% correct, but it should work somehow and not crash.
+ (Union[int, float], None),
+ (List, {'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_STRING)}),
+ ])
+ def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info):
+ type_info = get_basic_type_info_from_hint(hint_class)
+ assert type_info == expected_swagger_type_info