diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..9f1c676
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,32 @@
+[run]
+source = drf_swagger
+branch = True
+
+[report]
+# Regexes for lines to exclude from consideration
+exclude_lines =
+ # Have to re-enable the standard pragma
+ pragma: no cover
+
+ # Don't complain about missing debug-only code:
+ def __repr__
+ if self/.debug
+
+ # Don't complain if tests don't hit defensive assertion code:
+ raise AssertionError
+ raise NotImplementedError
+ warnings.warn
+
+ # Don't complain if non-runnable code isn't run:
+ if 0:
+ if __name__ == .__main__.:
+
+ignore_errors = True
+precision = 0
+
+[paths]
+source =
+ src/drf_swagger/
+ .tox/*/Lib/site-packages/drf_swagger/
+ .tox/*/lib/*/site-packages/drf_swagger/
+ /home/travis/virtualenv/*/lib/*/site-packages/drf_swagger/
diff --git a/.idea/drf-swagger.iml b/.idea/drf-swagger.iml
new file mode 100644
index 0000000..30b590b
--- /dev/null
+++ b/.idea/drf-swagger.iml
@@ -0,0 +1,35 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..2af15d1
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,60 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..b6e1878
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,73 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..b317f5e
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..94a25f7
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.travis.yml b/.travis.yml
index 909e9ed..80aff3f 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,21 +1,31 @@
language: python
cache: pip
python:
+ - '2.7'
+ - '3.4'
- '3.5'
- '3.6'
- '3.7-dev'
+env:
+ - DRF=3.7
+
matrix:
fast_finish: true
include:
- - python: "3.5"
+ - python: '2.7'
env: TOXENV=flake8
+ - python: '3.6'
+ env: DRF=master
allow_failures:
- env: TOXENV=flake8
+ - env: DRF=master
+ - python: '2.7'
+ - python: '3.7'
install:
- - pip install -r requirements_dev.txt
+ - pip install -r requirements/ci.txt
before_script:
- coverage erase
@@ -24,4 +34,8 @@ script:
- tox
after_success:
- - coveralls
+ - codecov
+
+branches:
+ only:
+ - master
diff --git a/MANIFEST.in b/MANIFEST.in
index 93f2962..061705b 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,5 @@
include README.md
include LICENSE
-include requirements*
+recursive-include requirements *
recursive-include src/drf_swagger/static *
recursive-include src/drf_swagger/templates *
diff --git a/requirements.txt b/requirements/base.txt
similarity index 52%
rename from requirements.txt
rename to requirements/base.txt
index 9536b6c..77b4da0 100644
--- a/requirements.txt
+++ b/requirements/base.txt
@@ -1,5 +1,6 @@
-djangorestframework>=3.7.3
-django>=1.11.7
coreapi>=2.3.3
+coreschema>=0.0.4
openapi_codec>=1.3.2
ruamel.yaml>=0.15.34
+inflection>=0.3.1
+future>=0.16.0
diff --git a/requirements/ci.txt b/requirements/ci.txt
new file mode 100644
index 0000000..6f65907
--- /dev/null
+++ b/requirements/ci.txt
@@ -0,0 +1,4 @@
+# requirements for CI test suite
+-r dev.txt
+tox-travis>=0.10
+codecov>=2.0.9
diff --git a/requirements/dev.txt b/requirements/dev.txt
new file mode 100644
index 0000000..a278302
--- /dev/null
+++ b/requirements/dev.txt
@@ -0,0 +1,3 @@
+# requirements for local development
+tox>=2.9.1
+tox-battery>=0.5
diff --git a/requirements/test.txt b/requirements/test.txt
new file mode 100644
index 0000000..3b7fd2b
--- /dev/null
+++ b/requirements/test.txt
@@ -0,0 +1,11 @@
+# pytest runner + plugins
+pytest-django>=3.1.2
+pytest-pythonpath>=0.7.1
+pytest-cov>=2.5.1
+
+# test project requirements
+Pillow>=4.3.0
+pygments>=2.2.0
+django-cors-headers>=2.1.0
+django-filter>=1.1.0,<2.0; python_version == "2.7"
+django-filter>=1.1.0; python_version >= "3.4"
diff --git a/requirements/validation.txt b/requirements/validation.txt
new file mode 100644
index 0000000..65e9d43
--- /dev/null
+++ b/requirements/validation.txt
@@ -0,0 +1,3 @@
+# requirements for the validation feature
+flex>=6.11.1
+swagger-spec-validator>=2.1.0
diff --git a/requirements_dev.txt b/requirements_dev.txt
deleted file mode 100644
index 77209fb..0000000
--- a/requirements_dev.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-# Packages required for development and CI
-tox>=2.9.1
-tox-battery>=0.5
-tox-travis>=0.10
-python-coveralls>=2.9.1
diff --git a/requirements_test.txt b/requirements_test.txt
deleted file mode 100644
index d090a94..0000000
--- a/requirements_test.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-# Packages required for running the tests
-pygments>=2.2.0
-django-cors-headers>=2.1.0
-pytest-django>=3.1.2
-pytest-pythonpath>=0.7.1
-pytest-cov>=2.5.1
diff --git a/requirements_validation.txt b/requirements_validation.txt
deleted file mode 100644
index cc26ca0..0000000
--- a/requirements_validation.txt
+++ /dev/null
@@ -1,3 +0,0 @@
-# Packages required for the validation feature
-flex>=6.11.1
-swagger-spec-validator>=2.1.0
diff --git a/setup.py b/setup.py
index 6d196f9..29240b0 100644
--- a/setup.py
+++ b/setup.py
@@ -1,17 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
+import os
+
from setuptools import setup, find_packages
def read_req(req_file):
- with open(req_file) as req:
+ with open(os.path.join('requirements', req_file)) as req:
return [line for line in req.readlines() if line and not line.isspace()]
-requirements = read_req('requirements.txt')
-requirements_validation = read_req('requirements_validation.txt')
-requirements_dev = read_req('requirements_dev.txt')
-requirements_test = read_req('requirements_test.txt')
+requirements = ['djangorestframework>=3.7.3'] + read_req('base.txt')
+requirements_validation = read_req('validation.txt')
+requirements_test = read_req('test.txt')
setup(
name='drf-swagger',
@@ -24,7 +25,6 @@ setup(
extras_require={
'validation': requirements_validation,
'test': requirements_test,
- 'dev': requirements_dev,
},
license='BSD License',
description='Automated generation of real Swagger/OpenAPI 2.0 schemas from Django Rest Framework code.',
diff --git a/src/drf_swagger/codecs.py b/src/drf_swagger/codecs.py
index 54c8a6e..fd2ab14 100644
--- a/src/drf_swagger/codecs.py
+++ b/src/drf_swagger/codecs.py
@@ -1,52 +1,42 @@
import json
from collections import OrderedDict
-from coreapi.codecs import BaseCodec
-from coreapi.compat import force_bytes, urlparse
-from drf_swagger.app_settings import swagger_settings
-from openapi_codec import encode
+from coreapi.compat import force_bytes
+from future.utils import raise_from
from ruamel import yaml
+from drf_swagger.app_settings import swagger_settings
+from drf_swagger.errors import SwaggerValidationError
from . import openapi
-class SwaggerValidationError(Exception):
- def __init__(self, msg, validator_name, spec, *args) -> None:
- super(SwaggerValidationError, self).__init__(msg, *args)
- self.validator_name = validator_name
- self.spec = spec
-
- def __str__(self):
- return str(self.validator_name) + ": " + super(SwaggerValidationError, self).__str__()
-
-
-def _validate_flex(spec):
+def _validate_flex(spec, codec):
from flex.core import parse as validate_flex
from flex.exceptions import ValidationError
try:
validate_flex(spec)
except ValidationError as ex:
- raise SwaggerValidationError(str(ex), 'flex', spec) from ex
+ raise_from(SwaggerValidationError(str(ex), 'flex', spec, codec), ex)
-def _validate_swagger_spec_validator(spec):
+def _validate_swagger_spec_validator(spec, codec):
from swagger_spec_validator.validator20 import validate_spec as validate_ssv
from swagger_spec_validator.common import SwaggerValidationError as SSVErr
try:
validate_ssv(spec)
except SSVErr as ex:
- raise SwaggerValidationError(str(ex), 'swagger_spec_validator', spec) from ex
+ raise_from(SwaggerValidationError(str(ex), 'swagger_spec_validator', spec, codec), ex)
VALIDATORS = {
'flex': _validate_flex,
- 'swagger_spec_validator': _validate_swagger_spec_validator,
'ssv': _validate_swagger_spec_validator,
}
-class _OpenAPICodec(BaseCodec):
+class _OpenAPICodec(object):
format = 'openapi'
+ media_type = None
def __init__(self, validators):
self._validators = validators
@@ -61,43 +51,30 @@ class _OpenAPICodec(BaseCodec):
spec = self.generate_swagger_object(document)
for validator in self.validators:
- VALIDATORS[validator](spec)
- return force_bytes(self._dump_spec(spec))
+ VALIDATORS[validator](spec, self)
+ return force_bytes(self._dump_dict(spec))
- def _dump_spec(self, spec):
- return NotImplementedError("override this method")
+ def encode_error(self, err):
+ return force_bytes(self._dump_dict(err))
+
+ def _dump_dict(self, spec):
+ raise NotImplementedError("override this method")
def generate_swagger_object(self, swagger):
"""
- Generates root of the Swagger spec.
+ Generates the root Swagger object.
:param openapi.Swagger swagger:
:return OrderedDict: swagger spec as dict
"""
- parsed_url = urlparse.urlparse(swagger.url)
-
- spec = OrderedDict()
-
- spec['swagger'] = '2.0'
- spec['info'] = swagger.info.to_swagger(swagger.version)
-
- if parsed_url.netloc:
- spec['host'] = parsed_url.netloc
- if parsed_url.scheme:
- spec['schemes'] = [parsed_url.scheme]
- spec['basePath'] = '/'
-
- spec['paths'] = encode._get_paths_object(swagger)
-
- spec['securityDefinitions'] = swagger_settings.SECURITY_DEFINITIONS
-
- return spec
+ swagger.security_definitions = swagger_settings.SECURITY_DEFINITIONS
+ return swagger
class OpenAPICodecJson(_OpenAPICodec):
- media_type = 'application/openapi+json'
+ media_type = 'application/json'
- def _dump_spec(self, spec):
+ def _dump_dict(self, spec):
return json.dumps(spec)
@@ -109,7 +86,7 @@ class SaneYamlDumper(yaml.SafeDumper):
return super(SaneYamlDumper, self).increase_indent(flow=flow, indentless=False, **kwargs)
@staticmethod
- def represent_odict(dump, mapping, flow_style=None):
+ def represent_odict(dump, mapping, flow_style=None): # pragma: no cover
"""https://gist.github.com/miracle2k/3184458
Make PyYAML output an OrderedDict.
@@ -142,11 +119,11 @@ class SaneYamlDumper(yaml.SafeDumper):
return node
-SaneYamlDumper.add_representer(OrderedDict, SaneYamlDumper.represent_odict)
+SaneYamlDumper.add_multi_representer(OrderedDict, SaneYamlDumper.represent_odict)
class OpenAPICodecYaml(_OpenAPICodec):
- media_type = 'application/openapi+yaml'
+ media_type = 'application/yaml'
- def _dump_spec(self, spec):
+ def _dump_dict(self, spec):
return yaml.dump(spec, Dumper=SaneYamlDumper, default_flow_style=False, encoding='utf-8')
diff --git a/src/drf_swagger/errors.py b/src/drf_swagger/errors.py
new file mode 100644
index 0000000..745ec57
--- /dev/null
+++ b/src/drf_swagger/errors.py
@@ -0,0 +1,14 @@
+class SwaggerError(Exception):
+ pass
+
+
+class SwaggerValidationError(SwaggerError):
+ def __init__(self, msg, validator_name, spec, source_codec, *args):
+ super(SwaggerValidationError, self).__init__(msg, *args)
+ self.validator_name = validator_name
+ self.spec = spec
+ self.source_codec = source_codec
+
+
+class SwaggerGenerationError(SwaggerError):
+ pass
diff --git a/src/drf_swagger/generators.py b/src/drf_swagger/generators.py
index 23e2ec0..c8bb458 100644
--- a/src/drf_swagger/generators.py
+++ b/src/drf_swagger/generators.py
@@ -1,15 +1,118 @@
-from rest_framework.schemas import SchemaGenerator as _SchemaGenerator
+from collections import defaultdict
+import django.db.models
+import uritemplate
+from coreapi.compat import force_text
+from rest_framework.schemas.generators import SchemaGenerator
+from rest_framework.schemas.inspectors import get_pk_description
+
+from drf_swagger.inspectors import SwaggerAutoSchema
from . import openapi
-class OpenAPISchemaGenerator(_SchemaGenerator):
+class OpenAPISchemaGenerator(object):
+ """
+ This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema.
+ Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator.
+ """
+
def __init__(self, info, version, url=None, patterns=None, urlconf=None):
- super(OpenAPISchemaGenerator, self).__init__(info.title, url, info.description, patterns, urlconf)
+ self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info
self.version = version
+ self.endpoints = None
+ self.url = url
def get_schema(self, request=None, public=False):
- document = super(OpenAPISchemaGenerator, self).get_schema(request, public)
- swagger = openapi.Swagger.from_coreapi(document, self.info, self.version)
- return swagger
+ """Generate an openapi.Swagger representing the API schema."""
+ if self.endpoints is None:
+ inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf)
+ self.endpoints = inspector.get_api_endpoints()
+
+ self.get_endpoints(None if public else request)
+ paths = self.get_paths()
+
+ url = self._gen.url
+ if not url and request is not None:
+ url = request.build_absolute_uri()
+
+ # distribute_links(links)
+ return openapi.Swagger(
+ info=self.info, paths=paths,
+ _url=url, _version=self.version,
+ )
+
+ def get_endpoints(self, request):
+ """Generate {path: (view_class, [(method, view)]) given (path, method, callback)."""
+ view_paths = defaultdict(list)
+ view_cls = {}
+ for path, method, callback in self.endpoints:
+ view = self._gen.create_view(callback, method, request)
+ path = self._gen.coerce_path(path, method, view)
+ view_paths[path].append((method, view))
+ view_cls[path] = callback.cls
+ self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()}
+
+ def get_paths(self):
+ if not self.endpoints:
+ return []
+ prefix = self._gen.determine_path_prefix(self.endpoints.keys())
+ paths = {}
+
+ for path, (view_cls, methods) in self.endpoints.items():
+ path_parameters = self.get_path_parameters(path, view_cls)
+ operations = {}
+ for method, view in methods:
+ if not self._gen.has_view_permissions(path, method, view):
+ continue
+
+ schema = SwaggerAutoSchema(view)
+ operation_keys = self._gen.get_keys(path[len(prefix):], method, view)
+ operations[method.lower()] = schema.get_operation(operation_keys, path, method)
+
+ paths[path] = openapi.PathItem(parameters=path_parameters, **operations)
+
+ return openapi.Paths(paths=paths)
+
+ def get_path_parameters(self, path, view_cls):
+ """Return a list of Parameter instances corresponding to any templated path variables.
+
+ :param str path: templated request path
+ :param type view_cls: the view class associated with the path
+ :return list[openapi.Parameter]: path parameters
+ """
+ parameters = []
+ model = getattr(getattr(view_cls, 'queryset', None), 'model', None)
+
+ for variable in uritemplate.variables(path):
+ pattern = None
+ type = openapi.TYPE_STRING
+ description = None
+ if model is not None:
+ # Attempt to infer a field description if possible.
+ try:
+ model_field = model._meta.get_field(variable)
+ except Exception:
+ model_field = None
+
+ if model_field is not None and model_field.help_text:
+ description = force_text(model_field.help_text)
+ elif model_field is not None and model_field.primary_key:
+ description = get_pk_description(model, model_field)
+
+ if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
+ pattern = view_cls.lookup_value_regex
+ elif isinstance(model_field, django.db.models.AutoField):
+ type = openapi.TYPE_INTEGER
+
+ field = openapi.Parameter(
+ name=variable,
+ required=True,
+ in_=openapi.IN_PATH,
+ type=type,
+ pattern=pattern,
+ description=description,
+ )
+ parameters.append(field)
+
+ return parameters
diff --git a/src/drf_swagger/inspectors.py b/src/drf_swagger/inspectors.py
index 79786a3..3dba1ff 100644
--- a/src/drf_swagger/inspectors.py
+++ b/src/drf_swagger/inspectors.py
@@ -1,5 +1,362 @@
+import functools
+from collections import OrderedDict
+
+import coreschema
+from django.core.validators import RegexValidator
+from django.utils.encoding import force_text
+from rest_framework import serializers
from rest_framework.schemas import AutoSchema
+from rest_framework.schemas.utils import is_list_view
+
+from drf_swagger.errors import SwaggerGenerationError
+from . import openapi
-class SwaggerAutoSchema(AutoSchema):
- pass
+def find_regex(regex_field):
+ regex_validator = None
+ for validator in regex_field.validators:
+ if isinstance(validator, RegexValidator):
+ if regex_validator is not None:
+ # bail if multiple validators are found - no obvious way to choose
+ return None
+ regex_validator = validator
+
+ # regex_validator.regex should be a compiled re object...
+ return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
+
+
+class SwaggerAutoSchema(object):
+ def __init__(self, view):
+ super(SwaggerAutoSchema, self).__init__()
+ self._sch = AutoSchema()
+ self.view = view
+ self._sch.view = view
+
+ def get_operation(self, operation_keys, path, method):
+ """Get an Operation for the given API endpoint (path, method).
+ This includes query, body parameters and response schemas.
+
+ :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
+ e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc.
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return openapi.Operation: the resulting Operation object
+ """
+ body = self.get_request_body_parameters(path, method)
+ query = self.get_query_parameters(path, method)
+ parameters = body + query
+
+ parameters = [param for param in parameters if param is not None]
+ description = self.get_description(path, method)
+ responses = self.get_responses(path, method)
+ return openapi.Operation(
+ operation_id='_'.join(operation_keys),
+ description=description,
+ responses=responses,
+ parameters=parameters,
+ tags=[operation_keys[0]]
+ )
+
+ def get_request_body_parameters(self, path, method):
+ """Return the request body parameters for this view.
+ This is either:
+ - a list with a single object Parameter with a Schema derived from the request serializer
+ - a list of primitive Parameters parsed as form data
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData`
+ """
+ # only PUT, PATCH or POST can have a request body
+ if method not in ('PUT', 'PATCH', 'POST'):
+ return []
+
+ serializer = self.get_request_serializer(path, method)
+ if serializer is None:
+ return []
+
+ encoding = self._sch.get_encoding(path, method)
+ if 'form' in encoding:
+ return [
+ self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM)
+ for key, value
+ in serializer.fields.items()
+ ]
+ else:
+ schema = self.get_request_body_schema(path, method, serializer)
+ return [openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)]
+
+ def get_request_serializer(self, path, method):
+ """Return the request serializer (used for parsing the request payload) for this endpoint.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return serializers.Serializer: the request serializer
+ """
+ # TODO: only GenericAPIViews have defined serializers;
+ # APIViews and plain ViewSets will need some kind of manual treatment
+ if not hasattr(self.view, 'get_serializer'):
+ return None
+
+ return self.view.get_serializer()
+
+ def get_request_body_schema(self, path, method, serializer):
+ """Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :param serializer: the view's request serialzier
+ :return openapi.Schema: the request body schema
+ """
+ return self.field_to_swagger(serializer, openapi.Schema)
+
+ def get_responses(self, path, method):
+ """Get the possible responses for this view as a swagger Responses object.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return Responses: the documented responses
+ """
+ response_serializers = self.get_response_serializers(path, method)
+ return openapi.Responses(
+ responses=self.get_response_schemas(path, method, response_serializers)
+ )
+
+ def get_response_serializers(self, path, method):
+ """Return the response codes that this view is expected to return, and the serializer for each response body.
+ The return value should be a dict where the keys are possible status codes, and values are either strings,
+ `Serializer`s or `openapi.Response` objects.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return dict: the response serializers
+ """
+ if method.lower() == 'post':
+ return {'201': ''}
+ if method.lower() == 'delete':
+ return {'204': ''}
+ return {'200': ''}
+
+ def get_response_schemas(self, path, method, response_serializers):
+ """Return the `openapi.Response` objects calculated for this view.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :param dict response_serializers: result of get_response_serializers
+ :return dict[str, openapi.Response]: a dictionary of status code to Response object
+ """
+ responses = {}
+ for status, serializer in response_serializers.items():
+ if isinstance(serializer, str):
+ response = openapi.Response(
+ description=serializer
+ )
+ elif isinstance(serializer, openapi.Response):
+ response = serializer
+ else:
+ response = openapi.Response(
+ description='',
+ schema=self.field_to_swagger(serializer, openapi.Schema)
+ )
+
+ responses[str(status)] = response
+
+ return responses
+
+ def get_query_parameters(self, path, method):
+ """Return the query parameters accepted by this view.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return list[openapi.Parameter]: the query parameters
+ """
+ return self.get_filter_parameters(path, method) + self.get_pagination_parameters(path, method)
+
+ def get_filter_parameters(self, path, method):
+ """Return the parameters added to the view by its filter backends.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return list[openapi.Parameter]: the filter query parameters
+ """
+ if not self._sch._allows_filters(path, method):
+ return []
+
+ fields = []
+ for filter_backend in self.view.filter_backends:
+ filter = filter_backend()
+ if hasattr(filter, 'get_schema_fields'):
+ fields += filter.get_schema_fields(self.view)
+ return [self.coreapi_field_to_parameter(field) for field in fields]
+
+ def get_pagination_parameters(self, path, method):
+ """Return the parameters added to the view by its paginator.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return list[openapi.Parameter]: the pagination query parameters
+ """
+ if not is_list_view(path, method, self.view):
+ return []
+
+ paginator = getattr(self.view, 'paginator', None)
+ if paginator is None:
+ return []
+
+ return [
+ self.coreapi_field_to_parameter(field)
+ for field in paginator.get_schema_fields(self.view)
+ ]
+
+ def coreapi_field_to_parameter(self, field):
+ """Convert an instance of `coreapi.Field` to a swagger Parameter object.
+
+ :param coreapi.Field field: the coreapi field
+ :return openapi.Parameter: the equivalent openapi primitive Parameter
+ """
+ location_to_in = {
+ 'query': openapi.IN_QUERY,
+ 'path': openapi.IN_PATH,
+ 'form': openapi.IN_FORM,
+ 'body': openapi.IN_FORM,
+ }
+ coreapi_types = {
+ coreschema.Integer: openapi.TYPE_INTEGER,
+ coreschema.Number: openapi.TYPE_NUMBER,
+ coreschema.String: openapi.TYPE_STRING,
+ coreschema.Boolean: openapi.TYPE_BOOLEAN,
+ }
+ return openapi.Parameter(
+ name=field.name,
+ in_=location_to_in[field.location],
+ type=coreapi_types.get(field.schema.__class__, openapi.TYPE_STRING),
+ required=field.required,
+ description=field.schema.description,
+ )
+
+ def get_description(self, path, method):
+ """Return an operation description determined as appropriate from the view's method and class docstrings.
+
+ :param str path: the view's path
+ :param str method: HTTP request method
+ :return str: the operation description
+ """
+ return self._sch.get_description(path, method)
+
+ def field_to_swagger(self, field, swagger_object_type, **kwargs):
+ """Convert a drf Serializer or Field instance into a Swagger object.
+
+ :param rest_framework.serializers.Field field: the source field
+ :param type swagger_object_type: should be one of Schema, Parameter, Items
+ :param kwargs: extra attributes for constructing the object;
+ if swagger_object_type is Parameter, `name` and `in_` should be provided
+ :return Swagger,Parameter,Items: the swagger object
+ """
+ assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
+ title = force_text(field.label) if field.label else None
+ title = title if swagger_object_type == openapi.Schema else None # only Schema has title
+ title = None
+ description = force_text(field.help_text) if field.help_text else None
+ description = description if swagger_object_type != openapi.Items else None # Items has no description either
+
+ SwaggerType = functools.partial(swagger_object_type, title=title, description=description, **kwargs)
+ # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
+ ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
+
+ # ------ NESTED
+ if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
+ child_schema = self.field_to_swagger(field.child, ChildSwaggerType)
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=child_schema,
+ )
+ elif isinstance(field, serializers.Serializer):
+ if swagger_object_type != openapi.Schema:
+ raise SwaggerGenerationError("cannot instantiate nested serializer as "
+ + swagger_object_type.__name__)
+ return SwaggerType(
+ type=openapi.TYPE_OBJECT,
+ properties=OrderedDict(
+ (key, self.field_to_swagger(value, ChildSwaggerType))
+ for key, value
+ in field.fields.items()
+ )
+ )
+ elif isinstance(field, serializers.ManyRelatedField):
+ child_schema = self.field_to_swagger(field.child_relation, ChildSwaggerType)
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=child_schema,
+ unique_items=True, # is this OK?
+ )
+ elif isinstance(field, serializers.RelatedField):
+ # TODO: infer type for PrimaryKeyRelatedField?
+ return SwaggerType(type=openapi.TYPE_STRING)
+ # ------ CHOICES
+ elif isinstance(field, serializers.MultipleChoiceField):
+ return SwaggerType(
+ type=openapi.TYPE_ARRAY,
+ items=ChildSwaggerType(
+ type=openapi.TYPE_STRING,
+ enum=list(field.choices.keys())
+ )
+ )
+ elif isinstance(field, serializers.ChoiceField):
+ return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
+ # ------ BOOL
+ elif isinstance(field, serializers.BooleanField):
+ return SwaggerType(type=openapi.TYPE_BOOLEAN)
+ # ------ NUMERIC
+ elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
+ # TODO: min_value max_value
+ return SwaggerType(type=openapi.TYPE_NUMBER)
+ elif isinstance(field, serializers.IntegerField):
+ # TODO: min_value max_value
+ return SwaggerType(type=openapi.TYPE_INTEGER)
+ # ------ STRING
+ elif isinstance(field, serializers.EmailField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
+ elif isinstance(field, serializers.RegexField):
+ return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
+ elif isinstance(field, serializers.SlugField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG)
+ elif isinstance(field, serializers.URLField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
+ elif isinstance(field, serializers.IPAddressField):
+ format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
+ return SwaggerType(type=openapi.TYPE_STRING, format=format)
+ elif isinstance(field, serializers.CharField):
+ # TODO: min_length max_length (for all CharField subclasses above too)
+ return SwaggerType(type=openapi.TYPE_STRING)
+ elif isinstance(field, serializers.UUIDField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
+ # ------ DATE & TIME
+ elif isinstance(field, serializers.DateField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
+ elif isinstance(field, serializers.DateTimeField):
+ return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
+ # ------ OTHERS
+ elif isinstance(field, serializers.FileField):
+ # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
+ # OpenAPI 3.0 does support it, so a future implementation could handle this better
+ # TODO: appropriate produces/consumes somehow/somewhere?
+ if swagger_object_type != openapi.Parameter:
+ raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter")
+ return SwaggerType(type=openapi.TYPE_FILE)
+ elif isinstance(field, serializers.JSONField):
+ return SwaggerType(
+ type=openapi.TYPE_STRING,
+ format=openapi.FORMAT_BINARY if field.binary else None
+ )
+ elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
+ child_schema = self.field_to_swagger(field.child, ChildSwaggerType)
+ return SwaggerType(
+ type=openapi.TYPE_OBJECT,
+ additional_properties=child_schema
+ )
+
+ # TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField?
+ # TODO: return info about required/allowed empty
+
+ # everything else gets string by default
+ return SwaggerType(type=openapi.TYPE_STRING)
diff --git a/src/drf_swagger/middleware.py b/src/drf_swagger/middleware.py
new file mode 100644
index 0000000..38a714d
--- /dev/null
+++ b/src/drf_swagger/middleware.py
@@ -0,0 +1,18 @@
+from django.http import HttpResponse
+from django.utils.deprecation import MiddlewareMixin
+
+from drf_swagger.errors import SwaggerValidationError
+from .codecs import _OpenAPICodec
+
+
+class SwaggerExceptionMiddleware(MiddlewareMixin):
+ def process_exception(self, request, exception):
+ if isinstance(exception, SwaggerValidationError):
+ err = {'errors': {exception.validator_name: str(exception)}}
+ codec = exception.source_codec
+ if isinstance(codec, _OpenAPICodec):
+ err = codec.encode_error(err)
+ content_type = codec.media_type
+ return HttpResponse(err, status=500, content_type=content_type)
+
+ return None
diff --git a/src/drf_swagger/openapi.py b/src/drf_swagger/openapi.py
index 9434a15..8858cab 100644
--- a/src/drf_swagger/openapi.py
+++ b/src/drf_swagger/openapi.py
@@ -1,10 +1,105 @@
-import warnings
from collections import OrderedDict
-import coreapi
+from coreapi.compat import urlparse
+from future.utils import raise_from
+from inflection import camelize
+
+TYPE_OBJECT = "object"
+TYPE_STRING = "string"
+TYPE_NUMBER = "number"
+TYPE_INTEGER = "integer"
+TYPE_BOOLEAN = "boolean"
+TYPE_ARRAY = "array"
+TYPE_FILE = "file"
+
+# officially supported by Swagger 2.0 spec
+FORMAT_DATE = "date"
+FORMAT_DATETIME = "date-time"
+FORMAT_PASSWORD = "password"
+FORMAT_BINARY = "binary"
+FORMAT_BASE64 = "bytes"
+FORMAT_FLOAT = "float"
+FORMAT_DOUBLE = "double"
+FORMAT_INT32 = "int32"
+FORMAT_INT64 = "int64"
+
+# defined in JSON-schema
+FORMAT_EMAIL = "email"
+FORMAT_IPV4 = "ipv4"
+FORMAT_IPV6 = "ipv6"
+FORMAT_URI = "uri"
+
+# pulled out of my ass
+FORMAT_UUID = "uuid"
+FORMAT_SLUG = "slug"
+
+IN_BODY = 'body'
+IN_PATH = 'path'
+IN_QUERY = 'query'
+IN_FORM = 'formData'
+IN_HEADER = 'header'
-class Contact(object):
+def make_swagger_name(attribute_name):
+ """
+ Convert a python variable name into a Swagger spec attribute name.
+
+ In particular,
+ * if name starts with x_, return "x-{camelCase}"
+ * if name is 'ref', return "$ref"
+ * else return the name converted to camelCase, with trailing underscores stripped
+
+ :param str attribute_name: python attribute name
+ :return: swagger name
+ """
+ if attribute_name == 'ref':
+ return "$ref"
+ if attribute_name.startswith("x_"):
+ return "x-" + camelize(attribute_name[2:], uppercase_first_letter=False)
+ return camelize(attribute_name.rstrip('_'), uppercase_first_letter=False)
+
+
+class SwaggerDict(OrderedDict):
+ def __init__(self, **attrs):
+ super(SwaggerDict, self).__init__()
+ self._extras__ = attrs
+ if self.__class__ == SwaggerDict:
+ self._insert_extras__()
+
+ def __setattr__(self, key, value):
+ if key.startswith('_'):
+ super(SwaggerDict, self).__setattr__(key, value)
+ return
+ if value is not None:
+ self[make_swagger_name(key)] = value
+
+ def __getattr__(self, item):
+ if item.startswith('_'):
+ raise AttributeError
+ try:
+ return self[make_swagger_name(item)]
+ except KeyError as e:
+ raise_from(AttributeError("no attribute " + item), e)
+
+ def __delattr__(self, item):
+ if item.startswith('_'):
+ super(SwaggerDict, self).__delattr__(item)
+ return
+ del self[make_swagger_name(item)]
+
+ def _insert_extras__(self):
+ """
+ From an ordering perspective, it is desired that extra attributes such as vendor extensions stay at the
+ bottom of the object. However, python2.7's OrderdDict craps out if you try to insert into it before calling
+ init. This means that subclasses must call super().__init__ as the first statement of their own __init__,
+ which would result in the extra attributes being added first. For this reason, we defer the insertion of the
+ attributes and require that subclasses call ._insert_extras__ at the end of their __init__ method.
+ """
+ for attr, val in self._extras__.items():
+ setattr(self, attr, val)
+
+
+class Contact(SwaggerDict):
"""Swagger Contact object
At least one of the following fields is required:
@@ -12,47 +107,34 @@ class Contact(object):
:param str url: contact url
:param str email: contact e-mail
"""
- def __init__(self, name=None, url=None, email=None):
+
+ def __init__(self, name=None, url=None, email=None, **extra):
+ super(Contact, self).__init__(**extra)
+ if name is None and url is None and email is None:
+ raise ValueError("one of name, url or email is requires for Swagger Contact object")
self.name = name
self.url = url
self.email = email
- if name is None and url is None and email is None:
- raise ValueError("one of name, url or email is requires for Swagger Contact object")
-
- def to_swagger(self):
- contact = OrderedDict()
- if self.name is not None:
- contact['name'] = self.name
- if self.url is not None:
- contact['url'] = self.url
- if self.email is not None:
- contact['email'] = self.email
-
- return contact
+ self._insert_extras__()
-class License(object):
+class License(SwaggerDict):
"""Swagger License object
:param str name: Requird. License name
:param str url: link to detailed license information
"""
- def __init__(self, name, url=None):
- self.name = name
- self.url = url
+
+ def __init__(self, name, url=None, **extra):
+ super(License, self).__init__(**extra)
if name is None:
raise ValueError("name is required for Swagger License object")
-
- def to_swagger(self):
- license = OrderedDict()
- license['name'] = self.name
- if self.url is not None:
- license['url'] = self.url
-
- return license
+ self.name = name
+ self.url = url
+ self._insert_extras__()
-class Info(object):
+class Info(SwaggerDict):
"""Swagger Info object
:param str title: Required. API title.
@@ -62,7 +144,10 @@ class Info(object):
:param Contact contact: contact object
:param License license: license object
"""
- def __init__(self, title, default_version, description=None, terms_of_service=None, contact=None, license=None):
+
+ def __init__(self, title, default_version, description=None, terms_of_service=None, contact=None, license=None,
+ **extra):
+ super(Info, self).__init__(**extra)
if title is None or default_version is None:
raise ValueError("title and version are required for Swagger info object")
if contact is not None and not isinstance(contact, Contact):
@@ -70,67 +155,139 @@ class Info(object):
if license is not None and not isinstance(license, License):
raise ValueError("license must be a License object")
self.title = title
- self.default_version = default_version
+ self._default_version = default_version
self.description = description
self.terms_of_service = terms_of_service
self.contact = contact
self.license = license
-
- def to_swagger(self, version):
- info = OrderedDict()
- info['title'] = self.title
- if self.description is not None:
- info['description'] = self.description
- if self.terms_of_service is not None:
- info['termsOfService'] = self.terms_of_service
- if self.contact is not None:
- info['contact'] = self.contact.to_swagger()
- if self.license is not None:
- info['license'] = self.license.to_swagger()
- info['version'] = version or self.default_version
- return info
+ self._insert_extras__()
-class Swagger(coreapi.Document):
- @classmethod
- def from_coreapi(cls, document, info, version):
- """
- Create an openapi.Swagger from the fields of a coreapi.Document.
+class Swagger(SwaggerDict):
+ def __init__(self, info=None, _url=None, _version=None, paths=None, **extra):
+ super(Swagger, self).__init__(**extra)
+ self.swagger = '2.0'
+ self.info = info
+ self.info.version = _version or info._default_version
+ self.paths = paths
- :param coreapi.Document document: source coreapi.Document
- :param openapi.Info info: Swagger info object
- :param string version: API version string
- :return: an openapi.Swagger
- """
- if document.title and document.title != info.title:
- warnings.warn("document title is overriden by Swagger Info")
- if document.description and document.description != info.description:
- warnings.warn("document description is overriden by Swagger Info")
- return Swagger(
- info=info,
- version=version,
- url=document.url,
- media_type=document.media_type,
- content=document.data
- )
+ if _url:
+ url = urlparse.urlparse(_url)
+ if url.netloc:
+ self.host = url.netloc
+ if url.scheme:
+ self.schemes = [url.scheme]
- def __init__(self, info=None, version=None, url=None, media_type=None, content=None):
- super(Swagger, self).__init__(url, info.title, info.description, media_type, content)
- self._info = info
- self._version = version
-
- @property
- def info(self):
- return self._info
-
- @property
- def version(self):
- return self._version
+ self.base_path = '/'
+ self._insert_extras__()
-class Field(coreapi.Field):
- pass
+class Paths(SwaggerDict):
+ def __init__(self, paths, **extra):
+ super(Paths, self).__init__(**extra)
+ for path, path_obj in paths.items():
+ assert path.startswith("/")
+ if path_obj is not None:
+ self[path] = path_obj
+ self._insert_extras__()
-class Link(coreapi.Link):
- pass
+class PathItem(SwaggerDict):
+ def __init__(self, get=None, put=None, post=None, delete=None, options=None,
+ head=None, patch=None, parameters=None, **extra):
+ super(PathItem, self).__init__(**extra)
+ self.get = get
+ self.put = put
+ self.post = post
+ self.delete = delete
+ self.options = options
+ self.head = head
+ self.patch = patch
+ self.parameters = parameters
+ self._insert_extras__()
+
+
+class Operation(SwaggerDict):
+ def __init__(self, operation_id, responses, parameters=None, consumes=None,
+ produces=None, description=None, tags=None, **extra):
+ super(Operation, self).__init__(**extra)
+ self.operation_id = operation_id
+ self.responses = responses
+ self.parameters = [param for param in parameters if param is not None]
+ self.consumes = consumes
+ self.produces = produces
+ self.description = description
+ self.tags = tags
+ self._insert_extras__()
+
+
+class Items(SwaggerDict):
+ def __init__(self, type=None, format=None, enum=None, pattern=None, items=None, **extra):
+ super(Items, self).__init__(**extra)
+ self.type = type
+ self.format = format
+ self.enum = enum
+ self.pattern = pattern
+ self.items = items
+ self._insert_extras__()
+
+
+class Parameter(SwaggerDict):
+ def __init__(self, name, in_, description=None, required=None, schema=None,
+ type=None, format=None, enum=None, pattern=None, items=None, **extra):
+ super(Parameter, self).__init__(**extra)
+ if (not schema and not type) or (schema and type):
+ raise ValueError("either schema or type are required for Parameter object!")
+ self.name = name
+ self.in_ = in_
+ self.description = description
+ self.required = required
+ self.schema = schema
+ self.type = type
+ self.format = format
+ self.enum = enum
+ self.pattern = pattern
+ self.items = items
+ self._insert_extras__()
+
+
+class Schema(SwaggerDict):
+ def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None,
+ format=None, enum=None, pattern=None, items=None, **extra):
+ super(Schema, self).__init__(**extra)
+ self.description = description
+ self.required = required
+ self.type = type
+ self.properties = properties
+ self.additional_properties = additional_properties
+ self.format = format
+ self.enum = enum
+ self.pattern = pattern
+ self.items = items
+ self._insert_extras__()
+
+
+class Ref(SwaggerDict):
+ def __init__(self, ref):
+ super(Ref, self).__init__()
+ self.ref = ref
+ self._insert_extras__()
+
+
+class Responses(SwaggerDict):
+ def __init__(self, responses, default=None, **extra):
+ super(Responses, self).__init__(**extra)
+ for status, response in responses.items():
+ if response is not None:
+ self[str(status)] = response
+ self.default = default
+ self._insert_extras__()
+
+
+class Response(SwaggerDict):
+ def __init__(self, description, schema=None, examples=None, **extra):
+ super(Response, self).__init__(**extra)
+ self.description = description
+ self.schema = schema
+ self.examples = examples
+ self._insert_extras__()
diff --git a/src/drf_swagger/renderers.py b/src/drf_swagger/renderers.py
index 90e6d43..1d20a86 100644
--- a/src/drf_swagger/renderers.py
+++ b/src/drf_swagger/renderers.py
@@ -45,17 +45,17 @@ class _UIRenderer(BaseRenderer):
charset = 'utf-8'
template = ''
- def render(self, data, accepted_media_type=None, renderer_context=None):
- self.set_context(renderer_context, data)
+ def render(self, swagger, accepted_media_type=None, renderer_context=None):
+ self.set_context(renderer_context, swagger)
return render(
renderer_context['request'],
self.template,
renderer_context
)
- def set_context(self, renderer_context, data):
- renderer_context['title'] = data.title
- renderer_context['version'] = data.version
+ def set_context(self, renderer_context, swagger):
+ renderer_context['title'] = swagger.info.title
+ renderer_context['version'] = swagger.info.version
renderer_context['swagger_settings'] = json.dumps(self.get_swagger_ui_settings())
renderer_context['redoc_settings'] = json.dumps(self.get_redoc_settings())
renderer_context['USE_SESSION_AUTH'] = swagger_settings.USE_SESSION_AUTH
diff --git a/src/drf_swagger/templates/drf-swagger/swagger-ui.html b/src/drf_swagger/templates/drf-swagger/swagger-ui.html
index 1b3f58f..634da25 100644
--- a/src/drf_swagger/templates/drf-swagger/swagger-ui.html
+++ b/src/drf_swagger/templates/drf-swagger/swagger-ui.html
@@ -35,7 +35,7 @@
margin-right: 8px;
}
- #django-session-auth.hidden {
+ .hidden {
display: none;
}
@@ -126,14 +126,31 @@