Call APIView get_ methods instead of direct attribute access

queryset -> get_queryset
renderer_classes -> get_renderers
parser_classes -> get_parsers
master
Cristi Vîjdea 2018-12-23 16:57:01 +02:00
parent 04d61b9d97
commit bebcc982e6
8 changed files with 124 additions and 43 deletions

View File

@ -183,6 +183,7 @@ nitpick_ignore = [
('py:class', 'ruamel.yaml.dumper.SafeDumper'),
('py:class', 'rest_framework.serializers.Serializer'),
('py:class', 'rest_framework.renderers.BaseRenderer'),
('py:class', 'rest_framework.parsers.BaseParser'),
('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'),
('py:class', 'rest_framework.views.APIView'),

View File

@ -13,6 +13,7 @@ from .errors import SwaggerValidationError
logger = logging.getLogger(__name__)
def _validate_flex(spec):
from flex.core import parse as validate_flex
from flex.exceptions import ValidationError

View File

@ -8,9 +8,9 @@ from coreapi.compat import urlparse
from rest_framework import versioning
from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering, get_pk_name
from rest_framework.schemas.inspectors import get_pk_description
from rest_framework.settings import api_settings as rest_framework_settings
from rest_framework.settings import api_settings
from . import openapi
from .app_settings import swagger_settings
@ -132,7 +132,7 @@ class EndpointEnumerator(_EndpointEnumerator):
def unescape_path(self, path):
"""Remove backslashe escapes from all path components outside {parameters}. This is needed because
``simplify_regex`` does not handle this correctly - note however that this implementation is
``simplify_regex`` does not handle this correctly.
**NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d)
outside path parameter groups; if you are in this category, God help you
@ -164,7 +164,7 @@ class OpenAPISchemaGenerator(object):
def __init__(self, info, version='', url=None, patterns=None, urlconf=None):
"""
:param .Info info: information about the API
:param openapi.Info info: information about the API
:param str version: API version string; if omitted, `info.default_version` will be used
:param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url
will be inferred from the request made against the schema view, so you should generally not need to set
@ -216,7 +216,7 @@ class OpenAPISchemaGenerator(object):
:meth:`.get_security_definitions` returns `None`.
:param security_definitions: security definitions as returned by :meth:`.get_security_definitions`
:return:
:return: the security schemes accepted by default
:rtype: list[dict[str,list[str]]] or None
"""
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
@ -239,8 +239,8 @@ class OpenAPISchemaGenerator(object):
"""
endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES)
self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES)
paths, prefix = self.get_paths(endpoints, components, request, public)
security_definitions = self.get_security_definitions()
@ -280,6 +280,24 @@ class OpenAPISchemaGenerator(object):
setattr(view, 'swagger_fake_view', True)
return view
def coerce_path(self, path, view):
"""Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an
external representation (i.e. "this is an identifier", not "this is a database primary key").
:param str path: the path
:param rest_framework.views.APIView view: associated view
:rtype: str
"""
if '{pk}' not in path:
return path
model = getattr(get_queryset_from_view(view), 'model', None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
def get_endpoints(self, request):
"""Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
@ -295,7 +313,7 @@ class OpenAPISchemaGenerator(object):
view_cls = {}
for path, method, callback in endpoints:
view = self.create_view(callback, method, request)
path = self._gen.coerce_path(path, method, view)
path = self.coerce_path(path, view)
view_paths[path].append((method, view))
view_cls[path] = callback.cls
return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
@ -313,7 +331,7 @@ class OpenAPISchemaGenerator(object):
:param str subpath: path to the operation with any common prefix/base path removed
:param str method: HTTP method
:param view: the view associated with the operation
:rtype: tuple
:rtype: list[str]
"""
return self._gen.get_keys(subpath, method, view)

View File

@ -4,7 +4,7 @@ import logging
from rest_framework import serializers
from .. import openapi
from ..utils import force_real_str, get_field_default, is_list_view
from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view
#: Sentinel value that inspectors must return to signal that they do not know how to handle an object
NotHandled = object()
@ -12,14 +12,38 @@ NotHandled = object()
logger = logging.getLogger(__name__)
def call_view_method(view, method_name, fallback_attr=None, default=None):
"""Call a view method which might throw an exception. If an exception is thrown, log an informative error message
and return the value of fallback_attr, or default if not present.
:param rest_framework.views.APIView view:
:param str method_name: name of a method on the view
:param str fallback_attr: name of an attribute on the view to fall back on, if calling the method fails
:param default: default value if all else fails
:return: view method's return value, or value of view's fallback_attr, or default
"""
if hasattr(view, method_name):
try:
return getattr(view, method_name)()
except Exception: # pragma: no cover
logger.warning("view's %s.get_parsers raised exception during schema generation; use "
"`getattr(self, 'swagger_fake_view', False)` to detect and short-circuit this",
type(view).__name__, exc_info=True)
if fallback_attr and hasattr(view, fallback_attr):
return getattr(view, fallback_attr)
return default
class BaseInspector(object):
def __init__(self, view, path, method, components, request):
"""
:param view: the view associated with this endpoint
:param rest_framework.views.APIView view: the view associated with this endpoint
:param str path: the path component of the operation URL
:param str method: the http method of the operation
:param openapi.ReferenceResolver components: referenceable components
:param Request request: the request made against the schema view; can be None
:param rest_framework.request.Request request: the request made against the schema view; can be None
"""
self.view = view
self.path = path
@ -81,6 +105,22 @@ class BaseInspector(object):
return result
def get_renderer_classes(self):
"""Get the renderer classes of this view by calling `get_renderers`.
:return: renderer classes
:rtype: list[type[rest_framework.renderers.BaseRenderer]]
"""
return get_object_classes(call_view_method(self.view, 'get_renderers', 'renderer_classes', []))
def get_parser_classes(self):
"""Get the parser classes of this view by calling `get_parsers`.
:return: parser classes
:rtype: list[type[rest_framework.parsers.BaseParser]]
"""
return get_object_classes(call_view_method(self.view, 'get_parsers', 'parser_classes', []))
class PaginatorInspector(BaseInspector):
"""Base inspector for paginators.
@ -335,7 +375,7 @@ class ViewInspector(BaseInspector):
return []
fields = []
for filter_backend in self.view.filter_backends:
for filter_backend in getattr(self.view, 'filter_backends'):
fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or []
return fields
@ -361,7 +401,8 @@ class ViewInspector(BaseInspector):
if not self.should_page():
return []
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or []
return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters',
getattr(self.view, 'paginator')) or []
def serializer_to_schema(self, serializer):
"""Convert a serializer to an OpenAPI :class:`.Schema`.
@ -394,4 +435,4 @@ class ViewInspector(BaseInspector):
:rtype: openapi.Schema
"""
return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response',
self.view.paginator, response_schema=response_schema)
getattr(self.view, 'paginator'), response_schema=response_schema)

View File

@ -14,7 +14,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name
from .base import FieldInspector, NotHandled, SerializerInspector
from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method
try:
import typing
@ -177,7 +177,7 @@ def get_queryset_from_view(view, serializer=None):
:return: queryset or ``None``
"""
try:
queryset = getattr(view, 'queryset', None)
queryset = call_view_method(view, 'get_queryset', 'queryset', None)
if queryset is not None and serializer is not None:
# make sure the view is actually using *this* serializer
@ -733,8 +733,8 @@ class CamelCaseJSONFilter(FieldInspector):
if CamelCaseJSONParser and CamelCaseJSONRenderer:
def is_camel_case(self):
return (
any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) or
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes)
any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or
any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes())
)
else:
def is_camel_case(self):

