From a3e81ef7f636f8d2bea044521c15a9f7a89bd1e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Wed, 24 Jan 2018 14:44:00 +0200 Subject: [PATCH] Improve handling of consumes and produces attributes (#55) * Fix get_consumes * Generate produces for Operation * Set global consumes and produces from rest framework DEFAULT_ settings --- src/drf_yasg/generators.py | 37 +++++++++++++++-------- src/drf_yasg/inspectors/view.py | 21 +++++++++---- src/drf_yasg/openapi.py | 17 +++++++++-- src/drf_yasg/utils.py | 28 ++++++++++++++++++ testproj/snippets/views.py | 3 +- tests/reference.yaml | 52 ++++++--------------------------- 6 files changed, 93 insertions(+), 65 deletions(-) diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 6fbd0f4..85aa2da 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -10,13 +10,14 @@ from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering from rest_framework.schemas.inspectors import get_pk_description - -from drf_yasg.errors import SwaggerGenerationError +from rest_framework.settings import api_settings as rest_framework_settings from . import openapi from .app_settings import swagger_settings +from .errors import SwaggerGenerationError from .inspectors.field import get_basic_type_info, get_queryset_field from .openapi import ReferenceResolver +from .utils import get_consumes, get_produces logger = logging.getLogger(__name__) @@ -165,6 +166,9 @@ class OpenAPISchemaGenerator(object): self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self.info = info self.version = version + self.consumes = [] + self.produces = [] + if url is None and swagger_settings.DEFAULT_API_URL is not None: url = swagger_settings.DEFAULT_API_URL @@ -191,22 +195,24 @@ class OpenAPISchemaGenerator(object): """ endpoints = self.get_endpoints(request) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) + self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES) + self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES) paths, prefix = self.get_paths(endpoints, components, request, public) + security_definitions = swagger_settings.SECURITY_DEFINITIONS + security_requirements = swagger_settings.SECURITY_REQUIREMENTS + if security_requirements is None: + security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}] + url = self.url if url is None and request is not None: url = request.build_absolute_uri() - swagger = openapi.Swagger( - info=self.info, paths=paths, + return openapi.Swagger( + info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None, + security_definitions=security_definitions, security=security_requirements, _url=url, _prefix=prefix, _version=self.version, **dict(components) ) - swagger.security_definitions = swagger_settings.SECURITY_DEFINITIONS - security_requirements = swagger_settings.SECURITY_REQUIREMENTS - if security_requirements is None: - security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}] - swagger.security = security_requirements - return swagger def create_view(self, callback, method, request=None): """Create a view instance from a view callback as registered in urlpatterns. @@ -330,7 +336,6 @@ class OpenAPISchemaGenerator(object): :param Request request: the request made against the schema view; can be None :rtype: openapi.Operation """ - operation_keys = self.get_operation_keys(path[len(prefix):], method, view) overrides = self.get_overrides(view, method) @@ -342,8 +347,16 @@ class OpenAPISchemaGenerator(object): # 3. on the swagger_auto_schema decorator view_inspector_cls = overrides.get('auto_schema', view_inspector_cls) + if view_inspector_cls is None: + return None + view_inspector = view_inspector_cls(view, path, method, components, request, overrides) - return view_inspector.get_operation(operation_keys) + operation = view_inspector.get_operation(operation_keys) + if set(operation.consumes) == set(self.consumes): + del operation.consumes + if set(operation.produces) == set(self.produces): + del operation.produces + return operation def get_path_item(self, path, view_cls, operations): """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index d3fda3a..0e5d46f 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -6,7 +6,10 @@ from rest_framework.status import is_success from .. import openapi from ..errors import SwaggerGenerationError -from ..utils import force_serializer_instance, guess_response_status, is_list_view, no_body, param_list_to_odict +from ..utils import ( + force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body, + param_list_to_odict +) from .base import ViewInspector @@ -18,6 +21,7 @@ class SwaggerAutoSchema(ViewInspector): def get_operation(self, operation_keys): consumes = self.get_consumes() + produces = self.get_produces() body = self.get_request_body_parameters(consumes) query = self.get_query_parameters() @@ -39,6 +43,7 @@ class SwaggerAutoSchema(ViewInspector): responses=responses, parameters=parameters, consumes=consumes, + produces=produces, tags=tags, security=security ) @@ -296,7 +301,7 @@ class SwaggerAutoSchema(ViewInspector): authentication schemes). Returning ``None`` will inherit the top-level secuirty requirements. :return: security requirements - :rtype: list""" + :rtype: list[dict[str,list[str]]]""" return self.overrides.get('security', None) def get_tags(self, operation_keys): @@ -314,7 +319,11 @@ class SwaggerAutoSchema(ViewInspector): :rtype: list[str] """ - media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])] - if all(is_form_media_type(encoding) for encoding in media_types): - return media_types - return media_types[:1] + return get_consumes(getattr(self.view, 'parser_classes', [])) + + def get_produces(self): + """Return the MIME types this endpoint can produce. + + :rtype: list[str] + """ + return get_produces(getattr(self.view, 'renderer_classes', [])) diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py index 2aa8dca..d1ed257 100644 --- a/src/drf_yasg/openapi.py +++ b/src/drf_yasg/openapi.py @@ -211,7 +211,8 @@ class Info(SwaggerDict): class Swagger(SwaggerDict): - def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None, definitions=None, **extra): + def __init__(self, info=None, _url=None, _prefix=None, _version=None, consumes=None, produces=None, + security_definitions=None, security=None, paths=None, definitions=None, **extra): """Root Swagger object. :param .Info info: info object @@ -219,6 +220,10 @@ class Swagger(SwaggerDict): :param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable :param str _version: version string to override Info + :param list[dict] security_definitions: list of supported authentication mechanisms + :param list[dict] security: authentication mechanisms accepted by default; can be overriden in Operation + :param list[str] consumes: consumed MIME types; can be overriden in Operation + :param list[str] produces: produced MIME types; can be overriden in Operation :param .Paths paths: paths object :param dict[str,.Schema] definitions: named models """ @@ -234,6 +239,10 @@ class Swagger(SwaggerDict): self.schemes = [url.scheme] self.base_path = self.get_base_path(get_script_prefix(), _prefix) + self.consumes = consumes + self.produces = produces + self.security_definitions = filter_none(security_definitions) + self.security = filter_none(security) self.paths = paths self.definitions = filter_none(definitions) self._insert_extras__() @@ -304,8 +313,8 @@ class PathItem(SwaggerDict): class Operation(SwaggerDict): - def __init__(self, operation_id, responses, parameters=None, consumes=None, - produces=None, summary=None, description=None, tags=None, **extra): + def __init__(self, operation_id, responses, parameters=None, consumes=None, produces=None, summary=None, + description=None, tags=None, security=None, **extra): """Information about an API operation (path + http method combination) :param str operation_id: operation ID, should be unique across all operations @@ -316,6 +325,7 @@ class Operation(SwaggerDict): :param str summary: operation summary; should be < 120 characters :param str description: operation description; can be of any length and supports markdown :param list[str] tags: operation tags + :param list[dict[str,list[str]]] security: list of security requirements """ super(Operation, self).__init__(**extra) self.operation_id = operation_id @@ -326,6 +336,7 @@ class Operation(SwaggerDict): self.consumes = filter_none(consumes) self.produces = filter_none(produces) self.tags = filter_none(tags) + self.security = filter_none(security) self._insert_extras__() diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 7eda8a1..32b2dbc 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -4,6 +4,7 @@ from collections import OrderedDict from rest_framework import serializers, status from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin +from rest_framework.request import is_form_media_type from rest_framework.views import APIView logger = logging.getLogger(__name__) @@ -248,3 +249,30 @@ def force_serializer_instance(serializer): assert isinstance(serializer, serializers.BaseSerializer), \ "Serializer class or instance required, not %s" % type(serializer).__name__ return serializer + + +def get_consumes(parser_classes): + """Extract ``consumes`` MIME types from a list of parser classes. + + :param list parser_classes: parser classes + :return: MIME types for ``consumes`` + :rtype: list[str] + """ + media_types = [parser.media_type for parser in parser_classes or []] + if all(is_form_media_type(encoding) for encoding in media_types): + return media_types + else: + media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)] + return media_types + + +def get_produces(renderer_classes): + """Extract ``produces`` MIME types from a list of renderer classes. + + :param list renderer_classes: renderer classes + :return: MIME types for ``produces`` + :rtype: list[str] + """ + media_types = [renderer.media_type for renderer in renderer_classes or []] + media_types = [encoding for encoding in media_types if 'html' not in encoding] + return media_types diff --git a/testproj/snippets/views.py b/testproj/snippets/views.py index e4cf7ea..25cdb3d 100644 --- a/testproj/snippets/views.py +++ b/testproj/snippets/views.py @@ -2,6 +2,7 @@ from djangorestframework_camel_case.parser import CamelCaseJSONParser from djangorestframework_camel_case.render import CamelCaseJSONRenderer from inflection import camelize from rest_framework import generics +from rest_framework.parsers import FormParser from drf_yasg import openapi from drf_yasg.inspectors import SwaggerAutoSchema @@ -21,7 +22,7 @@ class SnippetList(generics.ListCreateAPIView): queryset = Snippet.objects.all() serializer_class = SnippetSerializer - parser_classes = (CamelCaseJSONParser,) + parser_classes = (FormParser, CamelCaseJSONParser,) renderer_classes = (CamelCaseJSONRenderer,) swagger_schema = CamelCaseOperationIDAutoSchema diff --git a/tests/reference.yaml b/tests/reference.yaml index e4d9c25..1990d7c 100644 --- a/tests/reference.yaml +++ b/tests/reference.yaml @@ -16,6 +16,15 @@ host: test.local:8002 schemes: - http basePath: / +consumes: + - application/json +produces: + - application/json +securityDefinitions: + basic: + type: basic +security: + - basic: [] paths: /articles/: get: @@ -63,8 +72,6 @@ paths: type: array items: $ref: '#/definitions/Article' - consumes: - - application/json tags: - articles post: @@ -81,8 +88,6 @@ paths: description: '' schema: $ref: '#/definitions/Article' - consumes: - - application/json tags: - articles parameters: [] @@ -108,8 +113,6 @@ paths: type: array items: $ref: '#/definitions/Article' - consumes: - - application/json tags: - articles parameters: [] @@ -123,8 +126,6 @@ paths: description: '' schema: $ref: '#/definitions/Article' - consumes: - - application/json tags: - articles put: @@ -141,8 +142,6 @@ paths: description: '' schema: $ref: '#/definitions/Article' - consumes: - - application/json tags: - articles patch: @@ -161,8 +160,6 @@ paths: $ref: '#/definitions/Article' '404': description: slug not found - consumes: - - application/json tags: - articles delete: @@ -172,8 +169,6 @@ paths: responses: '204': description: '' - consumes: - - application/json tags: - articles parameters: @@ -246,8 +241,6 @@ paths: responses: '200': description: '' - consumes: - - application/json tags: - plain parameters: [] @@ -263,8 +256,6 @@ paths: type: array items: $ref: '#/definitions/Snippet' - consumes: - - application/json tags: - snippets post: @@ -281,8 +272,6 @@ paths: description: '' schema: $ref: '#/definitions/Snippet' - consumes: - - application/json tags: - snippets parameters: [] @@ -296,8 +285,6 @@ paths: description: '' schema: $ref: '#/definitions/Snippet' - consumes: - - application/json tags: - snippets put: @@ -314,8 +301,6 @@ paths: description: '' schema: $ref: '#/definitions/Snippet' - consumes: - - application/json tags: - snippets patch: @@ -332,8 +317,6 @@ paths: description: '' schema: $ref: '#/definitions/Snippet' - consumes: - - application/json tags: - snippets delete: @@ -348,8 +331,6 @@ paths: responses: '204': description: '' - consumes: - - application/json tags: - snippets parameters: @@ -380,8 +361,6 @@ paths: type: array items: $ref: '#/definitions/UserSerializerrr' - consumes: - - application/json tags: - users post: @@ -408,8 +387,6 @@ paths: properties: username: type: string - consumes: - - application/json tags: - users security: [] @@ -420,8 +397,6 @@ paths: responses: '200': description: '' - consumes: - - application/json tags: - users parameters: [] @@ -439,8 +414,6 @@ paths: description: response description schema: $ref: '#/definitions/UserSerializerrr' - consumes: - - application/json tags: - users put: @@ -457,8 +430,6 @@ paths: description: '' schema: $ref: '#/definitions/UserSerializerrr' - consumes: - - application/json tags: - users parameters: @@ -1121,8 +1092,3 @@ definitions: pattern: ^[-a-zA-Z0-9_]+$ readOnly: true uniqueItems: true -securityDefinitions: - basic: - type: basic -security: - - basic: []