Call APIView get_ methods instead of direct attribute access

queryset -> get_queryset
renderer_classes -> get_renderers
parser_classes -> get_parsers
master
Cristi Vîjdea 2018-12-23 16:57:01 +02:00
parent 04d61b9d97
commit bebcc982e6
8 changed files with 124 additions and 43 deletions

View File

@ -183,6 +183,7 @@ nitpick_ignore = [
('py:class', 'ruamel.yaml.dumper.SafeDumper'), ('py:class', 'ruamel.yaml.dumper.SafeDumper'),
('py:class', 'rest_framework.serializers.Serializer'), ('py:class', 'rest_framework.serializers.Serializer'),
('py:class', 'rest_framework.renderers.BaseRenderer'), ('py:class', 'rest_framework.renderers.BaseRenderer'),
('py:class', 'rest_framework.parsers.BaseParser'),
('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'), ('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'),
('py:class', 'rest_framework.views.APIView'), ('py:class', 'rest_framework.views.APIView'),

View File

@ -13,6 +13,7 @@ from .errors import SwaggerValidationError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _validate_flex(spec): def _validate_flex(spec):
from flex.core import parse as validate_flex from flex.core import parse as validate_flex
from flex.exceptions import ValidationError from flex.exceptions import ValidationError

View File

@ -8,9 +8,9 @@ from coreapi.compat import urlparse
from rest_framework import versioning from rest_framework import versioning
from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering, get_pk_name
from rest_framework.schemas.inspectors import get_pk_description from rest_framework.schemas.inspectors import get_pk_description
from rest_framework.settings import api_settings as rest_framework_settings from rest_framework.settings import api_settings
from . import openapi from . import openapi
from .app_settings import swagger_settings from .app_settings import swagger_settings
@ -132,7 +132,7 @@ class EndpointEnumerator(_EndpointEnumerator):
def unescape_path(self, path): def unescape_path(self, path):
"""Remove backslashe escapes from all path components outside {parameters}. This is needed because """Remove backslashe escapes from all path components outside {parameters}. This is needed because
``simplify_regex`` does not handle this correctly - note however that this implementation is ``simplify_regex`` does not handle this correctly.
**NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d) **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d)
outside path parameter groups; if you are in this category, God help you outside path parameter groups; if you are in this category, God help you
@ -164,7 +164,7 @@ class OpenAPISchemaGenerator(object):
def __init__(self, info, version='', url=None, patterns=None, urlconf=None): def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
""" """
:param .Info info: information about the API :param openapi.Info info: information about the API
:param str version: API version string; if omitted, `info.default_version` will be used :param str version: API version string; if omitted, `info.default_version` will be used
:param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
will be inferred from the request made against the schema view, so you should generally not need to set will be inferred from the request made against the schema view, so you should generally not need to set
@ -216,7 +216,7 @@ class OpenAPISchemaGenerator(object):
:meth:`.get_security_definitions` returns `None`. :meth:`.get_security_definitions` returns `None`.
:param security_definitions: security definitions as returned by :meth:`.get_security_definitions` :param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
:return: :return: the security schemes accepted by default
:rtype: list[dict[str,list[str]]] or None :rtype: list[dict[str,list[str]]] or None
""" """
security_requirements = swagger_settings.SECURITY_REQUIREMENTS security_requirements = swagger_settings.SECURITY_REQUIREMENTS
@ -239,8 +239,8 @@ class OpenAPISchemaGenerator(object):
""" """
endpoints = self.get_endpoints(request) endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES) self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES) self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
paths, prefix = self.get_paths(endpoints, components, request, public) paths, prefix = self.get_paths(endpoints, components, request, public)
security_definitions = self.get_security_definitions() security_definitions = self.get_security_definitions()
@ -280,6 +280,24 @@ class OpenAPISchemaGenerator(object):
setattr(view, 'swagger_fake_view', True) setattr(view, 'swagger_fake_view', True)
return view return view
def coerce_path(self, path, view):
"""Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
external representation (i.e. "this is an identifier", not "this is a database primary key").
:param str path: the path
:param rest_framework.views.APIView view: associated view
:rtype: str
"""
if '{pk}' not in path:
return path
model = getattr(get_queryset_from_view(view), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_endpoints(self, request): def get_endpoints(self, request):
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters. """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
@ -295,7 +313,7 @@ class OpenAPISchemaGenerator(object):
view_cls = {} view_cls = {}
for path, method, callback in endpoints: for path, method, callback in endpoints:
view = self.create_view(callback, method, request) view = self.create_view(callback, method, request)
path = self._gen.coerce_path(path, method, view) path = self.coerce_path(path, view)
view_paths[path].append((method, view)) view_paths[path].append((method, view))
view_cls[path] = callback.cls view_cls[path] = callback.cls
return {path: (view_cls[path], methods) for path, methods in view_paths.items()} return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
@ -313,7 +331,7 @@ class OpenAPISchemaGenerator(object):
:param str subpath: path to the operation with any common prefix/base path removed :param str subpath: path to the operation with any common prefix/base path removed
:param str method: HTTP method :param str method: HTTP method
:param view: the view associated with the operation :param view: the view associated with the operation
:rtype: tuple :rtype: list[str]
""" """
return self._gen.get_keys(subpath, method, view) return self._gen.get_keys(subpath, method, view)

View File

@ -4,7 +4,7 @@ import logging
from rest_framework import serializers from rest_framework import serializers
from .. import openapi from .. import openapi
from ..utils import force_real_str, get_field_default, is_list_view from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object #: Sentinel value that inspectors must return to signal that they do not know how to handle an object
NotHandled = object() NotHandled = object()
@ -12,14 +12,38 @@ NotHandled = object()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def call_view_method(view, method_name, fallback_attr=None, default=None):
"""Call a view method which might throw an exception. If an exception is thrown, log an informative error message
and return the value of fallback_attr, or default if not present.
:param rest_framework.views.APIView view:
:param str method_name: name of a method on the view
:param str fallback_attr: name of an attribute on the view to fall back on, if calling the method fails
:param default: default value if all else fails
:return: view method's return value, or value of view's fallback_attr, or default
"""
if hasattr(view, method_name):
try:
return getattr(view, method_name)()
except Exception: # pragma: no cover
logger.warning("view's %s.get_parsers raised exception during schema generation; use "
"`getattr(self, 'swagger_fake_view', False)` to detect and short-circuit this",
type(view).__name__, exc_info=True)
if fallback_attr and hasattr(view, fallback_attr):
return getattr(view, fallback_attr)
return default
class BaseInspector(object): class BaseInspector(object):
def __init__(self, view, path, method, components, request): def __init__(self, view, path, method, components, request):
""" """
:param view: the view associated with this endpoint :param rest_framework.views.APIView view: the view associated with this endpoint
:param str path: the path component of the operation URL :param str path: the path component of the operation URL
:param str method: the http method of the operation :param str method: the http method of the operation
:param openapi.ReferenceResolver components: referenceable components :param openapi.ReferenceResolver components: referenceable components
:param Request request: the request made against the schema view; can be None :param rest_framework.request.Request request: the request made against the schema view; can be None
""" """
self.view = view self.view = view
self.path = path self.path = path
@ -81,6 +105,22 @@ class BaseInspector(object):
return result return result
def get_renderer_classes(self):
"""Get the renderer classes of this view by calling `get_renderers`.
:return: renderer classes
:rtype: list[type[rest_framework.renderers.BaseRenderer]]
"""
return get_object_classes(call_view_method(self.view, 'get_renderers', 'renderer_classes', []))
def get_parser_classes(self):
"""Get the parser classes of this view by calling `get_parsers`.
:return: parser classes
:rtype: list[type[rest_framework.parsers.BaseParser]]
"""
return get_object_classes(call_view_method(self.view, 'get_parsers', 'parser_classes', []))
class PaginatorInspector(BaseInspector): class PaginatorInspector(BaseInspector):
"""Base inspector for paginators. """Base inspector for paginators.
@ -335,7 +375,7 @@ class ViewInspector(BaseInspector):
return [] return []
fields = [] fields = []
for filter_backend in self.view.filter_backends: for filter_backend in getattr(self.view, 'filter_backends'):
fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or [] fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or []
return fields return fields
@ -361,7 +401,8 @@ class ViewInspector(BaseInspector):
if not self.should_page(): if not self.should_page():
return [] return []
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or [] return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters',
getattr(self.view, 'paginator')) or []
def serializer_to_schema(self, serializer): def serializer_to_schema(self, serializer):
"""Convert a serializer to an OpenAPI :class:`.Schema`. """Convert a serializer to an OpenAPI :class:`.Schema`.
@ -394,4 +435,4 @@ class ViewInspector(BaseInspector):
:rtype: openapi.Schema :rtype: openapi.Schema
""" """
return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response', return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response',
self.view.paginator, response_schema=response_schema) getattr(self.view, 'paginator'), response_schema=response_schema)

View File

@ -14,7 +14,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
from .. import openapi from .. import openapi
from ..errors import SwaggerGenerationError from ..errors import SwaggerGenerationError
from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name
from .base import FieldInspector, NotHandled, SerializerInspector from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method
try: try:
import typing import typing
@ -177,7 +177,7 @@ def get_queryset_from_view(view, serializer=None):
:return: queryset or ``None`` :return: queryset or ``None``
""" """
try: try:
queryset = getattr(view, 'queryset', None) queryset = call_view_method(view, 'get_queryset', 'queryset', None)
if queryset is not None and serializer is not None: if queryset is not None and serializer is not None:
# make sure the view is actually using *this* serializer # make sure the view is actually using *this* serializer
@ -733,8 +733,8 @@ class CamelCaseJSONFilter(FieldInspector):
if CamelCaseJSONParser and CamelCaseJSONRenderer: if CamelCaseJSONParser and CamelCaseJSONRenderer:
def is_camel_case(self): def is_camel_case(self):
return ( return (
any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) or any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes) any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes())
) )
else: else:
def is_camel_case(self): def is_camel_case(self):

View File

@ -12,9 +12,9 @@ from ..utils import (
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
is_list_view, merge_params, no_body, param_list_to_odict is_list_view, merge_params, no_body, param_list_to_odict
) )
from .base import ViewInspector from .base import ViewInspector, call_view_method
log = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class SwaggerAutoSchema(ViewInspector): class SwaggerAutoSchema(ViewInspector):
@ -103,21 +103,12 @@ class SwaggerAutoSchema(ViewInspector):
"""Return the serializer as defined by the view's ``get_serializer()`` method. """Return the serializer as defined by the view's ``get_serializer()`` method.
:return: the view's ``Serializer`` :return: the view's ``Serializer``
:rtype: rest_framework.serializers.Serializer
""" """
if not hasattr(self.view, 'get_serializer'): return call_view_method(self.view, 'get_serializer')
return None
try:
return self.view.get_serializer()
except Exception:
log.warning("view's get_serializer raised exception (%s %s %s)",
self.method, self.path, type(self.view).__name__, exc_info=True)
return None
def _get_request_body_override(self): def _get_request_body_override(self):
"""Parse the request_body key in the override dict. This method is not public API. """Parse the request_body key in the override dict. This method is not public API."""
:return:
"""
body_override = self.overrides.get('request_body', None) body_override = self.overrides.get('request_body', None)
if body_override is not None: if body_override is not None:
@ -136,6 +127,7 @@ class SwaggerAutoSchema(ViewInspector):
"""Return the request serializer (used for parsing the request payload) for this endpoint. """Return the request serializer (used for parsing the request payload) for this endpoint.
:return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None`` :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
:rtype: rest_framework.serializers.Serializer
""" """
body_override = self._get_request_body_override() body_override = self._get_request_body_override()
@ -430,11 +422,11 @@ class SwaggerAutoSchema(ViewInspector):
:rtype: list[str] :rtype: list[str]
""" """
return get_consumes(getattr(self.view, 'parser_classes', [])) return get_consumes(self.get_parser_classes())
def get_produces(self): def get_produces(self):
"""Return the MIME types this endpoint can produce. """Return the MIME types this endpoint can produce.
:rtype: list[str] :rtype: list[str]
""" """
return get_produces(getattr(self.view, 'renderer_classes', [])) return get_produces(self.get_renderer_classes())

View File

@ -95,12 +95,12 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=unset, request_bo
:type responses: dict[str,(drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or drf_yasg.openapi.Response or :type responses: dict[str,(drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or drf_yasg.openapi.Response or
str or rest_framework.serializers.Serializer)] str or rest_framework.serializers.Serializer)]
:param list[drf_yasg.inspectors.FieldInspector] field_inspectors: extra serializer and field inspectors; these will :param list[type[drf_yasg.inspectors.FieldInspector]] field_inspectors: extra serializer and field inspectors; these
be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance will be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[drf_yasg.inspectors.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried :param list[type[drf_yasg.inspectors.FilterInspector]] filter_inspectors: extra filter inspectors; these will be
before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance tried before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[drf_yasg.inspectors.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be :param list[type[drf_yasg.inspectors.PaginatorInspector]] paginator_inspectors: extra paginator inspectors; these
tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance will be tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[str] tags: tags override :param list[str] tags: tags override
:param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available :param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available
in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides`` in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
@ -300,6 +300,7 @@ def force_serializer_instance(serializer):
an assertion error. an assertion error.
:param serializer: serializer class or instance :param serializer: serializer class or instance
:type serializer: serializers.BaseSerializer or type[serializers.BaseSerializer]
:return: serializer instance :return: serializer instance
:rtype: serializers.BaseSerializer :rtype: serializers.BaseSerializer
""" """
@ -332,13 +333,38 @@ def get_serializer_class(serializer):
return type(serializer) return type(serializer)
def get_object_classes(classes_or_instances, expected_base_class=None):
"""Given a list of instances or class objects, return the list of their classes.
:param classes_or_instances: mixed list to parse
:type classes_or_instances: list[type or object]
:param expected_base_class: if given, only subclasses or instances of this type will be returned
:type expected_base_class: type
:return: list of classes
:rtype: list
"""
classes_or_instances = classes_or_instances or []
result = []
for obj in classes_or_instances:
if inspect.isclass(obj):
if not expected_base_class or issubclass(obj, expected_base_class):
result.append(obj)
else:
if not expected_base_class or isinstance(obj, expected_base_class):
result.append(type(obj))
return result
def get_consumes(parser_classes): def get_consumes(parser_classes):
"""Extract ``consumes`` MIME types from a list of parser classes. """Extract ``consumes`` MIME types from a list of parser classes.
:param list parser_classes: parser classes :param list parser_classes: parser classes
:type parser_classes: list[rest_framework.parsers.BaseParser or type[rest_framework.parsers.BaseParser]]
:return: MIME types for ``consumes`` :return: MIME types for ``consumes``
:rtype: list[str] :rtype: list[str]
""" """
parser_classes = get_object_classes(parser_classes)
media_types = [parser.media_type for parser in parser_classes or []] media_types = [parser.media_type for parser in parser_classes or []]
non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)] non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
if len(non_form_media_types) == 0: if len(non_form_media_types) == 0:
@ -351,9 +377,11 @@ def get_produces(renderer_classes):
"""Extract ``produces`` MIME types from a list of renderer classes. """Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes :param list renderer_classes: renderer classes
:type renderer_classes: list[rest_framework.renderers.BaseRenderer or type[rest_framework.renderers.BaseRenderer]]
:return: MIME types for ``produces`` :return: MIME types for ``produces``
:rtype: list[str] :rtype: list[str]
""" """
renderer_classes = get_object_classes(renderer_classes)
media_types = [renderer.media_type for renderer in renderer_classes or []] media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types media_types = [encoding for encoding in media_types
if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)] if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)]
@ -378,7 +406,7 @@ def get_serializer_ref_name(serializer):
:param serializer: Serializer instance :param serializer: Serializer instance
:return: Serializer's ``ref_name`` or ``None`` for inline serializer :return: Serializer's ``ref_name`` or ``None`` for inline serializer
:rtype: str :rtype: str or None
""" """
serializer_meta = getattr(serializer, 'Meta', None) serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__ serializer_name = type(serializer).__name__

View File

@ -55,7 +55,7 @@ def swagger_dict(swagger, codec_json):
@pytest.fixture @pytest.fixture
def validate_schema(db): def validate_schema():
def validate_schema(swagger): def validate_schema(swagger):
from flex.core import parse as validate_flex from flex.core import parse as validate_flex
from swagger_spec_validator.validator20 import validate_spec as validate_ssv from swagger_spec_validator.validator20 import validate_spec as validate_ssv