View File

@ -12,9 +12,9 @@ from ..utils import (
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
is_list_view, merge_params, no_body, param_list_to_odict
)
from .base import ViewInspector
from .base import ViewInspector, call_view_method
log = logging.getLogger(__name__)
logger = logging.getLogger(__name__)
class SwaggerAutoSchema(ViewInspector):
@ -103,21 +103,12 @@ class SwaggerAutoSchema(ViewInspector):
"""Return the serializer as defined by the view's ``get_serializer()`` method.
:return: the view's ``Serializer``
:rtype: rest_framework.serializers.Serializer
"""
if not hasattr(self.view, 'get_serializer'):
return None
try:
return self.view.get_serializer()
except Exception:
log.warning("view's get_serializer raised exception (%s %s %s)",
self.method, self.path, type(self.view).__name__, exc_info=True)
return None
return call_view_method(self.view, 'get_serializer')
def _get_request_body_override(self):
"""Parse the request_body key in the override dict. This method is not public API.
:return:
"""
"""Parse the request_body key in the override dict. This method is not public API."""
body_override = self.overrides.get('request_body', None)
if body_override is not None:
@ -136,6 +127,7 @@ class SwaggerAutoSchema(ViewInspector):
"""Return the request serializer (used for parsing the request payload) for this endpoint.
:return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None``
:rtype: rest_framework.serializers.Serializer
"""
body_override = self._get_request_body_override()
@ -430,11 +422,11 @@ class SwaggerAutoSchema(ViewInspector):
:rtype: list[str]
"""
return get_consumes(getattr(self.view, 'parser_classes', []))
return get_consumes(self.get_parser_classes())
def get_produces(self):
"""Return the MIME types this endpoint can produce.
:rtype: list[str]
"""
return get_produces(getattr(self.view, 'renderer_classes', []))
return get_produces(self.get_renderer_classes())

View File

@ -95,12 +95,12 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=unset, request_bo
:type responses: dict[str,(drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or drf_yasg.openapi.Response or
str or rest_framework.serializers.Serializer)]
:param list[drf_yasg.inspectors.FieldInspector] field_inspectors: extra serializer and field inspectors; these will
be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
:param list[drf_yasg.inspectors.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried
before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
:param list[drf_yasg.inspectors.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be
tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance
:param list[type[drf_yasg.inspectors.FieldInspector]] field_inspectors: extra serializer and field inspectors; these
will be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[type[drf_yasg.inspectors.FilterInspector]] filter_inspectors: extra filter inspectors; these will be
tried before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[type[drf_yasg.inspectors.PaginatorInspector]] paginator_inspectors: extra paginator inspectors; these
will be tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema`
:param list[str] tags: tags override
:param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available
in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides``
@ -300,6 +300,7 @@ def force_serializer_instance(serializer):
an assertion error.
:param serializer: serializer class or instance
:type serializer: serializers.BaseSerializer or type[serializers.BaseSerializer]
:return: serializer instance
:rtype: serializers.BaseSerializer
"""
@ -332,13 +333,38 @@ def get_serializer_class(serializer):
return type(serializer)
def get_object_classes(classes_or_instances, expected_base_class=None):
"""Given a list of instances or class objects, return the list of their classes.
:param classes_or_instances: mixed list to parse
:type classes_or_instances: list[type or object]
:param expected_base_class: if given, only subclasses or instances of this type will be returned
:type expected_base_class: type
:return: list of classes
:rtype: list
"""
classes_or_instances = classes_or_instances or []
result = []
for obj in classes_or_instances:
if inspect.isclass(obj):
if not expected_base_class or issubclass(obj, expected_base_class):
result.append(obj)
else:
if not expected_base_class or isinstance(obj, expected_base_class):
result.append(type(obj))
return result
def get_consumes(parser_classes):
"""Extract ``consumes`` MIME types from a list of parser classes.
:param list parser_classes: parser classes
:type parser_classes: list[rest_framework.parsers.BaseParser or type[rest_framework.parsers.BaseParser]]
:return: MIME types for ``consumes``
:rtype: list[str]
"""
parser_classes = get_object_classes(parser_classes)
media_types = [parser.media_type for parser in parser_classes or []]
non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
if len(non_form_media_types) == 0:
@ -351,9 +377,11 @@ def get_produces(renderer_classes):
"""Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes
:type renderer_classes: list[rest_framework.renderers.BaseRenderer or type[rest_framework.renderers.BaseRenderer]]
:return: MIME types for ``produces``
:rtype: list[str]
"""
renderer_classes = get_object_classes(renderer_classes)
media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types
if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)]
@ -378,7 +406,7 @@ def get_serializer_ref_name(serializer):
:param serializer: Serializer instance
:return: Serializer's ``ref_name`` or ``None`` for inline serializer
:rtype: str
:rtype: str or None
"""
serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__

View File

@ -55,7 +55,7 @@ def swagger_dict(swagger, codec_json):
@pytest.fixture
def validate_schema(db):
def validate_schema():
def validate_schema(swagger):
from flex.core import parse as validate_flex
from swagger_spec_validator.validator20 import validate_spec as validate_ssv