diff --git a/.codecov.yml b/.codecov.yml index fb3842d..57bfecc 100644 --- a/.codecov.yml +++ b/.codecov.yml @@ -11,15 +11,15 @@ coverage: default: enabled: yes target: auto - threshold: 0% + threshold: 100% if_no_uploads: error if_ci_failed: error patch: default: enabled: yes - target: 80% - threshold: 0% + target: 100% + threshold: 100% if_no_uploads: error if_ci_failed: error diff --git a/.coveragerc b/.coveragerc index d3750bc..996f08c 100644 --- a/.coveragerc +++ b/.coveragerc @@ -17,6 +17,8 @@ exclude_lines = raise TypeError raise NotImplementedError warnings.warn + logger.warning + return NotHandled # Don't complain if non-runnable code isn't run: if 0: diff --git a/.gitignore b/.gitignore index d48f341..0840b0f 100644 --- a/.gitignore +++ b/.gitignore @@ -156,3 +156,5 @@ com_crashlytics_export_strings.xml crashlytics.properties crashlytics-build.properties fabric.properties + +testproj/db\.sqlite3 diff --git a/.travis.yml b/.travis.yml index d8ceb8d..dab442f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -40,6 +40,7 @@ after_success: branches: only: - master + - /^release\/.*$/ notifications: email: diff --git a/CONTRIBUTING.rst b/CONTRIBUTING.rst index 1b239ad..ec46a5f 100644 --- a/CONTRIBUTING.rst +++ b/CONTRIBUTING.rst @@ -43,6 +43,8 @@ You want to contribute some code? Great! Here are a few steps to get you started .. code:: console (venv) $ cd testproj + (venv) $ python manage.py migrate + (venv) $ cat createsuperuser.py | python manage.py shell (venv) $ python manage.py runserver (venv) $ curl localhost:8000/swagger.yaml diff --git a/README.rst b/README.rst index 75efd12..2b01cbb 100644 --- a/README.rst +++ b/README.rst @@ -141,6 +141,7 @@ This exposes 4 cached, validated and publicly available endpoints: 2. Configuration ================ +--------------------------------- a. ``get_schema_view`` parameters --------------------------------- @@ -153,6 +154,7 @@ a. ``get_schema_view`` parameters - ``authentication_classes`` - authentication classes for the schema view itself - ``permission_classes`` - permission classes for the schema view itself +------------------------------- b. ``SchemaView`` options ------------------------------- @@ -169,6 +171,7 @@ All of the first 3 methods take two optional arguments, to Django’s :python:`cached_page` decorator in order to enable caching on the resulting view. See `3. Caching`_. +---------------------------------------------- c. ``SWAGGER_SETTINGS`` and ``REDOC_SETTINGS`` ---------------------------------------------- @@ -178,6 +181,26 @@ The possible settings and their default values are as follows: .. code:: python SWAGGER_SETTINGS = { + # default inspector classes, see advanced documentation + 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema', + 'DEFAULT_FIELD_INSPECTORS': [ + 'drf_yasg.inspectors.CamelCaseJSONFilter', + 'drf_yasg.inspectors.ReferencingSerializerInspector', + 'drf_yasg.inspectors.RelatedFieldInspector', + 'drf_yasg.inspectors.ChoiceFieldInspector', + 'drf_yasg.inspectors.FileFieldInspector', + 'drf_yasg.inspectors.DictFieldInspector', + 'drf_yasg.inspectors.SimpleFieldInspector', + 'drf_yasg.inspectors.StringDefaultFieldInspector', + ], + 'DEFAULT_FILTER_INSPECTORS': [ + 'drf_yasg.inspectors.CoreAPICompatInspector', + ], + 'DEFAULT_PAGINATOR_INSPECTORS': [ + 'drf_yasg.inspectors.DjangoRestResponsePagination', + 'drf_yasg.inspectors.CoreAPICompatInspector', + ], + 'USE_SESSION_AUTH': True, # add Django Login and Django Logout buttons, CSRF token to swagger UI page 'LOGIN_URL': getattr(django.conf.settings, 'LOGIN_URL', None), # URL for the login button 'LOGOUT_URL': getattr(django.conf.settings, 'LOGOUT_URL', None), # URL for the logout button @@ -241,6 +264,7 @@ Caching can mitigate the speed impact of validation. The provided validation will catch syntactic errors, but more subtle violations of the spec might slip by them. To ensure compatibility with code generation tools, it is recommended to also employ one or more of the following methods: +------------------------------- ``swagger-ui`` validation badge ------------------------------- @@ -271,6 +295,7 @@ If your schema is not accessible from the internet, you can run a local copy of $ curl http://localhost:8189/debug?url=http://test.local:8002/swagger/?format=openapi {} +--------------------- Using ``swagger-cli`` --------------------- @@ -283,6 +308,7 @@ https://www.npmjs.com/package/swagger-cli $ swagger-cli validate http://test.local:8002/swagger.yaml http://test.local:8002/swagger.yaml is valid +-------------------------------------------------------------- Manually on `editor.swagger.io `__ -------------------------------------------------------------- @@ -345,10 +371,16 @@ named schemas. Both projects are also currently unmantained. -Documentation, advanced usage -============================= +************************ +Third-party integrations +************************ -https://drf-yasg.readthedocs.io/en/latest/ +djangorestframework-camel-case +=============================== + +Integration with `djangorestframework-camel-case `_ is +provided out of the box - if you have ``djangorestframework-camel-case`` installed and your ``APIView`` uses +``CamelCaseJSONParser`` or ``CamelCaseJSONRenderer``, all property names will be converted to *camelCase* by default. .. |travis| image:: https://img.shields.io/travis/axnsan12/drf-yasg/master.svg :target: https://travis-ci.org/axnsan12/drf-yasg diff --git a/docs/_static/css/style.css b/docs/_static/css/style.css new file mode 100644 index 0000000..2571b7c --- /dev/null +++ b/docs/_static/css/style.css @@ -0,0 +1,18 @@ +.versionadded, .versionchanged, .deprecated { + font-family: "Roboto", Corbel, Avenir, "Lucida Grande", "Lucida Sans", sans-serif; + padding: 10px 13px; + border: 1px solid rgb(137, 191, 4); + border-radius: 4px; + margin-bottom: 10px; +} + +.versionmodified { + font-weight: bold; + display: block; +} + +.versionadded p, .versionchanged p, .deprecated p, +/*override fucking !important by being more specific */ +.rst-content dl .versionadded p, .rst-content dl .versionchanged p { + margin: 0 !important; +} diff --git a/docs/_templates/layout.html b/docs/_templates/layout.html new file mode 100644 index 0000000..c242078 --- /dev/null +++ b/docs/_templates/layout.html @@ -0,0 +1,4 @@ +{% extends "!layout.html" %} +{% block extrahead %} + +{% endblock %} diff --git a/docs/changelog.rst b/docs/changelog.rst index 6e3fd2a..79bbbc0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,21 @@ Changelog ######### +********* +**1.1.0** +********* + +- **ADDED:** added support for APIs versioned with ``URLPathVersioning`` or ``NamespaceVersioning`` +- **ADDED:** added ability to recursively customize schema generation + :ref:`using pluggable inspector classes ` +- **ADDED:** added ``operation_id`` parameter to :func:`@swagger_auto_schema <.swagger_auto_schema>` +- **ADDED:** integration with `djangorestframework-camel-case + `_ (:issue:`28`) +- **IMPROVED:** strings, arrays and integers will now have min/max validation attributes inferred from the + field-level validators +- **FIXED:** fixed a bug that caused ``title`` to never be generated for Schemas; ``title`` is now correctly + populated from the field's ``label`` property + ********* **1.0.6** ********* diff --git a/docs/conf.py b/docs/conf.py index bdad022..3e9b70b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -3,6 +3,7 @@ # # drf-yasg documentation build configuration file, created by # sphinx-quickstart on Sun Dec 10 15:20:34 2017. +import inspect import os import re import sys @@ -68,9 +69,6 @@ pygments_style = 'sphinx' modindex_common_prefix = ['drf_yasg.'] -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = False - # -- Options for HTML output ---------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for @@ -186,18 +184,23 @@ nitpick_ignore = [ ('py:obj', 'callable'), ('py:obj', 'type'), ('py:obj', 'OrderedDict'), + ('py:obj', 'None'), ('py:obj', 'coreapi.Field'), ('py:obj', 'BaseFilterBackend'), ('py:obj', 'BasePagination'), + ('py:obj', 'Request'), ('py:obj', 'rest_framework.request.Request'), ('py:obj', 'rest_framework.serializers.Field'), ('py:obj', 'serializers.Field'), ('py:obj', 'serializers.BaseSerializer'), ('py:obj', 'Serializer'), + ('py:obj', 'BaseSerializer'), ('py:obj', 'APIView'), ] +# TODO: inheritance aliases in sphinx 1.7 + # even though the package should be already installed, the sphinx build on RTD # for some reason needs the sources dir to be in the path in order for viewcode to work sys.path.insert(0, os.path.abspath('../src')) @@ -215,6 +218,40 @@ import drf_yasg.views # noqa: E402 drf_yasg.views.SchemaView = drf_yasg.views.get_schema_view(None) +# monkey patch to stop sphinx from trying to find classes by their real location instead of the +# top-level __init__ alias; this allows us to document only `drf_yasg.inspectors` and avoid broken references or +# double documenting + +import drf_yasg.inspectors # noqa: E402 + + +def redirect_cls(cls): + if cls.__module__.startswith('drf_yasg.inspectors'): + return getattr(drf_yasg.inspectors, cls.__name__) + return cls + + +for cls_name in drf_yasg.inspectors.__all__: + # first pass - replace all classes' module with the top level module + real_cls = getattr(drf_yasg.inspectors, cls_name) + if not inspect.isclass(real_cls): + continue + + patched_dict = dict(real_cls.__dict__) + patched_dict.update({'__module__': 'drf_yasg.inspectors'}) + patched_cls = type(cls_name, real_cls.__bases__, patched_dict) + setattr(drf_yasg.inspectors, cls_name, patched_cls) + +for cls_name in drf_yasg.inspectors.__all__: + # second pass - replace the inheritance bases for all classes to point to the new clean classes + real_cls = getattr(drf_yasg.inspectors, cls_name) + if not inspect.isclass(real_cls): + continue + + patched_bases = tuple(redirect_cls(base) for base in real_cls.__bases__) + patched_cls = type(cls_name, patched_bases, dict(real_cls.__dict__)) + setattr(drf_yasg.inspectors, cls_name, patched_cls) + # custom interpreted role for linking to GitHub issues and pull requests # use as :issue:`14` or :pr:`17` gh_issue_uri = "https://github.com/axnsan12/drf-yasg/issues/{}" @@ -273,3 +310,7 @@ def role_github_pull_request_or_issue(name, rawtext, text, lineno, inliner, opti roles.register_local_role('pr', role_github_pull_request_or_issue) roles.register_local_role('issue', role_github_pull_request_or_issue) roles.register_local_role('ghuser', role_github_user) + + +def setup(app): + app.add_stylesheet('css/style.css') diff --git a/docs/custom_spec.rst b/docs/custom_spec.rst index 2541b1a..1b23b79 100644 --- a/docs/custom_spec.rst +++ b/docs/custom_spec.rst @@ -249,15 +249,63 @@ Where you can use the :func:`@swagger_auto_schema <.swagger_auto_schema>` decora However, do note that both of the methods above can lead to unexpected (and maybe surprising) results by replacing/decorating methods on the base class itself. + +******************************** +Serializer ``Meta`` nested class +******************************** + +You can define some per-serializer options by adding a ``Meta`` class to your serializer, e.g.: + +.. code:: python + + class WhateverSerializer(Serializer): + ... + + class Meta: + ... options here ... + +Currently, the only option you can add here is + + * ``ref_name`` - a string which will be used as the model definition name for this serializer class; setting it to + ``None`` will force the serializer to be generated as an inline model everywhere it is used + ************************* Subclassing and extending ************************* -For more advanced control you can subclass :class:`.SwaggerAutoSchema` - see the documentation page for a list of -methods you can override. + +--------------------- +``SwaggerAutoSchema`` +--------------------- + +For more advanced control you can subclass :class:`~.inspectors.SwaggerAutoSchema` - see the documentation page +for a list of methods you can override. You can put your custom subclass to use by setting it on a view method using the -:func:`@swagger_auto_schema <.swagger_auto_schema>` decorator described above. +:ref:`@swagger_auto_schema ` decorator described above, by setting it as a +class-level attribute named ``swagger_schema`` on the view class, or +:ref:`globally via settings `. + +For example, to generate all operation IDs as camel case, you could do: + +.. code:: python + + from inflection import camelize + + class CamelCaseOperationIDAutoSchema(SwaggerAutoSchema): + def get_operation_id(self, operation_keys): + operation_id = super(CamelCaseOperationIDAutoSchema, self).get_operation_id(operation_keys) + return camelize(operation_id, uppercase_first_letter=False) + + + SWAGGER_SETTINGS = { + 'DEFAULT_AUTO_SCHEMA_CLASS': 'path.to.CamelCaseOperationIDAutoSchema', + ... + } + +-------------------------- +``OpenAPISchemaGenerator`` +-------------------------- If you need to control things at a higher level than :class:`.Operation` objects (e.g. overall document structure, vendor extensions in metadata) you can also subclass :class:`.OpenAPISchemaGenerator` - again, see the documentation @@ -265,3 +313,88 @@ page for a list of its methods. This custom generator can be put to use by setting it as the :attr:`.generator_class` of a :class:`.SchemaView` using :func:`.get_schema_view`. + +.. _custom-spec-inspectors: + +--------------------- +``Inspector`` classes +--------------------- + +.. versionadded:: 1.1 + +For customizing behavior related to specific field, serializer, filter or paginator classes you can implement the +:class:`~.inspectors.FieldInspector`, :class:`~.inspectors.SerializerInspector`, :class:`~.inspectors.FilterInspector`, +:class:`~.inspectors.PaginatorInspector` classes and use them with +:ref:`@swagger_auto_schema ` or one of the +:ref:`related settings `. + +A :class:`~.inspectors.FilterInspector` that adds a description to all ``DjangoFilterBackend`` parameters could be +implemented like so: + +.. code:: python + + class DjangoFilterDescriptionInspector(CoreAPICompatInspector): + def get_filter_parameters(self, filter_backend): + if isinstance(filter_backend, DjangoFilterBackend): + result = super(DjangoFilterDescriptionInspector, self).get_filter_parameters(filter_backend) + for param in result: + if not param.get('description', ''): + param.description = "Filter the returned list by {field_name}".format(field_name=param.name) + + return result + + return NotHandled + + @method_decorator(name='list', decorator=swagger_auto_schema( + filter_inspectors=[DjangoFilterDescriptionInspector] + )) + class ArticleViewSet(viewsets.ModelViewSet): + filter_backends = (DjangoFilterBackend,) + filter_fields = ('title',) + ... + + +A second example, of a :class:`~.inspectors.FieldInspector` that removes the ``title`` attribute from all generated +:class:`.Schema` objects: + +.. code:: python + + class NoSchemaTitleInspector(FieldInspector): + def process_result(self, result, method_name, obj, **kwargs): + # remove the `title` attribute of all Schema objects + if isinstance(result, openapi.Schema.OR_REF): + # traverse any references and alter the Schema object in place + schema = openapi.resolve_ref(result, self.components) + schema.pop('title', None) + + # no ``return schema`` here, because it would mean we always generate + # an inline `object` instead of a definition reference + + # return back the same object that we got - i.e. a reference if we got a reference + return result + + + class NoTitleAutoSchema(SwaggerAutoSchema): + field_inspectors = [NoSchemaTitleInspector] + swagger_settings.DEFAULT_FIELD_INSPECTORS + + class ArticleViewSet(viewsets.ModelViewSet): + swagger_schema = NoTitleAutoSchema + ... + + +.. Note:: + + A note on references - :class:`.Schema` objects are sometimes output by reference (:class:`.SchemaRef`); in fact, + that is how named models are implemented in OpenAPI: + + - in the output swagger document there is a ``definitions`` section containing :class:`.Schema` objects for all + models + - every usage of a model refers to that single :class:`.Schema` object - for example, in the ArticleViewSet + above, all requests and responses containg an ``Article`` model would refer to the same schema definition by a + ``'$ref': '#/definitions/Article'`` + + This is implemented by only generating **one** :class:`.Schema` object for every serializer **class** encountered. + + This means that you should generally avoid view or method-specific ``FieldInspector``\ s if you are dealing with + references (a.k.a named models), because you can never know which view will be the first to generate the schema + for a given serializer. diff --git a/docs/drf_yasg.rst b/docs/drf_yasg.rst index 9fa9a4b..77c95aa 100644 --- a/docs/drf_yasg.rst +++ b/docs/drf_yasg.rst @@ -1,14 +1,6 @@ drf\_yasg package ==================== -drf\_yasg\.app\_settings ----------------------------------- - -.. automodule:: drf_yasg.app_settings - :members: - :undoc-members: - :show-inheritance: - drf\_yasg\.codecs --------------------------- @@ -16,7 +8,7 @@ drf\_yasg\.codecs :members: :undoc-members: :show-inheritance: - :exclude-members: SaneYamlDumper + :exclude-members: SaneYamlDumper,SaneYamlLoader drf\_yasg\.errors --------------------------- @@ -37,6 +29,8 @@ drf\_yasg\.generators drf\_yasg\.inspectors ------------------------------- +.. autodata:: drf_yasg.inspectors.NotHandled + .. automodule:: drf_yasg.inspectors :members: :undoc-members: diff --git a/docs/settings.rst b/docs/settings.rst index 9a66d72..afd3ebf 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -37,6 +37,60 @@ The possible settings and their default values are as follows: ``SWAGGER_SETTINGS`` ******************** + +.. _default-class-settings: + +Default classes +=============== + +DEFAULT_AUTO_SCHEMA_CLASS +------------------------- + +:class:`~.inspectors.ViewInspector` subclass that will be used by default for generating :class:`.Operation` +objects when iterating over endpoints. Can be overriden by using the `auto_schema` argument of +:func:`@swagger_auto_schema <.swagger_auto_schema>` or by a ``swagger_schema`` attribute on the view class. + +**Default**: :class:`drf_yasg.inspectors.SwaggerAutoSchema` + +DEFAULT_FIELD_INSPECTORS +------------------------ + +List of :class:`~.inspectors.FieldInspector` subclasses that will be used by default for inspecting serializers and +serializer fields. Field inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended +to this list. + +**Default**: ``[`` |br| \ +:class:`'drf_yasg.inspectors.CamelCaseJSONFilter' <.inspectors.CamelCaseJSONFilter>`, |br| \ +:class:`'drf_yasg.inspectors.ReferencingSerializerInspector' <.inspectors.ReferencingSerializerInspector>`, |br| \ +:class:`'drf_yasg.inspectors.RelatedFieldInspector' <.inspectors.RelatedFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.ChoiceFieldInspector' <.inspectors.ChoiceFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.FileFieldInspector' <.inspectors.FileFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.DictFieldInspector' <.inspectors.DictFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.SimpleFieldInspector' <.inspectors.SimpleFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.StringDefaultFieldInspector' <.inspectors.StringDefaultFieldInspector>`, |br| \ +``]`` + +DEFAULT_FILTER_INSPECTORS +------------------------- + +List of :class:`~.inspectors.FilterInspector` subclasses that will be used by default for inspecting filter backends. +Filter inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended to this list. + +**Default**: ``[`` |br| \ +:class:`'drf_yasg.inspectors.CoreAPICompatInspector' <.inspectors.CoreAPICompatInspector>`, |br| \ +``]`` + +DEFAULT_PAGINATOR_INSPECTORS +---------------------------- + +List of :class:`~.inspectors.PaginatorInspector` subclasses that will be used by default for inspecting paginators. +Paginator inspectors given to :func:`@swagger_auto_schema <.swagger_auto_schema>` will be prepended to this list. + +**Default**: ``[`` |br| \ +:class:`'drf_yasg.inspectors.DjangoRestResponsePagination' <.inspectors.DjangoRestResponsePagination>`, |br| \ +:class:`'drf_yasg.inspectors.CoreAPICompatInspector' <.inspectors.CoreAPICompatInspector>`, |br| \ +``]`` + Authorization ============= diff --git a/requirements/test.txt b/requirements/test.txt index c45ebcf..762fcb9 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -12,3 +12,4 @@ pygments>=2.2.0 django-cors-headers>=2.1.0 django-filter>=1.1.0,<2.0; python_version == "2.7" django-filter>=1.1.0; python_version >= "3.4" +djangorestframework-camel-case>=0.2.0 diff --git a/setup.py b/setup.py index bf3e32b..4a8fb03 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ requirements_validation = read_req('validation.txt') setup( name='drf-yasg', use_scm_version=True, - packages=find_packages('src', include=['drf_yasg']), + packages=find_packages('src'), package_dir={'': 'src'}, include_package_data=True, install_requires=requirements, diff --git a/src/drf_yasg/app_settings.py b/src/drf_yasg/app_settings.py index 527c9d0..a6ae8e6 100644 --- a/src/drf_yasg/app_settings.py +++ b/src/drf_yasg/app_settings.py @@ -2,6 +2,26 @@ from django.conf import settings from rest_framework.settings import perform_import SWAGGER_DEFAULTS = { + 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema', + + 'DEFAULT_FIELD_INSPECTORS': [ + 'drf_yasg.inspectors.CamelCaseJSONFilter', + 'drf_yasg.inspectors.ReferencingSerializerInspector', + 'drf_yasg.inspectors.RelatedFieldInspector', + 'drf_yasg.inspectors.ChoiceFieldInspector', + 'drf_yasg.inspectors.FileFieldInspector', + 'drf_yasg.inspectors.DictFieldInspector', + 'drf_yasg.inspectors.SimpleFieldInspector', + 'drf_yasg.inspectors.StringDefaultFieldInspector', + ], + 'DEFAULT_FILTER_INSPECTORS': [ + 'drf_yasg.inspectors.CoreAPICompatInspector', + ], + 'DEFAULT_PAGINATOR_INSPECTORS': [ + 'drf_yasg.inspectors.DjangoRestResponsePagination', + 'drf_yasg.inspectors.CoreAPICompatInspector', + ], + 'USE_SESSION_AUTH': True, 'SECURITY_DEFINITIONS': { 'basic': { @@ -28,7 +48,12 @@ REDOC_DEFAULTS = { 'PATH_IN_MIDDLE': False, } -IMPORT_STRINGS = [] +IMPORT_STRINGS = [ + 'DEFAULT_AUTO_SCHEMA_CLASS', + 'DEFAULT_FIELD_INSPECTORS', + 'DEFAULT_FILTER_INSPECTORS', + 'DEFAULT_PAGINATOR_INSPECTORS', +] class AppSettings(object): diff --git a/src/drf_yasg/codecs.py b/src/drf_yasg/codecs.py index d2746a0..d52a54c 100644 --- a/src/drf_yasg/codecs.py +++ b/src/drf_yasg/codecs.py @@ -98,6 +98,9 @@ class OpenAPICodecJson(_OpenAPICodec): return json.dumps(spec) +YAML_MAP_TAG = u'tag:yaml.org,2002:map' + + class SaneYamlDumper(yaml.SafeDumper): """YamlDumper class usable for dumping ``OrderedDict`` and list instances in a standard way.""" @@ -122,7 +125,7 @@ class SaneYamlDumper(yaml.SafeDumper): To use yaml.safe_dump(), you need the following. """ - tag = u'tag:yaml.org,2002:map' + tag = YAML_MAP_TAG value = [] node = yaml.MappingNode(tag, value, flow_style=flow_style) if dump.alias_key is not None: @@ -158,7 +161,7 @@ def yaml_sane_dump(data, binary): * list elements are indented into their parents * YAML references/aliases are disabled - :param dict data: the data to be serializers + :param dict data: the data to be dumped :param bool binary: True to return a utf-8 encoded binary object, False to return a string :return: the serialized YAML :rtype: str,bytes @@ -166,6 +169,24 @@ def yaml_sane_dump(data, binary): return yaml.dump(data, Dumper=SaneYamlDumper, default_flow_style=False, encoding='utf-8' if binary else None) +class SaneYamlLoader(yaml.SafeLoader): + def construct_odict(self, node, deep=False): + self.flatten_mapping(node) + return OrderedDict(self.construct_pairs(node)) + + +SaneYamlLoader.add_constructor(YAML_MAP_TAG, SaneYamlLoader.construct_odict) + + +def yaml_sane_load(stream): + """Load the given YAML stream while preserving the input order for mapping items. + + :param stream: YAML stream (can be a string or a file-like object) + :rtype: OrderedDict + """ + return yaml.load(stream, Loader=SaneYamlLoader) + + class OpenAPICodecYaml(_OpenAPICodec): media_type = 'application/yaml' diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 3f1979f..0f3d8d7 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -2,12 +2,15 @@ import re from collections import defaultdict, OrderedDict import uritemplate +from django.utils.encoding import force_text +from rest_framework import versioning from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator +from rest_framework.schemas.inspectors import get_pk_description from . import openapi -from .inspectors import SwaggerAutoSchema +from .app_settings import swagger_settings +from .inspectors.field import get_queryset_field, get_basic_type_info from .openapi import ReferenceResolver -from .utils import inspect_model_field, get_model_field PATH_PARAMETER_RE = re.compile(r'{(?P\w+)}') @@ -52,7 +55,7 @@ class EndpointEnumerator(_EndpointEnumerator): class OpenAPISchemaGenerator(object): """ This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema. - Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator. + Method implementations shamelessly stolen and adapted from rest-framework ``SchemaGenerator``. """ endpoint_enumerator_class = EndpointEnumerator @@ -70,10 +73,14 @@ class OpenAPISchemaGenerator(object): self.info = info self.version = version - def get_schema(self, request=None, public=False): - """Generate an :class:`.Swagger` representing the API schema. + @property + def url(self): + return self._gen.url - :param rest_framework.request.Request request: the request used for filtering + def get_schema(self, request=None, public=False): + """Generate a :class:`.Swagger` object representing the API schema. + + :param Request request: the request used for filtering accesible endpoints and finding the spec URI :param bool public: if True, all endpoints are included regardless of access through `request` @@ -81,10 +88,11 @@ class OpenAPISchemaGenerator(object): :rtype: openapi.Swagger """ endpoints = self.get_endpoints(request) + endpoints = self.replace_version(endpoints, request) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) - paths = self.get_paths(endpoints, components, public) + paths = self.get_paths(endpoints, components, request, public) - url = self._gen.url + url = self.url if not url and request is not None: url = request.build_absolute_uri() @@ -102,16 +110,40 @@ class OpenAPISchemaGenerator(object): :return: the view instance """ view = self._gen.create_view(callback, method, request) - overrides = getattr(callback, 'swagger_auto_schema', None) + overrides = getattr(callback, '_swagger_auto_schema', None) if overrides is not None: # decorated function based view must have its decorator information passed on to the re-instantiated view for method, _ in overrides.items(): view_method = getattr(view, method, None) if view_method is not None: # pragma: no cover - setattr(view_method.__func__, 'swagger_auto_schema', overrides) + setattr(view_method.__func__, '_swagger_auto_schema', overrides) return view - def get_endpoints(self, request=None): + def replace_version(self, endpoints, request): + """If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using + ``URLPathVersioning`` as a versioning class. + + :param dict endpoints: endpoints as returned by :meth:`.get_endpoints` + :param Request request: the request made against the schema view + :return: endpoints with modified paths + """ + version = getattr(request, 'version', None) + if version is None: + return endpoints + + new_endpoints = {} + for path, endpoint in endpoints.items(): + view_cls = endpoint[0] + versioning_class = getattr(view_cls, 'versioning_class', None) + version_param = getattr(versioning_class, 'version_param', 'version') + if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning): + path = path.replace('{%s}' % version_param, version) + + new_endpoints[path] = endpoint + + return new_endpoints + + def get_endpoints(self, request): """Iterate over all the registered endpoints in the API and return a fake view with the right parameters. :param rest_framework.request.Request request: request to bind to the endpoint views @@ -131,9 +163,7 @@ class OpenAPISchemaGenerator(object): return {path: (view_cls[path], methods) for path, methods in view_paths.items()} def get_operation_keys(self, subpath, method, view): - """Return a list of keys that should be used to group an operation within the specification. - - :: + """Return a list of keys that should be used to group an operation within the specification. :: /users/ ("users", "list"), ("users", "create") /users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete") @@ -149,39 +179,94 @@ class OpenAPISchemaGenerator(object): """ return self._gen.get_keys(subpath, method, view) - def get_paths(self, endpoints, components, public): + def determine_path_prefix(self, paths): + """ + Given a list of all paths, return the common prefix which should be + discounted when generating a schema structure. + + This will be the longest common string that does not include that last + component of the URL, or the last component before a path parameter. + + For example: :: + + /api/v1/users/ + /api/v1/users/{pk}/ + + The path prefix is ``/api/v1/``. + + :param list[str] paths: list of paths + :rtype: str + """ + return self._gen.determine_path_prefix(paths) + + def get_paths(self, endpoints, components, request, public): """Generate the Swagger Paths for the API from the given endpoints. :param dict endpoints: endpoints as returned by get_endpoints :param ReferenceResolver components: resolver/container for Swagger References + :param Request request: the request made against the schema view; can be None :param bool public: if True, all endpoints are included regardless of access through `request` :rtype: openapi.Paths """ if not endpoints: return openapi.Paths(paths={}) - prefix = self._gen.determine_path_prefix(endpoints.keys()) + prefix = self.determine_path_prefix(list(endpoints.keys())) paths = OrderedDict() - default_schema_cls = SwaggerAutoSchema for path, (view_cls, methods) in sorted(endpoints.items()): - path_parameters = self.get_path_parameters(path, view_cls) operations = {} for method, view in methods: if not public and not self._gen.has_view_permissions(path, method, view): continue - operation_keys = self.get_operation_keys(path[len(prefix):], method, view) - overrides = self.get_overrides(view, method) - auto_schema_cls = overrides.get('auto_schema', default_schema_cls) - schema = auto_schema_cls(view, path, method, overrides, components) - operations[method.lower()] = schema.get_operation(operation_keys) + operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request) if operations: - paths[path] = openapi.PathItem(parameters=path_parameters, **operations) + paths[path] = self.get_path_item(path, view_cls, operations) return openapi.Paths(paths=paths) + def get_operation(self, view, path, prefix, method, components, request): + """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to + :meth:`~.inspectors.ViewInspector.get_operation` of a :class:`~.inspectors.ViewInspector` determined + according to settings and :func:`@swagger_auto_schema <.swagger_auto_schema>` overrides. + + :param view: the view associated with this endpoint + :param str path: the path component of the operation URL + :param str prefix: common path prefix among all endpoints + :param str method: the http method of the operation + :param openapi.ReferenceResolver components: referenceable components + :param Request request: the request made against the schema view; can be None + :rtype: openapi.Operation + """ + + operation_keys = self.get_operation_keys(path[len(prefix):], method, view) + overrides = self.get_overrides(view, method) + + # the inspector class can be specified, in decreasing order of priorty, + # 1. globaly via DEFAULT_AUTO_SCHEMA_CLASS + view_inspector_cls = swagger_settings.DEFAULT_AUTO_SCHEMA_CLASS + # 2. on the view/viewset class + view_inspector_cls = getattr(view, 'swagger_schema', view_inspector_cls) + # 3. on the swagger_auto_schema decorator + view_inspector_cls = overrides.get('auto_schema', view_inspector_cls) + + view_inspector = view_inspector_cls(view, path, method, components, request, overrides) + return view_inspector.get_operation(operation_keys) + + def get_path_item(self, path, view_cls, operations): + """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the + API. + + :param str path: the path + :param type view_cls: the view that was bound to this path in urlpatterns + :param dict[str,openapi.Operation] operations: operations defined on this path, keyed by lowercase HTTP method + :rtype: openapi.PathItem + """ + path_parameters = self.get_path_parameters(path, view_cls) + return openapi.PathItem(parameters=path_parameters, **operations) + def get_overrides(self, view, method): """Get overrides specified for a given operation. @@ -193,7 +278,7 @@ class OpenAPISchemaGenerator(object): method = method.lower() action = getattr(view, 'action', method) action_method = getattr(view, action, None) - overrides = getattr(action_method, 'swagger_auto_schema', {}) + overrides = getattr(action_method, '_swagger_auto_schema', {}) if method in overrides: overrides = overrides[method] @@ -212,13 +297,21 @@ class OpenAPISchemaGenerator(object): model = getattr(getattr(view_cls, 'queryset', None), 'model', None) for variable in uritemplate.variables(path): - model, model_field = get_model_field(queryset, variable) - attrs = inspect_model_field(model, model_field) + model, model_field = get_queryset_field(queryset, variable) + attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING} if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable: attrs['pattern'] = view_cls.lookup_value_regex + if model_field and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field and model_field.primary_key: + description = get_pk_description(model, model_field) + else: + description = None + field = openapi.Parameter( name=variable, + description=description, required=True, in_=openapi.IN_PATH, **attrs diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py new file mode 100644 index 0000000..50d7aa2 --- /dev/null +++ b/src/drf_yasg/inspectors/__init__.py @@ -0,0 +1,38 @@ +from .base import ( + BaseInspector, ViewInspector, FilterInspector, PaginatorInspector, + FieldInspector, SerializerInspector, NotHandled +) +from .field import ( + InlineSerializerInspector, ReferencingSerializerInspector, RelatedFieldInspector, SimpleFieldInspector, + FileFieldInspector, ChoiceFieldInspector, DictFieldInspector, StringDefaultFieldInspector, + CamelCaseJSONFilter +) +from .query import ( + CoreAPICompatInspector, DjangoRestResponsePagination +) +from .view import SwaggerAutoSchema +from ..app_settings import swagger_settings + +# these settings must be accesed only after definig/importing all the classes in this module to avoid ImportErrors +ViewInspector.field_inspectors = swagger_settings.DEFAULT_FIELD_INSPECTORS +ViewInspector.filter_inspectors = swagger_settings.DEFAULT_FILTER_INSPECTORS +ViewInspector.paginator_inspectors = swagger_settings.DEFAULT_PAGINATOR_INSPECTORS + +__all__ = [ + # base inspectors + 'BaseInspector', 'FilterInspector', 'PaginatorInspector', 'FieldInspector', 'SerializerInspector', 'ViewInspector', + + # filter and pagination inspectors + 'CoreAPICompatInspector', 'DjangoRestResponsePagination', + + # field inspectors + 'InlineSerializerInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector', 'SimpleFieldInspector', + 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', 'StringDefaultFieldInspector', + 'CamelCaseJSONFilter', + + # view inspectors + 'SwaggerAutoSchema', + + # module constants + 'NotHandled', +] diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py new file mode 100644 index 0000000..0dbcbce --- /dev/null +++ b/src/drf_yasg/inspectors/base.py @@ -0,0 +1,406 @@ +import inspect +import logging + +from django.utils.encoding import force_text +from rest_framework import serializers +from rest_framework.utils import json, encoders +from rest_framework.viewsets import GenericViewSet + +from .. import openapi +from ..utils import is_list_view + +#: Sentinel value that inspectors must return to signal that they do not know how to handle an object +NotHandled = object() + +logger = logging.getLogger(__name__) + + +class BaseInspector(object): + def __init__(self, view, path, method, components, request): + """ + :param view: the view associated with this endpoint + :param str path: the path component of the operation URL + :param str method: the http method of the operation + :param openapi.ReferenceResolver components: referenceable components + :param Request request: the request made against the schema view; can be None + """ + self.view = view + self.path = path + self.method = method + self.components = components + self.request = request + + def process_result(self, result, method_name, obj, **kwargs): + """After an inspector handles an object (i.e. returns a value other than :data:`.NotHandled`), all inspectors + that were probed get the chance to alter the result, in reverse order. The inspector that handled the object + is the first to receive a ``process_result`` call with the object it just returned. + + This behaviour is similar to the Django request/response middleware processing. + + If this inspector has no post-processing to do, it should just ``return result`` (the default implementation). + + :param result: the return value of the winning inspector, or ``None`` if no inspector handled the object + :param str method_name: name of the method that was called on the inspector + :param obj: first argument passed to inspector method + :param kwargs: additional arguments passed to inspector method + :return: + """ + return result + + def probe_inspectors(self, inspectors, method_name, obj, initkwargs=None, **kwargs): + """Probe a list of inspectors with a given object. The first inspector in the list to return a value that + is not :data:`.NotHandled` wins. + + :param list[type[BaseInspector]] inspectors: list of inspectors to probe + :param str method_name: name of the target method on the inspector + :param obj: first argument to inspector method + :param dict initkwargs: extra kwargs for instantiating inspector class + :param kwargs: additional arguments to inspector method + :return: the return value of the winning inspector, or ``None`` if no inspector handled the object + """ + initkwargs = initkwargs or {} + tried_inspectors = [] + + for inspector in inspectors: + assert inspect.isclass(inspector), "inspector must be a class, not an object" + assert issubclass(inspector, BaseInspector), "inspectors must subclass BaseInspector" + + inspector = inspector(self.view, self.path, self.method, self.components, self.request, **initkwargs) + tried_inspectors.append(inspector) + method = getattr(inspector, method_name, None) + if method is None: + continue + + result = method(obj, **kwargs) + if result is not NotHandled: + break + else: # pragma: no cover + logger.warning("%s ignored because no inspector in %s handled it (operation: %s)", + obj, inspectors, method_name) + result = None + + for inspector in reversed(tried_inspectors): + result = inspector.process_result(result, method_name, obj, **kwargs) + + return result + + +class PaginatorInspector(BaseInspector): + """Base inspector for paginators. + + Responisble for determining extra query parameters and response structure added by given paginators. + """ + + def get_paginator_parameters(self, paginator): + """Get the pagination parameters for a single paginator **instance**. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`. + + :param BasePagination paginator: the paginator + :rtype: list[openapi.Parameter] + """ + return NotHandled + + def get_paginated_response(self, paginator, response_schema): + """Add appropriate paging fields to a response :class:`.Schema`. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `paginator`. + + :param BasePagination paginator: the paginator + :param openapi.Schema response_schema: the response schema that must be paged. + :rtype: openapi.Schema + """ + return NotHandled + + +class FilterInspector(BaseInspector): + """Base inspector for filter backends. + + Responsible for determining extra query parameters added by given filter backends. + """ + + def get_filter_parameters(self, filter_backend): + """Get the filter parameters for a single filter backend **instance**. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `filter_backend`. + + :param BaseFilterBackend filter_backend: the filter backend + :rtype: list[openapi.Parameter] + """ + return NotHandled + + +class FieldInspector(BaseInspector): + """Base inspector for serializers and serializer fields. """ + + def __init__(self, view, path, method, components, request, field_inspectors): + super(FieldInspector, self).__init__(view, path, method, components, request) + self.field_inspectors = field_inspectors + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + """Convert a drf Serializer or Field instance into a Swagger object. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `field`. + + :param rest_framework.serializers.Field field: the source field + :param type[openapi.SwaggerDict] swagger_object_type: should be one of Schema, Parameter, Items + :param bool use_references: if False, forces all objects to be declared inline + instead of by referencing other components + :param kwargs: extra attributes for constructing the object; + if swagger_object_type is Parameter, ``name`` and ``in_`` should be provided + :return: the swagger object + :rtype: openapi.Parameter,openapi.Items,openapi.Schema,openapi.SchemaRef + """ + return NotHandled + + def probe_field_inspectors(self, field, swagger_object_type, use_references, **kwargs): + """Helper method for recursively probing `field_inspectors` to handle a given field. + + All arguments are the same as :meth:`.field_to_swagger_object`. + + :rtype: openapi.Parameter,openapi.Items,openapi.Schema,openapi.SchemaRef + """ + return self.probe_inspectors( + self.field_inspectors, 'field_to_swagger_object', field, {'field_inspectors': self.field_inspectors}, + swagger_object_type=swagger_object_type, use_references=use_references, **kwargs + ) + + def _get_partial_types(self, field, swagger_object_type, use_references, **kwargs): + """Helper method to extract generic information from a field and return a partial constructor for the + appropriate openapi object. + + All arguments are the same as :meth:`.field_to_swagger_object`. + + The return value is a tuple consisting of: + + * a function for constructing objects of `swagger_object_type`; its prototype is: :: + + def SwaggerType(existing_object=None, **instance_kwargs): + + This function creates an instance of `swagger_object_type`, passing the following attributes to its init, + in order of precedence: + + - arguments specified by the ``kwargs`` parameter of :meth:`._get_partial_types` + - ``instance_kwargs`` passed to the constructor function + - ``title``, ``description``, ``required``, ``default`` and ``read_only`` inferred from the field, + where appropriate + + If ``existing_object`` is not ``None``, it is updated instead of creating a new object. + + * a type that should be used for child objects if `field` is of an array type. This can currently have two + values: + + - :class:`.Schema` if `swagger_object_type` is :class:`.Schema` + - :class:`.Items` if `swagger_object_type` is :class:`.Parameter` or :class:`.Items` + + :rtype: tuple[callable,(type[openapi.Schema],type[openapi.Items])] + """ + assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items) + assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object" + title = force_text(field.label) if field.label else None + title = title if swagger_object_type == openapi.Schema else None # only Schema has title + description = force_text(field.help_text) if field.help_text else None + description = description if swagger_object_type != openapi.Items else None # Items has no description either + + def SwaggerType(existing_object=None, **instance_kwargs): + if 'required' not in instance_kwargs and swagger_object_type == openapi.Parameter: + instance_kwargs['required'] = field.required + + if 'default' not in instance_kwargs and swagger_object_type != openapi.Items: + default = getattr(field, 'default', serializers.empty) + if default is not serializers.empty: + if callable(default): + try: + if hasattr(default, 'set_context'): + default.set_context(field) + default = default() + except Exception: # pragma: no cover + logger.warning("default for %s is callable but it raised an exception when " + "called; 'default' field will not be added to schema", field, exc_info=True) + default = None + + if default is not None: + try: + default = field.to_representation(default) + # JSON roundtrip ensures that the value is valid JSON; + # for example, sets and tuples get transformed into lists + default = json.loads(json.dumps(default, cls=encoders.JSONEncoder)) + except Exception: # pragma: no cover + logger.warning("'default' on schema for %s will not be set because " + "to_representation raised an exception", field, exc_info=True) + default = None + + if default is not None: + instance_kwargs['default'] = default + + if 'read_only' not in instance_kwargs and swagger_object_type == openapi.Schema: + # TODO: read_only is only relevant for schema `properties` - should not be generated in other cases + if field.read_only: + instance_kwargs['read_only'] = True + + instance_kwargs.setdefault('title', title) + instance_kwargs.setdefault('description', description) + instance_kwargs.update(kwargs) + + if existing_object is not None: + assert isinstance(existing_object, swagger_object_type) + for attr, val in sorted(instance_kwargs.items()): + setattr(existing_object, attr, val) + return existing_object + + return swagger_object_type(**instance_kwargs) + + # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements + child_swagger_type = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items + return SwaggerType, child_swagger_type + + +class SerializerInspector(FieldInspector): + def get_schema(self, serializer): + """Convert a DRF Serializer instance to an :class:`.openapi.Schema`. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`. + + :param serializers.BaseSerializer serializer: the ``Serializer`` instance + :rtype: openapi.Schema + """ + return NotHandled + + def get_request_parameters(self, serializer, in_): + """Convert a DRF serializer into a list of :class:`.Parameter`\ s. + + Should return :data:`.NotHandled` if this inspector does not know how to handle the given `serializer`. + + :param serializers.BaseSerializer serializer: the ``Serializer`` instance + :param str in_: the location of the parameters, one of the `openapi.IN_*` constants + :rtype: list[openapi.Parameter] + """ + return NotHandled + + +class ViewInspector(BaseInspector): + body_methods = ('PUT', 'PATCH', 'POST') #: methods that are allowed to have a request body + + # real values set in __init__ to prevent import errors + field_inspectors = [] #: + filter_inspectors = [] #: + paginator_inspectors = [] #: + + def __init__(self, view, path, method, components, request, overrides): + """ + Inspector class responsible for providing :class:`.Operation` definitions given a view, path and method. + + :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>` + """ + super(ViewInspector, self).__init__(view, path, method, components, request) + self.overrides = overrides + self._prepend_inspector_overrides('field_inspectors') + self._prepend_inspector_overrides('filter_inspectors') + self._prepend_inspector_overrides('paginator_inspectors') + + def _prepend_inspector_overrides(self, inspectors): + extra_inspectors = self.overrides.get(inspectors, None) + if extra_inspectors: + default_inspectors = [insp for insp in getattr(self, inspectors) if insp not in extra_inspectors] + setattr(self, inspectors, extra_inspectors + default_inspectors) + + def get_operation(self, operation_keys): + """Get an :class:`.Operation` for the given API endpoint (path, method). + This includes query, body parameters and response schemas. + + :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API; + e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc. + :rtype: openapi.Operation + """ + raise NotImplementedError("ViewInspector must implement get_operation()!") + + # methods below provided as default implementations for probing inspectors + + def should_filter(self): + """Determine whether filter backend parameters should be included for this request. + + :rtype: bool + """ + if not getattr(self.view, 'filter_backends', None): + return False + + if self.method.lower() not in ["get", "delete"]: + return False + + if not isinstance(self.view, GenericViewSet): + return True + + return is_list_view(self.path, self.method, self.view) + + def get_filter_parameters(self): + """Return the parameters added to the view by its filter backends. + + :rtype: list[openapi.Parameter] + """ + if not self.should_filter(): + return [] + + fields = [] + for filter_backend in self.view.filter_backends: + fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or [] + + return fields + + def should_page(self): + """Determine whether paging parameters and structure should be added to this operation's request and response. + + :rtype: bool + """ + if not hasattr(self.view, 'paginator'): + return False + + if self.view.paginator is None: + return False + + if self.method.lower() != 'get': + return False + + return is_list_view(self.path, self.method, self.view) + + def get_pagination_parameters(self): + """Return the parameters added to the view by its paginator. + + :rtype: list[openapi.Parameter] + """ + if not self.should_page(): + return [] + + return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or [] + + def serializer_to_schema(self, serializer): + """Convert a serializer to an OpenAPI :class:`.Schema`. + + :param serializers.BaseSerializer serializer: the ``Serializer`` instance + :returns: the converted :class:`.Schema`, or ``None`` in case of an unknown serializer + :rtype: openapi.Schema,openapi.SchemaRef,None + """ + return self.probe_inspectors( + self.field_inspectors, 'get_schema', serializer, {'field_inspectors': self.field_inspectors} + ) + + def serializer_to_parameters(self, serializer, in_): + """Convert a serializer to a possibly empty list of :class:`.Parameter`\ s. + + :param serializers.BaseSerializer serializer: the ``Serializer`` instance + :param str in_: the location of the parameters, one of the `openapi.IN_*` constants + :rtype: list[openapi.Parameter] + """ + return self.probe_inspectors( + self.field_inspectors, 'get_request_parameters', serializer, {'field_inspectors': self.field_inspectors}, + in_=in_ + ) or [] + + def get_paginated_response(self, response_schema): + """Add appropriate paging fields to a response :class:`.Schema`. + + :param openapi.Schema response_schema: the response schema that must be paged. + :returns: the paginated response class:`.Schema`, or ``None`` in case of an unknown pagination scheme + :rtype: openapi.Schema + """ + return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response', + self.view.paginator, response_schema=response_schema) diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py new file mode 100644 index 0000000..ed77d3e --- /dev/null +++ b/src/drf_yasg/inspectors/field.py @@ -0,0 +1,455 @@ +import operator +from collections import OrderedDict + +from django.core import validators +from django.db import models +from rest_framework import serializers +from rest_framework.settings import api_settings as rest_framework_settings + +from .base import NotHandled, SerializerInspector, FieldInspector +from .. import openapi +from ..errors import SwaggerGenerationError +from ..utils import filter_none + + +class InlineSerializerInspector(SerializerInspector): + """Provides serializer conversions using :meth:`.FieldInspector.field_to_swagger_object`.""" + + #: whether to output :class:`.Schema` definitions inline or into the ``definitions`` section + use_definitions = False + + def get_schema(self, serializer): + return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions) + + def get_request_parameters(self, serializer, in_): + fields = getattr(serializer, 'fields', {}) + return [ + self.probe_field_inspectors( + value, openapi.Parameter, self.use_definitions, + name=self.get_parameter_name(key), in_=in_ + ) + for key, value + in fields.items() + ] + + def get_property_name(self, field_name): + return field_name + + def get_parameter_name(self, field_name): + return field_name + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references) + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=child_schema, + ) + elif isinstance(field, serializers.Serializer): + if swagger_object_type != openapi.Schema: + raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__) + + serializer = field + serializer_meta = getattr(serializer, 'Meta', None) + if hasattr(serializer_meta, 'ref_name'): + ref_name = serializer_meta.ref_name + else: + ref_name = type(serializer).__name__ + if ref_name.endswith('Serializer'): + ref_name = ref_name[:-len('Serializer')] + + def make_schema_definition(): + properties = OrderedDict() + required = [] + for key, value in serializer.fields.items(): + key = self.get_property_name(key) + properties[key] = self.probe_field_inspectors(value, ChildSwaggerType, use_references) + if value.required: + required.append(key) + + return SwaggerType( + type=openapi.TYPE_OBJECT, + properties=properties, + required=required or None, + ) + + if not ref_name or not use_references: + return make_schema_definition() + + definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS) + definitions.setdefault(ref_name, make_schema_definition) + return openapi.SchemaRef(definitions, ref_name) + + return NotHandled + + +class ReferencingSerializerInspector(InlineSerializerInspector): + use_definitions = True + + +def get_queryset_field(queryset, field_name): + """Try to get information about a model and model field from a queryset. + + :param queryset: the queryset + :param field_name: target field name + :returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None`` + :rtype: tuple + """ + model = getattr(queryset, 'model', None) + model_field = get_model_field(model, field_name) + return model, model_field + + +def get_model_field(model, field_name): + """Try to get the given field from a django db model. + + :param model: the model + :param field_name: target field name + :return: model field or ``None`` + """ + try: + if field_name == 'pk': + return model._meta.pk + else: + return model._meta.get_field(field_name) + except Exception: # pragma: no cover + return None + + +def get_parent_serializer(field): + """Get the nearest parent ``Serializer`` instance for the given field. + + :return: ``Serializer`` or ``None`` + """ + while field is not None: + if isinstance(field, serializers.Serializer): + return field + + field = field.parent + + return None # pragma: no cover + + +def get_related_model(model, source): + """Try to find the other side of a model relationship given the name of a related field. + + :param model: one side of the relationship + :param str source: related field name + :return: related model or ``None`` + """ + try: + return getattr(model, source).rel.related_model + except Exception: # pragma: no cover + return None + + +class RelatedFieldInspector(FieldInspector): + """Provides conversions for ``RelatedField``\ s.""" + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + + if isinstance(field, serializers.ManyRelatedField): + child_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references) + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=child_schema, + unique_items=True, + ) + + if not isinstance(field, serializers.RelatedField): + return NotHandled + + field_queryset = getattr(field, 'queryset', None) + + if isinstance(field, (serializers.PrimaryKeyRelatedField, serializers.SlugRelatedField)): + if getattr(field, 'pk_field', ''): + # a PrimaryKeyRelatedField can have a `pk_field` attribute which is a + # serializer field that will convert the PK value + result = self.probe_field_inspectors(field.pk_field, swagger_object_type, use_references, **kwargs) + # take the type, format, etc from `pk_field`, and the field-level information + # like title, description, default from the PrimaryKeyRelatedField + return SwaggerType(existing_object=result) + + target_field = getattr(field, 'slug_field', 'pk') + if field_queryset is not None: + # if the RelatedField has a queryset, try to get the related model field from there + model, model_field = get_queryset_field(field_queryset, target_field) + else: + # if the RelatedField has no queryset (e.g. read only), try to find the target model + # from the view queryset or ModelSerializer model, if present + view_queryset = getattr(self.view, 'queryset', None) + serializer_meta = getattr(get_parent_serializer(field), 'Meta', None) + this_model = getattr(view_queryset, 'model', None) or getattr(serializer_meta, 'model', None) + source = getattr(field, 'source', '') or field.field_name + model = get_related_model(this_model, source) + model_field = get_model_field(model, target_field) + + attrs = get_basic_type_info(model_field) or {'type': openapi.TYPE_STRING} + return SwaggerType(**attrs) + elif isinstance(field, serializers.HyperlinkedRelatedField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI) + + return SwaggerType(type=openapi.TYPE_STRING) + + +def find_regex(regex_field): + """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string. + + :param serializers.Field regex_field: the field instance + :return: the extracted pattern, or ``None`` + :rtype: str + """ + regex_validator = None + for validator in regex_field.validators: + if isinstance(validator, validators.RegexValidator): + if regex_validator is not None: + # bail if multiple validators are found - no obvious way to choose + return None # pragma: no cover + regex_validator = validator + + # regex_validator.regex should be a compiled re object... + return getattr(getattr(regex_validator, 'regex', None), 'pattern', None) + + +numeric_fields = (serializers.IntegerField, serializers.FloatField, serializers.DecimalField) +limit_validators = [ + # minimum and maximum apply to numbers + (validators.MinValueValidator, numeric_fields, 'minimum', operator.__gt__), + (validators.MaxValueValidator, numeric_fields, 'maximum', operator.__lt__), + + # minLength and maxLength apply to strings + (validators.MinLengthValidator, serializers.CharField, 'min_length', operator.__gt__), + (validators.MaxLengthValidator, serializers.CharField, 'max_length', operator.__lt__), + + # minItems and maxItems apply to lists + (validators.MinLengthValidator, serializers.ListField, 'min_items', operator.__gt__), + (validators.MaxLengthValidator, serializers.ListField, 'max_items', operator.__lt__), +] + + +def find_limits(field): + """Given a ``Field``, look for min/max value/length validators and return appropriate limit validation attributes. + + :param serializers.Field field: the field instance + :return: the extracted limits + :rtype: OrderedDict + """ + limits = {} + applicable_limits = [ + (validator, attr, improves) + for validator, field_class, attr, improves in limit_validators + if isinstance(field, field_class) + ] + + for validator in field.validators: + if not hasattr(validator, 'limit_value'): + continue + + for validator_class, attr, improves in applicable_limits: + if isinstance(validator, validator_class): + if attr not in limits or improves(validator.limit_value, limits[attr]): + limits[attr] = validator.limit_value + + return OrderedDict(sorted(limits.items())) + + +model_field_to_basic_type = [ + (models.AutoField, (openapi.TYPE_INTEGER, None)), + (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)), + (models.BooleanField, (openapi.TYPE_BOOLEAN, None)), + (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), + (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), + (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), + (models.DecimalField, (openapi.TYPE_NUMBER, None)), + (models.DurationField, (openapi.TYPE_INTEGER, None)), + (models.FloatField, (openapi.TYPE_NUMBER, None)), + (models.IntegerField, (openapi.TYPE_INTEGER, None)), + (models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)), + (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)), + (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)), + (models.TextField, (openapi.TYPE_STRING, None)), + (models.TimeField, (openapi.TYPE_STRING, None)), + (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), + (models.CharField, (openapi.TYPE_STRING, None)), +] + +ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6} + +serializer_field_to_basic_type = [ + (serializers.EmailField, (openapi.TYPE_STRING, openapi.FORMAT_EMAIL)), + (serializers.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)), + (serializers.URLField, (openapi.TYPE_STRING, openapi.FORMAT_URI)), + (serializers.IPAddressField, (openapi.TYPE_STRING, lambda field: ip_format.get(field.protocol, None))), + (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), + (serializers.RegexField, (openapi.TYPE_STRING, None)), + (serializers.CharField, (openapi.TYPE_STRING, None)), + ((serializers.BooleanField, serializers.NullBooleanField), (openapi.TYPE_BOOLEAN, None)), + (serializers.IntegerField, (openapi.TYPE_INTEGER, None)), + ((serializers.FloatField, serializers.DecimalField), (openapi.TYPE_NUMBER, None)), + (serializers.DurationField, (openapi.TYPE_NUMBER, None)), # ? + (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), + (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), + (serializers.ModelField, (openapi.TYPE_STRING, None)), +] + +basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type + + +def get_basic_type_info(field): + """Given a serializer or model ``Field``, return its basic type information - ``type``, ``format``, ``pattern``, + and any applicable min/max limit values. + + :param field: the field instance + :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known + :rtype: OrderedDict + """ + if field is None: + return None + + for field_class, type_format in basic_type_info: + if isinstance(field, field_class): + swagger_type, format = type_format + if callable(format): + format = format(field) + break + else: # pragma: no cover + return None + + pattern = find_regex(field) if format in (None, openapi.FORMAT_SLUG) else None + limits = find_limits(field) + + result = OrderedDict([ + ('type', swagger_type), + ('format', format), + ('pattern', pattern) + ]) + result.update(limits) + result = filter_none(result) + return result + + +class SimpleFieldInspector(FieldInspector): + """Provides conversions for fields which can be described using just ``type``, ``format``, ``pattern`` + and min/max validators. + """ + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + type_info = get_basic_type_info(field) + if type_info is None: + return NotHandled + + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + return SwaggerType(**type_info) + + +class ChoiceFieldInspector(FieldInspector): + """Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``.""" + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + + if isinstance(field, serializers.MultipleChoiceField): + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=ChildSwaggerType( + type=openapi.TYPE_STRING, + enum=list(field.choices.keys()) + ) + ) + elif isinstance(field, serializers.ChoiceField): + return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys())) + + return NotHandled + + +class FileFieldInspector(FieldInspector): + """Provides conversions for ``FileField``\ s.""" + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + + if isinstance(field, serializers.FileField): + # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment + # OpenAPI 3.0 does support it, so a future implementation could handle this better + err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema") + if swagger_object_type == openapi.Schema: + # FileField.to_representation returns URL or file name + result = SwaggerType(type=openapi.TYPE_STRING, read_only=True) + if getattr(field, 'use_url', rest_framework_settings.UPLOADED_FILES_USE_URL): + result.format = openapi.FORMAT_URI + return result + elif swagger_object_type == openapi.Parameter: + param = SwaggerType(type=openapi.TYPE_FILE) + if param['in'] != openapi.IN_FORM: + raise err # pragma: no cover + return param + else: + raise err # pragma: no cover + + return NotHandled + + +class DictFieldInspector(FieldInspector): + """Provides conversion for ``DictField``.""" + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + + if isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema: + child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references) + return SwaggerType( + type=openapi.TYPE_OBJECT, + additional_properties=child_schema + ) + + return NotHandled + + +class StringDefaultFieldInspector(FieldInspector): + """For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects.""" + + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): # pragma: no cover + # TODO unhandled fields: TimeField HiddenField JSONField + SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs) + return SwaggerType(type=openapi.TYPE_STRING) + + +try: + from djangorestframework_camel_case.parser import CamelCaseJSONParser + from djangorestframework_camel_case.render import CamelCaseJSONRenderer + from djangorestframework_camel_case.render import camelize +except ImportError: # pragma: no cover + class CamelCaseJSONFilter(FieldInspector): + pass +else: + def camelize_string(s): + """Hack to force ``djangorestframework_camel_case`` to camelize a plain string.""" + return next(iter(camelize({s: ''}))) + + def camelize_schema(schema_or_ref, components): + """Recursively camelize property names for the given schema using ``djangorestframework_camel_case``.""" + schema = openapi.resolve_ref(schema_or_ref, components) + if getattr(schema, 'properties', {}): + schema.properties = OrderedDict( + (camelize_string(key), camelize_schema(val, components)) + for key, val in schema.properties.items() + ) + + if getattr(schema, 'required', []): + schema.required = [camelize_string(p) for p in schema.required] + + return schema_or_ref + + class CamelCaseJSONFilter(FieldInspector): + def is_camel_case(self): + return any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) \ + or any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes) + + def process_result(self, result, method_name, obj, **kwargs): + if isinstance(result, openapi.Schema.OR_REF) and self.is_camel_case(): + return camelize_schema(result, self.components) + + return result diff --git a/src/drf_yasg/inspectors/query.py b/src/drf_yasg/inspectors/query.py new file mode 100644 index 0000000..90717d2 --- /dev/null +++ b/src/drf_yasg/inspectors/query.py @@ -0,0 +1,76 @@ +from collections import OrderedDict + +import coreschema +from rest_framework.pagination import CursorPagination, PageNumberPagination, LimitOffsetPagination + +from .base import PaginatorInspector, FilterInspector +from .. import openapi + + +class CoreAPICompatInspector(PaginatorInspector, FilterInspector): + """Converts ``coreapi.Field``\ s to :class:`.openapi.Parameter`\ s for filters and paginators that implement a + ``get_schema_fields`` method. + """ + + def get_paginator_parameters(self, paginator): + fields = [] + if hasattr(paginator, 'get_schema_fields'): + fields = paginator.get_schema_fields(self.view) + + return [self.coreapi_field_to_parameter(field) for field in fields] + + def get_filter_parameters(self, filter_backend): + fields = [] + if hasattr(filter_backend, 'get_schema_fields'): + fields = filter_backend.get_schema_fields(self.view) + return [self.coreapi_field_to_parameter(field) for field in fields] + + def coreapi_field_to_parameter(self, field): + """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object. + + :param coreapi.Field field: + :rtype: openapi.Parameter + """ + location_to_in = { + 'query': openapi.IN_QUERY, + 'path': openapi.IN_PATH, + 'form': openapi.IN_FORM, + 'body': openapi.IN_FORM, + } + coreapi_types = { + coreschema.Integer: openapi.TYPE_INTEGER, + coreschema.Number: openapi.TYPE_NUMBER, + coreschema.String: openapi.TYPE_STRING, + coreschema.Boolean: openapi.TYPE_BOOLEAN, + } + return openapi.Parameter( + name=field.name, + in_=location_to_in[field.location], + type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING), + required=field.required, + description=field.schema.description, + ) + + +class DjangoRestResponsePagination(PaginatorInspector): + """Provides response schema pagination warpping for django-rest-framework's LimitOffsetPagination, + PageNumberPagination and CursorPagination + """ + + def get_paginated_response(self, paginator, response_schema): + assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response" + paged_schema = None + if isinstance(paginator, (LimitOffsetPagination, PageNumberPagination, CursorPagination)): + has_count = not isinstance(paginator, CursorPagination) + paged_schema = openapi.Schema( + type=openapi.TYPE_OBJECT, + properties=OrderedDict(( + ('count', openapi.Schema(type=openapi.TYPE_INTEGER) if has_count else None), + ('next', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)), + ('previous', openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)), + ('results', response_schema), + )), + required=['count', 'results'] + ) + + return paged_schema diff --git a/src/drf_yasg/inspectors.py b/src/drf_yasg/inspectors/view.py similarity index 57% rename from src/drf_yasg/inspectors.py rename to src/drf_yasg/inspectors/view.py index 6cb03e8..ff3c6fd 100644 --- a/src/drf_yasg/inspectors.py +++ b/src/drf_yasg/inspectors/view.py @@ -1,63 +1,22 @@ -import inspect from collections import OrderedDict -import coreschema -from rest_framework import serializers, status from rest_framework.request import is_form_media_type from rest_framework.schemas import AutoSchema from rest_framework.status import is_success -from rest_framework.viewsets import GenericViewSet -from . import openapi -from .errors import SwaggerGenerationError -from .utils import serializer_field_to_swagger, no_body, is_list_view, param_list_to_odict +from .base import ViewInspector +from .. import openapi +from ..errors import SwaggerGenerationError +from ..utils import force_serializer_instance, no_body, is_list_view, param_list_to_odict, guess_response_status -def force_serializer_instance(serializer): - """Force `serializer` into a ``Serializer`` instance. If it is not a ``Serializer`` class or instance, raises - an assertion error. - - :param serializer: serializer class or instance - :return: serializer instance - """ - if inspect.isclass(serializer): - assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__ - return serializer() - - assert isinstance(serializer, serializers.BaseSerializer), \ - "Serializer class or instance required, not %s" % type(serializer).__name__ - return serializer - - -class SwaggerAutoSchema(object): - body_methods = ('PUT', 'PATCH', 'POST') #: methods allowed to have a request body - - def __init__(self, view, path, method, overrides, components): - """Inspector class responsible for providing :class:`.Operation` definitions given a - - :param view: the view associated with this endpoint - :param str path: the path component of the operation URL - :param str method: the http method of the operation - :param dict overrides: manual overrides as passed to :func:`@swagger_auto_schema <.swagger_auto_schema>` - :param openapi.ReferenceResolver components: referenceable components - """ - super(SwaggerAutoSchema, self).__init__() +class SwaggerAutoSchema(ViewInspector): + def __init__(self, view, path, method, components, request, overrides): + super(SwaggerAutoSchema, self).__init__(view, path, method, components, request, overrides) self._sch = AutoSchema() - self.view = view - self.path = path - self.method = method - self.overrides = overrides - self.components = components self._sch.view = view def get_operation(self, operation_keys): - """Get an :class:`.Operation` for the given API endpoint (path, method). - This includes query, body parameters and response schemas. - - :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API; - e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc. - :rtype: openapi.Operation - """ consumes = self.get_consumes() body = self.get_request_body_parameters(consumes) @@ -66,17 +25,19 @@ class SwaggerAutoSchema(object): parameters = [param for param in parameters if param is not None] parameters = self.add_manual_parameters(parameters) + operation_id = self.get_operation_id(operation_keys) description = self.get_description() + tags = self.get_tags(operation_keys) responses = self.get_responses() return openapi.Operation( - operation_id='_'.join(operation_keys), + operation_id=operation_id, description=description, responses=responses, parameters=parameters, consumes=consumes, - tags=[operation_keys[0]], + tags=tags, ) def get_request_body_parameters(self, consumes): @@ -105,7 +66,7 @@ class SwaggerAutoSchema(object): else: if schema is None: schema = self.get_request_body_schema(serializer) - return [self.make_body_parameter(schema)] + return [self.make_body_parameter(schema)] if schema is not None else [] def get_view_serializer(self): """Return the serializer as defined by the view's ``get_serializer()`` method. @@ -192,26 +153,6 @@ class SwaggerAutoSchema(object): responses=self.get_response_schemas(response_serializers) ) - def get_paged_response_schema(self, response_schema): - """Add appropriate paging fields to a response :class:`.Schema`. - - :param openapi.Schema response_schema: the response schema that must be paged. - :rtype: openapi.Schema - """ - assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response" - paged_schema = openapi.Schema( - type=openapi.TYPE_OBJECT, - properties={ - 'count': openapi.Schema(type=openapi.TYPE_INTEGER), - 'next': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI), - 'previous': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI), - 'results': response_schema, - }, - required=['count', 'results'] - ) - - return paged_schema - def get_default_responses(self): """Get the default responses determined for this view from the request serializer and request method. @@ -219,28 +160,26 @@ class SwaggerAutoSchema(object): """ method = self.method.lower() - default_status = status.HTTP_200_OK + default_status = guess_response_status(method) default_schema = '' if method == 'post': - default_status = status.HTTP_201_CREATED default_schema = self.get_request_serializer() or self.get_view_serializer() - elif method == 'delete': - default_status = status.HTTP_204_NO_CONTENT elif method in ('get', 'put', 'patch'): default_schema = self.get_request_serializer() or self.get_view_serializer() default_schema = default_schema or '' if any(is_form_media_type(encoding) for encoding in self.get_consumes()): default_schema = '' + if default_schema and not isinstance(default_schema, openapi.Schema): + default_schema = self.serializer_to_schema(default_schema) or '' + if default_schema: - if not isinstance(default_schema, openapi.Schema): - default_schema = self.serializer_to_schema(default_schema) if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get': default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema) if self.should_page(): - default_schema = self.get_paged_response_schema(default_schema) + default_schema = self.get_paginated_response(default_schema) or default_schema - return {str(default_status): default_schema} + return OrderedDict({str(default_status): default_schema}) def get_response_serializers(self): """Return the response codes that this view is expected to return, and the serializer for each response body. @@ -254,7 +193,7 @@ class SwaggerAutoSchema(object): manual_responses = self.overrides.get('responses', None) or {} manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items()) - responses = {} + responses = OrderedDict() if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'): responses = self.get_default_responses() @@ -268,7 +207,7 @@ class SwaggerAutoSchema(object): :return: a dictionary of status code to :class:`.Response` object :rtype: dict[str, openapi.Response] """ - responses = {} + responses = OrderedDict() for sc, serializer in response_serializers.items(): if isinstance(serializer, str): response = openapi.Response( @@ -325,84 +264,18 @@ class SwaggerAutoSchema(object): return natural_parameters + serializer_parameters - def should_filter(self): - """Determine whether filter backend parameters should be included for this request. + def get_operation_id(self, operation_keys): + """Return an unique ID for this operation. The ID must be unique across + all :class:`.Operation` objects in the API. - :rtype: bool + :param tuple[str] operation_keys: an array of keys derived from the pathdescribing the hierarchical layout + of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc. + :rtype: str """ - if not getattr(self.view, 'filter_backends', None): - return False - - if self.method.lower() not in ["get", "delete"]: - return False - - if not isinstance(self.view, GenericViewSet): - return True - - return is_list_view(self.path, self.method, self.view) - - def get_filter_backend_parameters(self, filter_backend): - """Get the filter parameters for a single filter backend **instance**. - - :param BaseFilterBackend filter_backend: the filter backend - :rtype: list[openapi.Parameter] - """ - fields = [] - if hasattr(filter_backend, 'get_schema_fields'): - fields = filter_backend.get_schema_fields(self.view) - return [self.coreapi_field_to_parameter(field) for field in fields] - - def get_filter_parameters(self): - """Return the parameters added to the view by its filter backends. - - :rtype: list[openapi.Parameter] - """ - if not self.should_filter(): - return [] - - fields = [] - for filter_backend in self.view.filter_backends: - fields += self.get_filter_backend_parameters(filter_backend()) - - return fields - - def should_page(self): - """Determine whether paging parameters and structure should be added to this operation's request and response. - - :rtype: bool - """ - if not hasattr(self.view, 'paginator'): - return False - - if self.view.paginator is None: - return False - - if self.method.lower() != 'get': - return False - - return is_list_view(self.path, self.method, self.view) - - def get_paginator_parameters(self, paginator): - """Get the pagination parameters for a single paginator **instance**. - - :param BasePagination paginator: the paginator - :rtype: list[openapi.Parameter] - """ - fields = [] - if hasattr(paginator, 'get_schema_fields'): - fields = paginator.get_schema_fields(self.view) - - return [self.coreapi_field_to_parameter(field) for field in fields] - - def get_pagination_parameters(self): - """Return the parameters added to the view by its paginator. - - :rtype: list[openapi.Parameter] - """ - if not self.should_page(): - return [] - - return self.get_paginator_parameters(self.view.paginator) + operation_id = self.overrides.get('operation_id', '') + if not operation_id: + operation_id = '_'.join(operation_keys) + return operation_id def get_description(self): """Return an operation description determined as appropriate from the view's method and class docstrings. @@ -415,6 +288,16 @@ class SwaggerAutoSchema(object): description = self._sch.get_description(self.path, self.method) return description + def get_tags(self, operation_keys): + """Get a list of tags for this operation. Tags determine how operations relate with each other, and in the UI + each tag will show as a group containing the operations that use it. + + :param tuple[str] operation_keys: an array of keys derived from the pathdescribing the hierarchical layout + of this view in the API; e.g. ``('snippets', 'list')``, ``('snippets', 'retrieve')``, etc. + :rtype: list[str] + """ + return [operation_keys[0]] + def get_consumes(self): """Return the MIME types this endpoint can consume. @@ -424,62 +307,3 @@ class SwaggerAutoSchema(object): if all(is_form_media_type(encoding) for encoding in media_types): return media_types return media_types[:1] - - def serializer_to_schema(self, serializer): - """Convert a DRF Serializer instance to an :class:`.openapi.Schema`. - - :param serializers.BaseSerializer serializer: the ``Serializer`` instance - :rtype: openapi.Schema - """ - definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS) - return serializer_field_to_swagger(serializer, openapi.Schema, definitions) - - def serializer_to_parameters(self, serializer, in_): - """Convert a DRF serializer into a list of :class:`.Parameter`\ s using :meth:`.field_to_parameter` - - :param serializers.BaseSerializer serializer: the ``Serializer`` instance - :param str in_: the location of the parameters, one of the `openapi.IN_*` constants - :rtype: list[openapi.Parameter] - """ - fields = getattr(serializer, 'fields', {}) - return [ - self.field_to_parameter(value, key, in_) - for key, value - in fields.items() - ] - - def field_to_parameter(self, field, name, in_): - """Convert a DRF serializer Field to a swagger :class:`.Parameter` object. - - :param coreapi.Field field: - :param str name: the name of the parameter - :param str in_: the location of the parameter, one of the `openapi.IN_*` constants - :rtype: openapi.Parameter - """ - return serializer_field_to_swagger(field, openapi.Parameter, name=name, in_=in_) - - def coreapi_field_to_parameter(self, field): - """Convert an instance of `coreapi.Field` to a swagger :class:`.Parameter` object. - - :param coreapi.Field field: - :rtype: openapi.Parameter - """ - location_to_in = { - 'query': openapi.IN_QUERY, - 'path': openapi.IN_PATH, - 'form': openapi.IN_FORM, - 'body': openapi.IN_FORM, - } - coreapi_types = { - coreschema.Integer: openapi.TYPE_INTEGER, - coreschema.Number: openapi.TYPE_NUMBER, - coreschema.String: openapi.TYPE_STRING, - coreschema.Boolean: openapi.TYPE_BOOLEAN, - } - return openapi.Parameter( - name=field.name, - in_=location_to_in[field.location], - type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING), - required=field.required, - description=field.schema.description, - ) diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py index 8b3d3b0..7d9397f 100644 --- a/src/drf_yasg/openapi.py +++ b/src/drf_yasg/openapi.py @@ -1,9 +1,11 @@ +import re from collections import OrderedDict from coreapi.compat import urlparse -from future.utils import raise_from from inflection import camelize +from .utils import filter_none + TYPE_OBJECT = "object" #: TYPE_STRING = "string" #: TYPE_NUMBER = "number" #: @@ -94,8 +96,9 @@ class SwaggerDict(OrderedDict): raise AttributeError try: return self[make_swagger_name(item)] - except KeyError as e: - raise_from(AttributeError("object of class " + type(self).__name__ + " has no attribute " + item), e) + except KeyError: + # raise_from is EXTREMELY slow, replaced with plain raise + raise AttributeError("object of class " + type(self).__name__ + " has no attribute " + item) def __delattr__(self, item): if item.startswith('_'): @@ -230,7 +233,7 @@ class Swagger(SwaggerDict): self.base_path = '/' self.paths = paths - self.definitions = definitions + self.definitions = filter_none(definitions) self._insert_extras__() @@ -270,13 +273,13 @@ class PathItem(SwaggerDict): self.patch = patch self.delete = delete self.options = options - self.parameters = parameters + self.parameters = filter_none(parameters) self._insert_extras__() class Operation(SwaggerDict): def __init__(self, operation_id, responses, parameters=None, consumes=None, - produces=None, description=None, tags=None, **extra): + produces=None, summary=None, description=None, tags=None, **extra): """Information about an API operation (path + http method combination) :param str operation_id: operation ID, should be unique across all operations @@ -284,17 +287,19 @@ class Operation(SwaggerDict): :param list[.Parameter] parameters: parameters accepted :param list[str] consumes: content types accepted :param list[str] produces: content types produced - :param str description: operation description + :param str summary: operation summary; should be < 120 characters + :param str description: operation description; can be of any length and supports markdown :param list[str] tags: operation tags """ super(Operation, self).__init__(**extra) self.operation_id = operation_id + self.summary = summary self.description = description - self.parameters = [param for param in parameters if param is not None] + self.parameters = filter_none(parameters) self.responses = responses - self.consumes = consumes - self.produces = produces - self.tags = tags + self.consumes = filter_none(consumes) + self.produces = filter_none(produces) + self.tags = filter_none(tags) self._insert_extras__() @@ -352,21 +357,26 @@ class Parameter(SwaggerDict): class Schema(SwaggerDict): - OR_REF = () + OR_REF = () #: useful for type-checking, e.g ``isinstance(obj, openapi.Schema.OR_REF)`` - def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None, - format=None, enum=None, pattern=None, items=None, **extra): + def __init__(self, title=None, description=None, type=None, format=None, enum=None, pattern=None, properties=None, + additional_properties=None, required=None, items=None, default=None, read_only=None, **extra): """Describes a complex object accepted as parameter or returned as a response. - :param description: schema description - :param list[str] required: list of requried property names + :param str title: schema title + :param str description: schema description :param str type: value type; required - :param list[.Schema,.SchemaRef] properties: object properties; required if `type` is ``object`` - :param bool,.Schema,.SchemaRef additional_properties: allow wildcard properties not listed in `properties` :param str format: value format, see OpenAPI spec :param list enum: restrict possible values :param str pattern: pattern if type is ``string`` - :param .Schema,.SchemaRef items: only valid if `type` is ``array`` + :param list[.Schema,.SchemaRef] properties: object properties; required if `type` is ``object`` + :param bool,.Schema,.SchemaRef additional_properties: allow wildcard properties not listed in `properties` + :param list[str] required: list of requried property names + :param .Schema,.SchemaRef items: type of array items, only valid if `type` is ``array`` + :param default: only valid when insider another ``Schema``\ 's ``properties``; + the default value of this property if it is not provided, must conform to the type of this Schema + :param read_only: only valid when insider another ``Schema``\ 's ``properties``; + declares the property as read only - it must only be sent as part of responses, never in requests """ super(Schema, self).__init__(**extra) if required is True or required is False: @@ -374,19 +384,24 @@ class Schema(SwaggerDict): raise AssertionError( "the `requires` attribute of schema must be an array of required properties, not a boolean!") assert type is not None, "type is required!" + self.title = title self.description = description - self.required = required + self.required = filter_none(required) self.type = type - self.properties = properties + self.properties = filter_none(properties) self.additional_properties = additional_properties self.format = format self.enum = enum self.pattern = pattern self.items = items + self.read_only = read_only + self.default = default self._insert_extras__() class _Ref(SwaggerDict): + ref_name_re = re.compile(r"#/(?P.+)/(?P[^/]+)$") + def __init__(self, resolver, name, scope, expected_type): """Base class for all reference types. A reference object has only one property, ``$ref``, which must be a JSON reference to a valid object in the specification, e.g. ``#/definitions/Article`` to refer to an article model. @@ -404,6 +419,15 @@ class _Ref(SwaggerDict): .format(actual=type(obj).__name__, expected=expected_type.__name__) self.ref = ref_name + def resolve(self, resolver): + """Get the object targeted by this reference from the given component resolver. + + :param .ReferenceResolver resolver: component resolver which must contain the referneced object + :returns: the target object + """ + ref_match = self.ref_name_re.match(self.ref) + return resolver.get(ref_match.group('name'), scope=ref_match.group('scope')) + def __setitem__(self, key, value, **kwargs): if key == "$ref": return super(_Ref, self).__setitem__(key, value, **kwargs) @@ -427,6 +451,17 @@ class SchemaRef(_Ref): Schema.OR_REF = (Schema, SchemaRef) +def resolve_ref(ref_or_obj, resolver): + """Resolve `ref_or_obj` if it is a reference type. Return it unchaged if not. + + :param SwaggerDict,_Ref ref_or_obj: + :param resolver: component resolver which must contain the referenced object + """ + if isinstance(ref_or_obj, _Ref): + return ref_or_obj.resolve(resolver) + return ref_or_obj + + class Responses(SwaggerDict): def __init__(self, responses, default=None, **extra): """Describes the expected responses of an :class:`.Operation`. @@ -483,7 +518,7 @@ class ReferenceResolver(object): self._objects[scope] = OrderedDict() def with_scope(self, scope): - """Return a new :class:`.ReferenceResolver` whose scope is defaulted and forced to `scope`. + """Return a view into this :class:`.ReferenceResolver` whose scope is defaulted and forced to `scope`. :param str scope: target scope, must be in this resolver's `scopes` :return: the bound resolver diff --git a/src/drf_yasg/renderers.py b/src/drf_yasg/renderers.py index 4352812..87448ed 100644 --- a/src/drf_yasg/renderers.py +++ b/src/drf_yasg/renderers.py @@ -14,7 +14,7 @@ class _SpecRenderer(BaseRenderer): @classmethod def with_validators(cls, validators): - assert all(vld in VALIDATORS for vld in validators), "allowed validators are" + ", ".join(VALIDATORS) + assert all(vld in VALIDATORS for vld in validators), "allowed validators are " + ", ".join(VALIDATORS) return type(cls.__name__, (cls,), {'validators': validators}) def render(self, data, media_type=None, renderer_context=None): @@ -45,7 +45,7 @@ class SwaggerYAMLRenderer(_SpecRenderer): class _UIRenderer(BaseRenderer): - """Base class for web UI renderers. Handles loading an passing settings to the appropriate template.""" + """Base class for web UI renderers. Handles loading and passing settings to the appropriate template.""" media_type = 'text/html' charset = 'utf-8' template = '' diff --git a/src/drf_yasg/templates/drf-yasg/swagger-ui.html b/src/drf_yasg/templates/drf-yasg/swagger-ui.html index 0a2ad2a..f3a0473 100644 --- a/src/drf_yasg/templates/drf-yasg/swagger-ui.html +++ b/src/drf_yasg/templates/drf-yasg/swagger-ui.html @@ -166,9 +166,11 @@ layout: "StandaloneLayout", filter: true, requestInterceptor: function(request) { - console.log(request); var headers = request.headers || {}; - headers["X-CSRFToken"] = document.querySelector("[name=csrfmiddlewaretoken]").value; + var csrftoken = document.querySelector("[name=csrfmiddlewaretoken]"); + if (csrftoken) { + headers["X-CSRFToken"] = csrftoken.value; + } return request; } }; diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 81a78fd..3c6302d 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -1,17 +1,9 @@ +import inspect import logging from collections import OrderedDict -from django.core.validators import RegexValidator -from django.db import models -from django.utils.encoding import force_text -from rest_framework import serializers +from rest_framework import status, serializers from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin -from rest_framework.schemas.inspectors import get_pk_description -from rest_framework.settings import api_settings -from rest_framework.utils import json, encoders - -from . import openapi -from .errors import SwaggerGenerationError logger = logging.getLogger(__name__) @@ -19,6 +11,141 @@ logger = logging.getLogger(__name__) no_body = object() +def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, query_serializer=None, + manual_parameters=None, operation_id=None, operation_description=None, responses=None, + field_inspectors=None, filter_inspectors=None, paginator_inspectors=None, + **extra_overrides): + """Decorate a view method to customize the :class:`.Operation` object generated from it. + + `method` and `methods` are mutually exclusive and must only be present when decorating a view method that accepts + more than one HTTP request method. + + The `auto_schema` and `operation_description` arguments take precendence over view- or method-level values. + + .. versionchanged:: 1.1 + Added the ``extra_overrides`` and ``operatiod_id`` parameters. + + .. versionchanged:: 1.1 + Added the ``field_inspectors``, ``filter_inspectors`` and ``paginator_inspectors`` parameters. + + :param str method: for multi-method views, the http method the options should apply to + :param list[str] methods: for multi-method views, the http methods the options should apply to + :param .inspectors.SwaggerAutoSchema auto_schema: custom class to use for generating the Operation object; + this overrides both the class-level ``swagger_schema`` attribute and the ``DEFAULT_AUTO_SCHEMA_CLASS`` + setting + :param .Schema,.SchemaRef,.Serializer request_body: custom request body, or :data:`.no_body`. The value given here + will be used as the ``schema`` property of a :class:`.Parameter` with ``in: 'body'``. + + A Schema or SchemaRef is not valid if this request consumes form-data, because ``form`` and ``body`` parameters + are mutually exclusive in an :class:`.Operation`. If you need to set custom ``form`` parameters, you can use + the `manual_parameters` argument. + + If a ``Serializer`` class or instance is given, it will be automatically converted into a :class:`.Schema` + used as a ``body`` :class:`.Parameter`, or into a list of ``form`` :class:`.Parameter`\ s, as appropriate. + + :param .Serializer query_serializer: if you use a ``Serializer`` to parse query parameters, you can pass it here + and have :class:`.Parameter` objects be generated automatically from it. + + If any ``Field`` on the serializer cannot be represented as a ``query`` :class:`.Parameter` + (e.g. nested Serializers, file fields, ...), the schema generation will fail with an error. + + Schema generation will also fail if the name of any Field on the `query_serializer` conflicts with parameters + generated by ``filter_backends`` or ``paginator``. + + :param list[.Parameter] manual_parameters: a list of manual parameters to override the automatically generated ones + + :class:`.Parameter`\ s are identified by their (``name``, ``in``) combination, and any parameters given + here will fully override automatically generated parameters if they collide. + + It is an error to supply ``form`` parameters when the request does not consume form-data. + + :param str operation_id: operation ID override; the operation ID must be unique accross the whole API + :param str operation_description: operation description override + :param dict[str,(.Schema,.SchemaRef,.Response,str,Serializer)] responses: a dict of documented manual responses + keyed on response status code. If no success (``2xx``) response is given, one will automatically be + generated from the request body and http method. If any ``2xx`` response is given the automatic response is + suppressed. + + * if a plain string is given as value, a :class:`.Response` with no body and that string as its description + will be generated + * if a :class:`.Schema`, :class:`.SchemaRef` is given, a :class:`.Response` with the schema as its body and + an empty description will be generated + * a ``Serializer`` class or instance will be converted into a :class:`.Schema` and treated as above + * a :class:`.Response` object will be used as-is; however if its ``schema`` attribute is a ``Serializer``, + it will automatically be converted into a :class:`.Schema` + + :param list[.FieldInspector] field_inspectors: extra serializer and field inspectors; these will be tried + before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance + :param list[.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried before + :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance + :param list[.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be tried before + :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance + :param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available + in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides`` + """ + + def decorator(view_method): + data = { + 'auto_schema': auto_schema, + 'request_body': request_body, + 'query_serializer': query_serializer, + 'manual_parameters': manual_parameters, + 'operation_id': operation_id, + 'operation_description': operation_description, + 'responses': responses, + 'filter_inspectors': list(filter_inspectors) if filter_inspectors else None, + 'paginator_inspectors': list(paginator_inspectors) if paginator_inspectors else None, + 'field_inspectors': list(field_inspectors) if field_inspectors else None, + } + data = {k: v for k, v in data.items() if v is not None} + data.update(extra_overrides) + + # if the method is a detail_route or list_route, it will have a bind_to_methods attribute + bind_to_methods = getattr(view_method, 'bind_to_methods', []) + # if the method is actually a function based view (@api_view), it will have a 'cls' attribute + view_cls = getattr(view_method, 'cls', None) + http_method_names = getattr(view_cls, 'http_method_names', []) + if bind_to_methods or http_method_names: + # detail_route, list_route or api_view + assert bool(http_method_names) != bool(bind_to_methods), "this should never happen" + available_methods = http_method_names + bind_to_methods + existing_data = getattr(view_method, '_swagger_auto_schema', {}) + + if http_method_names: + _route = "api_view" + else: + _route = "detail_route" if view_method.detail else "list_route" + + _methods = methods + if len(available_methods) > 1: + assert methods or method, \ + "on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \ + "using one of the `method` or `methods` arguments" % _route + assert bool(methods) != bool(method), "specify either method or methods" + assert not isinstance(methods, str), "`methods` expects to receive a list of methods;" \ + " use `method` for a single argument" + if method: + _methods = [method.lower()] + else: + _methods = [mth.lower() for mth in methods] + assert not any(mth in existing_data for mth in _methods), "method defined multiple times" + assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route + + existing_data.update((mth.lower(), data) for mth in _methods) + else: + existing_data[available_methods[0]] = data + view_method._swagger_auto_schema = existing_data + else: + assert method is None and methods is None, \ + "the methods argument should only be specified when decorating a detail_route or list_route; you " \ + "should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator" + view_method._swagger_auto_schema = data + + return view_method + + return decorator + + def is_list_view(path, method, view): """Check if the given path/method appears to represent a list view (as opposed to a detail/instance view). @@ -52,431 +179,13 @@ def is_list_view(path, method, view): return True -def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, query_serializer=None, - manual_parameters=None, operation_description=None, responses=None): - """Decorate a view method to customize the :class:`.Operation` object generated from it. - - `method` and `methods` are mutually exclusive and must only be present when decorating a view method that accepts - more than one HTTP request method. - - The `auto_schema` and `operation_description` arguments take precendence over view- or method-level values. - - :param str method: for multi-method views, the http method the options should apply to - :param list[str] methods: for multi-method views, the http methods the options should apply to - :param .SwaggerAutoSchema auto_schema: custom class to use for generating the Operation object - :param .Schema,.SchemaRef,.Serializer request_body: custom request body, or :data:`.no_body`. The value given here - will be used as the ``schema`` property of a :class:`.Parameter` with ``in: 'body'``. - - A Schema or SchemaRef is not valid if this request consumes form-data, because ``form`` and ``body`` parameters - are mutually exclusive in an :class:`.Operation`. If you need to set custom ``form`` parameters, you can use - the `manual_parameters` argument. - - If a ``Serializer`` class or instance is given, it will be automatically converted into a :class:`.Schema` - used as a ``body`` :class:`.Parameter`, or into a list of ``form`` :class:`.Parameter`\ s, as appropriate. - - :param .Serializer query_serializer: if you use a ``Serializer`` to parse query parameters, you can pass it here - and have :class:`.Parameter` objects be generated automatically from it. - - If any ``Field`` on the serializer cannot be represented as a ``query`` :class:`.Parameter` - (e.g. nested Serializers, file fields, ...), the schema generation will fail with an error. - - Schema generation will also fail if the name of any Field on the `query_serializer` conflicts with parameters - generated by ``filter_backends`` or ``paginator``. - - :param list[.Parameter] manual_parameters: a list of manual parameters to override the automatically generated ones - - :class:`.Parameter`\ s are identified by their (``name``, ``in``) combination, and any parameters given - here will fully override automatically generated parameters if they collide. - - It is an error to supply ``form`` parameters when the request does not consume form-data. - - :param str operation_description: operation description override - :param dict[str,(.Schema,.SchemaRef,.Response,str,Serializer)] responses: a dict of documented manual responses - keyed on response status code. If no success (``2xx``) response is given, one will automatically be - generated from the request body and http method. If any ``2xx`` response is given the automatic response is - suppressed. - - * if a plain string is given as value, a :class:`.Response` with no body and that string as its description - will be generated - * if a :class:`.Schema`, :class:`.SchemaRef` is given, a :class:`.Response` with the schema as its body and - an empty description will be generated - * a ``Serializer`` class or instance will be converted into a :class:`.Schema` and treated as above - * a :class:`.Response` object will be used as-is; however if its ``schema`` attribute is a ``Serializer``, - it will automatically be converted into a :class:`.Schema` - - """ - - def decorator(view_method): - data = { - 'auto_schema': auto_schema, - 'request_body': request_body, - 'query_serializer': query_serializer, - 'manual_parameters': manual_parameters, - 'operation_description': operation_description, - 'responses': responses, - } - data = {k: v for k, v in data.items() if v is not None} - - # if the method is a detail_route or list_route, it will have a bind_to_methods attribute - bind_to_methods = getattr(view_method, 'bind_to_methods', []) - # if the method is actually a function based view (@api_view), it will have a 'cls' attribute - view_cls = getattr(view_method, 'cls', None) - http_method_names = getattr(view_cls, 'http_method_names', []) - if bind_to_methods or http_method_names: - # detail_route, list_route or api_view - assert bool(http_method_names) != bool(bind_to_methods), "this should never happen" - available_methods = http_method_names + bind_to_methods - existing_data = getattr(view_method, 'swagger_auto_schema', {}) - - if http_method_names: - _route = "api_view" - else: - _route = "detail_route" if view_method.detail else "list_route" - - _methods = methods - if len(available_methods) > 1: - assert methods or method, \ - "on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \ - "using one of the `method` or `methods` arguments" % _route - assert bool(methods) != bool(method), "specify either method or methods" - assert not isinstance(methods, str), "`methods` expects to receive a list of methods;" \ - " use `method` for a single argument" - if method: - _methods = [method.lower()] - else: - _methods = [mth.lower() for mth in methods] - assert not any(mth in existing_data for mth in _methods), "method defined multiple times" - assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route - - existing_data.update((mth.lower(), data) for mth in _methods) - else: - existing_data[available_methods[0]] = data - view_method.swagger_auto_schema = existing_data - else: - assert method is None and methods is None, \ - "the methods argument should only be specified when decorating a detail_route or list_route; you " \ - "should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator" - view_method.swagger_auto_schema = data - - return view_method - - return decorator - - -def get_model_field(queryset, field_name): - """Try to get information about a model and model field from a queryset. - - :param queryset: the queryset - :param field_name: the target field name - :returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None`` - :rtype: tuple - """ - model = getattr(queryset, 'model', None) - try: - model_field = model._meta.get_field(field_name) - except Exception: # pragma: no cover - model_field = None - - return model, model_field - - -model_field_to_swagger_type = [ - (models.AutoField, (openapi.TYPE_INTEGER, None)), - (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)), - (models.BooleanField, (openapi.TYPE_BOOLEAN, None)), - (models.NullBooleanField, (openapi.TYPE_BOOLEAN, None)), - (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)), - (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)), - (models.DecimalField, (openapi.TYPE_NUMBER, None)), - (models.DurationField, (openapi.TYPE_INTEGER, None)), - (models.FloatField, (openapi.TYPE_NUMBER, None)), - (models.IntegerField, (openapi.TYPE_INTEGER, None)), - (models.IPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV4)), - (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)), - (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)), - (models.TextField, (openapi.TYPE_STRING, None)), - (models.TimeField, (openapi.TYPE_STRING, None)), - (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)), - (models.CharField, (openapi.TYPE_STRING, None)), -] - - -def inspect_model_field(model, model_field): - """Extract information from a django model field instance. - - :param model: the django model - :param model_field: a field on the model - :return: description, type, format and pattern extracted from the model field - :rtype: OrderedDict - """ - if model is not None and model_field is not None: - for model_field_class, tf in model_field_to_swagger_type: - if isinstance(model_field, model_field_class): - swagger_type, format = tf - break - else: # pragma: no cover - swagger_type, format = None, None - - if format is None or format == openapi.FORMAT_SLUG: - pattern = find_regex(model_field) - else: - pattern = None - - if model_field.help_text: - description = force_text(model_field.help_text) - elif model_field.primary_key: - description = get_pk_description(model, model_field) - else: - description = None +def guess_response_status(method): + if method == 'post': + return status.HTTP_201_CREATED + elif method == 'delete': + return status.HTTP_204_NO_CONTENT else: - description = None - swagger_type = None - format = None - pattern = None - - result = OrderedDict([ - ('description', description), - ('type', swagger_type or openapi.TYPE_STRING), - ('format', format), - ('pattern', pattern) - ]) - # TODO: filter none - return result - - -def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs): - """Convert a drf Serializer or Field instance into a Swagger object. - - :param rest_framework.serializers.Field field: the source field - :param type[openapi.SwaggerDict] swagger_object_type: should be one of Schema, Parameter, Items - :param .ReferenceResolver definitions: used to serialize Schemas by reference - :param kwargs: extra attributes for constructing the object; - if swagger_object_type is Parameter, ``name`` and ``in_`` should be provided - :return: the swagger object - :rtype: openapi.Parameter, openapi.Items, openapi.Schema - """ - assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items) - assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object" - title = force_text(field.label) if field.label else None - title = title if swagger_object_type == openapi.Schema else None # only Schema has title - title = None - description = force_text(field.help_text) if field.help_text else None - description = description if swagger_object_type != openapi.Items else None # Items has no description either - - def SwaggerType(existing_object=None, **instance_kwargs): - if swagger_object_type == openapi.Parameter and 'required' not in instance_kwargs: - instance_kwargs['required'] = field.required - if swagger_object_type != openapi.Items and 'default' not in instance_kwargs: - default = getattr(field, 'default', serializers.empty) - if default is not serializers.empty: - if callable(default): - try: - if hasattr(default, 'set_context'): - default.set_context(field) - default = default() - except Exception as e: # pragma: no cover - logger.warning("default for %s is callable but it raised an exception when " - "called; 'default' field will not be added to schema", field, exc_info=True) - default = None - - if default is not None: - try: - default = field.to_representation(default) - # JSON roundtrip ensures that the value is valid JSON; - # for example, sets get transformed into lists - default = json.loads(json.dumps(default, cls=encoders.JSONEncoder)) - except Exception: # pragma: no cover - logger.warning("'default' on schema for %s will not be set because " - "to_representation raised an exception", field, exc_info=True) - default = None - - if default is not None: - instance_kwargs['default'] = default - - if swagger_object_type == openapi.Schema and 'read_only' not in instance_kwargs: - if field.read_only: - instance_kwargs['read_only'] = True - instance_kwargs.update(kwargs) - instance_kwargs.pop('title', None) - instance_kwargs.pop('description', None) - - if existing_object is not None: - existing_object.title = title - existing_object.description = description - for attr, val in instance_kwargs.items(): - setattr(existing_object, attr, val) - return existing_object - - return swagger_object_type(title=title, description=description, **instance_kwargs) - - # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements - ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items - - # ------ NESTED - if isinstance(field, (serializers.ListSerializer, serializers.ListField)): - child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions) - return SwaggerType( - type=openapi.TYPE_ARRAY, - items=child_schema, - ) - elif isinstance(field, serializers.Serializer): - if swagger_object_type != openapi.Schema: - raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__) - assert definitions is not None, "ReferenceResolver required when instantiating Schema" - - serializer = field - if hasattr(serializer, '__ref_name__'): - ref_name = serializer.__ref_name__ - else: - ref_name = type(serializer).__name__ - if ref_name.endswith('Serializer'): - ref_name = ref_name[:-len('Serializer')] - - def make_schema_definition(): - properties = OrderedDict() - required = [] - for key, value in serializer.fields.items(): - properties[key] = serializer_field_to_swagger(value, ChildSwaggerType, definitions) - if value.required: - required.append(key) - - return SwaggerType( - type=openapi.TYPE_OBJECT, - properties=properties, - required=required or None, - ) - - if not ref_name: - return make_schema_definition() - - definitions.setdefault(ref_name, make_schema_definition) - return openapi.SchemaRef(definitions, ref_name) - elif isinstance(field, serializers.ManyRelatedField): - child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType, definitions) - return SwaggerType( - type=openapi.TYPE_ARRAY, - items=child_schema, - unique_items=True, # is this OK? - ) - elif isinstance(field, serializers.PrimaryKeyRelatedField): - if field.pk_field: - result = serializer_field_to_swagger(field.pk_field, swagger_object_type, definitions, **kwargs) - return SwaggerType(existing_object=result) - - attrs = {'type': openapi.TYPE_STRING} - try: - model = field.queryset.model - pk_field = model._meta.pk - except Exception: # pragma: no cover - logger.warning("an exception was raised when attempting to extract the primary key related to %s; " - "falling back to plain string" % field, exc_info=True) - else: - attrs.update(inspect_model_field(model, pk_field)) - - return SwaggerType(**attrs) - elif isinstance(field, serializers.HyperlinkedRelatedField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI) - elif isinstance(field, serializers.SlugRelatedField): - model, model_field = get_model_field(field.queryset, field.slug_field) - attrs = inspect_model_field(model, model_field) - return SwaggerType(**attrs) - elif isinstance(field, serializers.RelatedField): - return SwaggerType(type=openapi.TYPE_STRING) - # ------ CHOICES - elif isinstance(field, serializers.MultipleChoiceField): - return SwaggerType( - type=openapi.TYPE_ARRAY, - items=ChildSwaggerType( - type=openapi.TYPE_STRING, - enum=list(field.choices.keys()) - ) - ) - elif isinstance(field, serializers.ChoiceField): - return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys())) - # ------ BOOL - elif isinstance(field, (serializers.BooleanField, serializers.NullBooleanField)): - return SwaggerType(type=openapi.TYPE_BOOLEAN) - # ------ NUMERIC - elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): - # TODO: min_value max_value - return SwaggerType(type=openapi.TYPE_NUMBER) - elif isinstance(field, serializers.IntegerField): - # TODO: min_value max_value - return SwaggerType(type=openapi.TYPE_INTEGER) - elif isinstance(field, serializers.DurationField): - return SwaggerType(type=openapi.TYPE_INTEGER) - # ------ STRING - elif isinstance(field, serializers.EmailField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL) - elif isinstance(field, serializers.RegexField): - return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field)) - elif isinstance(field, serializers.SlugField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG, pattern=find_regex(field)) - elif isinstance(field, serializers.URLField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI) - elif isinstance(field, serializers.IPAddressField): - format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None) - return SwaggerType(type=openapi.TYPE_STRING, format=format) - elif isinstance(field, serializers.CharField): - # TODO: min_length max_length (for all CharField subclasses above too) - return SwaggerType(type=openapi.TYPE_STRING) - elif isinstance(field, serializers.UUIDField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID) - # ------ DATE & TIME - elif isinstance(field, serializers.DateField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE) - elif isinstance(field, serializers.DateTimeField): - return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME) - # ------ OTHERS - elif isinstance(field, serializers.FileField): - # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment - # OpenAPI 3.0 does support it, so a future implementation could handle this better - err = SwaggerGenerationError("FileField is supported only in a formData Parameter or response Schema") - if swagger_object_type == openapi.Schema: - # FileField.to_representation returns URL or file name - result = SwaggerType(type=openapi.TYPE_STRING, read_only=True) - if getattr(field, 'use_url', api_settings.UPLOADED_FILES_USE_URL): - result.format = openapi.FORMAT_URI - return result - elif swagger_object_type == openapi.Parameter: - param = SwaggerType(type=openapi.TYPE_FILE) - if param['in'] != openapi.IN_FORM: - raise err # pragma: no cover - return param - else: - raise err # pragma: no cover - elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema: - child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions) - return SwaggerType( - type=openapi.TYPE_OBJECT, - additional_properties=child_schema - ) - elif isinstance(field, serializers.ModelField): - return SwaggerType(type=openapi.TYPE_STRING) - - # TODO unhandled fields: TimeField HiddenField JSONField - - # everything else gets string by default - return SwaggerType(type=openapi.TYPE_STRING) - - -def find_regex(regex_field): - """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string. - - :param serializers.Field regex_field: the field instance - :return: the extracted pattern, or ``None`` - :rtype: str - """ - regex_validator = None - for validator in regex_field.validators: - if isinstance(validator, RegexValidator): - if regex_validator is not None: - # bail if multiple validators are found - no obvious way to choose - return None # pragma: no cover - regex_validator = validator - - # regex_validator.regex should be a compiled re object... - return getattr(getattr(regex_validator, 'regex', None), 'pattern', None) + return status.HTTP_200_OK def param_list_to_odict(parameters): @@ -492,3 +201,37 @@ def param_list_to_odict(parameters): result = OrderedDict(((param.name, param.in_), param) for param in parameters) assert len(result) == len(parameters), "duplicate Parameters found" return result + + +def filter_none(obj): + """Remove ``None`` values from tuples, lists or dictionaries. Return other objects as-is. + + :param obj: + :return: collection with ``None`` values removed + """ + if obj is None: + return None + new_obj = None + if isinstance(obj, dict): + new_obj = type(obj)((k, v) for k, v in obj.items() if k is not None and v is not None) + if isinstance(obj, (list, tuple)): + new_obj = type(obj)(v for v in obj if v is not None) + if new_obj is not None and len(new_obj) != len(obj): + return new_obj # pragma: no cover + return obj + + +def force_serializer_instance(serializer): + """Force `serializer` into a ``Serializer`` instance. If it is not a ``Serializer`` class or instance, raises + an assertion error. + + :param serializer: serializer class or instance + :return: serializer instance + """ + if inspect.isclass(serializer): + assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__ + return serializer() + + assert isinstance(serializer, serializers.BaseSerializer), \ + "Serializer class or instance required, not %s" % type(serializer).__name__ + return serializer diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py index 1720bdc..2d877d3 100644 --- a/src/drf_yasg/views.py +++ b/src/drf_yasg/views.py @@ -82,12 +82,24 @@ def get_schema_view(info, url=None, patterns=None, urlconf=None, public=False, v renderer_classes = _spec_renderers def get(self, request, version='', format=None): - generator = self.generator_class(info, version, url, patterns, urlconf) + generator = self.generator_class(info, request.version or version or '', url, patterns, urlconf) schema = generator.get_schema(request, self.public) if schema is None: raise exceptions.PermissionDenied() # pragma: no cover return Response(schema) + @classmethod + def apply_cache(cls, view, cache_timeout, cache_kwargs): + """Override this method to customize how caching is applied to the view. + + Arguments described in :meth:`.as_cached_view`. + """ + if not cls.public: + view = vary_on_headers('Cookie', 'Authorization')(view) + view = cache_page(cache_timeout, **cache_kwargs)(view) + view = deferred_never_cache(view) # disable in-browser caching + return view + @classmethod def as_cached_view(cls, cache_timeout=0, cache_kwargs=None, **initkwargs): """ @@ -102,10 +114,7 @@ def get_schema_view(info, url=None, patterns=None, urlconf=None, public=False, v cache_kwargs = cache_kwargs or {} view = cls.as_view(**initkwargs) if cache_timeout != 0: - if not public: - view = vary_on_headers('Cookie', 'Authorization')(view) - view = cache_page(cache_timeout, **cache_kwargs)(view) - view = deferred_never_cache(view) # disable in-browser caching + view = cls.apply_cache(view, cache_timeout, cache_kwargs) elif cache_kwargs: warnings.warn("cache_kwargs ignored because cache_timeout is 0 (disabled)") return view diff --git a/testproj/articles/serializers.py b/testproj/articles/serializers.py index ee90e2c..80c8e0c 100644 --- a/testproj/articles/serializers.py +++ b/testproj/articles/serializers.py @@ -1,11 +1,12 @@ from rest_framework import serializers from articles.models import Article +from django.utils.translation import ugettext_lazy as _ class ArticleSerializer(serializers.ModelSerializer): references = serializers.DictField( - help_text="this is a really bad example", + help_text=_("this is a really bad example"), child=serializers.URLField(help_text="but i needed to test these 2 fields somehow"), read_only=True, ) @@ -23,8 +24,8 @@ class ArticleSerializer(serializers.ModelSerializer): 'body': {'help_text': 'body serializer help_text'}, 'author': { 'default': serializers.CurrentUserDefault(), - 'help_text': "The ID of the user that created this article; if none is provided, " - "defaults to the currently logged in user." + 'help_text': _("The ID of the user that created this article; if none is provided, " + "defaults to the currently logged in user.") }, } diff --git a/testproj/articles/views.py b/testproj/articles/views.py index 860e681..ef0ba47 100644 --- a/testproj/articles/views.py +++ b/testproj/articles/views.py @@ -11,11 +11,45 @@ from rest_framework.response import Response from articles import serializers from articles.models import Article -from drf_yasg.inspectors import SwaggerAutoSchema +from drf_yasg import openapi +from drf_yasg.app_settings import swagger_settings +from drf_yasg.inspectors import SwaggerAutoSchema, FieldInspector, CoreAPICompatInspector, NotHandled from drf_yasg.utils import swagger_auto_schema -class NoPagingAutoSchema(SwaggerAutoSchema): +class DjangoFilterDescriptionInspector(CoreAPICompatInspector): + def get_filter_parameters(self, filter_backend): + if isinstance(filter_backend, DjangoFilterBackend): + result = super(DjangoFilterDescriptionInspector, self).get_filter_parameters(filter_backend) + for param in result: + if not param.get('description', ''): + param.description = "Filter the returned list by {field_name}".format(field_name=param.name) + + return result + + return NotHandled + + +class NoSchemaTitleInspector(FieldInspector): + def process_result(self, result, method_name, obj, **kwargs): + # remove the `title` attribute of all Schema objects + if isinstance(result, openapi.Schema.OR_REF): + # traverse any references and alter the Schema object in place + schema = openapi.resolve_ref(result, self.components) + schema.pop('title', None) + + # no ``return schema`` here, because it would mean we always generate + # an inline `object` instead of a definition reference + + # return back the same object that we got - i.e. a reference if we got a reference + return result + + +class NoTitleAutoSchema(SwaggerAutoSchema): + field_inspectors = [NoSchemaTitleInspector] + swagger_settings.DEFAULT_FIELD_INSPECTORS + + +class NoPagingAutoSchema(NoTitleAutoSchema): def should_page(self): return False @@ -26,7 +60,8 @@ class ArticlePagination(LimitOffsetPagination): @method_decorator(name='list', decorator=swagger_auto_schema( - operation_description="description from swagger_auto_schema via method_decorator" + operation_description="description from swagger_auto_schema via method_decorator", + filter_inspectors=[DjangoFilterDescriptionInspector] )) class ArticleViewSet(viewsets.ModelViewSet): """ @@ -52,7 +87,9 @@ class ArticleViewSet(viewsets.ModelViewSet): ordering_fields = ('date_modified', 'date_created') ordering = ('date_created',) - @swagger_auto_schema(auto_schema=NoPagingAutoSchema) + swagger_schema = NoTitleAutoSchema + + @swagger_auto_schema(auto_schema=NoPagingAutoSchema, filter_inspectors=[DjangoFilterDescriptionInspector]) @list_route(methods=['get']) def today(self, request): today_min = datetime.datetime.combine(datetime.date.today(), datetime.time.min) diff --git a/testproj/createsuperuser.py b/testproj/createsuperuser.py new file mode 100644 index 0000000..eb000e9 --- /dev/null +++ b/testproj/createsuperuser.py @@ -0,0 +1,4 @@ +from django.contrib.auth.models import User + +User.objects.filter(username='admin').delete() +User.objects.create_superuser('admin', 'admin@admin.admin', 'passwordadmin') diff --git a/testproj/db.sqlite3 b/testproj/db.sqlite3 deleted file mode 100644 index 532719e..0000000 Binary files a/testproj/db.sqlite3 and /dev/null differ diff --git a/testproj/snippets/serializers.py b/testproj/snippets/serializers.py index eaa5798..167d5ac 100644 --- a/testproj/snippets/serializers.py +++ b/testproj/snippets/serializers.py @@ -5,18 +5,22 @@ from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES class LanguageSerializer(serializers.Serializer): - __ref_name__ = None name = serializers.ChoiceField( choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language') + class Meta: + ref_name = None + class ExampleProjectSerializer(serializers.Serializer): - __ref_name__ = 'Project' project_name = serializers.CharField(help_text='Name of the project') github_repo = serializers.CharField(required=True, help_text='Github repository of the project') + class Meta: + ref_name = 'Project' + class SnippetSerializer(serializers.Serializer): """SnippetSerializer classdoc diff --git a/testproj/snippets/views.py b/testproj/snippets/views.py index a76d74c..462cf69 100644 --- a/testproj/snippets/views.py +++ b/testproj/snippets/views.py @@ -1,14 +1,28 @@ +from djangorestframework_camel_case.parser import CamelCaseJSONParser +from djangorestframework_camel_case.render import CamelCaseJSONRenderer +from inflection import camelize from rest_framework import generics +from drf_yasg.inspectors import SwaggerAutoSchema from snippets.models import Snippet from snippets.serializers import SnippetSerializer +class CamelCaseOperationIDAutoSchema(SwaggerAutoSchema): + def get_operation_id(self, operation_keys): + operation_id = super(CamelCaseOperationIDAutoSchema, self).get_operation_id(operation_keys) + return camelize(operation_id, uppercase_first_letter=False) + + class SnippetList(generics.ListCreateAPIView): """SnippetList classdoc""" queryset = Snippet.objects.all() serializer_class = SnippetSerializer + parser_classes = (CamelCaseJSONParser,) + renderer_classes = (CamelCaseJSONRenderer,) + swagger_schema = CamelCaseOperationIDAutoSchema + def perform_create(self, serializer): serializer.save(owner=self.request.user) @@ -31,6 +45,10 @@ class SnippetDetail(generics.RetrieveUpdateDestroyAPIView): serializer_class = SnippetSerializer pagination_class = None + parser_classes = (CamelCaseJSONParser,) + renderer_classes = (CamelCaseJSONRenderer,) + swagger_schema = CamelCaseOperationIDAutoSchema + def patch(self, request, *args, **kwargs): """patch method docstring""" return super(SnippetDetail, self).patch(request, *args, **kwargs) diff --git a/testproj/testproj/settings.py b/testproj/testproj/settings.py index c2dae4c..64ad1b8 100644 --- a/testproj/testproj/settings.py +++ b/testproj/testproj/settings.py @@ -128,3 +128,52 @@ USE_TZ = True STATIC_URL = '/static/' TEST_RUNNER = 'testproj.runner.PytestTestRunner' + +LOGGING = { + 'version': 1, + 'disable_existing_loggers': True, + 'formatters': { + 'pipe_separated': { + 'format': '%(asctime)s | %(levelname)s | %(name)s | %(message)s' + } + }, + 'handlers': { + 'console_log': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'stream': 'ext://sys.stdout', + 'formatter': 'pipe_separated', + }, + }, + 'loggers': { + 'drf_yasg': { + 'handlers': ['console_log'], + 'level': 'DEBUG', + 'propagate': False, + }, + 'django': { + 'handlers': ['console_log'], + 'level': 'DEBUG', + 'propagate': False, + }, + 'django.db.backends': { + 'handlers': ['console_log'], + 'level': 'INFO', + 'propagate': False, + }, + 'django.template': { + 'handlers': ['console_log'], + 'level': 'INFO', + 'propagate': False, + }, + 'swagger_spec_validator': { + 'handlers': ['console_log'], + 'level': 'INFO', + 'propagate': False, + } + }, + 'root': { + 'handlers': ['console_log'], + 'level': 'INFO', + } +} diff --git a/testproj/testproj/urls.py b/testproj/testproj/urls.py index d1ea6a4..efff18f 100644 --- a/testproj/testproj/urls.py +++ b/testproj/testproj/urls.py @@ -6,7 +6,7 @@ from rest_framework.decorators import api_view from drf_yasg import openapi from drf_yasg.views import get_schema_view -schema_view = get_schema_view( +SchemaView = get_schema_view( openapi.Info( title="Snippets API", default_version='v1', @@ -27,12 +27,12 @@ def plain_view(request): urlpatterns = [ - url(r'^swagger(?P.json|.yaml)$', schema_view.without_ui(cache_timeout=0), name='schema-json'), - url(r'^swagger/$', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'), - url(r'^redoc/$', schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'), - url(r'^cached/swagger(?P.json|.yaml)$', schema_view.without_ui(cache_timeout=None), name='schema-json'), - url(r'^cached/swagger/$', schema_view.with_ui('swagger', cache_timeout=None), name='schema-swagger-ui'), - url(r'^cached/redoc/$', schema_view.with_ui('redoc', cache_timeout=None), name='schema-redoc'), + url(r'^swagger(?P.json|.yaml)$', SchemaView.without_ui(cache_timeout=0), name='schema-json'), + url(r'^swagger/$', SchemaView.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'), + url(r'^redoc/$', SchemaView.with_ui('redoc', cache_timeout=0), name='schema-redoc'), + url(r'^cached/swagger(?P.json|.yaml)$', SchemaView.without_ui(cache_timeout=None), name='cschema-json'), + url(r'^cached/swagger/$', SchemaView.with_ui('swagger', cache_timeout=None), name='cschema-swagger-ui'), + url(r'^cached/redoc/$', SchemaView.with_ui('redoc', cache_timeout=None), name='cschema-redoc'), url(r'^admin/', admin.site.urls), url(r'^snippets/', include('snippets.urls')), diff --git a/testproj/users/serializers.py b/testproj/users/serializers.py index 87cc87c..c10f076 100644 --- a/testproj/users/serializers.py +++ b/testproj/users/serializers.py @@ -6,7 +6,7 @@ from snippets.models import Snippet class UserSerializerrr(serializers.ModelSerializer): snippets = serializers.PrimaryKeyRelatedField(many=True, queryset=Snippet.objects.all()) - article_slugs = serializers.SlugRelatedField(read_only=True, slug_field='slug', many=True, source='articlessss') + article_slugs = serializers.SlugRelatedField(read_only=True, slug_field='slug', many=True, source='articles') last_connected_ip = serializers.IPAddressField(help_text="i'm out of ideas", protocol='ipv4', read_only=True) last_connected_at = serializers.DateField(help_text="really?", read_only=True) diff --git a/testproj/users/views.py b/testproj/users/views.py index cac1162..5a44854 100644 --- a/testproj/users/views.py +++ b/testproj/users/views.py @@ -32,7 +32,7 @@ class UserList(APIView): serializer.save() return Response(serializer.data, status=status.HTTP_201_CREATED) - @swagger_auto_schema(request_body=no_body, operation_description="dummy operation") + @swagger_auto_schema(request_body=no_body, operation_id="users_dummy", operation_description="dummy operation") def patch(self, request): pass diff --git a/tests/conftest.py b/tests/conftest.py index ff80891..a7449a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,14 +1,15 @@ import copy import json import os +from collections import OrderedDict import pytest from django.contrib.auth.models import User from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -from ruamel import yaml from drf_yasg import openapi, codecs +from drf_yasg.codecs import yaml_sane_load from drf_yasg.generators import OpenAPISchemaGenerator @@ -47,7 +48,7 @@ def swagger(mock_schema_request): @pytest.fixture def swagger_dict(swagger): json_bytes = codec_json().encode(swagger) - return json.loads(json_bytes.decode('utf-8')) + return json.loads(json_bytes.decode('utf-8'), object_pairs_hook=OrderedDict) @pytest.fixture @@ -79,4 +80,4 @@ def redoc_settings(settings): @pytest.fixture def reference_schema(): with open(os.path.join(os.path.dirname(__file__), 'reference.yaml')) as reference: - return yaml.safe_load(reference) + return yaml_sane_load(reference) diff --git a/tests/reference.yaml b/tests/reference.yaml index b36e1ff..1145ff4 100644 --- a/tests/reference.yaml +++ b/tests/reference.yaml @@ -20,7 +20,7 @@ paths: parameters: - name: title in: query - description: '' + description: Filter the returned list by title required: false type: string - name: ordering @@ -89,7 +89,7 @@ paths: parameters: - name: title in: query - description: '' + description: Filter the returned list by title required: false type: string - name: ordering @@ -249,7 +249,7 @@ paths: parameters: [] /snippets/: get: - operationId: snippets_list + operationId: snippetsList description: SnippetList classdoc parameters: [] responses: @@ -264,7 +264,7 @@ paths: tags: - snippets post: - operationId: snippets_create + operationId: snippetsCreate description: post method docstring parameters: - name: data @@ -284,7 +284,7 @@ paths: parameters: [] /snippets/{id}/: get: - operationId: snippets_read + operationId: snippetsRead description: SnippetDetail classdoc parameters: [] responses: @@ -297,7 +297,7 @@ paths: tags: - snippets put: - operationId: snippets_update + operationId: snippetsUpdate description: put class docstring parameters: - name: data @@ -315,7 +315,7 @@ paths: tags: - snippets patch: - operationId: snippets_partial_update + operationId: snippetsPartialUpdate description: patch method docstring parameters: - name: data @@ -333,7 +333,7 @@ paths: tags: - snippets delete: - operationId: snippets_delete + operationId: snippetsDelete description: delete method docstring parameters: [] responses: @@ -404,7 +404,7 @@ paths: tags: - users patch: - operationId: users_partial_update + operationId: users_dummy description: dummy operation parameters: [] responses: @@ -466,6 +466,7 @@ definitions: title: description: title model help_text type: string + maxLength: 255 author: description: The ID of the user that created this article; if none is provided, defaults to the currently logged in user. @@ -474,11 +475,13 @@ definitions: body: description: body serializer help_text type: string + maxLength: 5000 slug: description: slug model help_text type: string format: slug pattern: ^[-a-zA-Z0-9_]+$ + maxLength: 50 date_created: type: string format: date-time @@ -509,14 +512,16 @@ definitions: readOnly: true Project: required: - - project_name - - github_repo + - projectName + - githubRepo type: object properties: - project_name: + projectName: + title: Project name description: Name of the project type: string - github_repo: + githubRepo: + title: Github repo description: Github repository of the project type: string Snippet: @@ -526,29 +531,38 @@ definitions: type: object properties: id: + title: Id description: id serializer help text type: integer readOnly: true owner: + title: Owner description: The ID of the user that created this snippet; if none is provided, defaults to the currently logged in user. type: integer default: 1 - owner_as_string: + ownerAsString: description: The ID of the user that created this snippet. type: string readOnly: true + title: Owner as string title: + title: Title type: string + maxLength: 100 code: + title: Code type: string linenos: + title: Linenos type: boolean language: + title: Language description: Sample help text for language type: object properties: name: + title: Name description: The name of the programming language type: string enum: @@ -988,6 +1002,7 @@ definitions: - zephir default: python styles: + title: Styles type: array items: type: string @@ -1024,19 +1039,22 @@ definitions: default: - friendly lines: + title: Lines type: array items: type: integer - example_projects: + exampleProjects: + title: Example projects type: array items: $ref: '#/definitions/Project' readOnly: true - difficulty_factor: + difficultyFactor: + title: Difficulty factor description: this is here just to test FloatField type: number - default: 6.9 readOnly: true + default: 6.9 UserSerializerrr: required: - username @@ -1045,42 +1063,55 @@ definitions: type: object properties: id: + title: ID type: integer readOnly: true username: + title: Username description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_ only. type: string + pattern: ^[\w.@+-]+$ + maxLength: 150 email: + title: Email address type: string format: email + maxLength: 254 articles: + title: Articles type: array items: type: integer uniqueItems: true snippets: + title: Snippets type: array items: type: integer uniqueItems: true last_connected_ip: + title: Last connected ip description: i'm out of ideas type: string format: ipv4 readOnly: true last_connected_at: + title: Last connected at description: really? type: string format: date readOnly: true article_slugs: + title: Article slugs type: array items: type: string + format: slug + pattern: ^[-a-zA-Z0-9_]+\Z readOnly: true - uniqueItems: true readOnly: true + uniqueItems: true securityDefinitions: basic: type: basic diff --git a/tests/test_api_view.py b/tests/test_api_view.py deleted file mode 100644 index 43f9112..0000000 --- a/tests/test_api_view.py +++ /dev/null @@ -1,17 +0,0 @@ -from drf_yasg import openapi - - -def test_operation_docstrings(swagger_dict): - users_list = swagger_dict['paths']['/users/'] - assert users_list['get']['description'] == "UserList cbv classdoc" - assert users_list['post']['description'] == "apiview post description override" - - users_detail = swagger_dict['paths']['/users/{id}/'] - assert users_detail['get']['description'] == "user_detail fbv docstring" - assert users_detail['put']['description'] == "user_detail fbv docstring" - - -def test_parameter_docstrings(swagger_dict): - users_detail = swagger_dict['paths']['/users/{id}/'] - assert users_detail['get']['parameters'][0]['description'] == "test manual param" - assert users_detail['put']['parameters'][0]['in'] == openapi.IN_BODY diff --git a/tests/test_generic_api_view.py b/tests/test_generic_api_view.py deleted file mode 100644 index 2b698c7..0000000 --- a/tests/test_generic_api_view.py +++ /dev/null @@ -1,22 +0,0 @@ -def test_appropriate_status_codes(swagger_dict): - snippets_list = swagger_dict['paths']['/snippets/'] - assert '200' in snippets_list['get']['responses'] - assert '201' in snippets_list['post']['responses'] - - snippets_detail = swagger_dict['paths']['/snippets/{id}/'] - assert '200' in snippets_detail['get']['responses'] - assert '200' in snippets_detail['put']['responses'] - assert '200' in snippets_detail['patch']['responses'] - assert '204' in snippets_detail['delete']['responses'] - - -def test_operation_docstrings(swagger_dict): - snippets_list = swagger_dict['paths']['/snippets/'] - assert snippets_list['get']['description'] == "SnippetList classdoc" - assert snippets_list['post']['description'] == "post method docstring" - - snippets_detail = swagger_dict['paths']['/snippets/{id}/'] - assert snippets_detail['get']['description'] == "SnippetDetail classdoc" - assert snippets_detail['put']['description'] == "put class docstring" - assert snippets_detail['patch']['description'] == "patch method docstring" - assert snippets_detail['delete']['description'] == "delete method docstring" diff --git a/tests/test_generic_viewset.py b/tests/test_generic_viewset.py deleted file mode 100644 index d67882b..0000000 --- a/tests/test_generic_viewset.py +++ /dev/null @@ -1,29 +0,0 @@ -def test_appropriate_status_codes(swagger_dict): - articles_list = swagger_dict['paths']['/articles/'] - assert '200' in articles_list['get']['responses'] - assert '201' in articles_list['post']['responses'] - - articles_detail = swagger_dict['paths']['/articles/{slug}/'] - assert '200' in articles_detail['get']['responses'] - assert '200' in articles_detail['put']['responses'] - assert '200' in articles_detail['patch']['responses'] - assert '204' in articles_detail['delete']['responses'] - - -def test_operation_docstrings(swagger_dict): - articles_list = swagger_dict['paths']['/articles/'] - assert articles_list['get']['description'] == "description from swagger_auto_schema via method_decorator" - assert articles_list['post']['description'] == "ArticleViewSet class docstring" - - articles_detail = swagger_dict['paths']['/articles/{slug}/'] - assert articles_detail['get']['description'] == "retrieve class docstring" - assert articles_detail['put']['description'] == "update method docstring" - assert articles_detail['patch']['description'] == "partial_update description override" - assert articles_detail['delete']['description'] == "destroy method docstring" - - articles_today = swagger_dict['paths']['/articles/today/'] - assert articles_today['get']['description'] == "ArticleViewSet class docstring" - - articles_image = swagger_dict['paths']['/articles/{slug}/image/'] - assert articles_image['get']['description'] == "image GET description override" - assert articles_image['post']['description'] == "image method docstring" diff --git a/tests/test_reference_schema.py b/tests/test_reference_schema.py index d3bd8a0..cf04cd1 100644 --- a/tests/test_reference_schema.py +++ b/tests/test_reference_schema.py @@ -1,13 +1,46 @@ +from collections import OrderedDict + from datadiff.tools import assert_equal +from drf_yasg.codecs import yaml_sane_dump +from drf_yasg.inspectors import FieldInspector, SerializerInspector, PaginatorInspector, FilterInspector + def test_reference_schema(swagger_dict, reference_schema): - swagger_dict = dict(swagger_dict) - reference_schema = dict(reference_schema) + swagger_dict = OrderedDict(swagger_dict) + reference_schema = OrderedDict(reference_schema) ignore = ['info', 'host', 'schemes', 'basePath', 'securityDefinitions'] for attr in ignore: swagger_dict.pop(attr, None) reference_schema.pop(attr, None) - # formatted better than pytest diff - assert_equal(swagger_dict, reference_schema) + # print diff between YAML strings because it's prettier + assert_equal(yaml_sane_dump(swagger_dict, binary=False), yaml_sane_dump(reference_schema, binary=False)) + + +class NoOpFieldInspector(FieldInspector): + pass + + +class NoOpSerializerInspector(SerializerInspector): + pass + + +class NoOpFilterInspector(FilterInspector): + pass + + +class NoOpPaginatorInspector(PaginatorInspector): + pass + + +def test_noop_inspectors(swagger_settings, swagger_dict, reference_schema): + from drf_yasg import app_settings + + def set_inspectors(inspectors, setting_name): + swagger_settings[setting_name] = inspectors + app_settings.SWAGGER_DEFAULTS[setting_name] + + set_inspectors([NoOpFieldInspector, NoOpSerializerInspector], 'DEFAULT_FIELD_INSPECTORS') + set_inspectors([NoOpFilterInspector], 'DEFAULT_FILTER_INSPECTORS') + set_inspectors([NoOpPaginatorInspector], 'DEFAULT_PAGINATOR_INSPECTORS') + test_reference_schema(swagger_dict, reference_schema) diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index 62b4f80..df51b13 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -1,9 +1,10 @@ import json +from collections import OrderedDict import pytest -from ruamel import yaml from drf_yasg import openapi, codecs +from drf_yasg.codecs import yaml_sane_load from drf_yasg.generators import OpenAPISchemaGenerator @@ -35,12 +36,12 @@ def test_yaml_codec_roundtrip(codec_yaml, swagger, validate_schema): yaml_bytes = codec_yaml.encode(swagger) assert b'omap' not in yaml_bytes # ensure no ugly !!omap is outputted assert b'&id' not in yaml_bytes and b'*id' not in yaml_bytes # ensure no YAML references are generated - validate_schema(yaml.safe_load(yaml_bytes.decode('utf-8'))) + validate_schema(yaml_sane_load(yaml_bytes.decode('utf-8'))) def test_yaml_and_json_match(codec_yaml, codec_json, swagger): - yaml_schema = yaml.safe_load(codec_yaml.encode(swagger).decode('utf-8')) - json_schema = json.loads(codec_json.encode(swagger).decode('utf-8')) + yaml_schema = yaml_sane_load(codec_yaml.encode(swagger).decode('utf-8')) + json_schema = json.loads(codec_json.encode(swagger).decode('utf-8'), object_pairs_hook=OrderedDict) assert yaml_schema == json_schema diff --git a/tests/test_schema_structure.py b/tests/test_schema_structure.py deleted file mode 100644 index 77ec55c..0000000 --- a/tests/test_schema_structure.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_paths_not_empty(swagger_dict): - assert len(swagger_dict['paths']) > 0 diff --git a/tests/test_schema_views.py b/tests/test_schema_views.py index aabd828..136e5aa 100644 --- a/tests/test_schema_views.py +++ b/tests/test_schema_views.py @@ -2,7 +2,8 @@ import json from collections import OrderedDict import pytest -from ruamel import yaml + +from drf_yasg.codecs import yaml_sane_load def _validate_text_schema_view(client, validate_schema, path, loader): @@ -22,10 +23,10 @@ def test_swagger_json(client, validate_schema): def test_swagger_yaml(client, validate_schema): - _validate_text_schema_view(client, validate_schema, "/swagger.yaml", yaml.safe_load) + _validate_text_schema_view(client, validate_schema, "/swagger.yaml", yaml_sane_load) -def test_exception_middleware(client, swagger_settings): +def test_exception_middleware(client, swagger_settings, db): swagger_settings['SECURITY_DEFINITIONS'] = { 'bad': { 'bad_attribute': 'should not be accepted' @@ -70,5 +71,5 @@ def test_caching(client, validate_schema): @pytest.mark.urls('urlconfs.non_public_urls') def test_non_public(client): response = client.get('/private/swagger.yaml') - swagger = yaml.safe_load(response.content.decode('utf-8')) + swagger = yaml_sane_load(response.content.decode('utf-8')) assert len(swagger['paths']) == 0 diff --git a/tests/test_versioning.py b/tests/test_versioning.py new file mode 100644 index 0000000..265a58e --- /dev/null +++ b/tests/test_versioning.py @@ -0,0 +1,56 @@ +import pytest + +from drf_yasg.codecs import yaml_sane_load + + +def _get_versioned_schema(prefix, client, validate_schema): + response = client.get(prefix + 'swagger.yaml') + assert response.status_code == 200 + swagger = yaml_sane_load(response.content.decode('utf-8')) + validate_schema(swagger) + assert prefix + 'snippets/' in swagger['paths'] + return swagger + + +def _check_v1(swagger, prefix): + assert swagger['info']['version'] == '1.0' + versioned_post = swagger['paths'][prefix + 'snippets/']['post'] + assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/Snippet' + assert 'v2field' not in swagger['definitions']['Snippet']['properties'] + + +def _check_v2(swagger, prefix): + assert swagger['info']['version'] == '2.0' + versioned_post = swagger['paths'][prefix + 'snippets/']['post'] + assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/SnippetV2' + assert 'v2field' in swagger['definitions']['SnippetV2']['properties'] + v2field = swagger['definitions']['SnippetV2']['properties']['v2field'] + assert v2field['description'] == 'version 2.0 field' + + +@pytest.mark.urls('urlconfs.url_versioning') +def test_url_v1(client, validate_schema): + prefix = '/versioned/url/v1.0/' + swagger = _get_versioned_schema(prefix, client, validate_schema) + _check_v1(swagger, prefix) + + +@pytest.mark.urls('urlconfs.url_versioning') +def test_url_v2(client, validate_schema): + prefix = '/versioned/url/v2.0/' + swagger = _get_versioned_schema(prefix, client, validate_schema) + _check_v2(swagger, prefix) + + +@pytest.mark.urls('urlconfs.ns_versioning') +def test_ns_v1(client, validate_schema): + prefix = '/versioned/ns/v1.0/' + swagger = _get_versioned_schema(prefix, client, validate_schema) + _check_v1(swagger, prefix) + + +@pytest.mark.urls('urlconfs.ns_versioning') +def test_ns_v2(client, validate_schema): + prefix = '/versioned/ns/v2.0/' + swagger = _get_versioned_schema(prefix, client, validate_schema) + _check_v2(swagger, prefix) diff --git a/tests/urlconfs/ns_version1.py b/tests/urlconfs/ns_version1.py new file mode 100644 index 0000000..8d22a56 --- /dev/null +++ b/tests/urlconfs/ns_version1.py @@ -0,0 +1,26 @@ +from django.conf.urls import url +from rest_framework import generics, versioning + +from snippets.models import Snippet +from snippets.serializers import SnippetSerializer + + +class SnippetList(generics.ListCreateAPIView): + """SnippetList classdoc""" + queryset = Snippet.objects.all() + serializer_class = SnippetSerializer + versioning_class = versioning.NamespaceVersioning + + def perform_create(self, serializer): + serializer.save(owner=self.request.user) + + def post(self, request, *args, **kwargs): + """post method docstring""" + return super(SnippetList, self).post(request, *args, **kwargs) + + +app_name = 'test_ns_versioning' + +urlpatterns = [ + url(r"^$", SnippetList.as_view()) +] diff --git a/tests/urlconfs/ns_version2.py b/tests/urlconfs/ns_version2.py new file mode 100644 index 0000000..69908f2 --- /dev/null +++ b/tests/urlconfs/ns_version2.py @@ -0,0 +1,23 @@ +from django.conf.urls import url +from rest_framework import fields + +from snippets.serializers import SnippetSerializer +from .ns_version1 import SnippetList as SnippetListV1 + + +class SnippetSerializerV2(SnippetSerializer): + v2field = fields.IntegerField(help_text="version 2.0 field") + + class Meta: + ref_name = 'SnippetV2' + + +class SnippetListV2(SnippetListV1): + serializer_class = SnippetSerializerV2 + + +app_name = 'test_ns_versioning' + +urlpatterns = [ + url(r"^$", SnippetListV2.as_view()) +] diff --git a/tests/urlconfs/ns_versioning.py b/tests/urlconfs/ns_versioning.py new file mode 100644 index 0000000..5875908 --- /dev/null +++ b/tests/urlconfs/ns_versioning.py @@ -0,0 +1,24 @@ +from django.conf.urls import url, include +from rest_framework import versioning + +from testproj.urls import SchemaView +from . import ns_version1, ns_version2 + +VERSION_PREFIX_NS = r"^versioned/ns/" + + +class VersionedSchemaView(SchemaView): + versioning_class = versioning.NamespaceVersioning + + +schema_patterns = [ + url(r'swagger(?P.json|.yaml)$', VersionedSchemaView.without_ui(), name='ns-schema') +] + + +urlpatterns = [ + url(VERSION_PREFIX_NS + r"v1.0/snippets/", include(ns_version1, namespace='1.0')), + url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2, namespace='2.0')), + url(VERSION_PREFIX_NS + r'v1.0/', include((schema_patterns, '1.0'))), + url(VERSION_PREFIX_NS + r'v2.0/', include((schema_patterns, '2.0'))), +] diff --git a/tests/urlconfs/url_versioning.py b/tests/urlconfs/url_versioning.py new file mode 100644 index 0000000..5642b5c --- /dev/null +++ b/tests/urlconfs/url_versioning.py @@ -0,0 +1,48 @@ +from django.conf.urls import url +from rest_framework import generics, versioning, fields + +from snippets.models import Snippet +from snippets.serializers import SnippetSerializer +from testproj.urls import SchemaView + + +class SnippetSerializerV2(SnippetSerializer): + v2field = fields.IntegerField(help_text="version 2.0 field") + + class Meta: + ref_name = 'SnippetV2' + + +class SnippetList(generics.ListCreateAPIView): + """SnippetList classdoc""" + queryset = Snippet.objects.all() + serializer_class = SnippetSerializer + versioning_class = versioning.URLPathVersioning + + def get_serializer_class(self): + context = self.get_serializer_context() + request = context['request'] + if int(float(request.version)) >= 2: + return SnippetSerializerV2 + else: + return SnippetSerializer + + def perform_create(self, serializer): + serializer.save(owner=self.request.user) + + def post(self, request, *args, **kwargs): + """post method docstring""" + return super(SnippetList, self).post(request, *args, **kwargs) + + +VERSION_PREFIX_URL = r"^versioned/url/v(?P1.0|2.0)/" + + +class VersionedSchemaView(SchemaView): + versioning_class = versioning.URLPathVersioning + + +urlpatterns = [ + url(VERSION_PREFIX_URL + r"snippets/$", SnippetList.as_view()), + url(VERSION_PREFIX_URL + r'swagger(?P.json|.yaml)$', VersionedSchemaView.without_ui(), name='vschema-json'), +]