Improve host, schemes and basePath handling (#42)

* added handling of basePath by taking into account SCRIPT_NAME and the longest common prefix
* improved handling of NamespaceVersioning by excluding URLs of differing versions
* added documentation and error messages for the problem reported in #37
openapi3
Cristi Vîjdea 2018-01-12 03:37:04 +01:00 committed by GitHub
parent 757d47e1c0
commit 7a3fe8ec0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 247 additions and 99 deletions

View File

@ -42,6 +42,10 @@ Features
- generated Swagger schema can be automatically validated by
`swagger-spec-validator <https://github.com/Yelp/swagger_spec_validator>`_ or
`flex <https://github.com/pipermerriam/flex>`_
- supports Django REST Framework API versioning
+ ``URLPathVersioning``, ``NamespaceVersioning`` and ``HostnameVersioning`` are supported
+ ``AcceptHeaderVersioning`` and ``QueryParameterVersioning`` are not currently supported
.. figure:: https://raw.githubusercontent.com/axnsan12/drf-yasg/1.0.2/screenshots/redoc-nested-response.png
:width: 100%
@ -211,7 +215,7 @@ The possible settings and their default values are as follows:
# default api Info if none is otherwise given; should be an import string to an openapi.Info object
'DEFAULT_INFO': None,
# default API url if none is otherwise given
'DEFAULT_API_URL': '',
'DEFAULT_API_URL': None,
'USE_SESSION_AUTH': True, # add Django Login and Django Logout buttons, CSRF token to swagger UI page
'LOGIN_URL': getattr(django.conf.settings, 'LOGIN_URL', None), # URL for the login button

View File

@ -127,6 +127,29 @@ This section describes where information is sourced from when using the default
* *descriptions* for :class:`.Operation`\ s, :class:`.Parameter`\ s and :class:`.Schema`\ s are picked up from
docstrings and ``help_text`` attributes in the same manner as the `default DRF SchemaGenerator
<http://www.django-rest-framework.org/api-guide/schemas/#schemas-as-documentation>`_
* .. _custom-spec-base-url:
the base URL for the API consists of three values - the ``host``, ``schemes`` and ``basePath`` attributes
* the host name and scheme are determined, in descending order of priority:
+ from the ``url`` argument passed to :func:`.get_schema_view` (more specifically, to the underlying
:class:`.OpenAPISchemaGenerator`)
+ from the :ref:`DEFAULT_API_URL setting <default-swagger-settings>`
+ inferred from the request made to the schema endpoint
For example, an url of ``https://www.example.com:8080/some/path`` will populate the ``host`` and ``schemes``
attributes with ``www.example.com:8080`` and ``['https']``, respectively. The path component will be ignored.
* the base path is determined as the concatenation of two variables:
#. the `SCRIPT_NAME`_ wsgi environment variable; this is set, for example, when serving the site from a
sub-path using web server url rewriting
.. Tip::
The Django `FORCE_SCRIPT_NAME`_ setting can be used to override the `SCRIPT_NAME`_ or set it when it's
missing from the environment.
#. the longest common path prefix of all the urls in your API - see :meth:`.determine_path_prefix`
.. _custom-spec-swagger-auto-schema:
@ -398,3 +421,6 @@ A second example, of a :class:`~.inspectors.FieldInspector` that removes the ``t
This means that you should generally avoid view or method-specific ``FieldInspector``\ s if you are dealing with
references (a.k.a named models), because you can never know which view will be the first to generate the schema
for a given serializer.
.. _SCRIPT_NAME: https://www.python.org/dev/peps/pep-0333/#environ-variables
.. _FORCE_SCRIPT_NAME: https://docs.djangoproject.com/en/2.0/ref/settings/#force-script-name

View File

@ -66,7 +66,7 @@ See the command help for more advanced options:
In ``settings.py``:
.. code:: python
.. code-block:: python
SWAGGER_SETTINGS = {
'DEFAULT_INFO': 'import.path.to.urls.api_info',
@ -74,7 +74,7 @@ See the command help for more advanced options:
In ``urls.py``:
.. code:: python
.. code-block:: python
api_info = openapi.Info(
title="Snippets API",

View File

@ -107,10 +107,13 @@ management command, or if no ``info`` argument is passed to ``get_schema_view``.
DEFAULT_API_URL
---------------
A string representing the default API URL. This will be used to populate the ``host``, ``schemes`` and ``basePath``
attributes of the Swagger document if no API URL is otherwise provided.
A string representing the default API URL. This will be used to populate the ``host`` and ``schemes`` attributes
of the Swagger document if no API URL is otherwise provided. The Django `FORCE_SCRIPT_NAME`_ setting can be used for
providing an API mount point prefix.
**Default**: :python:`''`
See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
**Default**: :python:`None`
Authorization
=============
@ -274,3 +277,6 @@ PATH_IN_MIDDLE
**Default**: :python:`False` |br|
*Maps to attribute*: ``path-in-middle-panel``
.. _FORCE_SCRIPT_NAME: https://docs.djangoproject.com/en/2.0/ref/settings/#force-script-name

View File

@ -23,7 +23,7 @@ SWAGGER_DEFAULTS = {
],
'DEFAULT_INFO': None,
'DEFAULT_API_URL': '',
'DEFAULT_API_URL': None,
'USE_SESSION_AUTH': True,
'SECURITY_DEFINITIONS': {

View File

@ -1,25 +1,107 @@
import logging
import re
from collections import OrderedDict, defaultdict
import uritemplate
from coreapi.compat import urlparse
from django.utils.encoding import force_text
from rest_framework import versioning
from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
from rest_framework.schemas.inspectors import get_pk_description
from drf_yasg.errors import SwaggerGenerationError
from . import openapi
from .app_settings import swagger_settings
from .inspectors.field import get_basic_type_info, get_queryset_field
from .openapi import ReferenceResolver
logger = logging.getLogger(__name__)
PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')
class EndpointEnumerator(_EndpointEnumerator):
def __init__(self, patterns=None, urlconf=None, request=None):
super(EndpointEnumerator, self).__init__(patterns, urlconf)
self.request = request
def get_path_from_regex(self, path_regex):
if path_regex.endswith(')'):
logger.warning("url pattern does not end in $ ('%s') - unexpected things might happen")
return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex))
def should_include_endpoint(self, path, callback, app_name='', namespace='', url_name=None):
if not super(EndpointEnumerator, self).should_include_endpoint(path, callback):
return False
version = getattr(self.request, 'version', None)
versioning_class = getattr(callback.cls, 'versioning_class', None)
if versioning_class is not None and issubclass(versioning_class, versioning.NamespaceVersioning):
if version and version not in namespace.split(':'):
return False
return True
def replace_version(self, path, callback):
"""If ``request.version`` is not ``None`` and `callback` uses ``URLPathVersioning``, this function replaces
the ``version`` parameter in `path` with the actual version.
:param str path: the templated path
:param callback: the view callback
:rtype: str
"""
versioning_class = getattr(callback.cls, 'versioning_class', None)
if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
version = getattr(self.request, 'version', None)
if version:
version_param = getattr(versioning_class, 'version_param', 'version')
version_param = '{%s}' % version_param
if version_param not in path:
logger.info("view %s uses URLPathVersioning but URL %s has no param %s"
% (callback.cls, path, version_param))
path = path.replace(version_param, version)
return path
def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None):
"""
Return a list of all available API endpoints by inspecting the URL conf.
Copied entirely from super.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + get_original_route(pattern)
if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
url_name = pattern.name
if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
path = self.replace_version(path, callback)
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex,
app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace
)
api_endpoints.extend(nested_endpoints)
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
return api_endpoints
def unescape(self, s):
"""Unescape all backslash escapes from `s`.
@ -30,8 +112,8 @@ class EndpointEnumerator(_EndpointEnumerator):
return re.sub(r'\\(.)', r'\1', s)
def unescape_path(self, path):
"""Remove backslashes from all path components outside {parameters}. This is needed because
Django>=2.0 ``path()``/``RoutePattern`` aggresively escapes all non-parameter path components.
"""Remove backslashe escapes from all path components outside {parameters}. This is needed because
``simplify_regex`` does not handle this correctly - note however that this implementation is
**NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \w, \d)
outside path parameter groups; if you are in this category, God help you
@ -60,12 +142,19 @@ class OpenAPISchemaGenerator(object):
"""
endpoint_enumerator_class = EndpointEnumerator
def __init__(self, info, version='', url=swagger_settings.DEFAULT_API_URL, patterns=None, urlconf=None):
def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
"""
:param .Info info: information about the API
:param str version: API version string; can be omitted to use `info.default_version`
:param str url: API url; can be empty to remove URL info from the result
:param str version: API version string; if omitted, `info.default_version` will be used
:param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
will be inferred from the request made against the schema view, so you should generally not need to set
this parameter explicitly; if the empty string is passed, no host and scheme will be emitted
If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
or https://), and any path component is ignored;
See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
:param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
:param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
if not given, the default urlconf is used
@ -73,6 +162,15 @@ class OpenAPISchemaGenerator(object):
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info
self.version = version
if url is None and swagger_settings.DEFAULT_API_URL is not None:
url = swagger_settings.DEFAULT_API_URL
if url:
parsed_url = urlparse.urlparse(url)
if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
if parsed_url.path:
logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead" % url)
@property
def url(self):
@ -89,17 +187,16 @@ class OpenAPISchemaGenerator(object):
:rtype: openapi.Swagger
"""
endpoints = self.get_endpoints(request)
endpoints = self.replace_version(endpoints, request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
paths = self.get_paths(endpoints, components, request, public)
paths, prefix = self.get_paths(endpoints, components, request, public)
url = self.url
if not url and request is not None:
if url is None and request is not None:
url = request.build_absolute_uri()
return openapi.Swagger(
info=self.info, paths=paths,
_url=url, _version=self.version, **dict(components)
_url=url, _prefix=prefix, _version=self.version, **dict(components)
)
def create_view(self, callback, method, request=None):
@ -120,30 +217,6 @@ class OpenAPISchemaGenerator(object):
setattr(view_method.__func__, '_swagger_auto_schema', overrides)
return view
def replace_version(self, endpoints, request):
"""If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using
``URLPathVersioning`` as a versioning class.
:param dict endpoints: endpoints as returned by :meth:`.get_endpoints`
:param Request request: the request made against the schema view
:return: endpoints with modified paths
"""
version = getattr(request, 'version', None)
if version is None:
return endpoints
new_endpoints = {}
for path, endpoint in endpoints.items():
view_cls = endpoint[0]
versioning_class = getattr(view_cls, 'versioning_class', None)
version_param = getattr(versioning_class, 'version_param', 'version')
if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
path = path.replace('{%s}' % version_param, version)
new_endpoints[path] = endpoint
return new_endpoints
def get_endpoints(self, request):
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
@ -151,7 +224,7 @@ class OpenAPISchemaGenerator(object):
:return: {path: (view_class, list[(http_method, view_instance)])
:rtype: dict
"""
enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf)
enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
endpoints = enumerator.get_api_endpoints()
view_paths = defaultdict(list)
@ -207,14 +280,16 @@ class OpenAPISchemaGenerator(object):
:param ReferenceResolver components: resolver/container for Swagger References
:param Request request: the request made against the schema view; can be None
:param bool public: if True, all endpoints are included regardless of access through `request`
:rtype: openapi.Paths
:returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
:rtype: tuple[openapi.Paths,str]
"""
if not endpoints:
return openapi.Paths(paths={})
return openapi.Paths(paths={}), ''
prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
assert '{' not in prefix, "base path cannot be templated in swagger 2.0"
prefix = self.determine_path_prefix(list(endpoints.keys()))
paths = OrderedDict()
for path, (view_cls, methods) in sorted(endpoints.items()):
operations = {}
for method, view in methods:
@ -224,9 +299,14 @@ class OpenAPISchemaGenerator(object):
operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request)
if operations:
paths[path] = self.get_path_item(path, view_cls, operations)
# since the common prefix is used as the API basePath, it must be stripped
# from individual paths when writing them into the swagger document
path_suffix = path[len(prefix):]
if not path_suffix.startswith('/'):
path_suffix = '/' + path_suffix
paths[path_suffix] = self.get_path_item(path, view_cls, operations)
return openapi.Paths(paths=paths)
return openapi.Paths(paths=paths), prefix
def get_operation(self, view, path, prefix, method, components, request):
"""Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to

