Handle duplicate urls in urlconf (#155)
Django resolves urls in order from top to bottom, and only uses the first matching URL found.openapi3
parent
696ec3a94a
commit
544d72db0a
|
|
@ -72,7 +72,7 @@ class EndpointEnumerator(_EndpointEnumerator):
|
||||||
|
|
||||||
return path
|
return path
|
||||||
|
|
||||||
def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None):
|
def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, previously_seen_endpoints=None):
|
||||||
"""
|
"""
|
||||||
Return a list of all available API endpoints by inspecting the URL conf.
|
Return a list of all available API endpoints by inspecting the URL conf.
|
||||||
|
|
||||||
|
|
@ -82,6 +82,8 @@ class EndpointEnumerator(_EndpointEnumerator):
|
||||||
patterns = self.patterns
|
patterns = self.patterns
|
||||||
|
|
||||||
api_endpoints = []
|
api_endpoints = []
|
||||||
|
if previously_seen_endpoints is None:
|
||||||
|
previously_seen_endpoints = set()
|
||||||
|
|
||||||
for pattern in patterns:
|
for pattern in patterns:
|
||||||
path_regex = prefix + get_original_route(pattern)
|
path_regex = prefix + get_original_route(pattern)
|
||||||
|
|
@ -92,6 +94,13 @@ class EndpointEnumerator(_EndpointEnumerator):
|
||||||
url_name = pattern.name
|
url_name = pattern.name
|
||||||
if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
|
if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name):
|
||||||
path = self.replace_version(path, callback)
|
path = self.replace_version(path, callback)
|
||||||
|
|
||||||
|
# avoid adding endpoints that have already been seen,
|
||||||
|
# as Django resolves urls in top-down order
|
||||||
|
if path in previously_seen_endpoints:
|
||||||
|
continue
|
||||||
|
previously_seen_endpoints.add(path)
|
||||||
|
|
||||||
for method in self.get_allowed_methods(callback):
|
for method in self.get_allowed_methods(callback):
|
||||||
endpoint = (path, method, callback)
|
endpoint = (path, method, callback)
|
||||||
api_endpoints.append(endpoint)
|
api_endpoints.append(endpoint)
|
||||||
|
|
@ -103,7 +112,8 @@ class EndpointEnumerator(_EndpointEnumerator):
|
||||||
patterns=pattern.url_patterns,
|
patterns=pattern.url_patterns,
|
||||||
prefix=path_regex,
|
prefix=path_regex,
|
||||||
app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
|
app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name,
|
||||||
namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace
|
namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace,
|
||||||
|
previously_seen_endpoints=previously_seen_endpoints
|
||||||
)
|
)
|
||||||
api_endpoints.extend(nested_endpoints)
|
api_endpoints.extend(nested_endpoints)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,9 @@ import json
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from django.conf.urls import url
|
||||||
from rest_framework import routers, serializers, viewsets
|
from rest_framework import routers, serializers, viewsets
|
||||||
|
from rest_framework.decorators import api_view
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
|
||||||
from drf_yasg import codecs, openapi
|
from drf_yasg import codecs, openapi
|
||||||
|
|
@ -113,3 +115,35 @@ def test_replaced_serializer():
|
||||||
responses = swagger['paths']['/details/{id}/']['get']['responses']
|
responses = swagger['paths']['/details/{id}/']['get']['responses']
|
||||||
assert '404' in responses
|
assert '404' in responses
|
||||||
assert responses['404']['schema']['$ref'] == "#/definitions/Detail"
|
assert responses['404']['schema']['$ref'] == "#/definitions/Detail"
|
||||||
|
|
||||||
|
|
||||||
|
def test_url_order():
|
||||||
|
# this view with description override should show up in the schema ...
|
||||||
|
@swagger_auto_schema(method='get', operation_description="description override")
|
||||||
|
@api_view()
|
||||||
|
def test_override(request, pk=None):
|
||||||
|
return Response({"message": "Hello, world!"})
|
||||||
|
|
||||||
|
# ... instead of this view which appears later in the url patterns
|
||||||
|
@api_view()
|
||||||
|
def test_view(request, pk=None):
|
||||||
|
return Response({"message": "Hello, world!"})
|
||||||
|
|
||||||
|
patterns = [
|
||||||
|
url(r'^/test/$', test_override),
|
||||||
|
url(r'^/test/$', test_view),
|
||||||
|
]
|
||||||
|
|
||||||
|
generator = OpenAPISchemaGenerator(
|
||||||
|
info=openapi.Info(title="Test generator", default_version="v1"),
|
||||||
|
version="v2",
|
||||||
|
url='',
|
||||||
|
patterns=patterns
|
||||||
|
)
|
||||||
|
|
||||||
|
# description override is successful
|
||||||
|
swagger = generator.get_schema(None, True)
|
||||||
|
assert swagger['paths']['/test/']['get']['description'] == 'description override'
|
||||||
|
|
||||||
|
# get_endpoints only includes one endpoint
|
||||||
|
assert len(generator.get_endpoints(None)['/test/'][1]) == 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue