diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 32fa0d0..a647b84 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -72,7 +72,7 @@ class EndpointEnumerator(_EndpointEnumerator): return path - def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, previously_seen_endpoints=None): + def get_api_endpoints(self, patterns=None, prefix='', app_name=None, namespace=None, ignored_endpoints=None): """ Return a list of all available API endpoints by inspecting the URL conf. @@ -82,8 +82,8 @@ class EndpointEnumerator(_EndpointEnumerator): patterns = self.patterns api_endpoints = [] - if previously_seen_endpoints is None: - previously_seen_endpoints = set() + if ignored_endpoints is None: + ignored_endpoints = set() for pattern in patterns: path_regex = prefix + get_original_route(pattern) @@ -97,9 +97,9 @@ class EndpointEnumerator(_EndpointEnumerator): # avoid adding endpoints that have already been seen, # as Django resolves urls in top-down order - if path in previously_seen_endpoints: + if path in ignored_endpoints: continue - previously_seen_endpoints.add(path) + ignored_endpoints.add(path) for method in self.get_allowed_methods(callback): endpoint = (path, method, callback) @@ -113,7 +113,7 @@ class EndpointEnumerator(_EndpointEnumerator): prefix=path_regex, app_name="%s:%s" % (app_name, pattern.app_name) if app_name else pattern.app_name, namespace="%s:%s" % (namespace, pattern.namespace) if namespace else pattern.namespace, - previously_seen_endpoints=previously_seen_endpoints + ignored_endpoints=ignored_endpoints ) api_endpoints.extend(nested_endpoints) else: diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 821328c..957c8f7 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -93,10 +93,10 @@ class SwaggerAutoSchema(ViewInspector): self.method, self.path, type(self.view).__name__, exc_info=True) return None - def get_request_serializer(self): - """Return the request serializer (used for parsing the request payload) for this endpoint. + def _get_request_body_override(self): + """Parse the request_body key in the override dict. This method is not public API. - :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None`` + :return: """ body_override = self.overrides.get('request_body', None) @@ -109,10 +109,20 @@ class SwaggerAutoSchema(ViewInspector): if isinstance(body_override, openapi.Schema.OR_REF): return body_override return force_serializer_instance(body_override) - elif self.method in self.implicit_body_methods: + + return body_override + + def get_request_serializer(self): + """Return the request serializer (used for parsing the request payload) for this endpoint. + + :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None`` + """ + body_override = self._get_request_body_override() + + if body_override is None and self.method in self.implicit_body_methods: return self.get_view_serializer() - return None + return body_override def get_request_form_parameters(self, serializer): """Given a Serializer, return a list of ``in: formData`` :class:`.Parameter`\ s. @@ -172,6 +182,14 @@ class SwaggerAutoSchema(ViewInspector): responses=self.get_response_schemas(response_serializers) ) + def get_default_response_serializer(self): + """Return the default response serializer for this endpoint. This is derived from either the ``request_body`` + override or the request serializer (:meth:`.get_view_serializer`). + + :return: response serializer, :class:`.Schema`, :class:`.SchemaRef`, ``None`` + """ + return self._get_request_body_override() or self.get_view_serializer() + def get_default_responses(self): """Get the default responses determined for this view from the request serializer and request method. @@ -182,7 +200,7 @@ class SwaggerAutoSchema(ViewInspector): default_status = guess_response_status(method) default_schema = '' if method in ('get', 'post', 'put', 'patch'): - default_schema = self.get_request_serializer() or self.get_view_serializer() + default_schema = self.get_default_response_serializer() default_schema = default_schema or '' if any(is_form_media_type(encoding) for encoding in self.get_consumes()):