diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index b551eaf..32fa0d0 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -72,7 +72,7 @@ class EndpointEnumerator(_EndpointEnumerator): return path - def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None): + def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, previously_seen_endpoints=None): """ Return a list of all available API endpoints by inspecting the URL conf. @@ -82,6 +82,8 @@ class EndpointEnumerator(_EndpointEnumerator): patterns = self.patterns api_endpoints = [] + if previously_seen_endpoints is None: + previously_seen_endpoints = set() for pattern in patterns: path_regex = prefix + get_original_route(pattern) @@ -92,6 +94,13 @@ class EndpointEnumerator(_EndpointEnumerator): url_name = pattern.name if self.should_include_endpoint(path, callback, app_name or '', namespace or '', url_name): path = self.replace_version(path, callback) + + # avoid adding endpoints that have already been seen, + # as Django resolves urls in top-down order + if path in previously_seen_endpoints: + continue + previously_seen_endpoints.add(path) + for method in self.get_allowed_methods(callback): endpoint = (path, method, callback) api_endpoints.append(endpoint) @@ -103,7 +112,8 @@ class EndpointEnumerator(_EndpointEnumerator): patterns=pattern.url_patterns, prefix=path_regex, app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name, - namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace + namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace, + previously_seen_endpoints=previously_seen_endpoints ) api_endpoints.extend(nested_endpoints) else: diff --git a/tests/test_schema_generator.py b/tests/test_schema_generator.py index 6a229f2..6b4c915 100644 --- a/tests/test_schema_generator.py +++ b/tests/test_schema_generator.py @@ -2,7 +2,9 @@ import json from collections import OrderedDict import pytest +from django.conf.urls import url from rest_framework import routers, serializers, viewsets +from rest_framework.decorators import api_view from rest_framework.response import Response from drf_yasg import codecs, openapi @@ -113,3 +115,35 @@ def test_replaced_serializer(): responses = swagger['paths']['/details/{id}/']['get']['responses'] assert '404' in responses assert responses['404']['schema']['$ref'] == "#/definitions/Detail" + + +def test_url_order(): + # this view with description override should show up in the schema ... + @swagger_auto_schema(method='get', operation_description="description override") + @api_view() + def test_override(request, pk=None): + return Response({"message": "Hello, world!"}) + + # ... instead of this view which appears later in the url patterns + @api_view() + def test_view(request, pk=None): + return Response({"message": "Hello, world!"}) + + patterns = [ + url(r'^/test/$', test_override), + url(r'^/test/$', test_view), + ] + + generator = OpenAPISchemaGenerator( + info=openapi.Info(title="Test generator", default_version="v1"), + version="v2", + url='', + patterns=patterns + ) + + # description override is successful + swagger = generator.get_schema(None, True) + assert swagger['paths']['/test/']['get']['description'] == 'description override' + + # get_endpoints only includes one endpoint + assert len(generator.get_endpoints(None)['/test/'][1]) == 1