From bebcc982e67597412f9207a6e36d8f1aa0822ba0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Sun, 23 Dec 2018 16:57:01 +0200 Subject: [PATCH] Call APIView get_ methods instead of direct attribute access queryset -> get_queryset renderer_classes -> get_renderers parser_classes -> get_parsers --- docs/conf.py | 1 + src/drf_yasg/codecs.py | 1 + src/drf_yasg/generators.py | 36 ++++++++++++++++------ src/drf_yasg/inspectors/base.py | 53 ++++++++++++++++++++++++++++---- src/drf_yasg/inspectors/field.py | 8 ++--- src/drf_yasg/inspectors/view.py | 24 +++++---------- src/drf_yasg/utils.py | 42 ++++++++++++++++++++----- tests/conftest.py | 2 +- 8 files changed, 124 insertions(+), 43 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index d571e5d..6fc44f4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -183,6 +183,7 @@ nitpick_ignore = [ ('py:class', 'ruamel.yaml.dumper.SafeDumper'), ('py:class', 'rest_framework.serializers.Serializer'), ('py:class', 'rest_framework.renderers.BaseRenderer'), + ('py:class', 'rest_framework.parsers.BaseParser'), ('py:class', 'rest_framework.schemas.generators.EndpointEnumerator'), ('py:class', 'rest_framework.views.APIView'), diff --git a/src/drf_yasg/codecs.py b/src/drf_yasg/codecs.py index c942929..a9cf557 100644 --- a/src/drf_yasg/codecs.py +++ b/src/drf_yasg/codecs.py @@ -13,6 +13,7 @@ from .errors import SwaggerValidationError logger = logging.getLogger(__name__) + def _validate_flex(spec): from flex.core import parse as validate_flex from flex.exceptions import ValidationError diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index db67786..950fac6 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -8,9 +8,9 @@ from coreapi.compat import urlparse from rest_framework import versioning from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator -from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering +from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering, get_pk_name from rest_framework.schemas.inspectors import get_pk_description -from rest_framework.settings import api_settings as rest_framework_settings +from rest_framework.settings import api_settings from . import openapi from .app_settings import swagger_settings @@ -132,7 +132,7 @@ class EndpointEnumerator(_EndpointEnumerator): def unescape_path(self, path): """Remove backslashe escapes from all path components outside {parameters}. This is needed because - ``simplify_regex`` does not handle this correctly - note however that this implementation is + ``simplify_regex`` does not handle this correctly. **NOTE:** this might destructively affect some url regex patterns that contain metacharacters (e.g. \\w, \\d) outside path parameter groups; if you are in this category, God help you @@ -164,7 +164,7 @@ class OpenAPISchemaGenerator(object): def __init__(self, info, version='', url=None, patterns=None, urlconf=None): """ - :param .Info info: information about the API + :param openapi.Info info: information about the API :param str version: API version string; if omitted, `info.default_version` will be used :param str url: API scheme, host and port; if ``None`` is passed and ``DEFAULT_API_URL`` is not set, the url will be inferred from the request made against the schema view, so you should generally not need to set @@ -216,7 +216,7 @@ class OpenAPISchemaGenerator(object): :meth:`.get_security_definitions` returns `None`. :param security_definitions: security definitions as returned by :meth:`.get_security_definitions` - :return: + :return: the security schemes accepted by default :rtype: list[dict[str,list[str]]] or None """ security_requirements = swagger_settings.SECURITY_REQUIREMENTS @@ -239,8 +239,8 @@ class OpenAPISchemaGenerator(object): """ endpoints = self.get_endpoints(request) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) - self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES) - self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES) + self.consumes = get_consumes(api_settings.DEFAULT_PARSER_CLASSES) + self.produces = get_produces(api_settings.DEFAULT_RENDERER_CLASSES) paths, prefix = self.get_paths(endpoints, components, request, public) security_definitions = self.get_security_definitions() @@ -280,6 +280,24 @@ class OpenAPISchemaGenerator(object): setattr(view, 'swagger_fake_view', True) return view + def coerce_path(self, path, view): + """Coerce {pk} path arguments into the name of the model field, where possible. This is cleaner for an + external representation (i.e. "this is an identifier", not "this is a database primary key"). + + :param str path: the path + :param rest_framework.views.APIView view: associated view + :rtype: str + """ + if '{pk}' not in path: + return path + + model = getattr(get_queryset_from_view(view), 'model', None) + if model: + field_name = get_pk_name(model) + else: + field_name = 'id' + return path.replace('{pk}', '{%s}' % field_name) + def get_endpoints(self, request): """Iterate over all the registered endpoints in the API and return a fake view with the right parameters. @@ -295,7 +313,7 @@ class OpenAPISchemaGenerator(object): view_cls = {} for path, method, callback in endpoints: view = self.create_view(callback, method, request) - path = self._gen.coerce_path(path, method, view) + path = self.coerce_path(path, view) view_paths[path].append((method, view)) view_cls[path] = callback.cls return {path: (view_cls[path], methods) for path, methods in view_paths.items()} @@ -313,7 +331,7 @@ class OpenAPISchemaGenerator(object): :param str subpath: path to the operation with any common prefix/base path removed :param str method: HTTP method :param view: the view associated with the operation - :rtype: tuple + :rtype: list[str] """ return self._gen.get_keys(subpath, method, view) diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index 7c6f11a..853b0e9 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -4,7 +4,7 @@ import logging from rest_framework import serializers from .. import openapi -from ..utils import force_real_str, get_field_default, is_list_view +from ..utils import force_real_str, get_field_default, get_object_classes, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object NotHandled = object() @@ -12,14 +12,38 @@ NotHandled = object() logger = logging.getLogger(__name__) +def call_view_method(view, method_name, fallback_attr=None, default=None): + """Call a view method which might throw an exception. If an exception is thrown, log an informative error message + and return the value of fallback_attr, or default if not present. + + :param rest_framework.views.APIView view: + :param str method_name: name of a method on the view + :param str fallback_attr: name of an attribute on the view to fall back on, if calling the method fails + :param default: default value if all else fails + :return: view method's return value, or value of view's fallback_attr, or default + """ + if hasattr(view, method_name): + try: + return getattr(view, method_name)() + except Exception: # pragma: no cover + logger.warning("view's %s.get_parsers raised exception during schema generation; use " + "`getattr(self, 'swagger_fake_view', False)` to detect and short-circuit this", + type(view).__name__, exc_info=True) + + if fallback_attr and hasattr(view, fallback_attr): + return getattr(view, fallback_attr) + + return default + + class BaseInspector(object): def __init__(self, view, path, method, components, request): """ - :param view: the view associated with this endpoint + :param rest_framework.views.APIView view: the view associated with this endpoint :param str path: the path component of the operation URL :param str method: the http method of the operation :param openapi.ReferenceResolver components: referenceable components - :param Request request: the request made against the schema view; can be None + :param rest_framework.request.Request request: the request made against the schema view; can be None """ self.view = view self.path = path @@ -81,6 +105,22 @@ class BaseInspector(object): return result + def get_renderer_classes(self): + """Get the renderer classes of this view by calling `get_renderers`. + + :return: renderer classes + :rtype: list[type[rest_framework.renderers.BaseRenderer]] + """ + return get_object_classes(call_view_method(self.view, 'get_renderers', 'renderer_classes', [])) + + def get_parser_classes(self): + """Get the parser classes of this view by calling `get_parsers`. + + :return: parser classes + :rtype: list[type[rest_framework.parsers.BaseParser]] + """ + return get_object_classes(call_view_method(self.view, 'get_parsers', 'parser_classes', [])) + class PaginatorInspector(BaseInspector): """Base inspector for paginators. @@ -335,7 +375,7 @@ class ViewInspector(BaseInspector): return [] fields = [] - for filter_backend in self.view.filter_backends: + for filter_backend in getattr(self.view, 'filter_backends'): fields += self.probe_inspectors(self.filter_inspectors, 'get_filter_parameters', filter_backend()) or [] return fields @@ -361,7 +401,8 @@ class ViewInspector(BaseInspector): if not self.should_page(): return [] - return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', self.view.paginator) or [] + return self.probe_inspectors(self.paginator_inspectors, 'get_paginator_parameters', + getattr(self.view, 'paginator')) or [] def serializer_to_schema(self, serializer): """Convert a serializer to an OpenAPI :class:`.Schema`. @@ -394,4 +435,4 @@ class ViewInspector(BaseInspector): :rtype: openapi.Schema """ return self.probe_inspectors(self.paginator_inspectors, 'get_paginated_response', - self.view.paginator, response_schema=response_schema) + getattr(self.view, 'paginator'), response_schema=response_schema) diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 45b23e3..194504b 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -14,7 +14,7 @@ from rest_framework.settings import api_settings as rest_framework_settings from .. import openapi from ..errors import SwaggerGenerationError from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name -from .base import FieldInspector, NotHandled, SerializerInspector +from .base import FieldInspector, NotHandled, SerializerInspector, call_view_method try: import typing @@ -177,7 +177,7 @@ def get_queryset_from_view(view, serializer=None): :return: queryset or ``None`` """ try: - queryset = getattr(view, 'queryset', None) + queryset = call_view_method(view, 'get_queryset', 'queryset', None) if queryset is not None and serializer is not None: # make sure the view is actually using *this* serializer @@ -733,8 +733,8 @@ class CamelCaseJSONFilter(FieldInspector): if CamelCaseJSONParser and CamelCaseJSONRenderer: def is_camel_case(self): return ( - any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes) or - any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes) + any(issubclass(parser, CamelCaseJSONParser) for parser in self.get_parser_classes()) or + any(issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.get_renderer_classes()) ) else: def is_camel_case(self): diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 211bd31..167d689 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -12,9 +12,9 @@ from ..utils import ( filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, merge_params, no_body, param_list_to_odict ) -from .base import ViewInspector +from .base import ViewInspector, call_view_method -log = logging.getLogger(__name__) +logger = logging.getLogger(__name__) class SwaggerAutoSchema(ViewInspector): @@ -103,21 +103,12 @@ class SwaggerAutoSchema(ViewInspector): """Return the serializer as defined by the view's ``get_serializer()`` method. :return: the view's ``Serializer`` + :rtype: rest_framework.serializers.Serializer """ - if not hasattr(self.view, 'get_serializer'): - return None - try: - return self.view.get_serializer() - except Exception: - log.warning("view's get_serializer raised exception (%s %s %s)", - self.method, self.path, type(self.view).__name__, exc_info=True) - return None + return call_view_method(self.view, 'get_serializer') def _get_request_body_override(self): - """Parse the request_body key in the override dict. This method is not public API. - - :return: - """ + """Parse the request_body key in the override dict. This method is not public API.""" body_override = self.overrides.get('request_body', None) if body_override is not None: @@ -136,6 +127,7 @@ class SwaggerAutoSchema(ViewInspector): """Return the request serializer (used for parsing the request payload) for this endpoint. :return: the request serializer, or one of :class:`.Schema`, :class:`.SchemaRef`, ``None`` + :rtype: rest_framework.serializers.Serializer """ body_override = self._get_request_body_override() @@ -430,11 +422,11 @@ class SwaggerAutoSchema(ViewInspector): :rtype: list[str] """ - return get_consumes(getattr(self.view, 'parser_classes', [])) + return get_consumes(self.get_parser_classes()) def get_produces(self): """Return the MIME types this endpoint can produce. :rtype: list[str] """ - return get_produces(getattr(self.view, 'renderer_classes', [])) + return get_produces(self.get_renderer_classes()) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 0486515..e69958d 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -95,12 +95,12 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=unset, request_bo :type responses: dict[str,(drf_yasg.openapi.Schema or drf_yasg.openapi.SchemaRef or drf_yasg.openapi.Response or str or rest_framework.serializers.Serializer)] - :param list[drf_yasg.inspectors.FieldInspector] field_inspectors: extra serializer and field inspectors; these will - be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance - :param list[drf_yasg.inspectors.FilterInspector] filter_inspectors: extra filter inspectors; these will be tried - before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance - :param list[drf_yasg.inspectors.PaginatorInspector] paginator_inspectors: extra paginator inspectors; these will be - tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` instance + :param list[type[drf_yasg.inspectors.FieldInspector]] field_inspectors: extra serializer and field inspectors; these + will be tried before :attr:`.ViewInspector.field_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` + :param list[type[drf_yasg.inspectors.FilterInspector]] filter_inspectors: extra filter inspectors; these will be + tried before :attr:`.ViewInspector.filter_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` + :param list[type[drf_yasg.inspectors.PaginatorInspector]] paginator_inspectors: extra paginator inspectors; these + will be tried before :attr:`.ViewInspector.paginator_inspectors` on the :class:`.inspectors.SwaggerAutoSchema` :param list[str] tags: tags override :param extra_overrides: extra values that will be saved into the ``overrides`` dict; these values will be available in the handling :class:`.inspectors.SwaggerAutoSchema` instance via ``self.overrides`` @@ -300,6 +300,7 @@ def force_serializer_instance(serializer): an assertion error. :param serializer: serializer class or instance + :type serializer: serializers.BaseSerializer or type[serializers.BaseSerializer] :return: serializer instance :rtype: serializers.BaseSerializer """ @@ -332,13 +333,38 @@ def get_serializer_class(serializer): return type(serializer) +def get_object_classes(classes_or_instances, expected_base_class=None): + """Given a list of instances or class objects, return the list of their classes. + + :param classes_or_instances: mixed list to parse + :type classes_or_instances: list[type or object] + :param expected_base_class: if given, only subclasses or instances of this type will be returned + :type expected_base_class: type + :return: list of classes + :rtype: list + """ + classes_or_instances = classes_or_instances or [] + result = [] + for obj in classes_or_instances: + if inspect.isclass(obj): + if not expected_base_class or issubclass(obj, expected_base_class): + result.append(obj) + else: + if not expected_base_class or isinstance(obj, expected_base_class): + result.append(type(obj)) + + return result + + def get_consumes(parser_classes): """Extract ``consumes`` MIME types from a list of parser classes. :param list parser_classes: parser classes + :type parser_classes: list[rest_framework.parsers.BaseParser or type[rest_framework.parsers.BaseParser]] :return: MIME types for ``consumes`` :rtype: list[str] """ + parser_classes = get_object_classes(parser_classes) media_types = [parser.media_type for parser in parser_classes or []] non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)] if len(non_form_media_types) == 0: @@ -351,9 +377,11 @@ def get_produces(renderer_classes): """Extract ``produces`` MIME types from a list of renderer classes. :param list renderer_classes: renderer classes + :type renderer_classes: list[rest_framework.renderers.BaseRenderer or type[rest_framework.renderers.BaseRenderer]] :return: MIME types for ``produces`` :rtype: list[str] """ + renderer_classes = get_object_classes(renderer_classes) media_types = [renderer.media_type for renderer in renderer_classes or []] media_types = [encoding for encoding in media_types if not any(excluded in encoding for excluded in swagger_settings.EXCLUDED_MEDIA_TYPES)] @@ -378,7 +406,7 @@ def get_serializer_ref_name(serializer): :param serializer: Serializer instance :return: Serializer's ``ref_name`` or ``None`` for inline serializer - :rtype: str + :rtype: str or None """ serializer_meta = getattr(serializer, 'Meta', None) serializer_name = type(serializer).__name__ diff --git a/tests/conftest.py b/tests/conftest.py index 3cafaa7..48180aa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,7 +55,7 @@ def swagger_dict(swagger, codec_json): @pytest.fixture -def validate_schema(db): +def validate_schema(): def validate_schema(swagger): from flex.core import parse as validate_flex from swagger_spec_validator.validator20 import validate_spec as validate_ssv