diff --git a/.codecov.yml b/.codecov.yml
index fb3842d..57bfecc 100644
--- a/.codecov.yml
+++ b/.codecov.yml
@@ -11,15 +11,15 @@ coverage:
default:
enabled: yes
target: auto
- threshold: 0%
+ threshold: 100%
if_no_uploads: error
if_ci_failed: error
patch:
default:
enabled: yes
- target: 80%
- threshold: 0%
+ target: 100%
+ threshold: 100%
if_no_uploads: error
if_ci_failed: error
diff --git a/.coveragerc b/.coveragerc
index d3750bc..996f08c 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -17,6 +17,8 @@ exclude_lines =
raise TypeError
raise NotImplementedError
warnings.warn
+ logger.warning
+ return NotHandled
# Don't complain if non-runnable code isn't run:
if 0:
diff --git a/.gitignore b/.gitignore
index d48f341..0840b0f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,3 +156,5 @@ com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties
+
+testproj/db\.sqlite3
diff --git a/.travis.yml b/.travis.yml
index d8ceb8d..dab442f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -40,6 +40,7 @@ after_success:
branches:
only:
- master
+ - /^release\/.*$/
notifications:
email:
diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst
index 1b239ad..ec46a5f 100644
--- a/CONTRIBUTING.rst
+++ b/CONTRIBUTING.rst
@@ -43,6 +43,8 @@ You want to contribute some code? Great! Here are a few steps to get you started
.. code:: console
(venv) $ cd testproj
+ (venv) $ python manage.py migrate
+ (venv) $ cat createsuperuser.py | python manage.py shell
(venv) $ python manage.py runserver
(venv) $ curl localhost:8000/swagger.yaml
diff --git a/README.rst b/README.rst
index 75efd12..2b01cbb 100644
--- a/README.rst
+++ b/README.rst
@@ -141,6 +141,7 @@ This exposes 4 cached, validated and publicly available endpoints:
2. Configuration
================
+---------------------------------
a. ``get_schema_view`` parameters
---------------------------------
@@ -153,6 +154,7 @@ a. ``get_schema_view`` parameters
- ``authentication_classes`` - authentication classes for the schema view itself
- ``permission_classes`` - permission classes for the schema view itself
+-------------------------------
b. ``SchemaView`` options
-------------------------------
@@ -169,6 +171,7 @@ All of the first 3 methods take two optional arguments,
to Django’s :python:`cached_page` decorator in order to enable caching on the
resulting view. See `3. Caching`_.
+----------------------------------------------
c. ``SWAGGER_SETTINGS`` and ``REDOC_SETTINGS``
----------------------------------------------
@@ -178,6 +181,26 @@ The possible settings and their default values are as follows:
.. code:: python
SWAGGER_SETTINGS = {
+ # default inspector classes, see advanced documentation
+ 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema',
+ 'DEFAULT_FIELD_INSPECTORS': [
+ 'drf_yasg.inspectors.CamelCaseJSONFilter',
+ 'drf_yasg.inspectors.ReferencingSerializerInspector',
+ 'drf_yasg.inspectors.RelatedFieldInspector',
+ 'drf_yasg.inspectors.ChoiceFieldInspector',
+ 'drf_yasg.inspectors.FileFieldInspector',
+ 'drf_yasg.inspectors.DictFieldInspector',
+ 'drf_yasg.inspectors.SimpleFieldInspector',
+ 'drf_yasg.inspectors.StringDefaultFieldInspector',
+ ],
+ 'DEFAULT_FILTER_INSPECTORS': [
+ 'drf_yasg.inspectors.CoreAPICompatInspector',
+ ],
+ 'DEFAULT_PAGINATOR_INSPECTORS': [
+ 'drf_yasg.inspectors.DjangoRestResponsePagination',
+ 'drf_yasg.inspectors.CoreAPICompatInspector',
+ ],
+
'USE_SESSION_AUTH': True, # add Django Login and Django Logout buttons, CSRF token to swagger UI page
'LOGIN_URL': getattr(django.conf.settings, 'LOGIN_URL', None), # URL for the login button
'LOGOUT_URL': getattr(django.conf.settings, 'LOGOUT_URL', None), # URL for the logout button
@@ -241,6 +264,7 @@ Caching can mitigate the speed impact of validation.
The provided validation will catch syntactic errors, but more subtle violations of the spec might slip by them. To
ensure compatibility with code generation tools, it is recommended to also employ one or more of the following methods:
+-------------------------------
``swagger-ui`` validation badge
-------------------------------
@@ -271,6 +295,7 @@ If your schema is not accessible from the internet, you can run a local copy of
$ curl http://localhost:8189/debug?url=http://test.local:8002/swagger/?format=openapi
{}
+---------------------
Using ``swagger-cli``
---------------------
@@ -283,6 +308,7 @@ https://www.npmjs.com/package/swagger-cli
$ swagger-cli validate http://test.local:8002/swagger.yaml
http://test.local:8002/swagger.yaml is valid
+--------------------------------------------------------------
Manually on `editor.swagger.io `__
--------------------------------------------------------------
@@ -345,10 +371,16 @@ named schemas.
Both projects are also currently unmantained.
-Documentation, advanced usage
-=============================
+************************
+Third-party integrations
+************************
-https://drf-yasg.readthedocs.io/en/latest/
+djangorestframework-camel-case
+===============================
+
+Integration with `djangorestframework-camel-case `_ is
+provided out of the box - if you have ``djangorestframework-camel-case`` installed and your ``APIView`` uses
+``CamelCaseJSONParser`` or ``CamelCaseJSONRenderer``, all property names will be converted to *camelCase* by default.
.. |travis| image:: https://img.shields.io/travis/axnsan12/drf-yasg/master.svg
:target: https://travis-ci.org/axnsan12/drf-yasg
diff --git a/docs/_static/css/style.css b/docs/_static/css/style.css
new file mode 100644
index 0000000..2571b7c
--- /dev/null
+++ b/docs/_static/css/style.css
@@ -0,0 +1,18 @@
+.versionadded, .versionchanged, .deprecated {
+ font-family: "Roboto", Corbel, Avenir, "Lucida Grande", "Lucida Sans", sans-serif;
+ padding: 10px 13px;
+ border: 1px solid rgb(137, 191, 4);
+ border-radius: 4px;
+ margin-bottom: 10px;
+}
+
+.versionmodified {
+ font-weight: bold;
+ display: block;
+}
+
+.versionadded p, .versionchanged p, .deprecated p,
+/*override fucking !important by being more specific */
+.rst-content dl .versionadded p, .rst-content dl .versionchanged p {
+ margin: 0 !important;
+}
diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html
new file mode 100644
index 0000000..c242078
--- /dev/null
+++ b/docs/_templates/layout.html
@@ -0,0 +1,4 @@
+{% extends "!layout.html" %}
+{% block extrahead %}
+
+{% endblock %}
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 6e3fd2a..79bbbc0 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -3,6 +3,21 @@ Changelog
#########
+*********
+**1.1.0**
+*********
+
+- **ADDED:** added support for APIs versioned with ``URLPathVersioning`` or ``NamespaceVersioning``
+- **ADDED:** added ability to recursively customize schema generation
+ :ref:`using pluggable inspector classes `
+- **ADDED:** added ``operation_id`` parameter to :func:`@swagger_auto_schema <.swagger_auto_schema>`
+- **ADDED:** integration with `djangorestframework-camel-case
+ `_ (:issue:`28`)
+- **IMPROVED:** strings, arrays and integers will now have min/max validation attributes inferred from the
+ field-level validators
+- **FIXED:** fixed a bug that caused ``title`` to never be generated for Schemas; ``title`` is now correctly
+ populated from the field's ``label`` property
+
*********
**1.0.6**
*********
diff --git a/docs/conf.py b/docs/conf.py
index bdad022..3e9b70b 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -3,6 +3,7 @@
#
# drf-yasg documentation build configuration file, created by
# sphinx-quickstart on Sun Dec 10 15:20:34 2017.
+import inspect
import os
import re
import sys
@@ -68,9 +69,6 @@ pygments_style = 'sphinx'
modindex_common_prefix = ['drf_yasg.']
-# If true, `todo` and `todoList` produce output, else they produce nothing.
-todo_include_todos = False
-
# -- Options for HTML output ----------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
@@ -186,18 +184,23 @@ nitpick_ignore = [
('py:obj', 'callable'),
('py:obj', 'type'),
('py:obj', 'OrderedDict'),
+ ('py:obj', 'None'),
('py:obj', 'coreapi.Field'),
('py:obj', 'BaseFilterBackend'),
('py:obj', 'BasePagination'),
+ ('py:obj', 'Request'),
('py:obj', 'rest_framework.request.Request'),
('py:obj', 'rest_framework.serializers.Field'),
('py:obj', 'serializers.Field'),
('py:obj', 'serializers.BaseSerializer'),
('py:obj', 'Serializer'),
+ ('py:obj', 'BaseSerializer'),
('py:obj', 'APIView'),
]
+# TODO: inheritance aliases in sphinx 1.7
+
# even though the package should be already installed, the sphinx build on RTD
# for some reason needs the sources dir to be in the path in order for viewcode to work
sys.path.insert(0, os.path.abspath('../src'))
@@ -215,6 +218,40 @@ import drf_yasg.views # noqa: E402
drf_yasg.views.SchemaView = drf_yasg.views.get_schema_view(None)
+# monkey patch to stop sphinx from trying to find classes by their real location instead of the
+# top-level __init__ alias; this allows us to document only `drf_yasg.inspectors` and avoid broken references or
+# double documenting
+
+import drf_yasg.inspectors # noqa: E402
+
+
+def redirect_cls(cls):
+ if cls.__module__.startswith('drf_yasg.inspectors'):
+ return getattr(drf_yasg.inspectors, cls.__name__)
+ return cls
+
+
+for cls_name in drf_yasg.inspectors.__all__:
+ # first pass - replace all classes' module with the top level module
+ real_cls = getattr(drf_yasg.inspectors, cls_name)
+ if not inspect.isclass(real_cls):
+ continue
+
+ patched_dict = dict(real_cls.__dict__)
+ patched_dict.update({'__module__': 'drf_yasg.inspectors'})
+ patched_cls = type(cls_name, real_cls.__bases__, patched_dict)
+ setattr(drf_yasg.inspectors, cls_name, patched_cls)
+
+for cls_name in drf_yasg.inspectors.__all__:
+ # second pass - replace the inheritance bases for all classes to point to the new clean classes
+ real_cls = getattr(drf_yasg.inspectors, cls_name)
+ if not inspect.isclass(real_cls):
+ continue
+
+ patched_bases = tuple(redirect_cls(base) for base in real_cls.__bases__)
+ patched_cls = type(cls_name, patched_bases, dict(real_cls.__dict__))
+ setattr(drf_yasg.inspectors, cls_name, patched_cls)
+
# custom interpreted role for linking to GitHub issues and pull requests
# use as :issue:`14` or :pr:`17`
gh_issue_uri = "https://github.com/axnsan12/drf-yasg/issues/{}"
@@ -273,3 +310,7 @@ def role_github_pull_request_or_issue(name, rawtext, text, lineno, inliner, opti
roles.register_local_role('pr', role_github_pull_request_or_issue)
roles.register_local_role('issue', role_github_pull_request_or_issue)
roles.register_local_role('ghuser', role_github_user)
+
+
+def setup(app):
+ app.add_stylesheet('css/style.css')
diff --git a/docs/custom_spec.rst b/docs/custom_spec.rst
index 2541b1a..1b23b79 100644
--- a/docs/custom_spec.rst
+++ b/docs/custom_spec.rst
@@ -249,15 +249,63 @@ Where you can use the :func:`@swagger_auto_schema <.swagger_auto_schema>` decora
However, do note that both of the methods above can lead to unexpected (and maybe surprising) results by
replacing/decorating methods on the base class itself.
+
+********************************
+Serializer ``Meta`` nested class
+********************************
+
+You can define some per-serializer options by adding a ``Meta`` class to your serializer, e.g.:
+
+.. code:: python
+
+ class WhateverSerializer(Serializer):
+ ...
+
+ class Meta:
+ ... options here ...
+
+Currently, the only option you can add here is
+
+ * ``ref_name`` - a string which will be used as the model definition name for this serializer class; setting it to
+ ``None`` will force the serializer to be generated as an inline model everywhere it is used
+
*************************
Subclassing and extending
*************************
-For more advanced control you can subclass :class:`.SwaggerAutoSchema` - see the documentation page for a list of
-methods you can override.
+
+---------------------
+``SwaggerAutoSchema``
+---------------------
+
+For more advanced control you can subclass :class:`~.inspectors.SwaggerAutoSchema` - see the documentation page
+for a list of methods you can override.
You can put your custom subclass to use by setting it on a view method using the
-:func:`@swagger_auto_schema <.swagger_auto_schema>` decorator described above.
+:ref:`@swagger_auto_schema ` decorator described above, by setting it as a
+class-level attribute named ``swagger_schema`` on the view class, or
+:ref:`globally via settings `.
+
+For example, to generate all operation IDs as camel case, you could do:
+
+.. code:: python
+
+ from inflection import camelize
+
+ class CamelCaseOperationIDAutoSchema(SwaggerAutoSchema):
+ def get_operation_id(self, operation_keys):
+ operation_id = super(CamelCaseOperationIDAutoSchema, self).get_operation_id(operation_keys)
+ return camelize(operation_id, uppercase_first_letter=False)
+
+
+ SWAGGER_SETTINGS = {
+ 'DEFAULT_AUTO_SCHEMA_CLASS': 'path.to.CamelCaseOperationIDAutoSchema',
+ ...
+ }
+
+--------------------------
+``OpenAPISchemaGenerator``
+--------------------------
If you need to control things at a higher level than :class:`.Operation` objects (e.g. overall document structure,
vendor extensions in metadata) you can also subclass :class:`.OpenAPISchemaGenerator` - again, see the documentation
@@ -265,3 +313,88 @@ page for a list of its methods.
This custom generator can be put to use by setting it as the :attr:`.generator_class` of a :class:`.SchemaView` using
:func:`.get_schema_view`.
+
+.. _custom-spec-inspectors:
+
+---------------------
+``Inspector`` classes
+---------------------
+
+.. versionadded:: 1.1
+
+For customizing behavior related to specific field, serializer, filter or paginator classes you can implement the
+:class:`~.inspectors.FieldInspector`, :class:`~.inspectors.SerializerInspector`, :class:`~.inspectors.FilterInspector`,
+:class:`~.inspectors.PaginatorInspector` classes and use them with
+:ref:`@swagger_auto_schema ` or one of the
+:ref:`related settings `.
+
+A :class:`~.inspectors.FilterInspector` that adds a description to all ``DjangoFilterBackend`` parameters could be
+implemented like so:
+
+.. code:: python
+
+ class DjangoFilterDescriptionInspector(CoreAPICompatInspector):
+ def get_filter_parameters(self, filter_backend):
+ if isinstance(filter_backend, DjangoFilterBackend):
+ result = super(DjangoFilterDescriptionInspector, self).get_filter_parameters(filter_backend)
+ for param in result:
+ if not param.get('description', ''):
+ param.description = "Filter the returned list by {field_name}".format(field_name=param.name)
+
+ return result
+
+ return NotHandled
+
+ @method_decorator(name='list', decorator=swagger_auto_schema(
+ filter_inspectors=[DjangoFilterDescriptionInspector]
+ ))
+ class ArticleViewSet(viewsets.ModelViewSet):
+ filter_backends = (DjangoFilterBackend,)
+ filter_fields = ('title',)
+ ...
+
+
+A second example, of a :class:`~.inspectors.FieldInspector` that removes the ``title`` attribute from all generated
+:class:`.Schema` objects:
+
+.. code:: python
+
+ class NoSchemaTitleInspector(FieldInspector):
+ def process_result(self, result, method_name, obj, **kwargs):
+ # remove the `title` attribute of all Schema objects
+ if isinstance(result, openapi.Schema.OR_REF):
+ # traverse any references and alter the Schema object in place
+ schema = openapi.resolve_ref(result, self.components)
+ schema.pop('title', None)
+
+ # no ``return schema`` here, because it would mean we always generate
+ # an inline `object` instead of a definition reference
+
+ # return back the same object that we got - i.e. a reference if we got a reference
+ return result
+
+
+ class NoTitleAutoSchema(SwaggerAutoSchema):
+ field_inspectors = [NoSchemaTitleInspector] + swagger_settings.DEFAULT_FIELD_INSPECTORS
+
+ class ArticleViewSet(viewsets.ModelViewSet):
+ swagger_schema = NoTitleAutoSchema
+ ...
+
+
+.. Note::
+
+ A note on references - :class:`.Schema` objects are sometimes output by reference (:class:`.SchemaRef`); in fact,
+ that is how named models are implemented in OpenAPI:
+
+ - in the output swagger document there is a ``definitions`` section containing :class:`.Schema` objects for all
+ models
+ - every usage of a model refers to that single :class:`.Schema` object - for example, in the ArticleViewSet
+ above, all requests and responses containg an ``Article`` model would refer to the same schema definition by a
+ ``'$ref': '#/definitions/Article'``
+
+ This is implemented by only generating **one** :class:`.Schema` object for every serializer **class** encountered.
+
+ This means that you should generally avoid view or method-specific ``FieldInspector``\ s if you are dealing with
+ references (a.k.a named models), because you can never know which view will be the first to generate the schema
+ for a given serializer.
diff --git a/docs/drf_yasg.rst b/docs/drf_yasg.rst
index 9fa9a4b..77c95aa 100644
--- a/docs/drf_yasg.rst
+++ b/docs/drf_yasg.rst
@@ -1,14 +1,6 @@
drf\_yasg package
====================
-drf\_yasg\.app\_settings
-----------------------------------
-
-.. automodule:: drf_yasg.app_settings
- :members:
- :undoc-members:
- :show-inheritance:
-
drf\_yasg\.codecs
---------------------------
@@ -16,7 +8,7 @@ drf\_yasg\.codecs
:members:
:undoc-members:
:show-inheritance:
- :exclude-members: SaneYamlDumper
+ :exclude-members: SaneYamlDumper,SaneYamlLoader
drf\_yasg\.errors
---------------------------
@@ -37,6 +29,8 @@ drf\_yasg\.generators
drf\_yasg\.inspectors
-------------------------------
+.. autodata:: drf_yasg.inspectors.NotHandled
+
.. automodule:: drf_yasg.inspectors
:members:
:undoc-members:
diff --git a/docs/settings.rst b/docs/settings.rst
index 9a66d72..afd3ebf 100644
--- a/docs/settings.rst
+++ b/docs/settings.rst
@@ -37,6 +37,60 @@ The possible settings and their default values are as follows:
``SWAGGER_SETTINGS``
********************
+
+.. _default-class-settings:
+
+Default classes
+===============
+
+DEFAULT_AUTO_SCHEMA_CLASS
+-------------------------
+
+:class:`~.inspectors.ViewInspector` subclass that will be used by default for generating :class:`.Operation`
+objects when iterating over endpoints. Can be overriden by using the `auto_schema` argument of
+:func:`@swagger_auto_schema <.swagger_auto_schema>` or by a ``swagger_schema`` attribute on the view class.
+
+**Default**: :class:`drf_yasg.inspectors.SwaggerAutoSchema`
+
+DEFAULT_FIELD_INSPECTORS
+------------------------
+
+List of :class:`~.inspectors.FieldInspector` subclasses that will be used by default for inspecting serializers and
+serializer fields. Field inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended
+to this list.
+
+**Default**: ``[`` |br| \
+:class:`'drf_yasg.inspectors.CamelCaseJSONFilter' <.inspectors.CamelCaseJSONFilter>`, |br| \
+:class:`'drf_yasg.inspectors.ReferencingSerializerInspector' <.inspectors.ReferencingSerializerInspector>`, |br| \
+:class:`'drf_yasg.inspectors.RelatedFieldInspector' <.inspectors.RelatedFieldInspector>`, |br| \
+:class:`'drf_yasg.inspectors.ChoiceFieldInspector' <.inspectors.ChoiceFieldInspector>`, |br| \
+:class:`'drf_yasg.inspectors.FileFieldInspector' <.inspectors.FileFieldInspector>`, |br| \
+:class:`'drf_yasg.inspectors.DictFieldInspector' <.inspectors.DictFieldInspector>`, |br| \
+:class:`'drf_yasg.inspectors.SimpleFieldInspector' <.inspectors.SimpleFieldInspector>`, |br| \
+:class:`'drf_yasg.inspectors.StringDefaultFieldInspector' <.inspectors.StringDefaultFieldInspector>`, |br| \
+``]``
+
+DEFAULT_FILTER_INSPECTORS
+-------------------------
+
+List of :class:`~.inspectors.FilterInspector` subclasses that will be used by default for inspecting filter backends.
+Filter inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended to this list.
+
+**Default**: ``[`` |br| \
+:class:`'drf_yasg.inspectors.CoreAPICompatInspector' <.inspectors.CoreAPICompatInspector>`, |br| \
+``]``
+
+DEFAULT_PAGINATOR_INSPECTORS
+----------------------------
+
+List of :class:`~.inspectors.PaginatorInspector` subclasses that will be used by default for inspecting paginators.
+Paginator inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended to this list.
+
+**Default**: ``[`` |br| \
+:class:`'drf_yasg.inspectors.DjangoRestResponsePagination' <.inspectors.DjangoRestResponsePagination>`, |br| \
+:class:`'drf_yasg.inspectors.CoreAPICompatInspector' <.inspectors.CoreAPICompatInspector>`, |br| \
+``]``
+
Authorization
=============
diff --git a/requirements/test.txt b/requirements/test.txt
index c45ebcf..762fcb9 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -12,3 +12,4 @@ pygments>=2.2.0
django-cors-headers>=2.1.0
django-filter>=1.1.0,<2.0; python_version == "2.7"
django-filter>=1.1.0; python_version >= "3.4"
+djangorestframework-camel-case>=0.2.0
diff --git a/setup.py b/setup.py
index bf3e32b..4a8fb03 100644
--- a/setup.py
+++ b/setup.py
@@ -56,7 +56,7 @@ requirements_validation = read_req('validation.txt')
setup(
name='drf-yasg',
use_scm_version=True,
- packages=find_packages('src', include=['drf_yasg']),
+ packages=find_packages('src'),
package_dir={'': 'src'},
include_package_data=True,
install_requires=requirements,
diff --git a/src/drf_yasg/app_settings.py b/src/drf_yasg/app_settings.py
index 527c9d0..a6ae8e6 100644
--- a/src/drf_yasg/app_settings.py
+++ b/src/drf_yasg/app_settings.py
@@ -2,6 +2,26 @@ from django.conf import settings
from rest_framework.settings import perform_import
SWAGGER_DEFAULTS = {
+ 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema',
+
+ 'DEFAULT_FIELD_INSPECTORS': [
+ 'drf_yasg.inspectors.CamelCaseJSONFilter',
+ 'drf_yasg.inspectors.ReferencingSerializerInspector',
+ 'drf_yasg.inspectors.RelatedFieldInspector',
+ 'drf_yasg.inspectors.ChoiceFieldInspector',
+ 'drf_yasg.inspectors.FileFieldInspector',
+ 'drf_yasg.inspectors.DictFieldInspector',
+ 'drf_yasg.inspectors.SimpleFieldInspector',
+ 'drf_yasg.inspectors.StringDefaultFieldInspector',
+ ],
+ 'DEFAULT_FILTER_INSPECTORS': [
+ 'drf_yasg.inspectors.CoreAPICompatInspector',
+ ],
+ 'DEFAULT_PAGINATOR_INSPECTORS': [
+ 'drf_yasg.inspectors.DjangoRestResponsePagination',
+ 'drf_yasg.inspectors.CoreAPICompatInspector',
+ ],
+
'USE_SESSION_AUTH': True,
'SECURITY_DEFINITIONS': {
'basic': {
@@ -28,7 +48,12 @@ REDOC_DEFAULTS = {
'PATH_IN_MIDDLE': False,
}
-IMPORT_STRINGS = []
+IMPORT_STRINGS = [
+ 'DEFAULT_AUTO_SCHEMA_CLASS',
+ 'DEFAULT_FIELD_INSPECTORS',
+ 'DEFAULT_FILTER_INSPECTORS',
+ 'DEFAULT_PAGINATOR_INSPECTORS',
+]
class AppSettings(object):
diff --git a/src/drf_yasg/codecs.py b/src/drf_yasg/codecs.py
index d2746a0..d52a54c 100644
--- a/src/drf_yasg/codecs.py
+++ b/src/drf_yasg/codecs.py
@@ -98,6 +98,9 @@ class OpenAPICodecJson(_OpenAPICodec):
return json.dumps(spec)
+YAML_MAP_TAG = u'tag:yaml.org,2002:map'
+
+
class SaneYamlDumper(yaml.SafeDumper):
"""YamlDumper class usable for dumping ``OrderedDict`` and list instances in a standard way."""
@@ -122,7 +125,7 @@ class SaneYamlDumper(yaml.SafeDumper):
To use yaml.safe_dump(), you need the following.
"""
- tag = u'tag:yaml.org,2002:map'
+ tag = YAML_MAP_TAG
value = []
node = yaml.MappingNode(tag, value, flow_style=flow_style)
if dump.alias_key is not None:
@@ -158,7 +161,7 @@ def yaml_sane_dump(data, binary):
* list elements are indented into their parents
* YAML references/aliases are disabled
- :param dict data: the data to be serializers
+ :param dict data: the data to be dumped
:param bool binary: True to return a utf-8 encoded binary object, False to return a string
:return: the serialized YAML
:rtype: str,bytes
@@ -166,6 +169,24 @@ def yaml_sane_dump(data, binary):
return yaml.dump(data, Dumper=SaneYamlDumper, default_flow_style=False, encoding='utf-8' if binary else None)
+class SaneYamlLoader(yaml.SafeLoader):
+ def construct_odict(self, node, deep=False):
+ self.flatten_mapping(node)
+ return OrderedDict(self.construct_pairs(node))
+
+
+SaneYamlLoader.add_constructor(YAML_MAP_TAG, SaneYamlLoader.construct_odict)
+
+
+def yaml_sane_load(stream):
+ """Load the given YAML stream while preserving the input order for mapping items.
+
+ :param stream: YAML stream (can be a string or a file-like object)
+ :rtype: OrderedDict
+ """
+ return yaml.load(stream, Loader=SaneYamlLoader)
+
+
class OpenAPICodecYaml(_OpenAPICodec):
media_type = 'application/yaml'
diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py
index 3f1979f..0f3d8d7 100644
--- a/src/drf_yasg/generators.py
+++ b/src/drf_yasg/generators.py
@@ -2,12 +2,15 @@ import re
from collections import defaultdict, OrderedDict
import uritemplate
+from django.utils.encoding import force_text
+from rest_framework import versioning
from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator
+from rest_framework.schemas.inspectors import get_pk_description
from . import openapi
-from .inspectors import SwaggerAutoSchema
+from .app_settings import swagger_settings
+from .inspectors.field import get_queryset_field, get_basic_type_info
from .openapi import ReferenceResolver
-from .utils import inspect_model_field, get_model_field
PATH_PARAMETER_RE = re.compile(r'{(?P\w+)}')
@@ -52,7 +55,7 @@ class EndpointEnumerator(_EndpointEnumerator):
class OpenAPISchemaGenerator(object):
"""
This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
- Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator.
+ Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``.
"""
endpoint_enumerator_class = EndpointEnumerator
@@ -70,10 +73,14 @@ class OpenAPISchemaGenerator(object):
self.info = info
self.version = version
- def get_schema(self, request=None, public=False):
- """Generate an :class:`.Swagger` representing the API schema.
+ @property
+ def url(self):
+ return self._gen.url
- :param rest_framework.request.Request request: the request used for filtering
+ def get_schema(self, request=None, public=False):
+ """Generate a :class:`.Swagger` object representing the API schema.
+
+ :param Request request: the request used for filtering
accesible endpoints and finding the spec URI
:param bool public: if True, all endpoints are included regardless of access through `request`
@@ -81,10 +88,11 @@ class OpenAPISchemaGenerator(object):
:rtype: openapi.Swagger
"""
endpoints = self.get_endpoints(request)
+ endpoints = self.replace_version(endpoints, request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
- paths = self.get_paths(endpoints, components, public)
+ paths = self.get_paths(endpoints, components, request, public)
- url = self._gen.url
+ url = self.url
if not url and request is not None:
url = request.build_absolute_uri()
@@ -102,16 +110,40 @@ class OpenAPISchemaGenerator(object):
:return: the view instance
"""
view = self._gen.create_view(callback, method, request)
- overrides = getattr(callback, 'swagger_auto_schema', None)
+ overrides = getattr(callback, '_swagger_auto_schema', None)
if overrides is not None:
# decorated function based view must have its decorator information passed on to the re-instantiated view
for method, _ in overrides.items():
view_method = getattr(view, method, None)
if view_method is not None: # pragma: no cover
- setattr(view_method.__func__, 'swagger_auto_schema', overrides)
+ setattr(view_method.__func__, '_swagger_auto_schema', overrides)
return view
- def get_endpoints(self, request=None):
+ def replace_version(self, endpoints, request):
+ """If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using
+ ``URLPathVersioning`` as a versioning class.
+
+ :param dict endpoints: endpoints as returned by :meth:`.get_endpoints`
+ :param Request request: the request made against the schema view
+ :return: endpoints with modified paths
+ """
+ version = getattr(request, 'version', None)
+ if version is None:
+ return endpoints
+
+ new_endpoints = {}
+ for path, endpoint in endpoints.items():
+ view_cls = endpoint[0]
+ versioning_class = getattr(view_cls, 'versioning_class', None)
+ version_param = getattr(versioning_class, 'version_param', 'version')
+ if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
+ path = path.replace('{%s}' % version_param, version)
+
+ new_endpoints[path] = endpoint
+
+ return new_endpoints
+
+ def get_endpoints(self, request):
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
:param rest_framework.request.Request request: request to bind to the endpoint views
@@ -131,9 +163,7 @@ class OpenAPISchemaGenerator(object):
return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
def get_operation_keys(self, subpath, method, view):
- """Return a list of keys that should be used to group an operation within the specification.
-
- ::
+ """Return a list of keys that should be used to group an operation within the specification. ::
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
@@ -149,39 +179,94 @@ class OpenAPISchemaGenerator(object):
"""
return self._gen.get_keys(subpath, method, view)
- def get_paths(self, endpoints, components, public):
+ def determine_path_prefix(self, paths):
+ """
+ Given a list of all paths, return the common prefix which should be
+ discounted when generating a schema structure.
+
+ This will be the longest common string that does not include that last
+ component of the URL, or the last component before a path parameter.
+
+ For example: ::
+
+ /api/v1/users/
+ /api/v1/users/{pk}/
+
+ The path prefix is ``/api/v1/``.
+
+ :param list[str] paths: list of paths
+ :rtype: str
+ """
+ return self._gen.determine_path_prefix(paths)
+
+ def get_paths(self, endpoints, components, request, public):
"""Generate the Swagger Paths for the API from the given endpoints.
:param dict endpoints: endpoints as returned by get_endpoints
:param ReferenceResolver components: resolver/container for Swagger References
+ :param Request request: the request made against the schema view; can be None
:param bool public: if True, all endpoints are included regardless of access through `request`
:rtype: openapi.Paths
"""
if not endpoints:
return openapi.Paths(paths={})
- prefix = self._gen.determine_path_prefix(endpoints.keys())
+ prefix = self.determine_path_prefix(list(endpoints.keys()))
paths = OrderedDict()
- default_schema_cls = SwaggerAutoSchema
for path, (view_cls, methods) in sorted(endpoints.items()):
- path_parameters = self.get_path_parameters(path, view_cls)
operations = {}
for method, view in methods:
if not public and not self._gen.has_view_permissions(path, method, view):
continue
- operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
- overrides = self.get_overrides(view, method)
- auto_schema_cls = overrides.get('auto_schema', default_schema_cls)
- schema = auto_schema_cls(view, path, method, overrides, components)
- operations[method.lower()] = schema.get_operation(operation_keys)
+ operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request)
if operations:
- paths[path] = openapi.PathItem(parameters=path_parameters, **operations)
+ paths[path] = self.get_path_item(path, view_cls, operations)
return openapi.Paths(paths=paths)
+ def get_operation(self, view, path, prefix, method, components, request):
+ """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to
+ :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined
+ according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides.
+
+ :param view: the view associated with this endpoint
+ :param str path: the path component of the operation URL
+ :param str prefix: common path prefix among all endpoints
+ :param str method: the http method of the operation
+ :param openapi.ReferenceResolver components: referenceable components
+ :param Request request: the request made against the schema view; can be None
+ :rtype: openapi.Operation
+ """
+
+ operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
+ overrides = self.get_overrides(view, method)
+
+ # the inspector class can be specified, in decreasing order of priorty,
+ # 1. globaly via DEFAULT_AUTO_SCHEMA_CLASS
+ view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS
+ # 2. on the view/viewset class
+ view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls)
+ # 3. on the swagger_auto_schema decorator
+ view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)
+
+ view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
+ return view_inspector.get_operation(operation_keys)
+
+ def get_path_item(self, path, view_cls, operations):
+ """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the
+ API.
+
+ :param str path: the path
+ :param type view_cls: the view that was bound to this path in urlpatterns
+ :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method
+ :rtype: openapi.PathItem
+ """
+ path_parameters = self.get_path_parameters(path, view_cls)
+ return openapi.PathItem(parameters=path_parameters, **operations)
+
def get_overrides(self, view, method):
"""Get overrides specified for a given operation.
@@ -193,7 +278,7 @@ class OpenAPISchemaGenerator(object):
method = method.lower()
action = getattr(view, 'action', method)
action_method = getattr(view, action, None)
- overrides = getattr(action_method, 'swagger_auto_schema', {})
+ overrides = getattr(action_method, '_swagger_auto_schema', {})
if method in overrides:
overrides = overrides[method]
@@ -212,13 +297,21 @@ class OpenAPISchemaGenerator(object):
model = getattr(getattr(view_cls, 'queryset', None), 'model', None)
for variable in uritemplate.variables(path):
- model, model_field = get_model_field(queryset, variable)
- attrs = inspect_model_field(model, model_field)
+ model, model_field = get_queryset_field(queryset, variable)
+ attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
attrs['pattern'] = view_cls.lookup_value_regex
+ if model_field and model_field.help_text:
+ description = force_text(model_field.help_text)
+ elif model_field and model_field.primary_key:
+ description = get_pk_description(model, model_field)
+ else:
+ description = None
+
field = openapi.Parameter(
name=variable,
+ description=description,
required=True,
in_=openapi.IN_PATH,
**attrs
diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py
new file mode 100644
index 0000000..50d7aa2
--- /dev/null
+++ b/src/drf_yasg/inspectors/__init__.py
@@ -0,0 +1,38 @@
+from .base import (
+ BaseInspector, ViewInspector, FilterInspector, PaginatorInspector,
+ FieldInspector, SerializerInspector, NotHandled
+)
+from .field import (
+ InlineSerializerInspector, ReferencingSerializerInspector, RelatedFieldInspector, SimpleFieldInspector,
+ FileFieldInspector, ChoiceFieldInspector, DictFieldInspector, StringDefaultFieldInspector,
+ CamelCaseJSONFilter
+)
+from .query import (
+ CoreAPICompatInspector, DjangoRestResponsePagination
+)
+from .view import SwaggerAutoSchema
+from ..app_settings import swagger_settings
+
+# these settings must be accesed only after definig/importing all the classes in this module to avoid ImportErrors
+ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS
+ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS
+ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS
+
+__all__ = [
+ # base inspectors
+ 'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector',
+
+ # filter and pagination inspectors
+ 'CoreAPICompatInspector', 'DjangoRestResponsePagination',
+
+ # field inspectors
+ 'InlineSerializerInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector', 'SimpleFieldInspector',
+ 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', 'StringDefaultFieldInspector',
+ 'CamelCaseJSONFilter',
+
+ # view inspectors
+ 'SwaggerAutoSchema',
+
+ # module constants
+ 'NotHandled',
+]
diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py
new file mode 100644
index 0000000..0dbcbce
--- /dev/null
+++ b/src/drf_yasg/inspectors/base.py
@@ -0,0 +1,406 @@
+import inspect
+import logging
+
+from django.utils.encoding import force_text
+from rest_framework import serializers
+from rest_framework.utils import json, encoders
+from rest_framework.viewsets import GenericViewSet
+
+from .. import openapi
+from ..utils import is_list_view
+
+#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
+NotHandled = object()
+
+logger = logging.getLogger(__name__)
+
+
+class BaseInspector(object):
+ def __init__(self, view, path, method, components, request):
+ """
+ :param view: the view associated with this endpoint
+ :param str path: the path component of the operation URL
+ :param str method: the http method of the operation
+ :param openapi.ReferenceResolver components: referenceable components
+ :param Request request: the request made against the schema view; can be None
+ """
+ self.view = view
+ self.path = path
+ self.method = method
+ self.components = components
+ self.request = request
+
+ def process_result(self, result, method_name, obj, **kwargs):
+ """After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors
+ that were probed get the chance to alter the result, in reverse order. The inspector that handled the object
+ is the first to receive a ``process_result`` call with the object it just returned.
+
+ This behaviour is similar to the Django request/response middleware processing.
+
+ If this inspector has no post-processing to do, it should just ``return result`` (the default implementation).
+
+ :param result: the return value of the winning inspector, or ``None`` if no inspector handled the object
+ :param str method_name: name of the method that was called on the inspector
+ :param obj: first argument passed to inspector method
+ :param kwargs: additional arguments passed to inspector method
+ :return:
+ """
+ return result
+
+ def probe_inspectors(self, inspectors, method_name, obj, initkwargs=None, **kwargs):
+ """Probe a list of inspectors with a given object. The first inspector in the list to return a value that
+ is not :data:`.NotHandled` wins.
+
+ :param list[type[BaseInspector]] inspectors: list of inspectors to probe
+ :param str method_name: name of the target method on the inspector
+ :param obj: first argument to inspector method
+ :param dict initkwargs: extra kwargs for instantiating inspector class
+ :param kwargs: additional arguments to inspector method
+ :return: the return value of the winning inspector, or ``None`` if no inspector handled the object
+ """
+ initkwargs = initkwargs or {}
+ tried_inspectors = []
+
+ for inspector in inspectors:
+ assert inspect.isclass(inspector), "inspector must be a class, not an object"
+ assert issubclass(inspector, BaseInspector), "inspectors must subclass BaseInspector"
+
+ inspector = inspector(self.view, self.path, self.method, self.components, self.request, **initkwargs)
+ tried_inspectors.append(inspector)
+ method = getattr(inspector, method_name, None)
+ if method is None:
+ continue
+
+ result = method(obj, **kwargs)
+ if result is not NotHandled:
+ break
+ else: # pragma: no cover
+ logger.warning("%s ignored because no inspector in %s handled it (operation: %s)",
+ obj, inspectors, method_name)
+ result = None
+
+ for inspector in reversed(tried_inspectors):
+ result = inspector.process_result(result, method_name, obj, **kwargs)
+
+ return result
+
+
+class PaginatorInspector(BaseInspector):
+ """Base inspector for paginators.
+
+ Responisble for determining extra query parameters and response structure added by given paginators.
+ """
+
+ def get_paginator_parameters(self, paginator):
+ """Get the pagination parameters for a single paginator **instance**.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`.
+
+ :param BasePagination paginator: the paginator
+ :rtype: list[openapi.Parameter]
+ """
+ return NotHandled
+
+ def get_paginated_response(self, paginator, response_schema):
+ """Add appropriate paging fields to a response :class:`.Schema`.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`.
+
+ :param BasePagination paginator: the paginator
+ :param openapi.Schema response_schema: the response schema that must be paged.
+ :rtype: openapi.Schema
+ """
+ return NotHandled
+
+
+class FilterInspector(BaseInspector):
+ """Base inspector for filter backends.
+
+ Responsible for determining extra query parameters added by given filter backends.
+ """
+
+ def get_filter_parameters(self, filter_backend):
+ """Get the filter parameters for a single filter backend **instance**.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `filter_backend`.
+
+ :param BaseFilterBackend filter_backend: the filter backend
+ :rtype: list[openapi.Parameter]
+ """
+ return NotHandled
+
+
+class FieldInspector(BaseInspector):
+ """Base inspector for serializers and serializer fields. """
+
+ def __init__(self, view, path, method, components, request, field_inspectors):
+ super(FieldInspector, self).__init__(view, path, method, components, request)
+ self.field_inspectors = field_inspectors
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ """Convert a drf Serializer or Field instance into a Swagger object.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`.
+
+ :param rest_framework.serializers.Field field: the source field
+ :param type[openapi.SwaggerDict] swagger_object_type: should be one of Schema, Parameter, Items
+ :param bool use_references: if False, forces all objects to be declared inline
+ instead of by referencing other components
+ :param kwargs: extra attributes for constructing the object;
+ if swagger_object_type is Parameter, ``name`` and ``in_`` should be provided
+ :return: the swagger object
+ :rtype: openapi.Parameter,openapi.Items,openapi.Schema,openapi.SchemaRef
+ """
+ return NotHandled
+
+ def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs):
+ """Helper method for recursively probing `field_inspectors` to handle a given field.
+
+ All arguments are the same as :meth:`.field_to_swagger_object`.
+
+ :rtype: openapi.Parameter,openapi.Items,openapi.Schema,openapi.SchemaRef
+ """
+ return self.probe_inspectors(
+ self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors},
+ swagger_object_type=swagger_object_type, use_references=use_references, **kwargs
+ )
+
+ def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs):
+ """Helper method to extract generic information from a field and return a partial constructor for the
+ appropriate openapi object.
+
+ All arguments are the same as :meth:`.field_to_swagger_object`.
+
+ The return value is a tuple consisting of:
+
+ * a function for constructing objects of `swagger_object_type`; its prototype is: ::
+
+ def SwaggerType(existing_object=None, **instance_kwargs):
+
+ This function creates an instance of `swagger_object_type`, passing the following attributes to its init,
+ in order of precedence:
+
+ - arguments specified by the ``kwargs`` parameter of :meth:`._get_partial_types`
+ - ``instance_kwargs`` passed to the constructor function
+ - ``title``, ``description``, ``required``, ``default`` and ``read_only`` inferred from the field,
+ where appropriate
+
+ If ``existing_object`` is not ``None``, it is updated instead of creating a new object.
+
+ * a type that should be used for child objects if `field` is of an array type. This can currently have two
+ values:
+
+ - :class:`.Schema` if `swagger_object_type` is :class:`.Schema`
+ - :class:`.Items` if `swagger_object_type` is :class:`.Parameter` or :class:`.Items`
+
+ :rtype: tuple[callable,(type[openapi.Schema],type[openapi.Items])]
+ """
+ assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
+ assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object"
+ title = force_text(field.label) if field.label else None
+ title = title if swagger_object_type == openapi.Schema else None # only Schema has title
+ description = force_text(field.help_text) if field.help_text else None
+ description = description if swagger_object_type != openapi.Items else None # Items has no description either
+
+ def SwaggerType(existing_object=None, **instance_kwargs):
+ if 'required' not in instance_kwargs and swagger_object_type == openapi.Parameter:
+ instance_kwargs['required'] = field.required
+
+ if 'default' not in instance_kwargs and swagger_object_type != openapi.Items:
+ default = getattr(field, 'default', serializers.empty)
+ if default is not serializers.empty:
+ if callable(default):
+ try:
+ if hasattr(default, 'set_context'):
+ default.set_context(field)
+ default = default()
+ except Exception: # pragma: no cover
+ logger.warning("default for %s is callable but it raised an exception when "
+ "called; 'default' field will not be added to schema", field, exc_info=True)
+ default = None
+
+ if default is not None:
+ try:
+ default = field.to_representation(default)
+ # JSON roundtrip ensures that the value is valid JSON;
+ # for example, sets and tuples get transformed into lists
+ default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
+ except Exception: # pragma: no cover
+ logger.warning("'default' on schema for %s will not be set because "
+ "to_representation raised an exception", field, exc_info=True)
+ default = None
+
+ if default is not None:
+ instance_kwargs['default'] = default
+
+ if 'read_only' not in instance_kwargs and swagger_object_type == openapi.Schema:
+ # TODO: read_only is only relevant for schema `properties` - should not be generated in other cases
+ if field.read_only:
+ instance_kwargs['read_only'] = True
+
+ instance_kwargs.setdefault('title', title)
+ instance_kwargs.setdefault('description', description)
+ instance_kwargs.update(kwargs)
+
+ if existing_object is not None:
+ assert isinstance(existing_object, swagger_object_type)
+ for attr, val in sorted(instance_kwargs.items()):
+ setattr(existing_object, attr, val)
+ return existing_object
+
+ return swagger_object_type(**instance_kwargs)
+
+ # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
+ child_swagger_type = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
+ return SwaggerType, child_swagger_type
+
+
+class SerializerInspector(FieldInspector):
+ def get_schema(self, serializer):
+ """Convert a DRF Serializer instance to an :class:`.openapi.Schema`.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
+
+ :param serializers.BaseSerializer serializer: the ``Serializer`` instance
+ :rtype: openapi.Schema
+ """
+ return NotHandled
+
+ def get_request_parameters(self, serializer, in_):
+ """Convert a DRF serializer into a list of :class:`.Parameter`\ s.
+
+ Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`.
+
+ :param serializers.BaseSerializer serializer: the ``Serializer`` instance
+ :param str in_: the location of the parameters, one of the `openapi.IN_*` constants
+ :rtype: list[openapi.Parameter]
+ """
+ return NotHandled
+
+
+class ViewInspector(BaseInspector):
+ body_methods = ('PUT', 'PATCH', 'POST') #: methods that are allowed to have a request body
+
+ # real values set in __init__ to prevent import errors
+ field_inspectors = [] #:
+ filter_inspectors = [] #:
+ paginator_inspectors = [] #:
+
+ def __init__(self, view, path, method, components, request, overrides):
+ """
+ Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method.
+
+ :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>`
+ """
+ super(ViewInspector, self).__init__(view, path, method, components, request)
+ self.overrides = overrides
+ self._prepend_inspector_overrides('field_inspectors')
+ self._prepend_inspector_overrides('filter_inspectors')
+ self._prepend_inspector_overrides('paginator_inspectors')
+
+ def _prepend_inspector_overrides(self, inspectors):
+ extra_inspectors = self.overrides.get(inspectors, None)
+ if extra_inspectors:
+ default_inspectors = [insp for insp in getattr(self, inspectors) if insp not in extra_inspectors]
+ setattr(self, inspectors, extra_inspectors + default_inspectors)
+
+ def get_operation(self, operation_keys):
+ """Get an :class:`.Operation` for the given API endpoint (path, method).
+ This includes query, body parameters and response schemas.
+
+ :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
+ e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
+ :rtype: openapi.Operation
+ """
+ raise NotImplementedError("ViewInspector must implement get_operation()!")
+
+ # methods below provided as default implementations for probing inspectors
+
+ def should_filter(self):
+ """Determine whether filter backend parameters should be included for this request.
+
+ :rtype: bool
+ """
+ if not getattr(self.view, 'filter_backends', None):
+ return False
+
+ if self.method.lower() not in ["get", "delete"]:
+ return False
+
+ if not isinstance(self.view, GenericViewSet):
+ return True
+
+ return is_list_view(self.path, self.method, self.view)
+
+ def get_filter_parameters(self):
+ """Return the parameters added to the view by its filter backends.
+
+ :rtype: list[openapi.Parameter]
+ """
+ if not self.should_filter():
+ return []
+
+ fields = []
+ for filter_backend in self.view.filter_backends:
+ fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or []
+
+ return fields
+
+ def should_page(self):
+ """Determine whether paging parameters and structure should be added to this operation's request and response.
+
+ :rtype: bool
+ """
+ if not hasattr(self.view, 'paginator'):
+ return False
+
+ if self.view.paginator is None:
+ return False
+
+ if self.method.lower() != 'get':
+ return False
+
+ return is_list_view(self.path, self.method, self.view)
+
+ def get_pagination_parameters(self):
+ """Return the parameters added to the view by its paginator.
+
+ :rtype: list[openapi.Parameter]
+ """
+ if not self.should_page():
+ return []
+
+ return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or []
+
+ def serializer_to_schema(self, serializer):
+ """Convert a serializer to an OpenAPI :class:`.Schema`.
+
+ :param serializers.BaseSerializer serializer: the ``Serializer`` instance
+ :returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer
+ :rtype: openapi.Schema,openapi.SchemaRef,None
+ """
+ return self.probe_inspectors(
+ self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors}
+ )
+
+ def serializer_to_parameters(self, serializer, in_):
+ """Convert a serializer to a possibly empty list of :class:`.Parameter`\ s.
+
+ :param serializers.BaseSerializer serializer: the ``Serializer`` instance
+ :param str in_: the location of the parameters, one of the `openapi.IN_*` constants
+ :rtype: list[openapi.Parameter]
+ """
+ return self.probe_inspectors(
+ self.field_inspectors, 'get_request_parameters', serializer, {'field_inspectors': self.field_inspectors},
+ in_=in_
+ ) or []
+
+ def get_paginated_response(self, response_schema):
+ """Add appropriate paging fields to a response :class:`.Schema`.
+
+ :param openapi.Schema response_schema: the response schema that must be paged.
+ :returns: the paginated response class:`.Schema`, or ``None`` in case of an unknown pagination scheme
+ :rtype: openapi.Schema
+ """
+ return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response',
+ self.view.paginator, response_schema=response_schema)
diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py
new file mode 100644
index 0000000..ed77d3e
--- /dev/null
+++ b/src/drf_yasg/inspectors/field.py
@@ -0,0 +1,455 @@
+import operator
+from collections import OrderedDict
+
+from django.core import validators
+from django.db import models
+from rest_framework import serializers
+from rest_framework.settings import api_settings as rest_framework_settings
+
+from .base import NotHandled, SerializerInspector, FieldInspector
+from .. import openapi
+from ..errors import SwaggerGenerationError
+from ..utils import filter_none
+
+
+class InlineSerializerInspector(SerializerInspector):
+ """Provides serializer conversions using :meth:`.FieldInspector.field_to_swagger_object`."""
+
+ #: whether to output :class:`.Schema` definitions inline or into the ``definitions`` section
+ use_definitions = False
+
+ def get_schema(self, serializer):
+ return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)
+
+ def get_request_parameters(self, serializer, in_):
+ fields = getattr(serializer, 'fields', {})
+ return [
+ self.probe_field_inspectors(
+ value, openapi.Parameter, self.use_definitions,
+ name=self.get_parameter_name(key), in_=in_
+ )
+ for key, value
+ in fields.items()
+ ]
+
+ def get_property_name(self, field_name):
+ return field_name
+
+ def get_parameter_name(self, field_name):
+ return field_name
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+
+ if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
+ child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=child_schema,
+ )
+ elif isinstance(field, serializers.Serializer):
+ if swagger_object_type != openapi.Schema:
+ raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
+
+ serializer = field
+ serializer_meta = getattr(serializer, 'Meta', None)
+ if hasattr(serializer_meta, 'ref_name'):
+ ref_name = serializer_meta.ref_name
+ else:
+ ref_name = type(serializer).__name__
+ if ref_name.endswith('Serializer'):
+ ref_name = ref_name[:-len('Serializer')]
+
+ def make_schema_definition():
+ properties = OrderedDict()
+ required = []
+ for key, value in serializer.fields.items():
+ key = self.get_property_name(key)
+ properties[key] = self.probe_field_inspectors(value, ChildSwaggerType, use_references)
+ if value.required:
+ required.append(key)
+
+ return SwaggerType(
+ type=openapi.TYPE_OBJECT,
+ properties=properties,
+ required=required or None,
+ )
+
+ if not ref_name or not use_references:
+ return make_schema_definition()
+
+ definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
+ definitions.setdefault(ref_name, make_schema_definition)
+ return openapi.SchemaRef(definitions, ref_name)
+
+ return NotHandled
+
+
+class ReferencingSerializerInspector(InlineSerializerInspector):
+ use_definitions = True
+
+
+def get_queryset_field(queryset, field_name):
+ """Try to get information about a model and model field from a queryset.
+
+ :param queryset: the queryset
+ :param field_name: target field name
+ :returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
+ :rtype: tuple
+ """
+ model = getattr(queryset, 'model', None)
+ model_field = get_model_field(model, field_name)
+ return model, model_field
+
+
+def get_model_field(model, field_name):
+ """Try to get the given field from a django db model.
+
+ :param model: the model
+ :param field_name: target field name
+ :return: model field or ``None``
+ """
+ try:
+ if field_name == 'pk':
+ return model._meta.pk
+ else:
+ return model._meta.get_field(field_name)
+ except Exception: # pragma: no cover
+ return None
+
+
+def get_parent_serializer(field):
+ """Get the nearest parent ``Serializer`` instance for the given field.
+
+ :return: ``Serializer`` or ``None``
+ """
+ while field is not None:
+ if isinstance(field, serializers.Serializer):
+ return field
+
+ field = field.parent
+
+ return None # pragma: no cover
+
+
+def get_related_model(model, source):
+ """Try to find the other side of a model relationship given the name of a related field.
+
+ :param model: one side of the relationship
+ :param str source: related field name
+ :return: related model or ``None``
+ """
+ try:
+ return getattr(model, source).rel.related_model
+ except Exception: # pragma: no cover
+ return None
+
+
+class RelatedFieldInspector(FieldInspector):
+ """Provides conversions for ``RelatedField``\ s."""
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+
+ if isinstance(field, serializers.ManyRelatedField):
+ child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=child_schema,
+ unique_items=True,
+ )
+
+ if not isinstance(field, serializers.RelatedField):
+ return NotHandled
+
+ field_queryset = getattr(field, 'queryset', None)
+
+ if isinstance(field, (serializers.PrimaryKeyRelatedField, serializers.SlugRelatedField)):
+ if getattr(field, 'pk_field', ''):
+ # a PrimaryKeyRelatedField can have a `pk_field` attribute which is a
+ # serializer field that will convert the PK value
+ result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs)
+ # take the type, format, etc from `pk_field`, and the field-level information
+ # like title, description, default from the PrimaryKeyRelatedField
+ return SwaggerType(existing_object=result)
+
+ target_field = getattr(field, 'slug_field', 'pk')
+ if field_queryset is not None:
+ # if the RelatedField has a queryset, try to get the related model field from there
+ model, model_field = get_queryset_field(field_queryset, target_field)
+ else:
+ # if the RelatedField has no queryset (e.g. read only), try to find the target model
+ # from the view queryset or ModelSerializer model, if present
+ view_queryset = getattr(self.view, 'queryset', None)
+ serializer_meta = getattr(get_parent_serializer(field), 'Meta', None)
+ this_model = getattr(view_queryset, 'model', None) or getattr(serializer_meta, 'model', None)
+ source = getattr(field, 'source', '') or field.field_name
+ model = get_related_model(this_model, source)
+ model_field = get_model_field(model, target_field)
+
+ attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING}
+ return SwaggerType(**attrs)
+ elif isinstance(field, serializers.HyperlinkedRelatedField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
+
+ return SwaggerType(type=openapi.TYPE_STRING)
+
+
+def find_regex(regex_field):
+ """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string.
+
+ :param serializers.Field regex_field: the field instance
+ :return: the extracted pattern, or ``None``
+ :rtype: str
+ """
+ regex_validator = None
+ for validator in regex_field.validators:
+ if isinstance(validator, validators.RegexValidator):
+ if regex_validator is not None:
+ # bail if multiple validators are found - no obvious way to choose
+ return None # pragma: no cover
+ regex_validator = validator
+
+ # regex_validator.regex should be a compiled re object...
+ return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
+
+
+numeric_fields = (serializers.IntegerField, serializers.FloatField, serializers.DecimalField)
+limit_validators = [
+ # minimum and maximum apply to numbers
+ (validators.MinValueValidator, numeric_fields, 'minimum', operator.__gt__),
+ (validators.MaxValueValidator, numeric_fields, 'maximum', operator.__lt__),
+
+ # minLength and maxLength apply to strings
+ (validators.MinLengthValidator, serializers.CharField, 'min_length', operator.__gt__),
+ (validators.MaxLengthValidator, serializers.CharField, 'max_length', operator.__lt__),
+
+ # minItems and maxItems apply to lists
+ (validators.MinLengthValidator, serializers.ListField, 'min_items', operator.__gt__),
+ (validators.MaxLengthValidator, serializers.ListField, 'max_items', operator.__lt__),
+]
+
+
+def find_limits(field):
+ """Given a ``Field``, look for min/max value/length validators and return appropriate limit validation attributes.
+
+ :param serializers.Field field: the field instance
+ :return: the extracted limits
+ :rtype: OrderedDict
+ """
+ limits = {}
+ applicable_limits = [
+ (validator, attr, improves)
+ for validator, field_class, attr, improves in limit_validators
+ if isinstance(field, field_class)
+ ]
+
+ for validator in field.validators:
+ if not hasattr(validator, 'limit_value'):
+ continue
+
+ for validator_class, attr, improves in applicable_limits:
+ if isinstance(validator, validator_class):
+ if attr not in limits or improves(validator.limit_value, limits[attr]):
+ limits[attr] = validator.limit_value
+
+ return OrderedDict(sorted(limits.items()))
+
+
+model_field_to_basic_type = [
+ (models.AutoField, (openapi.TYPE_INTEGER, None)),
+ (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
+ (models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
+ (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
+ (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
+ (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
+ (models.DecimalField, (openapi.TYPE_NUMBER, None)),
+ (models.DurationField, (openapi.TYPE_INTEGER, None)),
+ (models.FloatField, (openapi.TYPE_NUMBER, None)),
+ (models.IntegerField, (openapi.TYPE_INTEGER, None)),
+ (models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
+ (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
+ (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
+ (models.TextField, (openapi.TYPE_STRING, None)),
+ (models.TimeField, (openapi.TYPE_STRING, None)),
+ (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
+ (models.CharField, (openapi.TYPE_STRING, None)),
+]
+
+ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}
+
+serializer_field_to_basic_type = [
+ (serializers.EmailField, (openapi.TYPE_STRING, openapi.FORMAT_EMAIL)),
+ (serializers.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
+ (serializers.URLField, (openapi.TYPE_STRING, openapi.FORMAT_URI)),
+ (serializers.IPAddressField, (openapi.TYPE_STRING, lambda field: ip_format.get(field.protocol, None))),
+ (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
+ (serializers.RegexField, (openapi.TYPE_STRING, None)),
+ (serializers.CharField, (openapi.TYPE_STRING, None)),
+ ((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)),
+ (serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
+ ((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)),
+ (serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ?
+ (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
+ (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
+ (serializers.ModelField, (openapi.TYPE_STRING, None)),
+]
+
+basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type
+
+
+def get_basic_type_info(field):
+ """Given a serializer or model ``Field``, return its basic type information - ``type``, ``format``, ``pattern``,
+ and any applicable min/max limit values.
+
+ :param field: the field instance
+ :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
+ :rtype: OrderedDict
+ """
+ if field is None:
+ return None
+
+ for field_class, type_format in basic_type_info:
+ if isinstance(field, field_class):
+ swagger_type, format = type_format
+ if callable(format):
+ format = format(field)
+ break
+ else: # pragma: no cover
+ return None
+
+ pattern = find_regex(field) if format in (None, openapi.FORMAT_SLUG) else None
+ limits = find_limits(field)
+
+ result = OrderedDict([
+ ('type', swagger_type),
+ ('format', format),
+ ('pattern', pattern)
+ ])
+ result.update(limits)
+ result = filter_none(result)
+ return result
+
+
+class SimpleFieldInspector(FieldInspector):
+ """Provides conversions for fields which can be described using just ``type``, ``format``, ``pattern``
+ and min/max validators.
+ """
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ type_info = get_basic_type_info(field)
+ if type_info is None:
+ return NotHandled
+
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+ return SwaggerType(**type_info)
+
+
+class ChoiceFieldInspector(FieldInspector):
+ """Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+
+ if isinstance(field, serializers.MultipleChoiceField):
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=ChildSwaggerType(
+ type=openapi.TYPE_STRING,
+ enum=list(field.choices.keys())
+ )
+ )
+ elif isinstance(field, serializers.ChoiceField):
+ return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
+
+ return NotHandled
+
+
+class FileFieldInspector(FieldInspector):
+ """Provides conversions for ``FileField``\ s."""
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+
+ if isinstance(field, serializers.FileField):
+ # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
+ # OpenAPI 3.0 does support it, so a future implementation could handle this better
+ err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema")
+ if swagger_object_type == openapi.Schema:
+ # FileField.to_representation returns URL or file name
+ result = SwaggerType(type=openapi.TYPE_STRING, read_only=True)
+ if getattr(field, 'use_url', rest_framework_settings.UPLOADED_FILES_USE_URL):
+ result.format = openapi.FORMAT_URI
+ return result
+ elif swagger_object_type == openapi.Parameter:
+ param = SwaggerType(type=openapi.TYPE_FILE)
+ if param['in'] != openapi.IN_FORM:
+ raise err # pragma: no cover
+ return param
+ else:
+ raise err # pragma: no cover
+
+ return NotHandled
+
+
+class DictFieldInspector(FieldInspector):
+ """Provides conversion for ``DictField``."""
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+
+ if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
+ child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
+ return SwaggerType(
+ type=openapi.TYPE_OBJECT,
+ additional_properties=child_schema
+ )
+
+ return NotHandled
+
+
+class StringDefaultFieldInspector(FieldInspector):
+ """For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""
+
+ def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover
+ # TODO unhandled fields: TimeField HiddenField JSONField
+ SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
+ return SwaggerType(type=openapi.TYPE_STRING)
+
+
+try:
+ from djangorestframework_camel_case.parser import CamelCaseJSONParser
+ from djangorestframework_camel_case.render import CamelCaseJSONRenderer
+ from djangorestframework_camel_case.render import camelize
+except ImportError: # pragma: no cover
+ class CamelCaseJSONFilter(FieldInspector):
+ pass
+else:
+ def camelize_string(s):
+ """Hack to force ``djangorestframework_camel_case`` to camelize a plain string."""
+ return next(iter(camelize({s: ''})))
+
+ def camelize_schema(schema_or_ref, components):
+ """Recursively camelize property names for the given schema using ``djangorestframework_camel_case``."""
+ schema = openapi.resolve_ref(schema_or_ref, components)
+ if getattr(schema, 'properties', {}):
+ schema.properties = OrderedDict(
+ (camelize_string(key), camelize_schema(val, components))
+ for key, val in schema.properties.items()
+ )
+
+ if getattr(schema, 'required', []):
+ schema.required = [camelize_string(p) for p in schema.required]
+
+ return schema_or_ref
+
+ class CamelCaseJSONFilter(FieldInspector):
+ def is_camel_case(self):
+ return any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) \
+ or any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes)
+
+ def process_result(self, result, method_name, obj, **kwargs):
+ if isinstance(result, openapi.Schema.OR_REF) and self.is_camel_case():
+ return camelize_schema(result, self.components)
+
+ return result
diff --git a/src/drf_yasg/inspectors/query.py b/src/drf_yasg/inspectors/query.py
new file mode 100644
index 0000000..90717d2
--- /dev/null
+++ b/src/drf_yasg/inspectors/query.py
@@ -0,0 +1,76 @@
+from collections import OrderedDict
+
+import coreschema
+from rest_framework.pagination import CursorPagination, PageNumberPagination, LimitOffsetPagination
+
+from .base import PaginatorInspector, FilterInspector
+from .. import openapi
+
+
+class CoreAPICompatInspector(PaginatorInspector, FilterInspector):
+ """Converts ``coreapi.Field``\ s to :class:`.openapi.Parameter`\ s for filters and paginators that implement a
+ ``get_schema_fields`` method.
+ """
+
+ def get_paginator_parameters(self, paginator):
+ fields = []
+ if hasattr(paginator, 'get_schema_fields'):
+ fields = paginator.get_schema_fields(self.view)
+
+ return [self.coreapi_field_to_parameter(field) for field in fields]
+
+ def get_filter_parameters(self, filter_backend):
+ fields = []
+ if hasattr(filter_backend, 'get_schema_fields'):
+ fields = filter_backend.get_schema_fields(self.view)
+ return [self.coreapi_field_to_parameter(field) for field in fields]
+
+ def coreapi_field_to_parameter(self, field):
+ """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object.
+
+ :param coreapi.Field field:
+ :rtype: openapi.Parameter
+ """
+ location_to_in = {
+ 'query': openapi.IN_QUERY,
+ 'path': openapi.IN_PATH,
+ 'form': openapi.IN_FORM,
+ 'body': openapi.IN_FORM,
+ }
+ coreapi_types = {
+ coreschema.Integer: openapi.TYPE_INTEGER,
+ coreschema.Number: openapi.TYPE_NUMBER,
+ coreschema.String: openapi.TYPE_STRING,
+ coreschema.Boolean: openapi.TYPE_BOOLEAN,
+ }
+ return openapi.Parameter(
+ name=field.name,
+ in_=location_to_in[field.location],
+ type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING),
+ required=field.required,
+ description=field.schema.description,
+ )
+
+
+class DjangoRestResponsePagination(PaginatorInspector):
+ """Provides response schema pagination warpping for django-rest-framework's LimitOffsetPagination,
+ PageNumberPagination and CursorPagination
+ """
+
+ def get_paginated_response(self, paginator, response_schema):
+ assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response"
+ paged_schema = None
+ if isinstance(paginator, (LimitOffsetPagination, PageNumberPagination, CursorPagination)):
+ has_count = not isinstance(paginator, CursorPagination)
+ paged_schema = openapi.Schema(
+ type=openapi.TYPE_OBJECT,
+ properties=OrderedDict((
+ ('count', openapi.Schema(type=openapi.TYPE_INTEGER) if has_count else None),
+ ('next', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)),
+ ('previous', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)),
+ ('results', response_schema),
+ )),
+ required=['count', 'results']
+ )
+
+ return paged_schema
diff --git a/src/drf_yasg/inspectors.py b/src/drf_yasg/inspectors/view.py
similarity index 57%
rename from src/drf_yasg/inspectors.py
rename to src/drf_yasg/inspectors/view.py
index 6cb03e8..ff3c6fd 100644
--- a/src/drf_yasg/inspectors.py
+++ b/src/drf_yasg/inspectors/view.py
@@ -1,63 +1,22 @@
-import inspect
from collections import OrderedDict
-import coreschema
-from rest_framework import serializers, status
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema
from rest_framework.status import is_success
-from rest_framework.viewsets import GenericViewSet
-from . import openapi
-from .errors import SwaggerGenerationError
-from .utils import serializer_field_to_swagger, no_body, is_list_view, param_list_to_odict
+from .base import ViewInspector
+from .. import openapi
+from ..errors import SwaggerGenerationError
+from ..utils import force_serializer_instance, no_body, is_list_view, param_list_to_odict, guess_response_status
-def force_serializer_instance(serializer):
- """Force `serializer` into a ``Serializer`` instance. If it is not a ``Serializer`` class or instance, raises
- an assertion error.
-
- :param serializer: serializer class or instance
- :return: serializer instance
- """
- if inspect.isclass(serializer):
- assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__
- return serializer()
-
- assert isinstance(serializer, serializers.BaseSerializer), \
- "Serializer class or instance required, not %s" % type(serializer).__name__
- return serializer
-
-
-class SwaggerAutoSchema(object):
- body_methods = ('PUT', 'PATCH', 'POST') #: methods allowed to have a request body
-
- def __init__(self, view, path, method, overrides, components):
- """Inspector class responsible for providing :class:`.Operation` definitions given a
-
- :param view: the view associated with this endpoint
- :param str path: the path component of the operation URL
- :param str method: the http method of the operation
- :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>`
- :param openapi.ReferenceResolver components: referenceable components
- """
- super(SwaggerAutoSchema, self).__init__()
+class SwaggerAutoSchema(ViewInspector):
+ def __init__(self, view, path, method, components, request, overrides):
+ super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides)
self._sch = AutoSchema()
- self.view = view
- self.path = path
- self.method = method
- self.overrides = overrides
- self.components = components
self._sch.view = view
def get_operation(self, operation_keys):
- """Get an :class:`.Operation` for the given API endpoint (path, method).
- This includes query, body parameters and response schemas.
-
- :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
- e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
- :rtype: openapi.Operation
- """
consumes = self.get_consumes()
body = self.get_request_body_parameters(consumes)
@@ -66,17 +25,19 @@ class SwaggerAutoSchema(object):
parameters = [param for param in parameters if param is not None]
parameters = self.add_manual_parameters(parameters)
+ operation_id = self.get_operation_id(operation_keys)
description = self.get_description()
+ tags = self.get_tags(operation_keys)
responses = self.get_responses()
return openapi.Operation(
- operation_id='_'.join(operation_keys),
+ operation_id=operation_id,
description=description,
responses=responses,
parameters=parameters,
consumes=consumes,
- tags=[operation_keys[0]],
+ tags=tags,
)
def get_request_body_parameters(self, consumes):
@@ -105,7 +66,7 @@ class SwaggerAutoSchema(object):
else:
if schema is None:
schema = self.get_request_body_schema(serializer)
- return [self.make_body_parameter(schema)]
+ return [self.make_body_parameter(schema)] if schema is not None else []
def get_view_serializer(self):
"""Return the serializer as defined by the view's ``get_serializer()`` method.
@@ -192,26 +153,6 @@ class SwaggerAutoSchema(object):
responses=self.get_response_schemas(response_serializers)
)
- def get_paged_response_schema(self, response_schema):
- """Add appropriate paging fields to a response :class:`.Schema`.
-
- :param openapi.Schema response_schema: the response schema that must be paged.
- :rtype: openapi.Schema
- """
- assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response"
- paged_schema = openapi.Schema(
- type=openapi.TYPE_OBJECT,
- properties={
- 'count': openapi.Schema(type=openapi.TYPE_INTEGER),
- 'next': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI),
- 'previous': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI),
- 'results': response_schema,
- },
- required=['count', 'results']
- )
-
- return paged_schema
-
def get_default_responses(self):
"""Get the default responses determined for this view from the request serializer and request method.
@@ -219,28 +160,26 @@ class SwaggerAutoSchema(object):
"""
method = self.method.lower()
- default_status = status.HTTP_200_OK
+ default_status = guess_response_status(method)
default_schema = ''
if method == 'post':
- default_status = status.HTTP_201_CREATED
default_schema = self.get_request_serializer() or self.get_view_serializer()
- elif method == 'delete':
- default_status = status.HTTP_204_NO_CONTENT
elif method in ('get', 'put', 'patch'):
default_schema = self.get_request_serializer() or self.get_view_serializer()
default_schema = default_schema or ''
if any(is_form_media_type(encoding) for encoding in self.get_consumes()):
default_schema = ''
+ if default_schema and not isinstance(default_schema, openapi.Schema):
+ default_schema = self.serializer_to_schema(default_schema) or ''
+
if default_schema:
- if not isinstance(default_schema, openapi.Schema):
- default_schema = self.serializer_to_schema(default_schema)
if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get':
default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
if self.should_page():
- default_schema = self.get_paged_response_schema(default_schema)
+ default_schema = self.get_paginated_response(default_schema) or default_schema
- return {str(default_status): default_schema}
+ return OrderedDict({str(default_status): default_schema})
def get_response_serializers(self):
"""Return the response codes that this view is expected to return, and the serializer for each response body.
@@ -254,7 +193,7 @@ class SwaggerAutoSchema(object):
manual_responses = self.overrides.get('responses', None) or {}
manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())
- responses = {}
+ responses = OrderedDict()
if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
responses = self.get_default_responses()
@@ -268,7 +207,7 @@ class SwaggerAutoSchema(object):
:return: a dictionary of status code to :class:`.Response` object
:rtype: dict[str, openapi.Response]
"""
- responses = {}
+ responses = OrderedDict()
for sc, serializer in response_serializers.items():
if isinstance(serializer, str):
response = openapi.Response(
@@ -325,84 +264,18 @@ class SwaggerAutoSchema(object):
return natural_parameters + serializer_parameters
- def should_filter(self):
- """Determine whether filter backend parameters should be included for this request.
+ def get_operation_id(self, operation_keys):
+ """Return an unique ID for this operation. The ID must be unique across
+ all :class:`.Operation` objects in the API.
- :rtype: bool
+ :param tuple[str] operation_keys: an array of keys derived from the pathdescribing the hierarchical layout
+ of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
+ :rtype: str
"""
- if not getattr(self.view, 'filter_backends', None):
- return False
-
- if self.method.lower() not in ["get", "delete"]:
- return False
-
- if not isinstance(self.view, GenericViewSet):
- return True
-
- return is_list_view(self.path, self.method, self.view)
-
- def get_filter_backend_parameters(self, filter_backend):
- """Get the filter parameters for a single filter backend **instance**.
-
- :param BaseFilterBackend filter_backend: the filter backend
- :rtype: list[openapi.Parameter]
- """
- fields = []
- if hasattr(filter_backend, 'get_schema_fields'):
- fields = filter_backend.get_schema_fields(self.view)
- return [self.coreapi_field_to_parameter(field) for field in fields]
-
- def get_filter_parameters(self):
- """Return the parameters added to the view by its filter backends.
-
- :rtype: list[openapi.Parameter]
- """
- if not self.should_filter():
- return []
-
- fields = []
- for filter_backend in self.view.filter_backends:
- fields += self.get_filter_backend_parameters(filter_backend())
-
- return fields
-
- def should_page(self):
- """Determine whether paging parameters and structure should be added to this operation's request and response.
-
- :rtype: bool
- """
- if not hasattr(self.view, 'paginator'):
- return False
-
- if self.view.paginator is None:
- return False
-
- if self.method.lower() != 'get':
- return False
-
- return is_list_view(self.path, self.method, self.view)
-
- def get_paginator_parameters(self, paginator):
- """Get the pagination parameters for a single paginator **instance**.
-
- :param BasePagination paginator: the paginator
- :rtype: list[openapi.Parameter]
- """
- fields = []
- if hasattr(paginator, 'get_schema_fields'):
- fields = paginator.get_schema_fields(self.view)
-
- return [self.coreapi_field_to_parameter(field) for field in fields]
-
- def get_pagination_parameters(self):
- """Return the parameters added to the view by its paginator.
-
- :rtype: list[openapi.Parameter]
- """
- if not self.should_page():
- return []
-
- return self.get_paginator_parameters(self.view.paginator)
+ operation_id = self.overrides.get('operation_id', '')
+ if not operation_id:
+ operation_id = '_'.join(operation_keys)
+ return operation_id
def get_description(self):
"""Return an operation description determined as appropriate from the view's method and class docstrings.
@@ -415,6 +288,16 @@ class SwaggerAutoSchema(object):
description = self._sch.get_description(self.path, self.method)
return description
+ def get_tags(self, operation_keys):
+ """Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI
+ each tag will show as a group containing the operations that use it.
+
+ :param tuple[str] operation_keys: an array of keys derived from the pathdescribing the hierarchical layout
+ of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc.
+ :rtype: list[str]
+ """
+ return [operation_keys[0]]
+
def get_consumes(self):
"""Return the MIME types this endpoint can consume.
@@ -424,62 +307,3 @@ class SwaggerAutoSchema(object):
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
-
- def serializer_to_schema(self, serializer):
- """Convert a DRF Serializer instance to an :class:`.openapi.Schema`.
-
- :param serializers.BaseSerializer serializer: the ``Serializer`` instance
- :rtype: openapi.Schema
- """
- definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
- return serializer_field_to_swagger(serializer, openapi.Schema, definitions)
-
- def serializer_to_parameters(self, serializer, in_):
- """Convert a DRF serializer into a list of :class:`.Parameter`\ s using :meth:`.field_to_parameter`
-
- :param serializers.BaseSerializer serializer: the ``Serializer`` instance
- :param str in_: the location of the parameters, one of the `openapi.IN_*` constants
- :rtype: list[openapi.Parameter]
- """
- fields = getattr(serializer, 'fields', {})
- return [
- self.field_to_parameter(value, key, in_)
- for key, value
- in fields.items()
- ]
-
- def field_to_parameter(self, field, name, in_):
- """Convert a DRF serializer Field to a swagger :class:`.Parameter` object.
-
- :param coreapi.Field field:
- :param str name: the name of the parameter
- :param str in_: the location of the parameter, one of the `openapi.IN_*` constants
- :rtype: openapi.Parameter
- """
- return serializer_field_to_swagger(field, openapi.Parameter, name=name, in_=in_)
-
- def coreapi_field_to_parameter(self, field):
- """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object.
-
- :param coreapi.Field field:
- :rtype: openapi.Parameter
- """
- location_to_in = {
- 'query': openapi.IN_QUERY,
- 'path': openapi.IN_PATH,
- 'form': openapi.IN_FORM,
- 'body': openapi.IN_FORM,
- }
- coreapi_types = {
- coreschema.Integer: openapi.TYPE_INTEGER,
- coreschema.Number: openapi.TYPE_NUMBER,
- coreschema.String: openapi.TYPE_STRING,
- coreschema.Boolean: openapi.TYPE_BOOLEAN,
- }
- return openapi.Parameter(
- name=field.name,
- in_=location_to_in[field.location],
- type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING),
- required=field.required,
- description=field.schema.description,
- )
diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py
index 8b3d3b0..7d9397f 100644
--- a/src/drf_yasg/openapi.py
+++ b/src/drf_yasg/openapi.py
@@ -1,9 +1,11 @@
+import re
from collections import OrderedDict
from coreapi.compat import urlparse
-from future.utils import raise_from
from inflection import camelize
+from .utils import filter_none
+
TYPE_OBJECT = "object" #:
TYPE_STRING = "string" #:
TYPE_NUMBER = "number" #:
@@ -94,8 +96,9 @@ class SwaggerDict(OrderedDict):
raise AttributeError
try:
return self[make_swagger_name(item)]
- except KeyError as e:
- raise_from(AttributeError("object of class " + type(self).__name__ + " has no attribute " + item), e)
+ except KeyError:
+ # raise_from is EXTREMELY slow, replaced with plain raise
+ raise AttributeError("object of class " + type(self).__name__ + " has no attribute " + item)
def __delattr__(self, item):
if item.startswith('_'):
@@ -230,7 +233,7 @@ class Swagger(SwaggerDict):
self.base_path = '/'
self.paths = paths
- self.definitions = definitions
+ self.definitions = filter_none(definitions)
self._insert_extras__()
@@ -270,13 +273,13 @@ class PathItem(SwaggerDict):
self.patch = patch
self.delete = delete
self.options = options
- self.parameters = parameters
+ self.parameters = filter_none(parameters)
self._insert_extras__()
class Operation(SwaggerDict):
def __init__(self, operation_id, responses, parameters=None, consumes=None,
- produces=None, description=None, tags=None, **extra):
+ produces=None, summary=None, description=None, tags=None, **extra):
"""Information about an API operation (path + http method combination)
:param str operation_id: operation ID, should be unique across all operations
@@ -284,17 +287,19 @@ class Operation(SwaggerDict):
:param list[.Parameter] parameters: parameters accepted
:param list[str] consumes: content types accepted
:param list[str] produces: content types produced
- :param str description: operation description
+ :param str summary: operation summary; should be < 120 characters
+ :param str description: operation description; can be of any length and supports markdown
:param list[str] tags: operation tags
"""
super(Operation, self).__init__(**extra)
self.operation_id = operation_id
+ self.summary = summary
self.description = description
- self.parameters = [param for param in parameters if param is not None]
+ self.parameters = filter_none(parameters)
self.responses = responses
- self.consumes = consumes
- self.produces = produces
- self.tags = tags
+ self.consumes = filter_none(consumes)
+ self.produces = filter_none(produces)
+ self.tags = filter_none(tags)
self._insert_extras__()
@@ -352,21 +357,26 @@ class Parameter(SwaggerDict):
class Schema(SwaggerDict):
- OR_REF = ()
+ OR_REF = () #: useful for type-checking, e.g ``isinstance(obj, openapi.Schema.OR_REF)``
- def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None,
- format=None, enum=None, pattern=None, items=None, **extra):
+ def __init__(self, title=None, description=None, type=None, format=None, enum=None, pattern=None, properties=None,
+ additional_properties=None, required=None, items=None, default=None, read_only=None, **extra):
"""Describes a complex object accepted as parameter or returned as a response.
- :param description: schema description
- :param list[str] required: list of requried property names
+ :param str title: schema title
+ :param str description: schema description
:param str type: value type; required
- :param list[.Schema,.SchemaRef] properties: object properties; required if `type` is ``object``
- :param bool,.Schema,.SchemaRef additional_properties: allow wildcard properties not listed in `properties`
:param str format: value format, see OpenAPI spec
:param list enum: restrict possible values
:param str pattern: pattern if type is ``string``
- :param .Schema,.SchemaRef items: only valid if `type` is ``array``
+ :param list[.Schema,.SchemaRef] properties: object properties; required if `type` is ``object``
+ :param bool,.Schema,.SchemaRef additional_properties: allow wildcard properties not listed in `properties`
+ :param list[str] required: list of requried property names
+ :param .Schema,.SchemaRef items: type of array items, only valid if `type` is ``array``
+ :param default: only valid when insider another ``Schema``\ 's ``properties``;
+ the default value of this property if it is not provided, must conform to the type of this Schema
+ :param read_only: only valid when insider another ``Schema``\ 's ``properties``;
+ declares the property as read only - it must only be sent as part of responses, never in requests
"""
super(Schema, self).__init__(**extra)
if required is True or required is False:
@@ -374,19 +384,24 @@ class Schema(SwaggerDict):
raise AssertionError(
"the `requires` attribute of schema must be an array of required properties, not a boolean!")
assert type is not None, "type is required!"
+ self.title = title
self.description = description
- self.required = required
+ self.required = filter_none(required)
self.type = type
- self.properties = properties
+ self.properties = filter_none(properties)
self.additional_properties = additional_properties
self.format = format
self.enum = enum
self.pattern = pattern
self.items = items
+ self.read_only = read_only
+ self.default = default
self._insert_extras__()
class _Ref(SwaggerDict):
+ ref_name_re = re.compile(r"#/(?P.+)/(?P[^/]+)$")
+
def __init__(self, resolver, name, scope, expected_type):
"""Base class for all reference types. A reference object has only one property, ``$ref``, which must be a JSON
reference to a valid object in the specification, e.g. ``#/definitions/Article`` to refer to an article model.
@@ -404,6 +419,15 @@ class _Ref(SwaggerDict):
.format(actual=type(obj).__name__, expected=expected_type.__name__)
self.ref = ref_name
+ def resolve(self, resolver):
+ """Get the object targeted by this reference from the given component resolver.
+
+ :param .ReferenceResolver resolver: component resolver which must contain the referneced object
+ :returns: the target object
+ """
+ ref_match = self.ref_name_re.match(self.ref)
+ return resolver.get(ref_match.group('name'), scope=ref_match.group('scope'))
+
def __setitem__(self, key, value, **kwargs):
if key == "$ref":
return super(_Ref, self).__setitem__(key, value, **kwargs)
@@ -427,6 +451,17 @@ class SchemaRef(_Ref):
Schema.OR_REF = (Schema, SchemaRef)
+def resolve_ref(ref_or_obj, resolver):
+ """Resolve `ref_or_obj` if it is a reference type. Return it unchaged if not.
+
+ :param SwaggerDict,_Ref ref_or_obj:
+ :param resolver: component resolver which must contain the referenced object
+ """
+ if isinstance(ref_or_obj, _Ref):
+ return ref_or_obj.resolve(resolver)
+ return ref_or_obj
+
+
class Responses(SwaggerDict):
def __init__(self, responses, default=None, **extra):
"""Describes the expected responses of an :class:`.Operation`.
@@ -483,7 +518,7 @@ class ReferenceResolver(object):
self._objects[scope] = OrderedDict()
def with_scope(self, scope):
- """Return a new :class:`.ReferenceResolver` whose scope is defaulted and forced to `scope`.
+ """Return a view into this :class:`.ReferenceResolver` whose scope is defaulted and forced to `scope`.
:param str scope: target scope, must be in this resolver's `scopes`
:return: the bound resolver
diff --git a/src/drf_yasg/renderers.py b/src/drf_yasg/renderers.py
index 4352812..87448ed 100644
--- a/src/drf_yasg/renderers.py
+++ b/src/drf_yasg/renderers.py
@@ -14,7 +14,7 @@ class _SpecRenderer(BaseRenderer):
@classmethod
def with_validators(cls, validators):
- assert all(vld in VALIDATORS for vld in validators), "allowed validators are" + ", ".join(VALIDATORS)
+ assert all(vld in VALIDATORS for vld in validators), "allowed validators are " + ", ".join(VALIDATORS)
return type(cls.__name__, (cls,), {'validators': validators})
def render(self, data, media_type=None, renderer_context=None):
@@ -45,7 +45,7 @@ class SwaggerYAMLRenderer(_SpecRenderer):
class _UIRenderer(BaseRenderer):
- """Base class for web UI renderers. Handles loading an passing settings to the appropriate template."""
+ """Base class for web UI renderers. Handles loading and passing settings to the appropriate template."""
media_type = 'text/html'
charset = 'utf-8'
template = ''
diff --git a/src/drf_yasg/templates/drf-yasg/swagger-ui.html b/src/drf_yasg/templates/drf-yasg/swagger-ui.html
index 0a2ad2a..f3a0473 100644
--- a/src/drf_yasg/templates/drf-yasg/swagger-ui.html
+++ b/src/drf_yasg/templates/drf-yasg/swagger-ui.html
@@ -166,9 +166,11 @@
layout: "StandaloneLayout",
filter: true,
requestInterceptor: function(request) {
- console.log(request);
var headers = request.headers || {};
- headers["X-CSRFToken"] = document.querySelector("[name=csrfmiddlewaretoken]").value;
+ var csrftoken = document.querySelector("[name=csrfmiddlewaretoken]");
+ if (csrftoken) {
+ headers["X-CSRFToken"] = csrftoken.value;
+ }
return request;
}
};
diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py
index 81a78fd..3c6302d 100644
--- a/src/drf_yasg/utils.py
+++ b/src/drf_yasg/utils.py
@@ -1,17 +1,9 @@
+import inspect
import logging
from collections import OrderedDict
-from django.core.validators import RegexValidator
-from django.db import models
-from django.utils.encoding import force_text
-from rest_framework import serializers
+from rest_framework import status, serializers
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin
-from rest_framework.schemas.inspectors import get_pk_description
-from rest_framework.settings import api_settings
-from rest_framework.utils import json, encoders
-
-from . import openapi
-from .errors import SwaggerGenerationError
logger = logging.getLogger(__name__)
@@ -19,6 +11,141 @@ logger = logging.getLogger(__name__)
no_body = object()
+def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, query_serializer=None,
+ manual_parameters=None, operation_id=None, operation_description=None, responses=None,
+ field_inspectors=None, filter_inspectors=None, paginator_inspectors=None,
+ **extra_overrides):
+ """Decorate a view method to customize the :class:`.Operation` object generated from it.
+
+ `method` and `methods` are mutually exclusive and must only be present when decorating a view method that accepts
+ more than one HTTP request method.
+
+ The `auto_schema` and `operation_description` arguments take precendence over view- or method-level values.
+
+ .. versionchanged:: 1.1
+ Added the ``extra_overrides`` and ``operatiod_id`` parameters.
+
+ .. versionchanged:: 1.1
+ Added the ``field_inspectors``, ``filter_inspectors`` and ``paginator_inspectors`` parameters.
+
+ :param str method: for multi-method views, the http method the options should apply to
+ :param list[str] methods: for multi-method views, the http methods the options should apply to
+ :param .inspectors.SwaggerAutoSchema auto_schema: custom class to use for generating the Operation object;
+ this overrides both the class-level ``swagger_schema`` attribute and the ``DEFAULT_AUTO_SCHEMA_CLASS``
+ setting
+ :param .Schema,.SchemaRef,.Serializer request_body: custom request body, or :data:`.no_body`. The value given here
+ will be used as the ``schema`` property of a :class:`.Parameter` with ``in: 'body'``.
+
+ A Schema or SchemaRef is not valid if this request consumes form-data, because ``form`` and ``body`` parameters
+ are mutually exclusive in an :class:`.Operation`. If you need to set custom ``form`` parameters, you can use
+ the `manual_parameters` argument.
+
+ If a ``Serializer`` class or instance is given, it will be automatically converted into a :class:`.Schema`
+ used as a ``body`` :class:`.Parameter`, or into a list of ``form`` :class:`.Parameter`\ s, as appropriate.
+
+ :param .Serializer query_serializer: if you use a ``Serializer`` to parse query parameters, you can pass it here
+ and have :class:`.Parameter` objects be generated automatically from it.
+
+ If any ``Field`` on the serializer cannot be represented as a ``query`` :class:`.Parameter`
+ (e.g. nested Serializers, file fields, ...), the schema generation will fail with an error.
+
+ Schema generation will also fail if the name of any Field on the `query_serializer` conflicts with parameters
+ generated by ``filter_backends`` or ``paginator``.
+
+ :param list[.Parameter] manual_parameters: a list of manual parameters to override the automatically generated ones
+
+ :class:`.Parameter`\ s are identified by their (``name``, ``in``) combination, and any parameters given
+ here will fully override automatically generated parameters if they collide.
+
+ It is an error to supply ``form`` parameters when the request does not consume form-data.
+
+ :param str operation_id: operation ID override; the operation ID must be unique accross the whole API
+ :param str operation_description: operation description override
+ :param dict[str,(.Schema,.SchemaRef,.Response,str,Serializer)] responses: a dict of documented manual responses
+ keyed on response status code. If no success (``2xx``) response is given, one will automatically be
+ generated from the request body and http method. If any ``2xx`` response is given the automatic response is
+ suppressed.
+
+ * if a plain string is given as value, a :class:`.Response` with no body and that string as its description
+ will be generated
+ * if a :class:`.Schema`, :class:`.SchemaRef` is given, a :class:`.Response` with the schema as its body and
+ an empty description will be generated
+ * a ``Serializer`` class or instance will be converted into a :class:`.Schema` and treated as above
+ * a :class:`.Response` object will be used as-is; however if its ``schema`` attribute is a ``Serializer``,
+ it will automatically be converted into a :class:`.Schema`
+
+ :param list[.FieldInspector] field_inspectors: extra serializer and field inspectors; these will be tried
+ before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
+ :param list[.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried before
+ :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
+ :param list[.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be tried before
+ :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
+ :param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available
+ in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
+ """
+
+ def decorator(view_method):
+ data = {
+ 'auto_schema': auto_schema,
+ 'request_body': request_body,
+ 'query_serializer': query_serializer,
+ 'manual_parameters': manual_parameters,
+ 'operation_id': operation_id,
+ 'operation_description': operation_description,
+ 'responses': responses,
+ 'filter_inspectors': list(filter_inspectors) if filter_inspectors else None,
+ 'paginator_inspectors': list(paginator_inspectors) if paginator_inspectors else None,
+ 'field_inspectors': list(field_inspectors) if field_inspectors else None,
+ }
+ data = {k: v for k, v in data.items() if v is not None}
+ data.update(extra_overrides)
+
+ # if the method is a detail_route or list_route, it will have a bind_to_methods attribute
+ bind_to_methods = getattr(view_method, 'bind_to_methods', [])
+ # if the method is actually a function based view (@api_view), it will have a 'cls' attribute
+ view_cls = getattr(view_method, 'cls', None)
+ http_method_names = getattr(view_cls, 'http_method_names', [])
+ if bind_to_methods or http_method_names:
+ # detail_route, list_route or api_view
+ assert bool(http_method_names) != bool(bind_to_methods), "this should never happen"
+ available_methods = http_method_names + bind_to_methods
+ existing_data = getattr(view_method, '_swagger_auto_schema', {})
+
+ if http_method_names:
+ _route = "api_view"
+ else:
+ _route = "detail_route" if view_method.detail else "list_route"
+
+ _methods = methods
+ if len(available_methods) > 1:
+ assert methods or method, \
+ "on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \
+ "using one of the `method` or `methods` arguments" % _route
+ assert bool(methods) != bool(method), "specify either method or methods"
+ assert not isinstance(methods, str), "`methods` expects to receive a list of methods;" \
+ " use `method` for a single argument"
+ if method:
+ _methods = [method.lower()]
+ else:
+ _methods = [mth.lower() for mth in methods]
+ assert not any(mth in existing_data for mth in _methods), "method defined multiple times"
+ assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route
+
+ existing_data.update((mth.lower(), data) for mth in _methods)
+ else:
+ existing_data[available_methods[0]] = data
+ view_method._swagger_auto_schema = existing_data
+ else:
+ assert method is None and methods is None, \
+ "the methods argument should only be specified when decorating a detail_route or list_route; you " \
+ "should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator"
+ view_method._swagger_auto_schema = data
+
+ return view_method
+
+ return decorator
+
+
def is_list_view(path, method, view):
"""Check if the given path/method appears to represent a list view (as opposed to a detail/instance view).
@@ -52,431 +179,13 @@ def is_list_view(path, method, view):
return True
-def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, query_serializer=None,
- manual_parameters=None, operation_description=None, responses=None):
- """Decorate a view method to customize the :class:`.Operation` object generated from it.
-
- `method` and `methods` are mutually exclusive and must only be present when decorating a view method that accepts
- more than one HTTP request method.
-
- The `auto_schema` and `operation_description` arguments take precendence over view- or method-level values.
-
- :param str method: for multi-method views, the http method the options should apply to
- :param list[str] methods: for multi-method views, the http methods the options should apply to
- :param .SwaggerAutoSchema auto_schema: custom class to use for generating the Operation object
- :param .Schema,.SchemaRef,.Serializer request_body: custom request body, or :data:`.no_body`. The value given here
- will be used as the ``schema`` property of a :class:`.Parameter` with ``in: 'body'``.
-
- A Schema or SchemaRef is not valid if this request consumes form-data, because ``form`` and ``body`` parameters
- are mutually exclusive in an :class:`.Operation`. If you need to set custom ``form`` parameters, you can use
- the `manual_parameters` argument.
-
- If a ``Serializer`` class or instance is given, it will be automatically converted into a :class:`.Schema`
- used as a ``body`` :class:`.Parameter`, or into a list of ``form`` :class:`.Parameter`\ s, as appropriate.
-
- :param .Serializer query_serializer: if you use a ``Serializer`` to parse query parameters, you can pass it here
- and have :class:`.Parameter` objects be generated automatically from it.
-
- If any ``Field`` on the serializer cannot be represented as a ``query`` :class:`.Parameter`
- (e.g. nested Serializers, file fields, ...), the schema generation will fail with an error.
-
- Schema generation will also fail if the name of any Field on the `query_serializer` conflicts with parameters
- generated by ``filter_backends`` or ``paginator``.
-
- :param list[.Parameter] manual_parameters: a list of manual parameters to override the automatically generated ones
-
- :class:`.Parameter`\ s are identified by their (``name``, ``in``) combination, and any parameters given
- here will fully override automatically generated parameters if they collide.
-
- It is an error to supply ``form`` parameters when the request does not consume form-data.
-
- :param str operation_description: operation description override
- :param dict[str,(.Schema,.SchemaRef,.Response,str,Serializer)] responses: a dict of documented manual responses
- keyed on response status code. If no success (``2xx``) response is given, one will automatically be
- generated from the request body and http method. If any ``2xx`` response is given the automatic response is
- suppressed.
-
- * if a plain string is given as value, a :class:`.Response` with no body and that string as its description
- will be generated
- * if a :class:`.Schema`, :class:`.SchemaRef` is given, a :class:`.Response` with the schema as its body and
- an empty description will be generated
- * a ``Serializer`` class or instance will be converted into a :class:`.Schema` and treated as above
- * a :class:`.Response` object will be used as-is; however if its ``schema`` attribute is a ``Serializer``,
- it will automatically be converted into a :class:`.Schema`
-
- """
-
- def decorator(view_method):
- data = {
- 'auto_schema': auto_schema,
- 'request_body': request_body,
- 'query_serializer': query_serializer,
- 'manual_parameters': manual_parameters,
- 'operation_description': operation_description,
- 'responses': responses,
- }
- data = {k: v for k, v in data.items() if v is not None}
-
- # if the method is a detail_route or list_route, it will have a bind_to_methods attribute
- bind_to_methods = getattr(view_method, 'bind_to_methods', [])
- # if the method is actually a function based view (@api_view), it will have a 'cls' attribute
- view_cls = getattr(view_method, 'cls', None)
- http_method_names = getattr(view_cls, 'http_method_names', [])
- if bind_to_methods or http_method_names:
- # detail_route, list_route or api_view
- assert bool(http_method_names) != bool(bind_to_methods), "this should never happen"
- available_methods = http_method_names + bind_to_methods
- existing_data = getattr(view_method, 'swagger_auto_schema', {})
-
- if http_method_names:
- _route = "api_view"
- else:
- _route = "detail_route" if view_method.detail else "list_route"
-
- _methods = methods
- if len(available_methods) > 1:
- assert methods or method, \
- "on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \
- "using one of the `method` or `methods` arguments" % _route
- assert bool(methods) != bool(method), "specify either method or methods"
- assert not isinstance(methods, str), "`methods` expects to receive a list of methods;" \
- " use `method` for a single argument"
- if method:
- _methods = [method.lower()]
- else:
- _methods = [mth.lower() for mth in methods]
- assert not any(mth in existing_data for mth in _methods), "method defined multiple times"
- assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route
-
- existing_data.update((mth.lower(), data) for mth in _methods)
- else:
- existing_data[available_methods[0]] = data
- view_method.swagger_auto_schema = existing_data
- else:
- assert method is None and methods is None, \
- "the methods argument should only be specified when decorating a detail_route or list_route; you " \
- "should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator"
- view_method.swagger_auto_schema = data
-
- return view_method
-
- return decorator
-
-
-def get_model_field(queryset, field_name):
- """Try to get information about a model and model field from a queryset.
-
- :param queryset: the queryset
- :param field_name: the target field name
- :returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
- :rtype: tuple
- """
- model = getattr(queryset, 'model', None)
- try:
- model_field = model._meta.get_field(field_name)
- except Exception: # pragma: no cover
- model_field = None
-
- return model, model_field
-
-
-model_field_to_swagger_type = [
- (models.AutoField, (openapi.TYPE_INTEGER, None)),
- (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
- (models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
- (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)),
- (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
- (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
- (models.DecimalField, (openapi.TYPE_NUMBER, None)),
- (models.DurationField, (openapi.TYPE_INTEGER, None)),
- (models.FloatField, (openapi.TYPE_NUMBER, None)),
- (models.IntegerField, (openapi.TYPE_INTEGER, None)),
- (models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
- (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
- (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
- (models.TextField, (openapi.TYPE_STRING, None)),
- (models.TimeField, (openapi.TYPE_STRING, None)),
- (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
- (models.CharField, (openapi.TYPE_STRING, None)),
-]
-
-
-def inspect_model_field(model, model_field):
- """Extract information from a django model field instance.
-
- :param model: the django model
- :param model_field: a field on the model
- :return: description, type, format and pattern extracted from the model field
- :rtype: OrderedDict
- """
- if model is not None and model_field is not None:
- for model_field_class, tf in model_field_to_swagger_type:
- if isinstance(model_field, model_field_class):
- swagger_type, format = tf
- break
- else: # pragma: no cover
- swagger_type, format = None, None
-
- if format is None or format == openapi.FORMAT_SLUG:
- pattern = find_regex(model_field)
- else:
- pattern = None
-
- if model_field.help_text:
- description = force_text(model_field.help_text)
- elif model_field.primary_key:
- description = get_pk_description(model, model_field)
- else:
- description = None
+def guess_response_status(method):
+ if method == 'post':
+ return status.HTTP_201_CREATED
+ elif method == 'delete':
+ return status.HTTP_204_NO_CONTENT
else:
- description = None
- swagger_type = None
- format = None
- pattern = None
-
- result = OrderedDict([
- ('description', description),
- ('type', swagger_type or openapi.TYPE_STRING),
- ('format', format),
- ('pattern', pattern)
- ])
- # TODO: filter none
- return result
-
-
-def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs):
- """Convert a drf Serializer or Field instance into a Swagger object.
-
- :param rest_framework.serializers.Field field: the source field
- :param type[openapi.SwaggerDict] swagger_object_type: should be one of Schema, Parameter, Items
- :param .ReferenceResolver definitions: used to serialize Schemas by reference
- :param kwargs: extra attributes for constructing the object;
- if swagger_object_type is Parameter, ``name`` and ``in_`` should be provided
- :return: the swagger object
- :rtype: openapi.Parameter, openapi.Items, openapi.Schema
- """
- assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
- assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object"
- title = force_text(field.label) if field.label else None
- title = title if swagger_object_type == openapi.Schema else None # only Schema has title
- title = None
- description = force_text(field.help_text) if field.help_text else None
- description = description if swagger_object_type != openapi.Items else None # Items has no description either
-
- def SwaggerType(existing_object=None, **instance_kwargs):
- if swagger_object_type == openapi.Parameter and 'required' not in instance_kwargs:
- instance_kwargs['required'] = field.required
- if swagger_object_type != openapi.Items and 'default' not in instance_kwargs:
- default = getattr(field, 'default', serializers.empty)
- if default is not serializers.empty:
- if callable(default):
- try:
- if hasattr(default, 'set_context'):
- default.set_context(field)
- default = default()
- except Exception as e: # pragma: no cover
- logger.warning("default for %s is callable but it raised an exception when "
- "called; 'default' field will not be added to schema", field, exc_info=True)
- default = None
-
- if default is not None:
- try:
- default = field.to_representation(default)
- # JSON roundtrip ensures that the value is valid JSON;
- # for example, sets get transformed into lists
- default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
- except Exception: # pragma: no cover
- logger.warning("'default' on schema for %s will not be set because "
- "to_representation raised an exception", field, exc_info=True)
- default = None
-
- if default is not None:
- instance_kwargs['default'] = default
-
- if swagger_object_type == openapi.Schema and 'read_only' not in instance_kwargs:
- if field.read_only:
- instance_kwargs['read_only'] = True
- instance_kwargs.update(kwargs)
- instance_kwargs.pop('title', None)
- instance_kwargs.pop('description', None)
-
- if existing_object is not None:
- existing_object.title = title
- existing_object.description = description
- for attr, val in instance_kwargs.items():
- setattr(existing_object, attr, val)
- return existing_object
-
- return swagger_object_type(title=title, description=description, **instance_kwargs)
-
- # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
- ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
-
- # ------ NESTED
- if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
- child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions)
- return SwaggerType(
- type=openapi.TYPE_ARRAY,
- items=child_schema,
- )
- elif isinstance(field, serializers.Serializer):
- if swagger_object_type != openapi.Schema:
- raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
- assert definitions is not None, "ReferenceResolver required when instantiating Schema"
-
- serializer = field
- if hasattr(serializer, '__ref_name__'):
- ref_name = serializer.__ref_name__
- else:
- ref_name = type(serializer).__name__
- if ref_name.endswith('Serializer'):
- ref_name = ref_name[:-len('Serializer')]
-
- def make_schema_definition():
- properties = OrderedDict()
- required = []
- for key, value in serializer.fields.items():
- properties[key] = serializer_field_to_swagger(value, ChildSwaggerType, definitions)
- if value.required:
- required.append(key)
-
- return SwaggerType(
- type=openapi.TYPE_OBJECT,
- properties=properties,
- required=required or None,
- )
-
- if not ref_name:
- return make_schema_definition()
-
- definitions.setdefault(ref_name, make_schema_definition)
- return openapi.SchemaRef(definitions, ref_name)
- elif isinstance(field, serializers.ManyRelatedField):
- child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType, definitions)
- return SwaggerType(
- type=openapi.TYPE_ARRAY,
- items=child_schema,
- unique_items=True, # is this OK?
- )
- elif isinstance(field, serializers.PrimaryKeyRelatedField):
- if field.pk_field:
- result = serializer_field_to_swagger(field.pk_field, swagger_object_type, definitions, **kwargs)
- return SwaggerType(existing_object=result)
-
- attrs = {'type': openapi.TYPE_STRING}
- try:
- model = field.queryset.model
- pk_field = model._meta.pk
- except Exception: # pragma: no cover
- logger.warning("an exception was raised when attempting to extract the primary key related to %s; "
- "falling back to plain string" % field, exc_info=True)
- else:
- attrs.update(inspect_model_field(model, pk_field))
-
- return SwaggerType(**attrs)
- elif isinstance(field, serializers.HyperlinkedRelatedField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
- elif isinstance(field, serializers.SlugRelatedField):
- model, model_field = get_model_field(field.queryset, field.slug_field)
- attrs = inspect_model_field(model, model_field)
- return SwaggerType(**attrs)
- elif isinstance(field, serializers.RelatedField):
- return SwaggerType(type=openapi.TYPE_STRING)
- # ------ CHOICES
- elif isinstance(field, serializers.MultipleChoiceField):
- return SwaggerType(
- type=openapi.TYPE_ARRAY,
- items=ChildSwaggerType(
- type=openapi.TYPE_STRING,
- enum=list(field.choices.keys())
- )
- )
- elif isinstance(field, serializers.ChoiceField):
- return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
- # ------ BOOL
- elif isinstance(field, (serializers.BooleanField, serializers.NullBooleanField)):
- return SwaggerType(type=openapi.TYPE_BOOLEAN)
- # ------ NUMERIC
- elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
- # TODO: min_value max_value
- return SwaggerType(type=openapi.TYPE_NUMBER)
- elif isinstance(field, serializers.IntegerField):
- # TODO: min_value max_value
- return SwaggerType(type=openapi.TYPE_INTEGER)
- elif isinstance(field, serializers.DurationField):
- return SwaggerType(type=openapi.TYPE_INTEGER)
- # ------ STRING
- elif isinstance(field, serializers.EmailField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
- elif isinstance(field, serializers.RegexField):
- return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
- elif isinstance(field, serializers.SlugField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG, pattern=find_regex(field))
- elif isinstance(field, serializers.URLField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
- elif isinstance(field, serializers.IPAddressField):
- format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
- return SwaggerType(type=openapi.TYPE_STRING, format=format)
- elif isinstance(field, serializers.CharField):
- # TODO: min_length max_length (for all CharField subclasses above too)
- return SwaggerType(type=openapi.TYPE_STRING)
- elif isinstance(field, serializers.UUIDField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
- # ------ DATE & TIME
- elif isinstance(field, serializers.DateField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
- elif isinstance(field, serializers.DateTimeField):
- return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
- # ------ OTHERS
- elif isinstance(field, serializers.FileField):
- # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
- # OpenAPI 3.0 does support it, so a future implementation could handle this better
- err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema")
- if swagger_object_type == openapi.Schema:
- # FileField.to_representation returns URL or file name
- result = SwaggerType(type=openapi.TYPE_STRING, read_only=True)
- if getattr(field, 'use_url', api_settings.UPLOADED_FILES_USE_URL):
- result.format = openapi.FORMAT_URI
- return result
- elif swagger_object_type == openapi.Parameter:
- param = SwaggerType(type=openapi.TYPE_FILE)
- if param['in'] != openapi.IN_FORM:
- raise err # pragma: no cover
- return param
- else:
- raise err # pragma: no cover
- elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
- child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions)
- return SwaggerType(
- type=openapi.TYPE_OBJECT,
- additional_properties=child_schema
- )
- elif isinstance(field, serializers.ModelField):
- return SwaggerType(type=openapi.TYPE_STRING)
-
- # TODO unhandled fields: TimeField HiddenField JSONField
-
- # everything else gets string by default
- return SwaggerType(type=openapi.TYPE_STRING)
-
-
-def find_regex(regex_field):
- """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string.
-
- :param serializers.Field regex_field: the field instance
- :return: the extracted pattern, or ``None``
- :rtype: str
- """
- regex_validator = None
- for validator in regex_field.validators:
- if isinstance(validator, RegexValidator):
- if regex_validator is not None:
- # bail if multiple validators are found - no obvious way to choose
- return None # pragma: no cover
- regex_validator = validator
-
- # regex_validator.regex should be a compiled re object...
- return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
+ return status.HTTP_200_OK
def param_list_to_odict(parameters):
@@ -492,3 +201,37 @@ def param_list_to_odict(parameters):
result = OrderedDict(((param.name, param.in_), param) for param in parameters)
assert len(result) == len(parameters), "duplicate Parameters found"
return result
+
+
+def filter_none(obj):
+ """Remove ``None`` values from tuples, lists or dictionaries. Return other objects as-is.
+
+ :param obj:
+ :return: collection with ``None`` values removed
+ """
+ if obj is None:
+ return None
+ new_obj = None
+ if isinstance(obj, dict):
+ new_obj = type(obj)((k, v) for k, v in obj.items() if k is not None and v is not None)
+ if isinstance(obj, (list, tuple)):
+ new_obj = type(obj)(v for v in obj if v is not None)
+ if new_obj is not None and len(new_obj) != len(obj):
+ return new_obj # pragma: no cover
+ return obj
+
+
+def force_serializer_instance(serializer):
+ """Force `serializer` into a ``Serializer`` instance. If it is not a ``Serializer`` class or instance, raises
+ an assertion error.
+
+ :param serializer: serializer class or instance
+ :return: serializer instance
+ """
+ if inspect.isclass(serializer):
+ assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__
+ return serializer()
+
+ assert isinstance(serializer, serializers.BaseSerializer), \
+ "Serializer class or instance required, not %s" % type(serializer).__name__
+ return serializer
diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py
index 1720bdc..2d877d3 100644
--- a/src/drf_yasg/views.py
+++ b/src/drf_yasg/views.py
@@ -82,12 +82,24 @@ def get_schema_view(info, url=None, patterns=None, urlconf=None, public=False, v
renderer_classes = _spec_renderers
def get(self, request, version='', format=None):
- generator = self.generator_class(info, version, url, patterns, urlconf)
+ generator = self.generator_class(info, request.version or version or '', url, patterns, urlconf)
schema = generator.get_schema(request, self.public)
if schema is None:
raise exceptions.PermissionDenied() # pragma: no cover
return Response(schema)
+ @classmethod
+ def apply_cache(cls, view, cache_timeout, cache_kwargs):
+ """Override this method to customize how caching is applied to the view.
+
+ Arguments described in :meth:`.as_cached_view`.
+ """
+ if not cls.public:
+ view = vary_on_headers('Cookie', 'Authorization')(view)
+ view = cache_page(cache_timeout, **cache_kwargs)(view)
+ view = deferred_never_cache(view) # disable in-browser caching
+ return view
+
@classmethod
def as_cached_view(cls, cache_timeout=0, cache_kwargs=None, **initkwargs):
"""
@@ -102,10 +114,7 @@ def get_schema_view(info, url=None, patterns=None, urlconf=None, public=False, v
cache_kwargs = cache_kwargs or {}
view = cls.as_view(**initkwargs)
if cache_timeout != 0:
- if not public:
- view = vary_on_headers('Cookie', 'Authorization')(view)
- view = cache_page(cache_timeout, **cache_kwargs)(view)
- view = deferred_never_cache(view) # disable in-browser caching
+ view = cls.apply_cache(view, cache_timeout, cache_kwargs)
elif cache_kwargs:
warnings.warn("cache_kwargs ignored because cache_timeout is 0 (disabled)")
return view
diff --git a/testproj/articles/serializers.py b/testproj/articles/serializers.py
index ee90e2c..80c8e0c 100644
--- a/testproj/articles/serializers.py
+++ b/testproj/articles/serializers.py
@@ -1,11 +1,12 @@
from rest_framework import serializers
from articles.models import Article
+from django.utils.translation import ugettext_lazy as _
class ArticleSerializer(serializers.ModelSerializer):
references = serializers.DictField(
- help_text="this is a really bad example",
+ help_text=_("this is a really bad example"),
child=serializers.URLField(help_text="but i needed to test these 2 fields somehow"),
read_only=True,
)
@@ -23,8 +24,8 @@ class ArticleSerializer(serializers.ModelSerializer):
'body': {'help_text': 'body serializer help_text'},
'author': {
'default': serializers.CurrentUserDefault(),
- 'help_text': "The ID of the user that created this article; if none is provided, "
- "defaults to the currently logged in user."
+ 'help_text': _("The ID of the user that created this article; if none is provided, "
+ "defaults to the currently logged in user.")
},
}
diff --git a/testproj/articles/views.py b/testproj/articles/views.py
index 860e681..ef0ba47 100644
--- a/testproj/articles/views.py
+++ b/testproj/articles/views.py
@@ -11,11 +11,45 @@ from rest_framework.response import Response
from articles import serializers
from articles.models import Article
-from drf_yasg.inspectors import SwaggerAutoSchema
+from drf_yasg import openapi
+from drf_yasg.app_settings import swagger_settings
+from drf_yasg.inspectors import SwaggerAutoSchema, FieldInspector, CoreAPICompatInspector, NotHandled
from drf_yasg.utils import swagger_auto_schema
-class NoPagingAutoSchema(SwaggerAutoSchema):
+class DjangoFilterDescriptionInspector(CoreAPICompatInspector):
+ def get_filter_parameters(self, filter_backend):
+ if isinstance(filter_backend, DjangoFilterBackend):
+ result = super(DjangoFilterDescriptionInspector, self).get_filter_parameters(filter_backend)
+ for param in result:
+ if not param.get('description', ''):
+ param.description = "Filter the returned list by {field_name}".format(field_name=param.name)
+
+ return result
+
+ return NotHandled
+
+
+class NoSchemaTitleInspector(FieldInspector):
+ def process_result(self, result, method_name, obj, **kwargs):
+ # remove the `title` attribute of all Schema objects
+ if isinstance(result, openapi.Schema.OR_REF):
+ # traverse any references and alter the Schema object in place
+ schema = openapi.resolve_ref(result, self.components)
+ schema.pop('title', None)
+
+ # no ``return schema`` here, because it would mean we always generate
+ # an inline `object` instead of a definition reference
+
+ # return back the same object that we got - i.e. a reference if we got a reference
+ return result
+
+
+class NoTitleAutoSchema(SwaggerAutoSchema):
+ field_inspectors = [NoSchemaTitleInspector] + swagger_settings.DEFAULT_FIELD_INSPECTORS
+
+
+class NoPagingAutoSchema(NoTitleAutoSchema):
def should_page(self):
return False
@@ -26,7 +60,8 @@ class ArticlePagination(LimitOffsetPagination):
@method_decorator(name='list', decorator=swagger_auto_schema(
- operation_description="description from swagger_auto_schema via method_decorator"
+ operation_description="description from swagger_auto_schema via method_decorator",
+ filter_inspectors=[DjangoFilterDescriptionInspector]
))
class ArticleViewSet(viewsets.ModelViewSet):
"""
@@ -52,7 +87,9 @@ class ArticleViewSet(viewsets.ModelViewSet):
ordering_fields = ('date_modified', 'date_created')
ordering = ('date_created',)
- @swagger_auto_schema(auto_schema=NoPagingAutoSchema)
+ swagger_schema = NoTitleAutoSchema
+
+ @swagger_auto_schema(auto_schema=NoPagingAutoSchema, filter_inspectors=[DjangoFilterDescriptionInspector])
@list_route(methods=['get'])
def today(self, request):
today_min = datetime.datetime.combine(datetime.date.today(), datetime.time.min)
diff --git a/testproj/createsuperuser.py b/testproj/createsuperuser.py
new file mode 100644
index 0000000..eb000e9
--- /dev/null
+++ b/testproj/createsuperuser.py
@@ -0,0 +1,4 @@
+from django.contrib.auth.models import User
+
+User.objects.filter(username='admin').delete()
+User.objects.create_superuser('admin', 'admin@admin.admin', 'passwordadmin')
diff --git a/testproj/db.sqlite3 b/testproj/db.sqlite3
deleted file mode 100644
index 532719e..0000000
Binary files a/testproj/db.sqlite3 and /dev/null differ
diff --git a/testproj/snippets/serializers.py b/testproj/snippets/serializers.py
index eaa5798..167d5ac 100644
--- a/testproj/snippets/serializers.py
+++ b/testproj/snippets/serializers.py
@@ -5,18 +5,22 @@ from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES
class LanguageSerializer(serializers.Serializer):
- __ref_name__ = None
name = serializers.ChoiceField(
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
+ class Meta:
+ ref_name = None
+
class ExampleProjectSerializer(serializers.Serializer):
- __ref_name__ = 'Project'
project_name = serializers.CharField(help_text='Name of the project')
github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
+ class Meta:
+ ref_name = 'Project'
+
class SnippetSerializer(serializers.Serializer):
"""SnippetSerializer classdoc
diff --git a/testproj/snippets/views.py b/testproj/snippets/views.py
index a76d74c..462cf69 100644
--- a/testproj/snippets/views.py
+++ b/testproj/snippets/views.py
@@ -1,14 +1,28 @@
+from djangorestframework_camel_case.parser import CamelCaseJSONParser
+from djangorestframework_camel_case.render import CamelCaseJSONRenderer
+from inflection import camelize
from rest_framework import generics
+from drf_yasg.inspectors import SwaggerAutoSchema
from snippets.models import Snippet
from snippets.serializers import SnippetSerializer
+class CamelCaseOperationIDAutoSchema(SwaggerAutoSchema):
+ def get_operation_id(self, operation_keys):
+ operation_id = super(CamelCaseOperationIDAutoSchema, self).get_operation_id(operation_keys)
+ return camelize(operation_id, uppercase_first_letter=False)
+
+
class SnippetList(generics.ListCreateAPIView):
"""SnippetList classdoc"""
queryset = Snippet.objects.all()
serializer_class = SnippetSerializer
+ parser_classes = (CamelCaseJSONParser,)
+ renderer_classes = (CamelCaseJSONRenderer,)
+ swagger_schema = CamelCaseOperationIDAutoSchema
+
def perform_create(self, serializer):
serializer.save(owner=self.request.user)
@@ -31,6 +45,10 @@ class SnippetDetail(generics.RetrieveUpdateDestroyAPIView):
serializer_class = SnippetSerializer
pagination_class = None
+ parser_classes = (CamelCaseJSONParser,)
+ renderer_classes = (CamelCaseJSONRenderer,)
+ swagger_schema = CamelCaseOperationIDAutoSchema
+
def patch(self, request, *args, **kwargs):
"""patch method docstring"""
return super(SnippetDetail, self).patch(request, *args, **kwargs)
diff --git a/testproj/testproj/settings.py b/testproj/testproj/settings.py
index c2dae4c..64ad1b8 100644
--- a/testproj/testproj/settings.py
+++ b/testproj/testproj/settings.py
@@ -128,3 +128,52 @@ USE_TZ = True
STATIC_URL = '/static/'
TEST_RUNNER = 'testproj.runner.PytestTestRunner'
+
+LOGGING = {
+ 'version': 1,
+ 'disable_existing_loggers': True,
+ 'formatters': {
+ 'pipe_separated': {
+ 'format': '%(asctime)s | %(levelname)s | %(name)s | %(message)s'
+ }
+ },
+ 'handlers': {
+ 'console_log': {
+ 'level': 'DEBUG',
+ 'class': 'logging.StreamHandler',
+ 'stream': 'ext://sys.stdout',
+ 'formatter': 'pipe_separated',
+ },
+ },
+ 'loggers': {
+ 'drf_yasg': {
+ 'handlers': ['console_log'],
+ 'level': 'DEBUG',
+ 'propagate': False,
+ },
+ 'django': {
+ 'handlers': ['console_log'],
+ 'level': 'DEBUG',
+ 'propagate': False,
+ },
+ 'django.db.backends': {
+ 'handlers': ['console_log'],
+ 'level': 'INFO',
+ 'propagate': False,
+ },
+ 'django.template': {
+ 'handlers': ['console_log'],
+ 'level': 'INFO',
+ 'propagate': False,
+ },
+ 'swagger_spec_validator': {
+ 'handlers': ['console_log'],
+ 'level': 'INFO',
+ 'propagate': False,
+ }
+ },
+ 'root': {
+ 'handlers': ['console_log'],
+ 'level': 'INFO',
+ }
+}
diff --git a/testproj/testproj/urls.py b/testproj/testproj/urls.py
index d1ea6a4..efff18f 100644
--- a/testproj/testproj/urls.py
+++ b/testproj/testproj/urls.py
@@ -6,7 +6,7 @@ from rest_framework.decorators import api_view
from drf_yasg import openapi
from drf_yasg.views import get_schema_view
-schema_view = get_schema_view(
+SchemaView = get_schema_view(
openapi.Info(
title="Snippets API",
default_version='v1',
@@ -27,12 +27,12 @@ def plain_view(request):
urlpatterns = [
- url(r'^swagger(?P.json|.yaml)$', schema_view.without_ui(cache_timeout=0), name='schema-json'),
- url(r'^swagger/$', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'),
- url(r'^redoc/$', schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'),
- url(r'^cached/swagger(?P.json|.yaml)$', schema_view.without_ui(cache_timeout=None), name='schema-json'),
- url(r'^cached/swagger/$', schema_view.with_ui('swagger', cache_timeout=None), name='schema-swagger-ui'),
- url(r'^cached/redoc/$', schema_view.with_ui('redoc', cache_timeout=None), name='schema-redoc'),
+ url(r'^swagger(?P.json|.yaml)$', SchemaView.without_ui(cache_timeout=0), name='schema-json'),
+ url(r'^swagger/$', SchemaView.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'),
+ url(r'^redoc/$', SchemaView.with_ui('redoc', cache_timeout=0), name='schema-redoc'),
+ url(r'^cached/swagger(?P.json|.yaml)$', SchemaView.without_ui(cache_timeout=None), name='cschema-json'),
+ url(r'^cached/swagger/$', SchemaView.with_ui('swagger', cache_timeout=None), name='cschema-swagger-ui'),
+ url(r'^cached/redoc/$', SchemaView.with_ui('redoc', cache_timeout=None), name='cschema-redoc'),
url(r'^admin/', admin.site.urls),
url(r'^snippets/', include('snippets.urls')),
diff --git a/testproj/users/serializers.py b/testproj/users/serializers.py
index 87cc87c..c10f076 100644
--- a/testproj/users/serializers.py
+++ b/testproj/users/serializers.py
@@ -6,7 +6,7 @@ from snippets.models import Snippet
class UserSerializerrr(serializers.ModelSerializer):
snippets = serializers.PrimaryKeyRelatedField(many=True, queryset=Snippet.objects.all())
- article_slugs = serializers.SlugRelatedField(read_only=True, slug_field='slug', many=True, source='articlessss')
+ article_slugs = serializers.SlugRelatedField(read_only=True, slug_field='slug', many=True, source='articles')
last_connected_ip = serializers.IPAddressField(help_text="i'm out of ideas", protocol='ipv4', read_only=True)
last_connected_at = serializers.DateField(help_text="really?", read_only=True)
diff --git a/testproj/users/views.py b/testproj/users/views.py
index cac1162..5a44854 100644
--- a/testproj/users/views.py
+++ b/testproj/users/views.py
@@ -32,7 +32,7 @@ class UserList(APIView):
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)
- @swagger_auto_schema(request_body=no_body, operation_description="dummy operation")
+ @swagger_auto_schema(request_body=no_body, operation_id="users_dummy", operation_description="dummy operation")
def patch(self, request):
pass
diff --git a/tests/conftest.py b/tests/conftest.py
index ff80891..a7449a1 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,14 +1,15 @@
import copy
import json
import os
+from collections import OrderedDict
import pytest
from django.contrib.auth.models import User
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
-from ruamel import yaml
from drf_yasg import openapi, codecs
+from drf_yasg.codecs import yaml_sane_load
from drf_yasg.generators import OpenAPISchemaGenerator
@@ -47,7 +48,7 @@ def swagger(mock_schema_request):
@pytest.fixture
def swagger_dict(swagger):
json_bytes = codec_json().encode(swagger)
- return json.loads(json_bytes.decode('utf-8'))
+ return json.loads(json_bytes.decode('utf-8'), object_pairs_hook=OrderedDict)
@pytest.fixture
@@ -79,4 +80,4 @@ def redoc_settings(settings):
@pytest.fixture
def reference_schema():
with open(os.path.join(os.path.dirname(__file__), 'reference.yaml')) as reference:
- return yaml.safe_load(reference)
+ return yaml_sane_load(reference)
diff --git a/tests/reference.yaml b/tests/reference.yaml
index b36e1ff..1145ff4 100644
--- a/tests/reference.yaml
+++ b/tests/reference.yaml
@@ -20,7 +20,7 @@ paths:
parameters:
- name: title
in: query
- description: ''
+ description: Filter the returned list by title
required: false
type: string
- name: ordering
@@ -89,7 +89,7 @@ paths:
parameters:
- name: title
in: query
- description: ''
+ description: Filter the returned list by title
required: false
type: string
- name: ordering
@@ -249,7 +249,7 @@ paths:
parameters: []
/snippets/:
get:
- operationId: snippets_list
+ operationId: snippetsList
description: SnippetList classdoc
parameters: []
responses:
@@ -264,7 +264,7 @@ paths:
tags:
- snippets
post:
- operationId: snippets_create
+ operationId: snippetsCreate
description: post method docstring
parameters:
- name: data
@@ -284,7 +284,7 @@ paths:
parameters: []
/snippets/{id}/:
get:
- operationId: snippets_read
+ operationId: snippetsRead
description: SnippetDetail classdoc
parameters: []
responses:
@@ -297,7 +297,7 @@ paths:
tags:
- snippets
put:
- operationId: snippets_update
+ operationId: snippetsUpdate
description: put class docstring
parameters:
- name: data
@@ -315,7 +315,7 @@ paths:
tags:
- snippets
patch:
- operationId: snippets_partial_update
+ operationId: snippetsPartialUpdate
description: patch method docstring
parameters:
- name: data
@@ -333,7 +333,7 @@ paths:
tags:
- snippets
delete:
- operationId: snippets_delete
+ operationId: snippetsDelete
description: delete method docstring
parameters: []
responses:
@@ -404,7 +404,7 @@ paths:
tags:
- users
patch:
- operationId: users_partial_update
+ operationId: users_dummy
description: dummy operation
parameters: []
responses:
@@ -466,6 +466,7 @@ definitions:
title:
description: title model help_text
type: string
+ maxLength: 255
author:
description: The ID of the user that created this article; if none is provided,
defaults to the currently logged in user.
@@ -474,11 +475,13 @@ definitions:
body:
description: body serializer help_text
type: string
+ maxLength: 5000
slug:
description: slug model help_text
type: string
format: slug
pattern: ^[-a-zA-Z0-9_]+$
+ maxLength: 50
date_created:
type: string
format: date-time
@@ -509,14 +512,16 @@ definitions:
readOnly: true
Project:
required:
- - project_name
- - github_repo
+ - projectName
+ - githubRepo
type: object
properties:
- project_name:
+ projectName:
+ title: Project name
description: Name of the project
type: string
- github_repo:
+ githubRepo:
+ title: Github repo
description: Github repository of the project
type: string
Snippet:
@@ -526,29 +531,38 @@ definitions:
type: object
properties:
id:
+ title: Id
description: id serializer help text
type: integer
readOnly: true
owner:
+ title: Owner
description: The ID of the user that created this snippet; if none is provided,
defaults to the currently logged in user.
type: integer
default: 1
- owner_as_string:
+ ownerAsString:
description: The ID of the user that created this snippet.
type: string
readOnly: true
+ title: Owner as string
title:
+ title: Title
type: string
+ maxLength: 100
code:
+ title: Code
type: string
linenos:
+ title: Linenos
type: boolean
language:
+ title: Language
description: Sample help text for language
type: object
properties:
name:
+ title: Name
description: The name of the programming language
type: string
enum:
@@ -988,6 +1002,7 @@ definitions:
- zephir
default: python
styles:
+ title: Styles
type: array
items:
type: string
@@ -1024,19 +1039,22 @@ definitions:
default:
- friendly
lines:
+ title: Lines
type: array
items:
type: integer
- example_projects:
+ exampleProjects:
+ title: Example projects
type: array
items:
$ref: '#/definitions/Project'
readOnly: true
- difficulty_factor:
+ difficultyFactor:
+ title: Difficulty factor
description: this is here just to test FloatField
type: number
- default: 6.9
readOnly: true
+ default: 6.9
UserSerializerrr:
required:
- username
@@ -1045,42 +1063,55 @@ definitions:
type: object
properties:
id:
+ title: ID
type: integer
readOnly: true
username:
+ title: Username
description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_
only.
type: string
+ pattern: ^[\w.@+-]+$
+ maxLength: 150
email:
+ title: Email address
type: string
format: email
+ maxLength: 254
articles:
+ title: Articles
type: array
items:
type: integer
uniqueItems: true
snippets:
+ title: Snippets
type: array
items:
type: integer
uniqueItems: true
last_connected_ip:
+ title: Last connected ip
description: i'm out of ideas
type: string
format: ipv4
readOnly: true
last_connected_at:
+ title: Last connected at
description: really?
type: string
format: date
readOnly: true
article_slugs:
+ title: Article slugs
type: array
items:
type: string
+ format: slug
+ pattern: ^[-a-zA-Z0-9_]+\Z
readOnly: true
- uniqueItems: true
readOnly: true
+ uniqueItems: true
securityDefinitions:
basic:
type: basic
diff --git a/tests/test_api_view.py b/tests/test_api_view.py
deleted file mode 100644
index 43f9112..0000000
--- a/tests/test_api_view.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from drf_yasg import openapi
-
-
-def test_operation_docstrings(swagger_dict):
- users_list = swagger_dict['paths']['/users/']
- assert users_list['get']['description'] == "UserList cbv classdoc"
- assert users_list['post']['description'] == "apiview post description override"
-
- users_detail = swagger_dict['paths']['/users/{id}/']
- assert users_detail['get']['description'] == "user_detail fbv docstring"
- assert users_detail['put']['description'] == "user_detail fbv docstring"
-
-
-def test_parameter_docstrings(swagger_dict):
- users_detail = swagger_dict['paths']['/users/{id}/']
- assert users_detail['get']['parameters'][0]['description'] == "test manual param"
- assert users_detail['put']['parameters'][0]['in'] == openapi.IN_BODY
diff --git a/tests/test_generic_api_view.py b/tests/test_generic_api_view.py
deleted file mode 100644
index 2b698c7..0000000
--- a/tests/test_generic_api_view.py
+++ /dev/null
@@ -1,22 +0,0 @@
-def test_appropriate_status_codes(swagger_dict):
- snippets_list = swagger_dict['paths']['/snippets/']
- assert '200' in snippets_list['get']['responses']
- assert '201' in snippets_list['post']['responses']
-
- snippets_detail = swagger_dict['paths']['/snippets/{id}/']
- assert '200' in snippets_detail['get']['responses']
- assert '200' in snippets_detail['put']['responses']
- assert '200' in snippets_detail['patch']['responses']
- assert '204' in snippets_detail['delete']['responses']
-
-
-def test_operation_docstrings(swagger_dict):
- snippets_list = swagger_dict['paths']['/snippets/']
- assert snippets_list['get']['description'] == "SnippetList classdoc"
- assert snippets_list['post']['description'] == "post method docstring"
-
- snippets_detail = swagger_dict['paths']['/snippets/{id}/']
- assert snippets_detail['get']['description'] == "SnippetDetail classdoc"
- assert snippets_detail['put']['description'] == "put class docstring"
- assert snippets_detail['patch']['description'] == "patch method docstring"
- assert snippets_detail['delete']['description'] == "delete method docstring"
diff --git a/tests/test_generic_viewset.py b/tests/test_generic_viewset.py
deleted file mode 100644
index d67882b..0000000
--- a/tests/test_generic_viewset.py
+++ /dev/null
@@ -1,29 +0,0 @@
-def test_appropriate_status_codes(swagger_dict):
- articles_list = swagger_dict['paths']['/articles/']
- assert '200' in articles_list['get']['responses']
- assert '201' in articles_list['post']['responses']
-
- articles_detail = swagger_dict['paths']['/articles/{slug}/']
- assert '200' in articles_detail['get']['responses']
- assert '200' in articles_detail['put']['responses']
- assert '200' in articles_detail['patch']['responses']
- assert '204' in articles_detail['delete']['responses']
-
-
-def test_operation_docstrings(swagger_dict):
- articles_list = swagger_dict['paths']['/articles/']
- assert articles_list['get']['description'] == "description from swagger_auto_schema via method_decorator"
- assert articles_list['post']['description'] == "ArticleViewSet class docstring"
-
- articles_detail = swagger_dict['paths']['/articles/{slug}/']
- assert articles_detail['get']['description'] == "retrieve class docstring"
- assert articles_detail['put']['description'] == "update method docstring"
- assert articles_detail['patch']['description'] == "partial_update description override"
- assert articles_detail['delete']['description'] == "destroy method docstring"
-
- articles_today = swagger_dict['paths']['/articles/today/']
- assert articles_today['get']['description'] == "ArticleViewSet class docstring"
-
- articles_image = swagger_dict['paths']['/articles/{slug}/image/']
- assert articles_image['get']['description'] == "image GET description override"
- assert articles_image['post']['description'] == "image method docstring"
diff --git a/tests/test_reference_schema.py b/tests/test_reference_schema.py
index d3bd8a0..cf04cd1 100644
--- a/tests/test_reference_schema.py
+++ b/tests/test_reference_schema.py
@@ -1,13 +1,46 @@
+from collections import OrderedDict
+
from datadiff.tools import assert_equal
+from drf_yasg.codecs import yaml_sane_dump
+from drf_yasg.inspectors import FieldInspector, SerializerInspector, PaginatorInspector, FilterInspector
+
def test_reference_schema(swagger_dict, reference_schema):
- swagger_dict = dict(swagger_dict)
- reference_schema = dict(reference_schema)
+ swagger_dict = OrderedDict(swagger_dict)
+ reference_schema = OrderedDict(reference_schema)
ignore = ['info', 'host', 'schemes', 'basePath', 'securityDefinitions']
for attr in ignore:
swagger_dict.pop(attr, None)
reference_schema.pop(attr, None)
- # formatted better than pytest diff
- assert_equal(swagger_dict, reference_schema)
+ # print diff between YAML strings because it's prettier
+ assert_equal(yaml_sane_dump(swagger_dict, binary=False), yaml_sane_dump(reference_schema, binary=False))
+
+
+class NoOpFieldInspector(FieldInspector):
+ pass
+
+
+class NoOpSerializerInspector(SerializerInspector):
+ pass
+
+
+class NoOpFilterInspector(FilterInspector):
+ pass
+
+
+class NoOpPaginatorInspector(PaginatorInspector):
+ pass
+
+
+def test_noop_inspectors(swagger_settings, swagger_dict, reference_schema):
+ from drf_yasg import app_settings
+
+ def set_inspectors(inspectors, setting_name):
+ swagger_settings[setting_name] = inspectors + app_settings.SWAGGER_DEFAULTS[setting_name]
+
+ set_inspectors([NoOpFieldInspector, NoOpSerializerInspector], 'DEFAULT_FIELD_INSPECTORS')
+ set_inspectors([NoOpFilterInspector], 'DEFAULT_FILTER_INSPECTORS')
+ set_inspectors([NoOpPaginatorInspector], 'DEFAULT_PAGINATOR_INSPECTORS')
+ test_reference_schema(swagger_dict, reference_schema)
diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py
index 62b4f80..df51b13 100644
--- a/tests/test_schema_generator.py
+++ b/tests/test_schema_generator.py
@@ -1,9 +1,10 @@
import json
+from collections import OrderedDict
import pytest
-from ruamel import yaml
from drf_yasg import openapi, codecs
+from drf_yasg.codecs import yaml_sane_load
from drf_yasg.generators import OpenAPISchemaGenerator
@@ -35,12 +36,12 @@ def test_yaml_codec_roundtrip(codec_yaml, swagger, validate_schema):
yaml_bytes = codec_yaml.encode(swagger)
assert b'omap' not in yaml_bytes # ensure no ugly !!omap is outputted
assert b'&id' not in yaml_bytes and b'*id' not in yaml_bytes # ensure no YAML references are generated
- validate_schema(yaml.safe_load(yaml_bytes.decode('utf-8')))
+ validate_schema(yaml_sane_load(yaml_bytes.decode('utf-8')))
def test_yaml_and_json_match(codec_yaml, codec_json, swagger):
- yaml_schema = yaml.safe_load(codec_yaml.encode(swagger).decode('utf-8'))
- json_schema = json.loads(codec_json.encode(swagger).decode('utf-8'))
+ yaml_schema = yaml_sane_load(codec_yaml.encode(swagger).decode('utf-8'))
+ json_schema = json.loads(codec_json.encode(swagger).decode('utf-8'), object_pairs_hook=OrderedDict)
assert yaml_schema == json_schema
diff --git a/tests/test_schema_structure.py b/tests/test_schema_structure.py
deleted file mode 100644
index 77ec55c..0000000
--- a/tests/test_schema_structure.py
+++ /dev/null
@@ -1,2 +0,0 @@
-def test_paths_not_empty(swagger_dict):
- assert len(swagger_dict['paths']) > 0
diff --git a/tests/test_schema_views.py b/tests/test_schema_views.py
index aabd828..136e5aa 100644
--- a/tests/test_schema_views.py
+++ b/tests/test_schema_views.py
@@ -2,7 +2,8 @@ import json
from collections import OrderedDict
import pytest
-from ruamel import yaml
+
+from drf_yasg.codecs import yaml_sane_load
def _validate_text_schema_view(client, validate_schema, path, loader):
@@ -22,10 +23,10 @@ def test_swagger_json(client, validate_schema):
def test_swagger_yaml(client, validate_schema):
- _validate_text_schema_view(client, validate_schema, "/swagger.yaml", yaml.safe_load)
+ _validate_text_schema_view(client, validate_schema, "/swagger.yaml", yaml_sane_load)
-def test_exception_middleware(client, swagger_settings):
+def test_exception_middleware(client, swagger_settings, db):
swagger_settings['SECURITY_DEFINITIONS'] = {
'bad': {
'bad_attribute': 'should not be accepted'
@@ -70,5 +71,5 @@ def test_caching(client, validate_schema):
@pytest.mark.urls('urlconfs.non_public_urls')
def test_non_public(client):
response = client.get('/private/swagger.yaml')
- swagger = yaml.safe_load(response.content.decode('utf-8'))
+ swagger = yaml_sane_load(response.content.decode('utf-8'))
assert len(swagger['paths']) == 0
diff --git a/tests/test_versioning.py b/tests/test_versioning.py
new file mode 100644
index 0000000..265a58e
--- /dev/null
+++ b/tests/test_versioning.py
@@ -0,0 +1,56 @@
+import pytest
+
+from drf_yasg.codecs import yaml_sane_load
+
+
+def _get_versioned_schema(prefix, client, validate_schema):
+ response = client.get(prefix + 'swagger.yaml')
+ assert response.status_code == 200
+ swagger = yaml_sane_load(response.content.decode('utf-8'))
+ validate_schema(swagger)
+ assert prefix + 'snippets/' in swagger['paths']
+ return swagger
+
+
+def _check_v1(swagger, prefix):
+ assert swagger['info']['version'] == '1.0'
+ versioned_post = swagger['paths'][prefix + 'snippets/']['post']
+ assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/Snippet'
+ assert 'v2field' not in swagger['definitions']['Snippet']['properties']
+
+
+def _check_v2(swagger, prefix):
+ assert swagger['info']['version'] == '2.0'
+ versioned_post = swagger['paths'][prefix + 'snippets/']['post']
+ assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/SnippetV2'
+ assert 'v2field' in swagger['definitions']['SnippetV2']['properties']
+ v2field = swagger['definitions']['SnippetV2']['properties']['v2field']
+ assert v2field['description'] == 'version 2.0 field'
+
+
+@pytest.mark.urls('urlconfs.url_versioning')
+def test_url_v1(client, validate_schema):
+ prefix = '/versioned/url/v1.0/'
+ swagger = _get_versioned_schema(prefix, client, validate_schema)
+ _check_v1(swagger, prefix)
+
+
+@pytest.mark.urls('urlconfs.url_versioning')
+def test_url_v2(client, validate_schema):
+ prefix = '/versioned/url/v2.0/'
+ swagger = _get_versioned_schema(prefix, client, validate_schema)
+ _check_v2(swagger, prefix)
+
+
+@pytest.mark.urls('urlconfs.ns_versioning')
+def test_ns_v1(client, validate_schema):
+ prefix = '/versioned/ns/v1.0/'
+ swagger = _get_versioned_schema(prefix, client, validate_schema)
+ _check_v1(swagger, prefix)
+
+
+@pytest.mark.urls('urlconfs.ns_versioning')
+def test_ns_v2(client, validate_schema):
+ prefix = '/versioned/ns/v2.0/'
+ swagger = _get_versioned_schema(prefix, client, validate_schema)
+ _check_v2(swagger, prefix)
diff --git a/tests/urlconfs/ns_version1.py b/tests/urlconfs/ns_version1.py
new file mode 100644
index 0000000..8d22a56
--- /dev/null
+++ b/tests/urlconfs/ns_version1.py
@@ -0,0 +1,26 @@
+from django.conf.urls import url
+from rest_framework import generics, versioning
+
+from snippets.models import Snippet
+from snippets.serializers import SnippetSerializer
+
+
+class SnippetList(generics.ListCreateAPIView):
+ """SnippetList classdoc"""
+ queryset = Snippet.objects.all()
+ serializer_class = SnippetSerializer
+ versioning_class = versioning.NamespaceVersioning
+
+ def perform_create(self, serializer):
+ serializer.save(owner=self.request.user)
+
+ def post(self, request, *args, **kwargs):
+ """post method docstring"""
+ return super(SnippetList, self).post(request, *args, **kwargs)
+
+
+app_name = 'test_ns_versioning'
+
+urlpatterns = [
+ url(r"^$", SnippetList.as_view())
+]
diff --git a/tests/urlconfs/ns_version2.py b/tests/urlconfs/ns_version2.py
new file mode 100644
index 0000000..69908f2
--- /dev/null
+++ b/tests/urlconfs/ns_version2.py
@@ -0,0 +1,23 @@
+from django.conf.urls import url
+from rest_framework import fields
+
+from snippets.serializers import SnippetSerializer
+from .ns_version1 import SnippetList as SnippetListV1
+
+
+class SnippetSerializerV2(SnippetSerializer):
+ v2field = fields.IntegerField(help_text="version 2.0 field")
+
+ class Meta:
+ ref_name = 'SnippetV2'
+
+
+class SnippetListV2(SnippetListV1):
+ serializer_class = SnippetSerializerV2
+
+
+app_name = 'test_ns_versioning'
+
+urlpatterns = [
+ url(r"^$", SnippetListV2.as_view())
+]
diff --git a/tests/urlconfs/ns_versioning.py b/tests/urlconfs/ns_versioning.py
new file mode 100644
index 0000000..5875908
--- /dev/null
+++ b/tests/urlconfs/ns_versioning.py
@@ -0,0 +1,24 @@
+from django.conf.urls import url, include
+from rest_framework import versioning
+
+from testproj.urls import SchemaView
+from . import ns_version1, ns_version2
+
+VERSION_PREFIX_NS = r"^versioned/ns/"
+
+
+class VersionedSchemaView(SchemaView):
+ versioning_class = versioning.NamespaceVersioning
+
+
+schema_patterns = [
+ url(r'swagger(?P.json|.yaml)$', VersionedSchemaView.without_ui(), name='ns-schema')
+]
+
+
+urlpatterns = [
+ url(VERSION_PREFIX_NS + r"v1.0/snippets/", include(ns_version1, namespace='1.0')),
+ url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2, namespace='2.0')),
+ url(VERSION_PREFIX_NS + r'v1.0/', include((schema_patterns, '1.0'))),
+ url(VERSION_PREFIX_NS + r'v2.0/', include((schema_patterns, '2.0'))),
+]
diff --git a/tests/urlconfs/url_versioning.py b/tests/urlconfs/url_versioning.py
new file mode 100644
index 0000000..5642b5c
--- /dev/null
+++ b/tests/urlconfs/url_versioning.py
@@ -0,0 +1,48 @@
+from django.conf.urls import url
+from rest_framework import generics, versioning, fields
+
+from snippets.models import Snippet
+from snippets.serializers import SnippetSerializer
+from testproj.urls import SchemaView
+
+
+class SnippetSerializerV2(SnippetSerializer):
+ v2field = fields.IntegerField(help_text="version 2.0 field")
+
+ class Meta:
+ ref_name = 'SnippetV2'
+
+
+class SnippetList(generics.ListCreateAPIView):
+ """SnippetList classdoc"""
+ queryset = Snippet.objects.all()
+ serializer_class = SnippetSerializer
+ versioning_class = versioning.URLPathVersioning
+
+ def get_serializer_class(self):
+ context = self.get_serializer_context()
+ request = context['request']
+ if int(float(request.version)) >= 2:
+ return SnippetSerializerV2
+ else:
+ return SnippetSerializer
+
+ def perform_create(self, serializer):
+ serializer.save(owner=self.request.user)
+
+ def post(self, request, *args, **kwargs):
+ """post method docstring"""
+ return super(SnippetList, self).post(request, *args, **kwargs)
+
+
+VERSION_PREFIX_URL = r"^versioned/url/v(?P1.0|2.0)/"
+
+
+class VersionedSchemaView(SchemaView):
+ versioning_class = versioning.URLPathVersioning
+
+
+urlpatterns = [
+ url(VERSION_PREFIX_URL + r"snippets/$", SnippetList.as_view()),
+ url(VERSION_PREFIX_URL + r'swagger(?P.json|.yaml)$', VersionedSchemaView.without_ui(), name='vschema-json'),
+]