View File

@ -42,7 +42,7 @@ class Command(BaseCommand):
'-u', '--url', dest='api_url',
default='',
type=str,
help='Base API URL - sets the host, scheme and basePath attributes of the generated document.'
help='Base API URL - sets the host and scheme attributes of the generated document.'
)
parser.add_argument(
'-m', '--mock-request', dest='mock',

View File

@ -2,6 +2,7 @@ import re
from collections import OrderedDict
from coreapi.compat import urlparse
from django.urls import get_script_prefix
from inflection import camelize
from .utils import filter_none
@ -210,11 +211,13 @@ class Info(SwaggerDict):
class Swagger(SwaggerDict):
def __init__(self, info=None, _url=None, _version=None, paths=None, definitions=None, **extra):
def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None, definitions=None, **extra):
"""Root Swagger object.
:param .Info info: info object
:param str _url: URL used for guessing the API host, scheme and basepath
:param str _url: URL used for setting the API host and scheme
:param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi
SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable
:param str _version: version string to override Info
:param .Paths paths: paths object
:param dict[str,.Schema] definitions: named models
@ -226,16 +229,39 @@ class Swagger(SwaggerDict):
if _url:
url = urlparse.urlparse(_url)
if url.netloc:
assert url.netloc and url.scheme, "if given, url must have both schema and netloc"
self.host = url.netloc
if url.scheme:
self.schemes = [url.scheme]
self.base_path = '/'
self.base_path = self.get_base_path(get_script_prefix(), _prefix)
self.paths = paths
self.definitions = filter_none(definitions)
self._insert_extras__()
@classmethod
def get_base_path(cls, script_prefix, api_prefix):
"""Determine an appropriate value for ``basePath`` based on the SCRIPT_NAME and the api common prefix.
:param str script_prefix: script prefix as defined by django ``get_script_prefix``
:param str api_prefix: api common prefix
:return: joined base path
"""
# avoid double slash when joining script_name with api_prefix
if script_prefix and script_prefix.endswith('/'):
script_prefix = script_prefix[:-1]
if not api_prefix.startswith('/'):
api_prefix = '/' + api_prefix
base_path = script_prefix + api_prefix
# ensure that the base path has a leading slash and no trailing slash
if base_path and base_path.endswith('/'):
base_path = base_path[:-1]
if not base_path.startswith('/'):
base_path = '/' + base_path
return base_path
class Paths(SwaggerDict):
def __init__(self, paths, **extra):

View File

@ -52,14 +52,13 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal
generator_class=OpenAPISchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
"""
Create a SchemaView class with default renderers and generators.
"""Create a SchemaView class with default renderers and generators.
:param .Info info: Swagger API Info object; if omitted, defaults to `DEFAULT_INFO`
:param str url: API base url; if left blank will be deduced from the location the view is served at
:param patterns: passed to SchemaGenerator
:param urlconf: passed to SchemaGenerator
:param bool public: if False, includes only endpoints the current user has access to
:param .Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO <default-swagger-settings>`
:param str url: same as :class:`.OpenAPISchemaGenerator`
:param patterns: same as :class:`.OpenAPISchemaGenerator`
:param urlconf: same as :class:`.OpenAPISchemaGenerator`
:param bool public: if False, includes only the endpoints that are accesible by the user viewing the schema
:param list validators: a list of validator names to apply; allowed values are ``flex``, ``ssv``
:param type generator_class: schema generator class to use; should be a subclass of :class:`.OpenAPISchemaGenerator`
:param tuple authentication_classes: authentication classes for the schema view itself

View File

@ -5,6 +5,7 @@ import pytest
from drf_yasg import codecs, openapi
from drf_yasg.codecs import yaml_sane_load
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.generators import OpenAPISchemaGenerator
@ -46,14 +47,23 @@ def test_yaml_and_json_match(codec_yaml, codec_json, swagger):
def test_basepath_only(mock_schema_request):
with pytest.raises(SwaggerGenerationError):
generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
url='/basepath/',
)
generator.get_schema(mock_schema_request, public=True)
def test_no_netloc(mock_schema_request):
generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
url='',
)
swagger = generator.get_schema(mock_schema_request, public=True)
assert 'host' not in swagger
assert 'schemes' not in swagger
assert swagger['basePath'] == '/' # base path is not implemented for now
assert 'host' not in swagger and 'schemes' not in swagger
assert swagger['info']['version'] == 'v2'

View File

@ -4,24 +4,25 @@ from drf_yasg.codecs import yaml_sane_load
def _get_versioned_schema(prefix, client, validate_schema):
response = client.get(prefix + 'swagger.yaml')
response = client.get(prefix + '/swagger.yaml')
assert response.status_code == 200
swagger = yaml_sane_load(response.content.decode('utf-8'))
assert swagger['basePath'] == prefix
validate_schema(swagger)
assert prefix + 'snippets/' in swagger['paths']
assert '/snippets/' in swagger['paths']
return swagger
def _check_v1(swagger, prefix):
def _check_v1(swagger):
assert swagger['info']['version'] == '1.0'
versioned_post = swagger['paths'][prefix + 'snippets/']['post']
versioned_post = swagger['paths']['/snippets/']['post']
assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/Snippet'
assert 'v2field' not in swagger['definitions']['Snippet']['properties']
def _check_v2(swagger, prefix):
def _check_v2(swagger):
assert swagger['info']['version'] == '2.0'
versioned_post = swagger['paths'][prefix + 'snippets/']['post']
versioned_post = swagger['paths']['/snippets/']['post']
assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/SnippetV2'
assert 'v2field' in swagger['definitions']['SnippetV2']['properties']
v2field = swagger['definitions']['SnippetV2']['properties']['v2field']
@ -30,27 +31,23 @@ def _check_v2(swagger, prefix):
@pytest.mark.urls('urlconfs.url_versioning')
def test_url_v1(client, validate_schema):
prefix = '/versioned/url/v1.0/'
swagger = _get_versioned_schema(prefix, client, validate_schema)
_check_v1(swagger, prefix)
swagger = _get_versioned_schema('/versioned/url/v1.0', client, validate_schema)
_check_v1(swagger)
@pytest.mark.urls('urlconfs.url_versioning')
def test_url_v2(client, validate_schema):
prefix = '/versioned/url/v2.0/'
swagger = _get_versioned_schema(prefix, client, validate_schema)
_check_v2(swagger, prefix)
swagger = _get_versioned_schema('/versioned/url/v2.0', client, validate_schema)
_check_v2(swagger)
@pytest.mark.urls('urlconfs.ns_versioning')
def test_ns_v1(client, validate_schema):
prefix = '/versioned/ns/v1.0/'
swagger = _get_versioned_schema(prefix, client, validate_schema)
_check_v1(swagger, prefix)
swagger = _get_versioned_schema('/versioned/ns/v1.0', client, validate_schema)
_check_v1(swagger)
@pytest.mark.urls('urlconfs.ns_versioning')
def test_ns_v2(client, validate_schema):
prefix = '/versioned/ns/v2.0/'
swagger = _get_versioned_schema(prefix, client, validate_schema)
_check_v2(swagger, prefix)
swagger = _get_versioned_schema('/versioned/ns/v2.0', client, validate_schema)
_check_v2(swagger)

View File

@ -17,7 +17,7 @@ class SnippetListV2(SnippetListV1):
serializer_class = SnippetSerializerV2
app_name = 'test_ns_versioning'
app_name = '2.0'
urlpatterns = [
url(r"^$", SnippetListV2.as_view())

View File

@ -19,7 +19,7 @@ schema_patterns = [
urlpatterns = [
url(VERSION_PREFIX_NS + r"v1.0/snippets/", include(ns_version1, namespace='1.0')),
url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2, namespace='2.0')),
url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2)),
url(VERSION_PREFIX_NS + r'v1.0/', include((schema_patterns, '1.0'))),
url(VERSION_PREFIX_NS + r'v2.0/', include((schema_patterns, '2.0'))),
]