diff --git a/docs/changelog.rst b/docs/changelog.rst index 0c625d9..c36ec9c 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -8,6 +8,8 @@ Changelog ********* - **FIX:** fixed a crash caused by having read-only Serializers nested by reference +- **FIX:** removed erroneous backslashes in paths when routes are generated using Django 2 + `path() `_ - **IMPROVEMENT:** updated ``swagger-ui`` to version 3.7.0 ********* diff --git a/docs/conf.py b/docs/conf.py index 1c9be10..22c83e7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -168,6 +168,7 @@ nitpick_ignore = [ ('py:class', 'ruamel.yaml.dumper.SafeDumper'), ('py:class', 'rest_framework.renderers.BaseRenderer'), + ('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'), ('py:class', 'rest_framework.views.APIView'), ('py:class', 'OpenAPICodecYaml'), diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 3549caf..dc98408 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -1,21 +1,63 @@ +import re from collections import defaultdict, OrderedDict import django.db.models import uritemplate from coreapi.compat import force_text -from rest_framework.schemas.generators import SchemaGenerator +from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.inspectors import get_pk_description from . import openapi from .inspectors import SwaggerAutoSchema from .openapi import ReferenceResolver +PATH_PARAMETER_RE = re.compile(r'{(?P\w+)}') + + +class EndpointEnumerator(_EndpointEnumerator): + def get_path_from_regex(self, path_regex): + return self.unescape_path(super(EndpointEnumerator, self).get_path_from_regex(path_regex)) + + def unescape(self, s): + """Unescape all backslash escapes from `s`. + + :param str s: string with backslash escapes + :rtype: str + """ + # unlike .replace('\\', ''), this corectly transforms a double backslash into a single backslash + return re.sub(r'\\(.)', r'\1', s) + + def unescape_path(self, path): + """Remove backslashes from all path components outside {parameters}. This is needed because + Django>=2.0 ``path()``/``RoutePattern`` aggresively escapes all non-parameter path components. + + **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \w, \d) + outside path parameter groups; if you are in this category, God help you + + :param str path: path possibly containing + :return: the unescaped path + :rtype: str + """ + original_path = path + clean_path = '' + while path: + match = PATH_PARAMETER_RE.search(path) + if not match: + clean_path += self.unescape(path) + break + clean_path += self.unescape(path[:match.start()]) + clean_path += match.group() + path = path[match.end():] + + return clean_path + class OpenAPISchemaGenerator(object): """ This class iterates over all registered API endpoints and returns an appropriate OpenAPI 2.0 compliant schema. Method implementations shamelessly stolen and adapted from rest_framework SchemaGenerator. """ + endpoint_enumerator_class = EndpointEnumerator def __init__(self, info, version, url=None, patterns=None, urlconf=None): """ @@ -79,8 +121,8 @@ class OpenAPISchemaGenerator(object): :return: {path: (view_class, list[(http_method, view_instance)]) :rtype: dict """ - inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf) - endpoints = inspector.get_api_endpoints() + enumerator = self.endpoint_enumerator_class(self._gen.patterns, self._gen.urlconf) + endpoints = enumerator.get_api_endpoints() view_paths = defaultdict(list) view_cls = {} diff --git a/testproj/snippets/urls.py b/testproj/snippets/urls.py index 7af103e..dd34aa5 100644 --- a/testproj/snippets/urls.py +++ b/testproj/snippets/urls.py @@ -1,8 +1,17 @@ -from django.conf.urls import url +import django from . import views -urlpatterns = [ - url(r'$', views.SnippetList.as_view()), - url(r'^(?P[0-9]+)/$', views.SnippetDetail.as_view()), -] +if django.VERSION[:2] >= (2, 0): + from django.urls import path + + urlpatterns = [ + path('', views.SnippetList.as_view()), + path('/', views.SnippetDetail.as_view()), + ] +else: + from django.conf.urls import url + urlpatterns = [ + url('^$', views.SnippetList.as_view()), + url(r'^(?P\d+)/$', views.SnippetDetail.as_view()), + ] diff --git a/testproj/users/urls.py b/testproj/users/urls.py index 66e4126..7ab1870 100644 --- a/testproj/users/urls.py +++ b/testproj/users/urls.py @@ -4,5 +4,5 @@ from users import views urlpatterns = [ url(r'^$', views.UserList.as_view()), - url(r'^(?P[0-9]+)/$', views.user_detail), + url(r'^(?P\d+)/$', views.user_detail), ]