Call APIView get_ methods instead of direct attribute access
queryset -> get_queryset renderer_classes -> get_renderers parser_classes -> get_parsersmaster
parent
04d61b9d97
commit
bebcc982e6
|
|
@ -183,6 +183,7 @@ nitpick_ignore = [
|
|||
('py:class', 'ruamel.yaml.dumper.SafeDumper'),
|
||||
('py:class', 'rest_framework.serializers.Serializer'),
|
||||
('py:class', 'rest_framework.renderers.BaseRenderer'),
|
||||
('py:class', 'rest_framework.parsers.BaseParser'),
|
||||
('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'),
|
||||
('py:class', 'rest_framework.views.APIView'),
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from .errors import SwaggerValidationError
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _validate_flex(spec):
|
||||
from flex.core import parse as validate_flex
|
||||
from flex.exceptions import ValidationError
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from coreapi.compat import urlparse
|
|||
from rest_framework import versioning
|
||||
from rest_framework.compat import URLPattern, URLResolver, get_original_route
|
||||
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
|
||||
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
|
||||
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering, get_pk_name
|
||||
from rest_framework.schemas.inspectors import get_pk_description
|
||||
from rest_framework.settings import api_settings as rest_framework_settings
|
||||
from rest_framework.settings import api_settings
|
||||
|
||||
from . import openapi
|
||||
from .app_settings import swagger_settings
|
||||
|
|
@ -132,7 +132,7 @@ class EndpointEnumerator(_EndpointEnumerator):
|
|||
|
||||
def unescape_path(self, path):
|
||||
"""Remove backslashe escapes from all path components outside {parameters}. This is needed because
|
||||
``simplify_regex`` does not handle this correctly - note however that this implementation is
|
||||
``simplify_regex`` does not handle this correctly.
|
||||
|
||||
**NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d)
|
||||
outside path parameter groups; if you are in this category, God help you
|
||||
|
|
@ -164,7 +164,7 @@ class OpenAPISchemaGenerator(object):
|
|||
def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
|
||||
"""
|
||||
|
||||
:param .Info info: information about the API
|
||||
:param openapi.Info info: information about the API
|
||||
:param str version: API version string; if omitted, `info.default_version` will be used
|
||||
:param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
|
||||
will be inferred from the request made against the schema view, so you should generally not need to set
|
||||
|
|
@ -216,7 +216,7 @@ class OpenAPISchemaGenerator(object):
|
|||
:meth:`.get_security_definitions` returns `None`.
|
||||
|
||||
:param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
|
||||
:return:
|
||||
:return: the security schemes accepted by default
|
||||
:rtype: list[dict[str,list[str]]] or None
|
||||
"""
|
||||
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
|
||||
|
|
@ -239,8 +239,8 @@ class OpenAPISchemaGenerator(object):
|
|||
"""
|
||||
endpoints = self.get_endpoints(request)
|
||||
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
|
||||
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES)
|
||||
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES)
|
||||
self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
|
||||
self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
|
||||
paths, prefix = self.get_paths(endpoints, components, request, public)
|
||||
|
||||
security_definitions = self.get_security_definitions()
|
||||
|
|
@ -280,6 +280,24 @@ class OpenAPISchemaGenerator(object):
|
|||
setattr(view, 'swagger_fake_view', True)
|
||||
return view
|
||||
|
||||
def coerce_path(self, path, view):
|
||||
"""Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
|
||||
external representation (i.e. "this is an identifier", not "this is a database primary key").
|
||||
|
||||
:param str path: the path
|
||||
:param rest_framework.views.APIView view: associated view
|
||||
:rtype: str
|
||||
"""
|
||||
if '{pk}' not in path:
|
||||
return path
|
||||
|
||||
model = getattr(get_queryset_from_view(view), 'model', None)
|
||||
if model:
|
||||
field_name = get_pk_name(model)
|
||||
else:
|
||||
field_name = 'id'
|
||||
return path.replace('{pk}', '{%s}' % field_name)
|
||||
|
||||
def get_endpoints(self, request):
|
||||
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
|
||||
|
||||
|
|
@ -295,7 +313,7 @@ class OpenAPISchemaGenerator(object):
|
|||
view_cls = {}
|
||||
for path, method, callback in endpoints:
|
||||
view = self.create_view(callback, method, request)
|
||||
path = self._gen.coerce_path(path, method, view)
|
||||
path = self.coerce_path(path, view)
|
||||
view_paths[path].append((method, view))
|
||||
view_cls[path] = callback.cls
|
||||
return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
|
||||
|
|
@ -313,7 +331,7 @@ class OpenAPISchemaGenerator(object):
|
|||
:param str subpath: path to the operation with any common prefix/base path removed
|
||||
:param str method: HTTP method
|
||||
:param view: the view associated with the operation
|
||||
:rtype: tuple
|
||||
:rtype: list[str]
|
||||
"""
|
||||
return self._gen.get_keys(subpath, method, view)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import logging
|
|||
from rest_framework import serializers
|
||||
|
||||
from .. import openapi
|
||||
from ..utils import force_real_str, get_field_default, is_list_view
|
||||
from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view
|
||||
|
||||
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
|
||||
NotHandled = object()
|
||||
|
|
@ -12,14 +12,38 @@ NotHandled = object()
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def call_view_method(view, method_name, fallback_attr=None, default=None):
|
||||
"""Call a view method which might throw an exception. If an exception is thrown, log an informative error message
|
||||
and return the value of fallback_attr, or default if not present.
|
||||
|
||||
:param rest_framework.views.APIView view:
|
||||
:param str method_name: name of a method on the view
|
||||
:param str fallback_attr: name of an attribute on the view to fall back on, if calling the method fails
|
||||
:param default: default value if all else fails
|
||||
:return: view method's return value, or value of view's fallback_attr, or default
|
||||
"""
|
||||
if hasattr(view, method_name):
|
||||
try:
|
||||
return getattr(view, method_name)()
|
||||
except Exception: # pragma: no cover
|
||||
logger.warning("view's %s.get_parsers raised exception during schema generation; use "
|
||||
"`getattr(self, 'swagger_fake_view', False)` to detect and short-circuit this",
|
||||
type(view).__name__, exc_info=True)
|
||||
|
||||
if fallback_attr and hasattr(view, fallback_attr):
|
||||
return getattr(view, fallback_attr)
|
||||
|
||||
return default
|
||||
|
||||
|
||||
class BaseInspector(object):
|
||||
def __init__(self, view, path, method, components, request):
|
||||
"""
|
||||
:param view: the view associated with this endpoint
|
||||
:param rest_framework.views.APIView view: the view associated with this endpoint
|
||||
:param str path: the path component of the operation URL
|
||||
:param str method: the http method of the operation
|
||||
:param openapi.ReferenceResolver components: referenceable components
|
||||
:param Request request: the request made against the schema view; can be None
|
||||
:param rest_framework.request.Request request: the request made against the schema view; can be None
|
||||
"""
|
||||
self.view = view
|
||||
self.path = path
|
||||
|
|
@ -81,6 +105,22 @@ class BaseInspector(object):
|
|||
|
||||
return result
|
||||
|
||||
def get_renderer_classes(self):
|
||||
"""Get the renderer classes of this view by calling `get_renderers`.
|
||||
|
||||
:return: renderer classes
|
||||
:rtype: list[type[rest_framework.renderers.BaseRenderer]]
|
||||
"""
|
||||
return get_object_classes(call_view_method(self.view, 'get_renderers', 'renderer_classes', []))
|
||||
|
||||
def get_parser_classes(self):
|
||||
"""Get the parser classes of this view by calling `get_parsers`.
|
||||
|
||||
:return: parser classes
|
||||
:rtype: list[type[rest_framework.parsers.BaseParser]]
|
||||
"""
|
||||
return get_object_classes(call_view_method(self.view, 'get_parsers', 'parser_classes', []))
|
||||
|
||||
|
||||
class PaginatorInspector(BaseInspector):
|
||||
"""Base inspector for paginators.
|
||||
|
|
@ -335,7 +375,7 @@ class ViewInspector(BaseInspector):
|
|||
return []
|
||||
|
||||
fields = []
|
||||
for filter_backend in self.view.filter_backends:
|
||||
for filter_backend in getattr(self.view, 'filter_backends'):
|
||||
fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or []
|
||||
|
||||
return fields
|
||||
|
|
@ -361,7 +401,8 @@ class ViewInspector(BaseInspector):
|
|||
if not self.should_page():
|
||||
return []
|
||||
|
||||
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or []
|
||||
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters',
|
||||
getattr(self.view, 'paginator')) or []
|
||||
|
||||
def serializer_to_schema(self, serializer):
|
||||
"""Convert a serializer to an OpenAPI :class:`.Schema`.
|
||||
|
|
@ -394,4 +435,4 @@ class ViewInspector(BaseInspector):
|
|||
:rtype: openapi.Schema
|
||||
"""
|
||||
return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response',
|
||||
self.view.paginator, response_schema=response_schema)
|
||||
getattr(self.view, 'paginator'), response_schema=response_schema)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
|
|||
from .. import openapi
|
||||
from ..errors import SwaggerGenerationError
|
||||
from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name
|
||||
from .base import FieldInspector, NotHandled, SerializerInspector
|
||||
from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method
|
||||
|
||||
try:
|
||||
import typing
|
||||
|
|
@ -177,7 +177,7 @@ def get_queryset_from_view(view, serializer=None):
|
|||
:return: queryset or ``None``
|
||||
"""
|
||||
try:
|
||||
queryset = getattr(view, 'queryset', None)
|
||||
queryset = call_view_method(view, 'get_queryset', 'queryset', None)
|
||||
|
||||
if queryset is not None and serializer is not None:
|
||||
# make sure the view is actually using *this* serializer
|
||||
|
|
@ -733,8 +733,8 @@ class CamelCaseJSONFilter(FieldInspector):
|
|||
if CamelCaseJSONParser and CamelCaseJSONRenderer:
|
||||
def is_camel_case(self):
|
||||
return (
|
||||
any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) or
|
||||
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes)
|
||||
any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or
|
||||
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes())
|
||||
)
|
||||
else:
|
||||
def is_camel_case(self):
|
||||
|
|
|
|||
|
|
@ -12,9 +12,9 @@ from ..utils import (
|
|||
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
|
||||
is_list_view, merge_params, no_body, param_list_to_odict
|
||||
)
|
||||
from .base import ViewInspector
|
||||
from .base import ViewInspector, call_view_method
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SwaggerAutoSchema(ViewInspector):
|
||||
|
|
@ -103,21 +103,12 @@ class SwaggerAutoSchema(ViewInspector):
|
|||
"""Return the serializer as defined by the view's ``get_serializer()`` method.
|
||||
|
||||
:return: the view's ``Serializer``
|
||||
:rtype: rest_framework.serializers.Serializer
|
||||
"""
|
||||
if not hasattr(self.view, 'get_serializer'):
|
||||
return None
|
||||
try:
|
||||
return self.view.get_serializer()
|
||||
except Exception:
|
||||
log.warning("view's get_serializer raised exception (%s %s %s)",
|
||||
self.method, self.path, type(self.view).__name__, exc_info=True)
|
||||
return None
|
||||
return call_view_method(self.view, 'get_serializer')
|
||||
|
||||
def _get_request_body_override(self):
|
||||
"""Parse the request_body key in the override dict. This method is not public API.
|
||||
|
||||
:return:
|
||||
"""
|
||||
"""Parse the request_body key in the override dict. This method is not public API."""
|
||||
body_override = self.overrides.get('request_body', None)
|
||||
|
||||
if body_override is not None:
|
||||
|
|
@ -136,6 +127,7 @@ class SwaggerAutoSchema(ViewInspector):
|
|||
"""Return the request serializer (used for parsing the request payload) for this endpoint.
|
||||
|
||||
:return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
|
||||
:rtype: rest_framework.serializers.Serializer
|
||||
"""
|
||||
body_override = self._get_request_body_override()
|
||||
|
||||
|
|
@ -430,11 +422,11 @@ class SwaggerAutoSchema(ViewInspector):
|
|||
|
||||
:rtype: list[str]
|
||||
"""
|
||||
return get_consumes(getattr(self.view, 'parser_classes', []))
|
||||
return get_consumes(self.get_parser_classes())
|
||||
|
||||
def get_produces(self):
|
||||
"""Return the MIME types this endpoint can produce.
|
||||
|
||||
:rtype: list[str]
|
||||
"""
|
||||
return get_produces(getattr(self.view, 'renderer_classes', []))
|
||||
return get_produces(self.get_renderer_classes())
|
||||
|
|
|
|||
|
|
@ -95,12 +95,12 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=unset, request_bo
|
|||
:type responses: dict[str,(drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or drf_yasg.openapi.Response or
|
||||
str or rest_framework.serializers.Serializer)]
|
||||
|
||||
:param list[drf_yasg.inspectors.FieldInspector] field_inspectors: extra serializer and field inspectors; these will
|
||||
be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
|
||||
:param list[drf_yasg.inspectors.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried
|
||||
before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
|
||||
:param list[drf_yasg.inspectors.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be
|
||||
tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
|
||||
:param list[type[drf_yasg.inspectors.FieldInspector]] field_inspectors: extra serializer and field inspectors; these
|
||||
will be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
|
||||
:param list[type[drf_yasg.inspectors.FilterInspector]] filter_inspectors: extra filter inspectors; these will be
|
||||
tried before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
|
||||
:param list[type[drf_yasg.inspectors.PaginatorInspector]] paginator_inspectors: extra paginator inspectors; these
|
||||
will be tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
|
||||
:param list[str] tags: tags override
|
||||
:param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available
|
||||
in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
|
||||
|
|
@ -300,6 +300,7 @@ def force_serializer_instance(serializer):
|
|||
an assertion error.
|
||||
|
||||
:param serializer: serializer class or instance
|
||||
:type serializer: serializers.BaseSerializer or type[serializers.BaseSerializer]
|
||||
:return: serializer instance
|
||||
:rtype: serializers.BaseSerializer
|
||||
"""
|
||||
|
|
@ -332,13 +333,38 @@ def get_serializer_class(serializer):
|
|||
return type(serializer)
|
||||
|
||||
|
||||
def get_object_classes(classes_or_instances, expected_base_class=None):
|
||||
"""Given a list of instances or class objects, return the list of their classes.
|
||||
|
||||
:param classes_or_instances: mixed list to parse
|
||||
:type classes_or_instances: list[type or object]
|
||||
:param expected_base_class: if given, only subclasses or instances of this type will be returned
|
||||
:type expected_base_class: type
|
||||
:return: list of classes
|
||||
:rtype: list
|
||||
"""
|
||||
classes_or_instances = classes_or_instances or []
|
||||
result = []
|
||||
for obj in classes_or_instances:
|
||||
if inspect.isclass(obj):
|
||||
if not expected_base_class or issubclass(obj, expected_base_class):
|
||||
result.append(obj)
|
||||
else:
|
||||
if not expected_base_class or isinstance(obj, expected_base_class):
|
||||
result.append(type(obj))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def get_consumes(parser_classes):
|
||||
"""Extract ``consumes`` MIME types from a list of parser classes.
|
||||
|
||||
:param list parser_classes: parser classes
|
||||
:type parser_classes: list[rest_framework.parsers.BaseParser or type[rest_framework.parsers.BaseParser]]
|
||||
:return: MIME types for ``consumes``
|
||||
:rtype: list[str]
|
||||
"""
|
||||
parser_classes = get_object_classes(parser_classes)
|
||||
media_types = [parser.media_type for parser in parser_classes or []]
|
||||
non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
|
||||
if len(non_form_media_types) == 0:
|
||||
|
|
@ -351,9 +377,11 @@ def get_produces(renderer_classes):
|
|||
"""Extract ``produces`` MIME types from a list of renderer classes.
|
||||
|
||||
:param list renderer_classes: renderer classes
|
||||
:type renderer_classes: list[rest_framework.renderers.BaseRenderer or type[rest_framework.renderers.BaseRenderer]]
|
||||
:return: MIME types for ``produces``
|
||||
:rtype: list[str]
|
||||
"""
|
||||
renderer_classes = get_object_classes(renderer_classes)
|
||||
media_types = [renderer.media_type for renderer in renderer_classes or []]
|
||||
media_types = [encoding for encoding in media_types
|
||||
if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)]
|
||||
|
|
@ -378,7 +406,7 @@ def get_serializer_ref_name(serializer):
|
|||
|
||||
:param serializer: Serializer instance
|
||||
:return: Serializer's ``ref_name`` or ``None`` for inline serializer
|
||||
:rtype: str
|
||||
:rtype: str or None
|
||||
"""
|
||||
serializer_meta = getattr(serializer, 'Meta', None)
|
||||
serializer_name = type(serializer).__name__
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ def swagger_dict(swagger, codec_json):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def validate_schema(db):
|
||||
def validate_schema():
|
||||
def validate_schema(swagger):
|
||||
from flex.core import parse as validate_flex
|
||||
from swagger_spec_validator.validator20 import validate_spec as validate_ssv
|
||||
|
|
|
|||
Loading…
Reference in New Issue