From dce00156d529d276aa1c2610ce67fe6708af8abd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Tue, 5 Dec 2017 19:46:02 +0100 Subject: [PATCH] Rewrite schema generation (#1) * Completeley rewritten schema generation * Added support for python 2.7 and 3.4 * Restructured testing and build configuration * Added nested request schemas This rewrite completely replaces the public interface of the django rest schema generation library, so further changes will be needed to re-enable and further extend the customization points one might want. --- .coveragerc | 32 ++ .idea/drf-swagger.iml | 35 ++ .idea/inspectionProfiles/Project_Default.xml | 60 +++ .idea/misc.xml | 73 ++++ .idea/modules.xml | 8 + .idea/vcs.xml | 6 + .travis.yml | 20 +- MANIFEST.in | 2 +- requirements.txt => requirements/base.txt | 5 +- requirements/ci.txt | 4 + requirements/dev.txt | 3 + requirements/test.txt | 11 + requirements/validation.txt | 3 + requirements_dev.txt | 5 - requirements_test.txt | 6 - requirements_validation.txt | 3 - setup.py | 12 +- src/drf_swagger/codecs.py | 75 ++-- src/drf_swagger/errors.py | 14 + src/drf_swagger/generators.py | 115 +++++- src/drf_swagger/inspectors.py | 361 +++++++++++++++++- src/drf_swagger/middleware.py | 18 + src/drf_swagger/openapi.py | 323 ++++++++++++---- src/drf_swagger/renderers.py | 10 +- .../templates/drf-swagger/swagger-ui.html | 53 ++- src/drf_swagger/views.py | 11 +- testproj/articles/__init__.py | 0 testproj/articles/migrations/0001_initial.py | 26 ++ testproj/articles/migrations/__init__.py | 0 testproj/articles/models.py | 11 + testproj/articles/serializers.py | 16 + testproj/articles/urls.py | 11 + testproj/articles/views.py | 60 +++ testproj/db.sqlite3 | Bin 135168 -> 155648 bytes .../migrations/0002_auto_20171205_0505.py | 27 ++ testproj/snippets/models.py | 5 +- testproj/snippets/serializers.py | 3 +- testproj/snippets/views.py | 10 +- testproj/testproj/settings.py | 5 +- testproj/testproj/tests.py | 91 ----- testproj/testproj/urls.py | 16 +- testproj/users/__init__.py | 0 testproj/users/models.py | 0 testproj/users/serializers.py | 12 + testproj/users/urls.py | 8 + testproj/users/views.py | 24 ++ tests/conftest.py | 30 +- tests/test_api_view.py | 6 + tests/test_generic_api_view.py | 22 ++ tests/test_generic_viewset.py | 29 ++ tests/test_schema_generator.py | 18 +- tests/test_schema_structure.py | 24 +- tests/test_schema_views.py | 39 ++ tests/test_swaggerdict.py | 53 +++ tox.ini | 44 ++- 55 files changed, 1512 insertions(+), 346 deletions(-) create mode 100644 .coveragerc create mode 100644 .idea/drf-swagger.iml create mode 100644 .idea/inspectionProfiles/Project_Default.xml create mode 100644 .idea/misc.xml create mode 100644 .idea/modules.xml create mode 100644 .idea/vcs.xml rename requirements.txt => requirements/base.txt (52%) create mode 100644 requirements/ci.txt create mode 100644 requirements/dev.txt create mode 100644 requirements/test.txt create mode 100644 requirements/validation.txt delete mode 100644 requirements_dev.txt delete mode 100644 requirements_test.txt delete mode 100644 requirements_validation.txt create mode 100644 src/drf_swagger/errors.py create mode 100644 src/drf_swagger/middleware.py create mode 100644 testproj/articles/__init__.py create mode 100644 testproj/articles/migrations/0001_initial.py create mode 100644 testproj/articles/migrations/__init__.py create mode 100644 testproj/articles/models.py create mode 100644 testproj/articles/serializers.py create mode 100644 testproj/articles/urls.py create mode 100644 testproj/articles/views.py create mode 100644 testproj/snippets/migrations/0002_auto_20171205_0505.py delete mode 100644 testproj/testproj/tests.py create mode 100644 testproj/users/__init__.py create mode 100644 testproj/users/models.py create mode 100644 testproj/users/serializers.py create mode 100644 testproj/users/urls.py create mode 100644 testproj/users/views.py create mode 100644 tests/test_api_view.py create mode 100644 tests/test_generic_api_view.py create mode 100644 tests/test_generic_viewset.py create mode 100644 tests/test_schema_views.py create mode 100644 tests/test_swaggerdict.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..9f1c676 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,32 @@ +[run] +source = drf_swagger +branch = True + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self/.debug + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + warnings.warn + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + +ignore_errors = True +precision = 0 + +[paths] +source = + src/drf_swagger/ + .tox/*/Lib/site-packages/drf_swagger/ + .tox/*/lib/*/site-packages/drf_swagger/ + /home/travis/virtualenv/*/lib/*/site-packages/drf_swagger/ diff --git a/.idea/drf-swagger.iml b/.idea/drf-swagger.iml new file mode 100644 index 0000000..30b590b --- /dev/null +++ b/.idea/drf-swagger.iml @@ -0,0 +1,35 @@ + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..2af15d1 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,60 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..b6e1878 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,73 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..b317f5e --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 909e9ed..80aff3f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,21 +1,31 @@ language: python cache: pip python: + - '2.7' + - '3.4' - '3.5' - '3.6' - '3.7-dev' +env: + - DRF=3.7 + matrix: fast_finish: true include: - - python: "3.5" + - python: '2.7' env: TOXENV=flake8 + - python: '3.6' + env: DRF=master allow_failures: - env: TOXENV=flake8 + - env: DRF=master + - python: '2.7' + - python: '3.7' install: - - pip install -r requirements_dev.txt + - pip install -r requirements/ci.txt before_script: - coverage erase @@ -24,4 +34,8 @@ script: - tox after_success: - - coveralls + - codecov + +branches: + only: + - master diff --git a/MANIFEST.in b/MANIFEST.in index 93f2962..061705b 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ include README.md include LICENSE -include requirements* +recursive-include requirements * recursive-include src/drf_swagger/static * recursive-include src/drf_swagger/templates * diff --git a/requirements.txt b/requirements/base.txt similarity index 52% rename from requirements.txt rename to requirements/base.txt index 9536b6c..77b4da0 100644 --- a/requirements.txt +++ b/requirements/base.txt @@ -1,5 +1,6 @@ -djangorestframework>=3.7.3 -django>=1.11.7 coreapi>=2.3.3 +coreschema>=0.0.4 openapi_codec>=1.3.2 ruamel.yaml>=0.15.34 +inflection>=0.3.1 +future>=0.16.0 diff --git a/requirements/ci.txt b/requirements/ci.txt new file mode 100644 index 0000000..6f65907 --- /dev/null +++ b/requirements/ci.txt @@ -0,0 +1,4 @@ +# requirements for CI test suite +-r dev.txt +tox-travis>=0.10 +codecov>=2.0.9 diff --git a/requirements/dev.txt b/requirements/dev.txt new file mode 100644 index 0000000..a278302 --- /dev/null +++ b/requirements/dev.txt @@ -0,0 +1,3 @@ +# requirements for local development +tox>=2.9.1 +tox-battery>=0.5 diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 0000000..3b7fd2b --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,11 @@ +# pytest runner + plugins +pytest-django>=3.1.2 +pytest-pythonpath>=0.7.1 +pytest-cov>=2.5.1 + +# test project requirements +Pillow>=4.3.0 +pygments>=2.2.0 +django-cors-headers>=2.1.0 +django-filter>=1.1.0,<2.0; python_version == "2.7" +django-filter>=1.1.0; python_version >= "3.4" diff --git a/requirements/validation.txt b/requirements/validation.txt new file mode 100644 index 0000000..65e9d43 --- /dev/null +++ b/requirements/validation.txt @@ -0,0 +1,3 @@ +# requirements for the validation feature +flex>=6.11.1 +swagger-spec-validator>=2.1.0 diff --git a/requirements_dev.txt b/requirements_dev.txt deleted file mode 100644 index 77209fb..0000000 --- a/requirements_dev.txt +++ /dev/null @@ -1,5 +0,0 @@ -# Packages required for development and CI -tox>=2.9.1 -tox-battery>=0.5 -tox-travis>=0.10 -python-coveralls>=2.9.1 diff --git a/requirements_test.txt b/requirements_test.txt deleted file mode 100644 index d090a94..0000000 --- a/requirements_test.txt +++ /dev/null @@ -1,6 +0,0 @@ -# Packages required for running the tests -pygments>=2.2.0 -django-cors-headers>=2.1.0 -pytest-django>=3.1.2 -pytest-pythonpath>=0.7.1 -pytest-cov>=2.5.1 diff --git a/requirements_validation.txt b/requirements_validation.txt deleted file mode 100644 index cc26ca0..0000000 --- a/requirements_validation.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Packages required for the validation feature -flex>=6.11.1 -swagger-spec-validator>=2.1.0 diff --git a/setup.py b/setup.py index 6d196f9..29240b0 100644 --- a/setup.py +++ b/setup.py @@ -1,17 +1,18 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- +import os + from setuptools import setup, find_packages def read_req(req_file): - with open(req_file) as req: + with open(os.path.join('requirements', req_file)) as req: return [line for line in req.readlines() if line and not line.isspace()] -requirements = read_req('requirements.txt') -requirements_validation = read_req('requirements_validation.txt') -requirements_dev = read_req('requirements_dev.txt') -requirements_test = read_req('requirements_test.txt') +requirements = ['djangorestframework>=3.7.3'] + read_req('base.txt') +requirements_validation = read_req('validation.txt') +requirements_test = read_req('test.txt') setup( name='drf-swagger', @@ -24,7 +25,6 @@ setup( extras_require={ 'validation': requirements_validation, 'test': requirements_test, - 'dev': requirements_dev, }, license='BSD License', description='Automated generation of real Swagger/OpenAPI 2.0 schemas from Django Rest Framework code.', diff --git a/src/drf_swagger/codecs.py b/src/drf_swagger/codecs.py index 54c8a6e..fd2ab14 100644 --- a/src/drf_swagger/codecs.py +++ b/src/drf_swagger/codecs.py @@ -1,52 +1,42 @@ import json from collections import OrderedDict -from coreapi.codecs import BaseCodec -from coreapi.compat import force_bytes, urlparse -from drf_swagger.app_settings import swagger_settings -from openapi_codec import encode +from coreapi.compat import force_bytes +from future.utils import raise_from from ruamel import yaml +from drf_swagger.app_settings import swagger_settings +from drf_swagger.errors import SwaggerValidationError from . import openapi -class SwaggerValidationError(Exception): - def __init__(self, msg, validator_name, spec, *args) -> None: - super(SwaggerValidationError, self).__init__(msg, *args) - self.validator_name = validator_name - self.spec = spec - - def __str__(self): - return str(self.validator_name) + ": " + super(SwaggerValidationError, self).__str__() - - -def _validate_flex(spec): +def _validate_flex(spec, codec): from flex.core import parse as validate_flex from flex.exceptions import ValidationError try: validate_flex(spec) except ValidationError as ex: - raise SwaggerValidationError(str(ex), 'flex', spec) from ex + raise_from(SwaggerValidationError(str(ex), 'flex', spec, codec), ex) -def _validate_swagger_spec_validator(spec): +def _validate_swagger_spec_validator(spec, codec): from swagger_spec_validator.validator20 import validate_spec as validate_ssv from swagger_spec_validator.common import SwaggerValidationError as SSVErr try: validate_ssv(spec) except SSVErr as ex: - raise SwaggerValidationError(str(ex), 'swagger_spec_validator', spec) from ex + raise_from(SwaggerValidationError(str(ex), 'swagger_spec_validator', spec, codec), ex) VALIDATORS = { 'flex': _validate_flex, - 'swagger_spec_validator': _validate_swagger_spec_validator, 'ssv': _validate_swagger_spec_validator, } -class _OpenAPICodec(BaseCodec): +class _OpenAPICodec(object): format = 'openapi' + media_type = None def __init__(self, validators): self._validators = validators @@ -61,43 +51,30 @@ class _OpenAPICodec(BaseCodec): spec = self.generate_swagger_object(document) for validator in self.validators: - VALIDATORS[validator](spec) - return force_bytes(self._dump_spec(spec)) + VALIDATORS[validator](spec, self) + return force_bytes(self._dump_dict(spec)) - def _dump_spec(self, spec): - return NotImplementedError("override this method") + def encode_error(self, err): + return force_bytes(self._dump_dict(err)) + + def _dump_dict(self, spec): + raise NotImplementedError("override this method") def generate_swagger_object(self, swagger): """ - Generates root of the Swagger spec. + Generates the root Swagger object. :param openapi.Swagger swagger: :return OrderedDict: swagger spec as dict """ - parsed_url = urlparse.urlparse(swagger.url) - - spec = OrderedDict() - - spec['swagger'] = '2.0' - spec['info'] = swagger.info.to_swagger(swagger.version) - - if parsed_url.netloc: - spec['host'] = parsed_url.netloc - if parsed_url.scheme: - spec['schemes'] = [parsed_url.scheme] - spec['basePath'] = '/' - - spec['paths'] = encode._get_paths_object(swagger) - - spec['securityDefinitions'] = swagger_settings.SECURITY_DEFINITIONS - - return spec + swagger.security_definitions = swagger_settings.SECURITY_DEFINITIONS + return swagger class OpenAPICodecJson(_OpenAPICodec): - media_type = 'application/openapi+json' + media_type = 'application/json' - def _dump_spec(self, spec): + def _dump_dict(self, spec): return json.dumps(spec) @@ -109,7 +86,7 @@ class SaneYamlDumper(yaml.SafeDumper): return super(SaneYamlDumper, self).increase_indent(flow=flow, indentless=False, **kwargs) @staticmethod - def represent_odict(dump, mapping, flow_style=None): + def represent_odict(dump, mapping, flow_style=None): # pragma: no cover """https://gist.github.com/miracle2k/3184458 Make PyYAML output an OrderedDict. @@ -142,11 +119,11 @@ class SaneYamlDumper(yaml.SafeDumper): return node -SaneYamlDumper.add_representer(OrderedDict, SaneYamlDumper.represent_odict) +SaneYamlDumper.add_multi_representer(OrderedDict, SaneYamlDumper.represent_odict) class OpenAPICodecYaml(_OpenAPICodec): - media_type = 'application/openapi+yaml' + media_type = 'application/yaml' - def _dump_spec(self, spec): + def _dump_dict(self, spec): return yaml.dump(spec, Dumper=SaneYamlDumper, default_flow_style=False, encoding='utf-8') diff --git a/src/drf_swagger/errors.py b/src/drf_swagger/errors.py new file mode 100644 index 0000000..745ec57 --- /dev/null +++ b/src/drf_swagger/errors.py @@ -0,0 +1,14 @@ +class SwaggerError(Exception): + pass + + +class SwaggerValidationError(SwaggerError): + def __init__(self, msg, validator_name, spec, source_codec, *args): + super(SwaggerValidationError, self).__init__(msg, *args) + self.validator_name = validator_name + self.spec = spec + self.source_codec = source_codec + + +class SwaggerGenerationError(SwaggerError): + pass diff --git a/src/drf_swagger/generators.py b/src/drf_swagger/generators.py index 23e2ec0..c8bb458 100644 --- a/src/drf_swagger/generators.py +++ b/src/drf_swagger/generators.py @@ -1,15 +1,118 @@ -from rest_framework.schemas import SchemaGenerator as _SchemaGenerator +from collections import defaultdict +import django.db.models +import uritemplate +from coreapi.compat import force_text +from rest_framework.schemas.generators import SchemaGenerator +from rest_framework.schemas.inspectors import get_pk_description + +from drf_swagger.inspectors import SwaggerAutoSchema from . import openapi -class OpenAPISchemaGenerator(_SchemaGenerator): +class OpenAPISchemaGenerator(object): + """ + This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema. + Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator. + """ + def __init__(self, info, version, url=None, patterns=None, urlconf=None): - super(OpenAPISchemaGenerator, self).__init__(info.title, url, info.description, patterns, urlconf) + self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self.info = info self.version = version + self.endpoints = None + self.url = url def get_schema(self, request=None, public=False): - document = super(OpenAPISchemaGenerator, self).get_schema(request, public) - swagger = openapi.Swagger.from_coreapi(document, self.info, self.version) - return swagger + """Generate an openapi.Swagger representing the API schema.""" + if self.endpoints is None: + inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf) + self.endpoints = inspector.get_api_endpoints() + + self.get_endpoints(None if public else request) + paths = self.get_paths() + + url = self._gen.url + if not url and request is not None: + url = request.build_absolute_uri() + + # distribute_links(links) + return openapi.Swagger( + info=self.info, paths=paths, + _url=url, _version=self.version, + ) + + def get_endpoints(self, request): + """Generate {path: (view_class, [(method, view)]) given (path, method, callback).""" + view_paths = defaultdict(list) + view_cls = {} + for path, method, callback in self.endpoints: + view = self._gen.create_view(callback, method, request) + path = self._gen.coerce_path(path, method, view) + view_paths[path].append((method, view)) + view_cls[path] = callback.cls + self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()} + + def get_paths(self): + if not self.endpoints: + return [] + prefix = self._gen.determine_path_prefix(self.endpoints.keys()) + paths = {} + + for path, (view_cls, methods) in self.endpoints.items(): + path_parameters = self.get_path_parameters(path, view_cls) + operations = {} + for method, view in methods: + if not self._gen.has_view_permissions(path, method, view): + continue + + schema = SwaggerAutoSchema(view) + operation_keys = self._gen.get_keys(path[len(prefix):], method, view) + operations[method.lower()] = schema.get_operation(operation_keys, path, method) + + paths[path] = openapi.PathItem(parameters=path_parameters, **operations) + + return openapi.Paths(paths=paths) + + def get_path_parameters(self, path, view_cls): + """Return a list of Parameter instances corresponding to any templated path variables. + + :param str path: templated request path + :param type view_cls: the view class associated with the path + :return list[openapi.Parameter]: path parameters + """ + parameters = [] + model = getattr(getattr(view_cls, 'queryset', None), 'model', None) + + for variable in uritemplate.variables(path): + pattern = None + type = openapi.TYPE_STRING + description = None + if model is not None: + # Attempt to infer a field description if possible. + try: + model_field = model._meta.get_field(variable) + except Exception: + model_field = None + + if model_field is not None and model_field.help_text: + description = force_text(model_field.help_text) + elif model_field is not None and model_field.primary_key: + description = get_pk_description(model, model_field) + + if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable: + pattern = view_cls.lookup_value_regex + elif isinstance(model_field, django.db.models.AutoField): + type = openapi.TYPE_INTEGER + + field = openapi.Parameter( + name=variable, + required=True, + in_=openapi.IN_PATH, + type=type, + pattern=pattern, + description=description, + ) + parameters.append(field) + + return parameters diff --git a/src/drf_swagger/inspectors.py b/src/drf_swagger/inspectors.py index 79786a3..3dba1ff 100644 --- a/src/drf_swagger/inspectors.py +++ b/src/drf_swagger/inspectors.py @@ -1,5 +1,362 @@ +import functools +from collections import OrderedDict + +import coreschema +from django.core.validators import RegexValidator +from django.utils.encoding import force_text +from rest_framework import serializers from rest_framework.schemas import AutoSchema +from rest_framework.schemas.utils import is_list_view + +from drf_swagger.errors import SwaggerGenerationError +from . import openapi -class SwaggerAutoSchema(AutoSchema): - pass +def find_regex(regex_field): + regex_validator = None + for validator in regex_field.validators: + if isinstance(validator, RegexValidator): + if regex_validator is not None: + # bail if multiple validators are found - no obvious way to choose + return None + regex_validator = validator + + # regex_validator.regex should be a compiled re object... + return getattr(getattr(regex_validator, 'regex', None), 'pattern', None) + + +class SwaggerAutoSchema(object): + def __init__(self, view): + super(SwaggerAutoSchema, self).__init__() + self._sch = AutoSchema() + self.view = view + self._sch.view = view + + def get_operation(self, operation_keys, path, method): + """Get an Operation for the given API endpoint (path, method). + This includes query, body parameters and response schemas. + + :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API; + e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc. + :param str path: the view's path + :param str method: HTTP request method + :return openapi.Operation: the resulting Operation object + """ + body = self.get_request_body_parameters(path, method) + query = self.get_query_parameters(path, method) + parameters = body + query + + parameters = [param for param in parameters if param is not None] + description = self.get_description(path, method) + responses = self.get_responses(path, method) + return openapi.Operation( + operation_id='_'.join(operation_keys), + description=description, + responses=responses, + parameters=parameters, + tags=[operation_keys[0]] + ) + + def get_request_body_parameters(self, path, method): + """Return the request body parameters for this view. + This is either: + - a list with a single object Parameter with a Schema derived from the request serializer + - a list of primitive Parameters parsed as form data + + :param str path: the view's path + :param str method: HTTP request method + :return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData` + """ + # only PUT, PATCH or POST can have a request body + if method not in ('PUT', 'PATCH', 'POST'): + return [] + + serializer = self.get_request_serializer(path, method) + if serializer is None: + return [] + + encoding = self._sch.get_encoding(path, method) + if 'form' in encoding: + return [ + self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM) + for key, value + in serializer.fields.items() + ] + else: + schema = self.get_request_body_schema(path, method, serializer) + return [openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)] + + def get_request_serializer(self, path, method): + """Return the request serializer (used for parsing the request payload) for this endpoint. + + :param str path: the view's path + :param str method: HTTP request method + :return serializers.Serializer: the request serializer + """ + # TODO: only GenericAPIViews have defined serializers; + # APIViews and plain ViewSets will need some kind of manual treatment + if not hasattr(self.view, 'get_serializer'): + return None + + return self.view.get_serializer() + + def get_request_body_schema(self, path, method, serializer): + """Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests. + + :param str path: the view's path + :param str method: HTTP request method + :param serializer: the view's request serialzier + :return openapi.Schema: the request body schema + """ + return self.field_to_swagger(serializer, openapi.Schema) + + def get_responses(self, path, method): + """Get the possible responses for this view as a swagger Responses object. + + :param str path: the view's path + :param str method: HTTP request method + :return Responses: the documented responses + """ + response_serializers = self.get_response_serializers(path, method) + return openapi.Responses( + responses=self.get_response_schemas(path, method, response_serializers) + ) + + def get_response_serializers(self, path, method): + """Return the response codes that this view is expected to return, and the serializer for each response body. + The return value should be a dict where the keys are possible status codes, and values are either strings, + `Serializer`s or `openapi.Response` objects. + + :param str path: the view's path + :param str method: HTTP request method + :return dict: the response serializers + """ + if method.lower() == 'post': + return {'201': ''} + if method.lower() == 'delete': + return {'204': ''} + return {'200': ''} + + def get_response_schemas(self, path, method, response_serializers): + """Return the `openapi.Response` objects calculated for this view. + + :param str path: the view's path + :param str method: HTTP request method + :param dict response_serializers: result of get_response_serializers + :return dict[str, openapi.Response]: a dictionary of status code to Response object + """ + responses = {} + for status, serializer in response_serializers.items(): + if isinstance(serializer, str): + response = openapi.Response( + description=serializer + ) + elif isinstance(serializer, openapi.Response): + response = serializer + else: + response = openapi.Response( + description='', + schema=self.field_to_swagger(serializer, openapi.Schema) + ) + + responses[str(status)] = response + + return responses + + def get_query_parameters(self, path, method): + """Return the query parameters accepted by this view. + + :param str path: the view's path + :param str method: HTTP request method + :return list[openapi.Parameter]: the query parameters + """ + return self.get_filter_parameters(path, method) + self.get_pagination_parameters(path, method) + + def get_filter_parameters(self, path, method): + """Return the parameters added to the view by its filter backends. + + :param str path: the view's path + :param str method: HTTP request method + :return list[openapi.Parameter]: the filter query parameters + """ + if not self._sch._allows_filters(path, method): + return [] + + fields = [] + for filter_backend in self.view.filter_backends: + filter = filter_backend() + if hasattr(filter, 'get_schema_fields'): + fields += filter.get_schema_fields(self.view) + return [self.coreapi_field_to_parameter(field) for field in fields] + + def get_pagination_parameters(self, path, method): + """Return the parameters added to the view by its paginator. + + :param str path: the view's path + :param str method: HTTP request method + :return list[openapi.Parameter]: the pagination query parameters + """ + if not is_list_view(path, method, self.view): + return [] + + paginator = getattr(self.view, 'paginator', None) + if paginator is None: + return [] + + return [ + self.coreapi_field_to_parameter(field) + for field in paginator.get_schema_fields(self.view) + ] + + def coreapi_field_to_parameter(self, field): + """Convert an instance of `coreapi.Field` to a swagger Parameter object. + + :param coreapi.Field field: the coreapi field + :return openapi.Parameter: the equivalent openapi primitive Parameter + """ + location_to_in = { + 'query': openapi.IN_QUERY, + 'path': openapi.IN_PATH, + 'form': openapi.IN_FORM, + 'body': openapi.IN_FORM, + } + coreapi_types = { + coreschema.Integer: openapi.TYPE_INTEGER, + coreschema.Number: openapi.TYPE_NUMBER, + coreschema.String: openapi.TYPE_STRING, + coreschema.Boolean: openapi.TYPE_BOOLEAN, + } + return openapi.Parameter( + name=field.name, + in_=location_to_in[field.location], + type=coreapi_types.get(field.schema.__class__, openapi.TYPE_STRING), + required=field.required, + description=field.schema.description, + ) + + def get_description(self, path, method): + """Return an operation description determined as appropriate from the view's method and class docstrings. + + :param str path: the view's path + :param str method: HTTP request method + :return str: the operation description + """ + return self._sch.get_description(path, method) + + def field_to_swagger(self, field, swagger_object_type, **kwargs): + """Convert a drf Serializer or Field instance into a Swagger object. + + :param rest_framework.serializers.Field field: the source field + :param type swagger_object_type: should be one of Schema, Parameter, Items + :param kwargs: extra attributes for constructing the object; + if swagger_object_type is Parameter, `name` and `in_` should be provided + :return Swagger,Parameter,Items: the swagger object + """ + assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items) + title = force_text(field.label) if field.label else None + title = title if swagger_object_type == openapi.Schema else None # only Schema has title + title = None + description = force_text(field.help_text) if field.help_text else None + description = description if swagger_object_type != openapi.Items else None # Items has no description either + + SwaggerType = functools.partial(swagger_object_type, title=title, description=description, **kwargs) + # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements + ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items + + # ------ NESTED + if isinstance(field, (serializers.ListSerializer, serializers.ListField)): + child_schema = self.field_to_swagger(field.child, ChildSwaggerType) + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=child_schema, + ) + elif isinstance(field, serializers.Serializer): + if swagger_object_type != openapi.Schema: + raise SwaggerGenerationError("cannot instantiate nested serializer as " + + swagger_object_type.__name__) + return SwaggerType( + type=openapi.TYPE_OBJECT, + properties=OrderedDict( + (key, self.field_to_swagger(value, ChildSwaggerType)) + for key, value + in field.fields.items() + ) + ) + elif isinstance(field, serializers.ManyRelatedField): + child_schema = self.field_to_swagger(field.child_relation, ChildSwaggerType) + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=child_schema, + unique_items=True, # is this OK? + ) + elif isinstance(field, serializers.RelatedField): + # TODO: infer type for PrimaryKeyRelatedField? + return SwaggerType(type=openapi.TYPE_STRING) + # ------ CHOICES + elif isinstance(field, serializers.MultipleChoiceField): + return SwaggerType( + type=openapi.TYPE_ARRAY, + items=ChildSwaggerType( + type=openapi.TYPE_STRING, + enum=list(field.choices.keys()) + ) + ) + elif isinstance(field, serializers.ChoiceField): + return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys())) + # ------ BOOL + elif isinstance(field, serializers.BooleanField): + return SwaggerType(type=openapi.TYPE_BOOLEAN) + # ------ NUMERIC + elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): + # TODO: min_value max_value + return SwaggerType(type=openapi.TYPE_NUMBER) + elif isinstance(field, serializers.IntegerField): + # TODO: min_value max_value + return SwaggerType(type=openapi.TYPE_INTEGER) + # ------ STRING + elif isinstance(field, serializers.EmailField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL) + elif isinstance(field, serializers.RegexField): + return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field)) + elif isinstance(field, serializers.SlugField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG) + elif isinstance(field, serializers.URLField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI) + elif isinstance(field, serializers.IPAddressField): + format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None) + return SwaggerType(type=openapi.TYPE_STRING, format=format) + elif isinstance(field, serializers.CharField): + # TODO: min_length max_length (for all CharField subclasses above too) + return SwaggerType(type=openapi.TYPE_STRING) + elif isinstance(field, serializers.UUIDField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID) + # ------ DATE & TIME + elif isinstance(field, serializers.DateField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE) + elif isinstance(field, serializers.DateTimeField): + return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME) + # ------ OTHERS + elif isinstance(field, serializers.FileField): + # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment + # OpenAPI 3.0 does support it, so a future implementation could handle this better + # TODO: appropriate produces/consumes somehow/somewhere? + if swagger_object_type != openapi.Parameter: + raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter") + return SwaggerType(type=openapi.TYPE_FILE) + elif isinstance(field, serializers.JSONField): + return SwaggerType( + type=openapi.TYPE_STRING, + format=openapi.FORMAT_BINARY if field.binary else None + ) + elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema: + child_schema = self.field_to_swagger(field.child, ChildSwaggerType) + return SwaggerType( + type=openapi.TYPE_OBJECT, + additional_properties=child_schema + ) + + # TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField? + # TODO: return info about required/allowed empty + + # everything else gets string by default + return SwaggerType(type=openapi.TYPE_STRING) diff --git a/src/drf_swagger/middleware.py b/src/drf_swagger/middleware.py new file mode 100644 index 0000000..38a714d --- /dev/null +++ b/src/drf_swagger/middleware.py @@ -0,0 +1,18 @@ +from django.http import HttpResponse +from django.utils.deprecation import MiddlewareMixin + +from drf_swagger.errors import SwaggerValidationError +from .codecs import _OpenAPICodec + + +class SwaggerExceptionMiddleware(MiddlewareMixin): + def process_exception(self, request, exception): + if isinstance(exception, SwaggerValidationError): + err = {'errors': {exception.validator_name: str(exception)}} + codec = exception.source_codec + if isinstance(codec, _OpenAPICodec): + err = codec.encode_error(err) + content_type = codec.media_type + return HttpResponse(err, status=500, content_type=content_type) + + return None diff --git a/src/drf_swagger/openapi.py b/src/drf_swagger/openapi.py index 9434a15..8858cab 100644 --- a/src/drf_swagger/openapi.py +++ b/src/drf_swagger/openapi.py @@ -1,10 +1,105 @@ -import warnings from collections import OrderedDict -import coreapi +from coreapi.compat import urlparse +from future.utils import raise_from +from inflection import camelize + +TYPE_OBJECT = "object" +TYPE_STRING = "string" +TYPE_NUMBER = "number" +TYPE_INTEGER = "integer" +TYPE_BOOLEAN = "boolean" +TYPE_ARRAY = "array" +TYPE_FILE = "file" + +# officially supported by Swagger 2.0 spec +FORMAT_DATE = "date" +FORMAT_DATETIME = "date-time" +FORMAT_PASSWORD = "password" +FORMAT_BINARY = "binary" +FORMAT_BASE64 = "bytes" +FORMAT_FLOAT = "float" +FORMAT_DOUBLE = "double" +FORMAT_INT32 = "int32" +FORMAT_INT64 = "int64" + +# defined in JSON-schema +FORMAT_EMAIL = "email" +FORMAT_IPV4 = "ipv4" +FORMAT_IPV6 = "ipv6" +FORMAT_URI = "uri" + +# pulled out of my ass +FORMAT_UUID = "uuid" +FORMAT_SLUG = "slug" + +IN_BODY = 'body' +IN_PATH = 'path' +IN_QUERY = 'query' +IN_FORM = 'formData' +IN_HEADER = 'header' -class Contact(object): +def make_swagger_name(attribute_name): + """ + Convert a python variable name into a Swagger spec attribute name. + + In particular, + * if name starts with x_, return "x-{camelCase}" + * if name is 'ref', return "$ref" + * else return the name converted to camelCase, with trailing underscores stripped + + :param str attribute_name: python attribute name + :return: swagger name + """ + if attribute_name == 'ref': + return "$ref" + if attribute_name.startswith("x_"): + return "x-" + camelize(attribute_name[2:], uppercase_first_letter=False) + return camelize(attribute_name.rstrip('_'), uppercase_first_letter=False) + + +class SwaggerDict(OrderedDict): + def __init__(self, **attrs): + super(SwaggerDict, self).__init__() + self._extras__ = attrs + if self.__class__ == SwaggerDict: + self._insert_extras__() + + def __setattr__(self, key, value): + if key.startswith('_'): + super(SwaggerDict, self).__setattr__(key, value) + return + if value is not None: + self[make_swagger_name(key)] = value + + def __getattr__(self, item): + if item.startswith('_'): + raise AttributeError + try: + return self[make_swagger_name(item)] + except KeyError as e: + raise_from(AttributeError("no attribute " + item), e) + + def __delattr__(self, item): + if item.startswith('_'): + super(SwaggerDict, self).__delattr__(item) + return + del self[make_swagger_name(item)] + + def _insert_extras__(self): + """ + From an ordering perspective, it is desired that extra attributes such as vendor extensions stay at the + bottom of the object. However, python2.7's OrderdDict craps out if you try to insert into it before calling + init. This means that subclasses must call super().__init__ as the first statement of their own __init__, + which would result in the extra attributes being added first. For this reason, we defer the insertion of the + attributes and require that subclasses call ._insert_extras__ at the end of their __init__ method. + """ + for attr, val in self._extras__.items(): + setattr(self, attr, val) + + +class Contact(SwaggerDict): """Swagger Contact object At least one of the following fields is required: @@ -12,47 +107,34 @@ class Contact(object): :param str url: contact url :param str email: contact e-mail """ - def __init__(self, name=None, url=None, email=None): + + def __init__(self, name=None, url=None, email=None, **extra): + super(Contact, self).__init__(**extra) + if name is None and url is None and email is None: + raise ValueError("one of name, url or email is requires for Swagger Contact object") self.name = name self.url = url self.email = email - if name is None and url is None and email is None: - raise ValueError("one of name, url or email is requires for Swagger Contact object") - - def to_swagger(self): - contact = OrderedDict() - if self.name is not None: - contact['name'] = self.name - if self.url is not None: - contact['url'] = self.url - if self.email is not None: - contact['email'] = self.email - - return contact + self._insert_extras__() -class License(object): +class License(SwaggerDict): """Swagger License object :param str name: Requird. License name :param str url: link to detailed license information """ - def __init__(self, name, url=None): - self.name = name - self.url = url + + def __init__(self, name, url=None, **extra): + super(License, self).__init__(**extra) if name is None: raise ValueError("name is required for Swagger License object") - - def to_swagger(self): - license = OrderedDict() - license['name'] = self.name - if self.url is not None: - license['url'] = self.url - - return license + self.name = name + self.url = url + self._insert_extras__() -class Info(object): +class Info(SwaggerDict): """Swagger Info object :param str title: Required. API title. @@ -62,7 +144,10 @@ class Info(object): :param Contact contact: contact object :param License license: license object """ - def __init__(self, title, default_version, description=None, terms_of_service=None, contact=None, license=None): + + def __init__(self, title, default_version, description=None, terms_of_service=None, contact=None, license=None, + **extra): + super(Info, self).__init__(**extra) if title is None or default_version is None: raise ValueError("title and version are required for Swagger info object") if contact is not None and not isinstance(contact, Contact): @@ -70,67 +155,139 @@ class Info(object): if license is not None and not isinstance(license, License): raise ValueError("license must be a License object") self.title = title - self.default_version = default_version + self._default_version = default_version self.description = description self.terms_of_service = terms_of_service self.contact = contact self.license = license - - def to_swagger(self, version): - info = OrderedDict() - info['title'] = self.title - if self.description is not None: - info['description'] = self.description - if self.terms_of_service is not None: - info['termsOfService'] = self.terms_of_service - if self.contact is not None: - info['contact'] = self.contact.to_swagger() - if self.license is not None: - info['license'] = self.license.to_swagger() - info['version'] = version or self.default_version - return info + self._insert_extras__() -class Swagger(coreapi.Document): - @classmethod - def from_coreapi(cls, document, info, version): - """ - Create an openapi.Swagger from the fields of a coreapi.Document. +class Swagger(SwaggerDict): + def __init__(self, info=None, _url=None, _version=None, paths=None, **extra): + super(Swagger, self).__init__(**extra) + self.swagger = '2.0' + self.info = info + self.info.version = _version or info._default_version + self.paths = paths - :param coreapi.Document document: source coreapi.Document - :param openapi.Info info: Swagger info object - :param string version: API version string - :return: an openapi.Swagger - """ - if document.title and document.title != info.title: - warnings.warn("document title is overriden by Swagger Info") - if document.description and document.description != info.description: - warnings.warn("document description is overriden by Swagger Info") - return Swagger( - info=info, - version=version, - url=document.url, - media_type=document.media_type, - content=document.data - ) + if _url: + url = urlparse.urlparse(_url) + if url.netloc: + self.host = url.netloc + if url.scheme: + self.schemes = [url.scheme] - def __init__(self, info=None, version=None, url=None, media_type=None, content=None): - super(Swagger, self).__init__(url, info.title, info.description, media_type, content) - self._info = info - self._version = version - - @property - def info(self): - return self._info - - @property - def version(self): - return self._version + self.base_path = '/' + self._insert_extras__() -class Field(coreapi.Field): - pass +class Paths(SwaggerDict): + def __init__(self, paths, **extra): + super(Paths, self).__init__(**extra) + for path, path_obj in paths.items(): + assert path.startswith("/") + if path_obj is not None: + self[path] = path_obj + self._insert_extras__() -class Link(coreapi.Link): - pass +class PathItem(SwaggerDict): + def __init__(self, get=None, put=None, post=None, delete=None, options=None, + head=None, patch=None, parameters=None, **extra): + super(PathItem, self).__init__(**extra) + self.get = get + self.put = put + self.post = post + self.delete = delete + self.options = options + self.head = head + self.patch = patch + self.parameters = parameters + self._insert_extras__() + + +class Operation(SwaggerDict): + def __init__(self, operation_id, responses, parameters=None, consumes=None, + produces=None, description=None, tags=None, **extra): + super(Operation, self).__init__(**extra) + self.operation_id = operation_id + self.responses = responses + self.parameters = [param for param in parameters if param is not None] + self.consumes = consumes + self.produces = produces + self.description = description + self.tags = tags + self._insert_extras__() + + +class Items(SwaggerDict): + def __init__(self, type=None, format=None, enum=None, pattern=None, items=None, **extra): + super(Items, self).__init__(**extra) + self.type = type + self.format = format + self.enum = enum + self.pattern = pattern + self.items = items + self._insert_extras__() + + +class Parameter(SwaggerDict): + def __init__(self, name, in_, description=None, required=None, schema=None, + type=None, format=None, enum=None, pattern=None, items=None, **extra): + super(Parameter, self).__init__(**extra) + if (not schema and not type) or (schema and type): + raise ValueError("either schema or type are required for Parameter object!") + self.name = name + self.in_ = in_ + self.description = description + self.required = required + self.schema = schema + self.type = type + self.format = format + self.enum = enum + self.pattern = pattern + self.items = items + self._insert_extras__() + + +class Schema(SwaggerDict): + def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None, + format=None, enum=None, pattern=None, items=None, **extra): + super(Schema, self).__init__(**extra) + self.description = description + self.required = required + self.type = type + self.properties = properties + self.additional_properties = additional_properties + self.format = format + self.enum = enum + self.pattern = pattern + self.items = items + self._insert_extras__() + + +class Ref(SwaggerDict): + def __init__(self, ref): + super(Ref, self).__init__() + self.ref = ref + self._insert_extras__() + + +class Responses(SwaggerDict): + def __init__(self, responses, default=None, **extra): + super(Responses, self).__init__(**extra) + for status, response in responses.items(): + if response is not None: + self[str(status)] = response + self.default = default + self._insert_extras__() + + +class Response(SwaggerDict): + def __init__(self, description, schema=None, examples=None, **extra): + super(Response, self).__init__(**extra) + self.description = description + self.schema = schema + self.examples = examples + self._insert_extras__() diff --git a/src/drf_swagger/renderers.py b/src/drf_swagger/renderers.py index 90e6d43..1d20a86 100644 --- a/src/drf_swagger/renderers.py +++ b/src/drf_swagger/renderers.py @@ -45,17 +45,17 @@ class _UIRenderer(BaseRenderer): charset = 'utf-8' template = '' - def render(self, data, accepted_media_type=None, renderer_context=None): - self.set_context(renderer_context, data) + def render(self, swagger, accepted_media_type=None, renderer_context=None): + self.set_context(renderer_context, swagger) return render( renderer_context['request'], self.template, renderer_context ) - def set_context(self, renderer_context, data): - renderer_context['title'] = data.title - renderer_context['version'] = data.version + def set_context(self, renderer_context, swagger): + renderer_context['title'] = swagger.info.title + renderer_context['version'] = swagger.info.version renderer_context['swagger_settings'] = json.dumps(self.get_swagger_ui_settings()) renderer_context['redoc_settings'] = json.dumps(self.get_redoc_settings()) renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH diff --git a/src/drf_swagger/templates/drf-swagger/swagger-ui.html b/src/drf_swagger/templates/drf-swagger/swagger-ui.html index 1b3f58f..634da25 100644 --- a/src/drf_swagger/templates/drf-swagger/swagger-ui.html +++ b/src/drf_swagger/templates/drf-swagger/swagger-ui.html @@ -35,7 +35,7 @@ margin-right: 8px; } - #django-session-auth.hidden { + .hidden { display: none; } @@ -126,14 +126,31 @@
+ - - - + + + +