Improve host, schemes and basePath handling (#42)

* added handling of basePath by taking into account SCRIPT_NAME and the longest common prefix
* improved handling of NamespaceVersioning by excluding URLs of differing versions
* added documentation and error messages for the problem reported in #37
openapi3
Cristi Vîjdea 2018-01-12 03:37:04 +01:00 committed by GitHub
parent 757d47e1c0
commit 7a3fe8ec0c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 247 additions and 99 deletions

View File

@ -42,6 +42,10 @@ Features
- generated Swagger schema can be automatically validated by - generated Swagger schema can be automatically validated by
`swagger-spec-validator <https://github.com/Yelp/swagger_spec_validator>`_ or `swagger-spec-validator <https://github.com/Yelp/swagger_spec_validator>`_ or
`flex <https://github.com/pipermerriam/flex>`_ `flex <https://github.com/pipermerriam/flex>`_
- supports Django REST Framework API versioning
+ ``URLPathVersioning``, ``NamespaceVersioning`` and ``HostnameVersioning`` are supported
+ ``AcceptHeaderVersioning`` and ``QueryParameterVersioning`` are not currently supported
.. figure:: https://raw.githubusercontent.com/axnsan12/drf-yasg/1.0.2/screenshots/redoc-nested-response.png .. figure:: https://raw.githubusercontent.com/axnsan12/drf-yasg/1.0.2/screenshots/redoc-nested-response.png
:width: 100% :width: 100%
@ -211,7 +215,7 @@ The possible settings and their default values are as follows:
# default api Info if none is otherwise given; should be an import string to an openapi.Info object # default api Info if none is otherwise given; should be an import string to an openapi.Info object
'DEFAULT_INFO': None, 'DEFAULT_INFO': None,
# default API url if none is otherwise given # default API url if none is otherwise given
'DEFAULT_API_URL': '', 'DEFAULT_API_URL': None,
'USE_SESSION_AUTH': True, # add Django Login and Django Logout buttons, CSRF token to swagger UI page 'USE_SESSION_AUTH': True, # add Django Login and Django Logout buttons, CSRF token to swagger UI page
'LOGIN_URL': getattr(django.conf.settings, 'LOGIN_URL', None), # URL for the login button 'LOGIN_URL': getattr(django.conf.settings, 'LOGIN_URL', None), # URL for the login button

View File

@ -127,6 +127,29 @@ This section describes where information is sourced from when using the default
* *descriptions* for :class:`.Operation`\ s, :class:`.Parameter`\ s and :class:`.Schema`\ s are picked up from * *descriptions* for :class:`.Operation`\ s, :class:`.Parameter`\ s and :class:`.Schema`\ s are picked up from
docstrings and ``help_text`` attributes in the same manner as the `default DRF SchemaGenerator docstrings and ``help_text`` attributes in the same manner as the `default DRF SchemaGenerator
<http://www.django-rest-framework.org/api-guide/schemas/#schemas-as-documentation>`_ <http://www.django-rest-framework.org/api-guide/schemas/#schemas-as-documentation>`_
* .. _custom-spec-base-url:
the base URL for the API consists of three values - the ``host``, ``schemes`` and ``basePath`` attributes
* the host name and scheme are determined, in descending order of priority:
+ from the ``url`` argument passed to :func:`.get_schema_view` (more specifically, to the underlying
:class:`.OpenAPISchemaGenerator`)
+ from the :ref:`DEFAULT_API_URL setting <default-swagger-settings>`
+ inferred from the request made to the schema endpoint
For example, an url of ``https://www.example.com:8080/some/path`` will populate the ``host`` and ``schemes``
attributes with ``www.example.com:8080`` and ``['https']``, respectively. The path component will be ignored.
* the base path is determined as the concatenation of two variables:
#. the `SCRIPT_NAME`_ wsgi environment variable; this is set, for example, when serving the site from a
sub-path using web server url rewriting
.. Tip::
The Django `FORCE_SCRIPT_NAME`_ setting can be used to override the `SCRIPT_NAME`_ or set it when it's
missing from the environment.
#. the longest common path prefix of all the urls in your API - see :meth:`.determine_path_prefix`
.. _custom-spec-swagger-auto-schema: .. _custom-spec-swagger-auto-schema:
@ -398,3 +421,6 @@ A second example, of a :class:`~.inspectors.FieldInspector` that removes the ``t
This means that you should generally avoid view or method-specific ``FieldInspector``\ s if you are dealing with This means that you should generally avoid view or method-specific ``FieldInspector``\ s if you are dealing with
references (a.k.a named models), because you can never know which view will be the first to generate the schema references (a.k.a named models), because you can never know which view will be the first to generate the schema
for a given serializer. for a given serializer.
.. _SCRIPT_NAME: https://www.python.org/dev/peps/pep-0333/#environ-variables
.. _FORCE_SCRIPT_NAME: https://docs.djangoproject.com/en/2.0/ref/settings/#force-script-name

View File

@ -66,7 +66,7 @@ See the command help for more advanced options:
In ``settings.py``: In ``settings.py``:
.. code:: python .. code-block:: python
SWAGGER_SETTINGS = { SWAGGER_SETTINGS = {
'DEFAULT_INFO': 'import.path.to.urls.api_info', 'DEFAULT_INFO': 'import.path.to.urls.api_info',
@ -74,7 +74,7 @@ See the command help for more advanced options:
In ``urls.py``: In ``urls.py``:
.. code:: python .. code-block:: python
api_info = openapi.Info( api_info = openapi.Info(
title="Snippets API", title="Snippets API",

View File

@ -107,10 +107,13 @@ management command, or if no ``info`` argument is passed to ``get_schema_view``.
DEFAULT_API_URL DEFAULT_API_URL
--------------- ---------------
A string representing the default API URL. This will be used to populate the ``host``, ``schemes`` and ``basePath`` A string representing the default API URL. This will be used to populate the ``host`` and ``schemes`` attributes
attributes of the Swagger document if no API URL is otherwise provided. of the Swagger document if no API URL is otherwise provided. The Django `FORCE_SCRIPT_NAME`_ setting can be used for
providing an API mount point prefix.
**Default**: :python:`''` See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
**Default**: :python:`None`
Authorization Authorization
============= =============
@ -274,3 +277,6 @@ PATH_IN_MIDDLE
**Default**: :python:`False` |br| **Default**: :python:`False` |br|
*Maps to attribute*: ``path-in-middle-panel`` *Maps to attribute*: ``path-in-middle-panel``
.. _FORCE_SCRIPT_NAME: https://docs.djangoproject.com/en/2.0/ref/settings/#force-script-name

View File

@ -23,7 +23,7 @@ SWAGGER_DEFAULTS = {
], ],
'DEFAULT_INFO': None, 'DEFAULT_INFO': None,
'DEFAULT_API_URL': '', 'DEFAULT_API_URL': None,
'USE_SESSION_AUTH': True, 'USE_SESSION_AUTH': True,
'SECURITY_DEFINITIONS': { 'SECURITY_DEFINITIONS': {

View File

@ -1,25 +1,107 @@
import logging
import re import re
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
import uritemplate import uritemplate
from coreapi.compat import urlparse
from django.utils.encoding import force_text from django.utils.encoding import force_text
from rest_framework import versioning from rest_framework import versioning
from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
from rest_framework.schemas.inspectors import get_pk_description from rest_framework.schemas.inspectors import get_pk_description
from drf_yasg.errors import SwaggerGenerationError
from . import openapi from . import openapi
from .app_settings import swagger_settings from .app_settings import swagger_settings
from .inspectors.field import get_basic_type_info, get_queryset_field from .inspectors.field import get_basic_type_info, get_queryset_field
from .openapi import ReferenceResolver from .openapi import ReferenceResolver
logger = logging.getLogger(__name__)
PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}') PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')
class EndpointEnumerator(_EndpointEnumerator): class EndpointEnumerator(_EndpointEnumerator):
def __init__(self, patterns=None, urlconf=None, request=None):
super(EndpointEnumerator, self).__init__(patterns, urlconf)
self.request = request
def get_path_from_regex(self, path_regex): def get_path_from_regex(self, path_regex):
if path_regex.endswith(')'):
logger.warning("url pattern does not end in $ ('%s') - unexpected things might happen")
return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex)) return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex))
def should_include_endpoint(self, path, callback, app_name='', namespace='', url_name=None):
if not super(EndpointEnumerator, self).should_include_endpoint(path, callback):
return False
version = getattr(self.request, 'version', None)
versioning_class = getattr(callback.cls, 'versioning_class', None)
if versioning_class is not None and issubclass(versioning_class, versioning.NamespaceVersioning):
if version and version not in namespace.split(':'):
return False
return True
def replace_version(self, path, callback):
"""If ``request.version`` is not ``None`` and `callback` uses ``URLPathVersioning``, this function replaces
the ``version`` parameter in `path` with the actual version.
:param str path: the templated path
:param callback: the view callback
:rtype: str
"""
versioning_class = getattr(callback.cls, 'versioning_class', None)
if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
version = getattr(self.request, 'version', None)
if version:
version_param = getattr(versioning_class, 'version_param', 'version')
version_param = '{%s}' % version_param
if version_param not in path:
logger.info("view %s uses URLPathVersioning but URL %s has no param %s"
% (callback.cls, path, version_param))
path = path.replace(version_param, version)
return path
def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None):
"""
Return a list of all available API endpoints by inspecting the URL conf.
Copied entirely from super.
"""
if patterns is None:
patterns = self.patterns
api_endpoints = []
for pattern in patterns:
path_regex = prefix + get_original_route(pattern)
if isinstance(pattern, URLPattern):
path = self.get_path_from_regex(path_regex)
callback = pattern.callback
url_name = pattern.name
if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
path = self.replace_version(path, callback)
for method in self.get_allowed_methods(callback):
endpoint = (path, method, callback)
api_endpoints.append(endpoint)
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex,
app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace
)
api_endpoints.extend(nested_endpoints)
api_endpoints = sorted(api_endpoints, key=endpoint_ordering)
return api_endpoints
def unescape(self, s): def unescape(self, s):
"""Unescape all backslash escapes from `s`. """Unescape all backslash escapes from `s`.
@ -30,8 +112,8 @@ class EndpointEnumerator(_EndpointEnumerator):
return re.sub(r'\\(.)', r'\1', s) return re.sub(r'\\(.)', r'\1', s)
def unescape_path(self, path): def unescape_path(self, path):
"""Remove backslashes from all path components outside {parameters}. This is needed because """Remove backslashe escapes from all path components outside {parameters}. This is needed because
Django>=2.0 ``path()``/``RoutePattern`` aggresively escapes all non-parameter path components. ``simplify_regex`` does not handle this correctly - note however that this implementation is
**NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \w, \d) **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \w, \d)
outside path parameter groups; if you are in this category, God help you outside path parameter groups; if you are in this category, God help you
@ -60,12 +142,19 @@ class OpenAPISchemaGenerator(object):
""" """
endpoint_enumerator_class = EndpointEnumerator endpoint_enumerator_class = EndpointEnumerator
def __init__(self, info, version='', url=swagger_settings.DEFAULT_API_URL, patterns=None, urlconf=None): def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
""" """
:param .Info info: information about the API :param .Info info: information about the API
:param str version: API version string; can be omitted to use `info.default_version` :param str version: API version string; if omitted, `info.default_version` will be used
:param str url: API url; can be empty to remove URL info from the result :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
will be inferred from the request made against the schema view, so you should generally not need to set
this parameter explicitly; if the empty string is passed, no host and scheme will be emitted
If `url` is not ``None`` or the empty string, it must be a scheme-absolute uri (i.e. starting with http://
or https://), and any path component is ignored;
See also: :ref:`documentation on base URL construction <custom-spec-base-url>`
:param patterns: if given, only these patterns will be enumerated for inclusion in the API spec :param patterns: if given, only these patterns will be enumerated for inclusion in the API spec
:param urlconf: if patterns is not given, use this urlconf to enumerate patterns; :param urlconf: if patterns is not given, use this urlconf to enumerate patterns;
if not given, the default urlconf is used if not given, the default urlconf is used
@ -73,6 +162,15 @@ class OpenAPISchemaGenerator(object):
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info self.info = info
self.version = version self.version = version
if url is None and swagger_settings.DEFAULT_API_URL is not None:
url = swagger_settings.DEFAULT_API_URL
if url:
parsed_url = urlparse.urlparse(url)
if parsed_url.scheme not in ('http', 'https') or not parsed_url.netloc:
raise SwaggerGenerationError("`url` must be an absolute HTTP(S) url")
if parsed_url.path:
logger.warning("path component of api base URL %s is ignored; use FORCE_SCRIPT_NAME instead" % url)
@property @property
def url(self): def url(self):
@ -89,17 +187,16 @@ class OpenAPISchemaGenerator(object):
:rtype: openapi.Swagger :rtype: openapi.Swagger
""" """
endpoints = self.get_endpoints(request) endpoints = self.get_endpoints(request)
endpoints = self.replace_version(endpoints, request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
paths = self.get_paths(endpoints, components, request, public) paths, prefix = self.get_paths(endpoints, components, request, public)
url = self.url url = self.url
if not url and request is not None: if url is None and request is not None:
url = request.build_absolute_uri() url = request.build_absolute_uri()
return openapi.Swagger( return openapi.Swagger(
info=self.info, paths=paths, info=self.info, paths=paths,
_url=url, _version=self.version, **dict(components) _url=url, _prefix=prefix, _version=self.version, **dict(components)
) )
def create_view(self, callback, method, request=None): def create_view(self, callback, method, request=None):
@ -120,30 +217,6 @@ class OpenAPISchemaGenerator(object):
setattr(view_method.__func__, '_swagger_auto_schema', overrides) setattr(view_method.__func__, '_swagger_auto_schema', overrides)
return view return view
def replace_version(self, endpoints, request):
"""If ``request.version`` is not ``None``, replace the version parameter in the path of any endpoints using
``URLPathVersioning`` as a versioning class.
:param dict endpoints: endpoints as returned by :meth:`.get_endpoints`
:param Request request: the request made against the schema view
:return: endpoints with modified paths
"""
version = getattr(request, 'version', None)
if version is None:
return endpoints
new_endpoints = {}
for path, endpoint in endpoints.items():
view_cls = endpoint[0]
versioning_class = getattr(view_cls, 'versioning_class', None)
version_param = getattr(versioning_class, 'version_param', 'version')
if versioning_class is not None and issubclass(versioning_class, versioning.URLPathVersioning):
path = path.replace('{%s}' % version_param, version)
new_endpoints[path] = endpoint
return new_endpoints
def get_endpoints(self, request): def get_endpoints(self, request):
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters. """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
@ -151,7 +224,7 @@ class OpenAPISchemaGenerator(object):
:return: {path: (view_class, list[(http_method, view_instance)]) :return: {path: (view_class, list[(http_method, view_instance)])
:rtype: dict :rtype: dict
""" """
enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf) enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf, request=request)
endpoints = enumerator.get_api_endpoints() endpoints = enumerator.get_api_endpoints()
view_paths = defaultdict(list) view_paths = defaultdict(list)
@ -207,14 +280,16 @@ class OpenAPISchemaGenerator(object):
:param ReferenceResolver components: resolver/container for Swagger References :param ReferenceResolver components: resolver/container for Swagger References
:param Request request: the request made against the schema view; can be None :param Request request: the request made against the schema view; can be None
:param bool public: if True, all endpoints are included regardless of access through `request` :param bool public: if True, all endpoints are included regardless of access through `request`
:rtype: openapi.Paths :returns: the :class:`.Paths` object and the longest common path prefix, as a 2-tuple
:rtype: tuple[openapi.Paths,str]
""" """
if not endpoints: if not endpoints:
return openapi.Paths(paths={}) return openapi.Paths(paths={}), ''
prefix = self.determine_path_prefix(list(endpoints.keys())) or ''
assert '{' not in prefix, "base path cannot be templated in swagger 2.0"
prefix = self.determine_path_prefix(list(endpoints.keys()))
paths = OrderedDict() paths = OrderedDict()
for path, (view_cls, methods) in sorted(endpoints.items()): for path, (view_cls, methods) in sorted(endpoints.items()):
operations = {} operations = {}
for method, view in methods: for method, view in methods:
@ -224,9 +299,14 @@ class OpenAPISchemaGenerator(object):
operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request) operations[method.lower()] = self.get_operation(view, path, prefix, method, components, request)
if operations: if operations:
paths[path] = self.get_path_item(path, view_cls, operations) # since the common prefix is used as the API basePath, it must be stripped
# from individual paths when writing them into the swagger document
path_suffix = path[len(prefix):]
if not path_suffix.startswith('/'):
path_suffix = '/' + path_suffix
paths[path_suffix] = self.get_path_item(path, view_cls, operations)
return openapi.Paths(paths=paths) return openapi.Paths(paths=paths), prefix
def get_operation(self, view, path, prefix, method, components, request): def get_operation(self, view, path, prefix, method, components, request):
"""Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to """Get an :class:`.Operation` for the given API endpoint (path, method). This method delegates to

View File

@ -42,7 +42,7 @@ class Command(BaseCommand):
'-u', '--url', dest='api_url', '-u', '--url', dest='api_url',
default='', default='',
type=str, type=str,
help='Base API URL - sets the host, scheme and basePath attributes of the generated document.' help='Base API URL - sets the host and scheme attributes of the generated document.'
) )
parser.add_argument( parser.add_argument(
'-m', '--mock-request', dest='mock', '-m', '--mock-request', dest='mock',

View File

@ -2,6 +2,7 @@ import re
from collections import OrderedDict from collections import OrderedDict
from coreapi.compat import urlparse from coreapi.compat import urlparse
from django.urls import get_script_prefix
from inflection import camelize from inflection import camelize
from .utils import filter_none from .utils import filter_none
@ -210,11 +211,13 @@ class Info(SwaggerDict):
class Swagger(SwaggerDict): class Swagger(SwaggerDict):
def __init__(self, info=None, _url=None, _version=None, paths=None, definitions=None, **extra): def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None, definitions=None, **extra):
"""Root Swagger object. """Root Swagger object.
:param .Info info: info object :param .Info info: info object
:param str _url: URL used for guessing the API host, scheme and basepath :param str _url: URL used for setting the API host and scheme
:param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi
SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable
:param str _version: version string to override Info :param str _version: version string to override Info
:param .Paths paths: paths object :param .Paths paths: paths object
:param dict[str,.Schema] definitions: named models :param dict[str,.Schema] definitions: named models
@ -226,16 +229,39 @@ class Swagger(SwaggerDict):
if _url: if _url:
url = urlparse.urlparse(_url) url = urlparse.urlparse(_url)
if url.netloc: assert url.netloc and url.scheme, "if given, url must have both schema and netloc"
self.host = url.netloc self.host = url.netloc
if url.scheme:
self.schemes = [url.scheme] self.schemes = [url.scheme]
self.base_path = '/'
self.base_path = self.get_base_path(get_script_prefix(), _prefix)
self.paths = paths self.paths = paths
self.definitions = filter_none(definitions) self.definitions = filter_none(definitions)
self._insert_extras__() self._insert_extras__()
@classmethod
def get_base_path(cls, script_prefix, api_prefix):
"""Determine an appropriate value for ``basePath`` based on the SCRIPT_NAME and the api common prefix.
:param str script_prefix: script prefix as defined by django ``get_script_prefix``
:param str api_prefix: api common prefix
:return: joined base path
"""
# avoid double slash when joining script_name with api_prefix
if script_prefix and script_prefix.endswith('/'):
script_prefix = script_prefix[:-1]
if not api_prefix.startswith('/'):
api_prefix = '/' + api_prefix
base_path = script_prefix + api_prefix
# ensure that the base path has a leading slash and no trailing slash
if base_path and base_path.endswith('/'):
base_path = base_path[:-1]
if not base_path.startswith('/'):
base_path = '/' + base_path
return base_path
class Paths(SwaggerDict): class Paths(SwaggerDict):
def __init__(self, paths, **extra): def __init__(self, paths, **extra):

View File

@ -52,14 +52,13 @@ def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=Fal
generator_class=OpenAPISchemaGenerator, generator_class=OpenAPISchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
""" """Create a SchemaView class with default renderers and generators.
Create a SchemaView class with default renderers and generators.
:param .Info info: Swagger API Info object; if omitted, defaults to `DEFAULT_INFO` :param .Info info: information about the API; if omitted, defaults to :ref:`DEFAULT_INFO <default-swagger-settings>`
:param str url: API base url; if left blank will be deduced from the location the view is served at :param str url: same as :class:`.OpenAPISchemaGenerator`
:param patterns: passed to SchemaGenerator :param patterns: same as :class:`.OpenAPISchemaGenerator`
:param urlconf: passed to SchemaGenerator :param urlconf: same as :class:`.OpenAPISchemaGenerator`
:param bool public: if False, includes only endpoints the current user has access to :param bool public: if False, includes only the endpoints that are accesible by the user viewing the schema
:param list validators: a list of validator names to apply; allowed values are ``flex``, ``ssv`` :param list validators: a list of validator names to apply; allowed values are ``flex``, ``ssv``
:param type generator_class: schema generator class to use; should be a subclass of :class:`.OpenAPISchemaGenerator` :param type generator_class: schema generator class to use; should be a subclass of :class:`.OpenAPISchemaGenerator`
:param tuple authentication_classes: authentication classes for the schema view itself :param tuple authentication_classes: authentication classes for the schema view itself

View File

@ -5,6 +5,7 @@ import pytest
from drf_yasg import codecs, openapi from drf_yasg import codecs, openapi
from drf_yasg.codecs import yaml_sane_load from drf_yasg.codecs import yaml_sane_load
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.generators import OpenAPISchemaGenerator from drf_yasg.generators import OpenAPISchemaGenerator
@ -46,14 +47,23 @@ def test_yaml_and_json_match(codec_yaml, codec_json, swagger):
def test_basepath_only(mock_schema_request): def test_basepath_only(mock_schema_request):
with pytest.raises(SwaggerGenerationError):
generator = OpenAPISchemaGenerator( generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"), info=openapi.Info(title="Test generator", default_version="v1"),
version="v2", version="v2",
url='/basepath/', url='/basepath/',
) )
generator.get_schema(mock_schema_request, public=True)
def test_no_netloc(mock_schema_request):
generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
url='',
)
swagger = generator.get_schema(mock_schema_request, public=True) swagger = generator.get_schema(mock_schema_request, public=True)
assert 'host' not in swagger assert 'host' not in swagger and 'schemes' not in swagger
assert 'schemes' not in swagger
assert swagger['basePath'] == '/' # base path is not implemented for now
assert swagger['info']['version'] == 'v2' assert swagger['info']['version'] == 'v2'

View File

@ -4,24 +4,25 @@ from drf_yasg.codecs import yaml_sane_load
def _get_versioned_schema(prefix, client, validate_schema): def _get_versioned_schema(prefix, client, validate_schema):
response = client.get(prefix + 'swagger.yaml') response = client.get(prefix + '/swagger.yaml')
assert response.status_code == 200 assert response.status_code == 200
swagger = yaml_sane_load(response.content.decode('utf-8')) swagger = yaml_sane_load(response.content.decode('utf-8'))
assert swagger['basePath'] == prefix
validate_schema(swagger) validate_schema(swagger)
assert prefix + 'snippets/' in swagger['paths'] assert '/snippets/' in swagger['paths']
return swagger return swagger
def _check_v1(swagger, prefix): def _check_v1(swagger):
assert swagger['info']['version'] == '1.0' assert swagger['info']['version'] == '1.0'
versioned_post = swagger['paths'][prefix + 'snippets/']['post'] versioned_post = swagger['paths']['/snippets/']['post']
assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/Snippet' assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/Snippet'
assert 'v2field' not in swagger['definitions']['Snippet']['properties'] assert 'v2field' not in swagger['definitions']['Snippet']['properties']
def _check_v2(swagger, prefix): def _check_v2(swagger):
assert swagger['info']['version'] == '2.0' assert swagger['info']['version'] == '2.0'
versioned_post = swagger['paths'][prefix + 'snippets/']['post'] versioned_post = swagger['paths']['/snippets/']['post']
assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/SnippetV2' assert versioned_post['responses']['201']['schema']['$ref'] == '#/definitions/SnippetV2'
assert 'v2field' in swagger['definitions']['SnippetV2']['properties'] assert 'v2field' in swagger['definitions']['SnippetV2']['properties']
v2field = swagger['definitions']['SnippetV2']['properties']['v2field'] v2field = swagger['definitions']['SnippetV2']['properties']['v2field']
@ -30,27 +31,23 @@ def _check_v2(swagger, prefix):
@pytest.mark.urls('urlconfs.url_versioning') @pytest.mark.urls('urlconfs.url_versioning')
def test_url_v1(client, validate_schema): def test_url_v1(client, validate_schema):
prefix = '/versioned/url/v1.0/' swagger = _get_versioned_schema('/versioned/url/v1.0', client, validate_schema)
swagger = _get_versioned_schema(prefix, client, validate_schema) _check_v1(swagger)
_check_v1(swagger, prefix)
@pytest.mark.urls('urlconfs.url_versioning') @pytest.mark.urls('urlconfs.url_versioning')
def test_url_v2(client, validate_schema): def test_url_v2(client, validate_schema):
prefix = '/versioned/url/v2.0/' swagger = _get_versioned_schema('/versioned/url/v2.0', client, validate_schema)
swagger = _get_versioned_schema(prefix, client, validate_schema) _check_v2(swagger)
_check_v2(swagger, prefix)
@pytest.mark.urls('urlconfs.ns_versioning') @pytest.mark.urls('urlconfs.ns_versioning')
def test_ns_v1(client, validate_schema): def test_ns_v1(client, validate_schema):
prefix = '/versioned/ns/v1.0/' swagger = _get_versioned_schema('/versioned/ns/v1.0', client, validate_schema)
swagger = _get_versioned_schema(prefix, client, validate_schema) _check_v1(swagger)
_check_v1(swagger, prefix)
@pytest.mark.urls('urlconfs.ns_versioning') @pytest.mark.urls('urlconfs.ns_versioning')
def test_ns_v2(client, validate_schema): def test_ns_v2(client, validate_schema):
prefix = '/versioned/ns/v2.0/' swagger = _get_versioned_schema('/versioned/ns/v2.0', client, validate_schema)
swagger = _get_versioned_schema(prefix, client, validate_schema) _check_v2(swagger)
_check_v2(swagger, prefix)

View File

@ -17,7 +17,7 @@ class SnippetListV2(SnippetListV1):
serializer_class = SnippetSerializerV2 serializer_class = SnippetSerializerV2
app_name = 'test_ns_versioning' app_name = '2.0'
urlpatterns = [ urlpatterns = [
url(r"^$", SnippetListV2.as_view()) url(r"^$", SnippetListV2.as_view())

View File

@ -19,7 +19,7 @@ schema_patterns = [
urlpatterns = [ urlpatterns = [
url(VERSION_PREFIX_NS + r"v1.0/snippets/", include(ns_version1, namespace='1.0')), url(VERSION_PREFIX_NS + r"v1.0/snippets/", include(ns_version1, namespace='1.0')),
url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2, namespace='2.0')), url(VERSION_PREFIX_NS + r"v2.0/snippets/", include(ns_version2)),
url(VERSION_PREFIX_NS + r'v1.0/', include((schema_patterns, '1.0'))), url(VERSION_PREFIX_NS + r'v1.0/', include((schema_patterns, '1.0'))),
url(VERSION_PREFIX_NS + r'v2.0/', include((schema_patterns, '2.0'))), url(VERSION_PREFIX_NS + r'v2.0/', include((schema_patterns, '2.0'))),
] ]