Improve handling of consumes and produces attributes (#55)

* Fix get_consumes
* Generate produces for Operation
* Set global consumes and produces from rest framework DEFAULT_ settings
openapi3
Cristi Vîjdea 2018-01-24 14:44:00 +02:00 committed by GitHub
parent a46b684fea
commit a3e81ef7f6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 93 additions and 65 deletions

View File

@ -10,13 +10,14 @@ from rest_framework.compat import URLPattern, URLResolver, get_original_route
from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering from rest_framework.schemas.generators import SchemaGenerator, endpoint_ordering
from rest_framework.schemas.inspectors import get_pk_description from rest_framework.schemas.inspectors import get_pk_description
from rest_framework.settings import api_settings as rest_framework_settings
from drf_yasg.errors import SwaggerGenerationError
from . import openapi from . import openapi
from .app_settings import swagger_settings from .app_settings import swagger_settings
from .errors import SwaggerGenerationError
from .inspectors.field import get_basic_type_info, get_queryset_field from .inspectors.field import get_basic_type_info, get_queryset_field
from .openapi import ReferenceResolver from .openapi import ReferenceResolver
from .utils import get_consumes, get_produces
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -165,6 +166,9 @@ class OpenAPISchemaGenerator(object):
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info self.info = info
self.version = version self.version = version
self.consumes = []
self.produces = []
if url is None and swagger_settings.DEFAULT_API_URL is not None: if url is None and swagger_settings.DEFAULT_API_URL is not None:
url = swagger_settings.DEFAULT_API_URL url = swagger_settings.DEFAULT_API_URL
@ -191,22 +195,24 @@ class OpenAPISchemaGenerator(object):
""" """
endpoints = self.get_endpoints(request) endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.consumes = get_consumes(rest_framework_settings.DEFAULT_PARSER_CLASSES)
self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES)
paths, prefix = self.get_paths(endpoints, components, request, public) paths, prefix = self.get_paths(endpoints, components, request, public)
security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}]
url = self.url url = self.url
if url is None and request is not None: if url is None and request is not None:
url = request.build_absolute_uri() url = request.build_absolute_uri()
swagger = openapi.Swagger( return openapi.Swagger(
info=self.info, paths=paths, info=self.info, paths=paths, consumes=self.consumes or None, produces=self.produces or None,
security_definitions=security_definitions, security=security_requirements,
_url=url, _prefix=prefix, _version=self.version, **dict(components) _url=url, _prefix=prefix, _version=self.version, **dict(components)
) )
swagger.security_definitions = swagger_settings.SECURITY_DEFINITIONS
security_requirements = swagger_settings.SECURITY_REQUIREMENTS
if security_requirements is None:
security_requirements = [{security_scheme: [] for security_scheme in swagger_settings.SECURITY_DEFINITIONS}]
swagger.security = security_requirements
return swagger
def create_view(self, callback, method, request=None): def create_view(self, callback, method, request=None):
"""Create a view instance from a view callback as registered in urlpatterns. """Create a view instance from a view callback as registered in urlpatterns.
@ -330,7 +336,6 @@ class OpenAPISchemaGenerator(object):
:param Request request: the request made against the schema view; can be None :param Request request: the request made against the schema view; can be None
:rtype: openapi.Operation :rtype: openapi.Operation
""" """
operation_keys = self.get_operation_keys(path[len(prefix):], method, view) operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
overrides = self.get_overrides(view, method) overrides = self.get_overrides(view, method)
@ -342,8 +347,16 @@ class OpenAPISchemaGenerator(object):
# 3. on the swagger_auto_schema decorator # 3. on the swagger_auto_schema decorator
view_inspector_cls = overrides.get('auto_schema', view_inspector_cls) view_inspector_cls = overrides.get('auto_schema', view_inspector_cls)
if view_inspector_cls is None:
return None
view_inspector = view_inspector_cls(view, path, method, components, request, overrides) view_inspector = view_inspector_cls(view, path, method, components, request, overrides)
return view_inspector.get_operation(operation_keys) operation = view_inspector.get_operation(operation_keys)
if set(operation.consumes) == set(self.consumes):
del operation.consumes
if set(operation.produces) == set(self.produces):
del operation.produces
return operation
def get_path_item(self, path, view_cls, operations): def get_path_item(self, path, view_cls, operations):
"""Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the """Get a :class:`.PathItem` object that describes the parameters and operations related to a single path in the

View File

@ -6,7 +6,10 @@ from rest_framework.status import is_success
from .. import openapi from .. import openapi
from ..errors import SwaggerGenerationError from ..errors import SwaggerGenerationError
from ..utils import force_serializer_instance, guess_response_status, is_list_view, no_body, param_list_to_odict from ..utils import (
force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body,
param_list_to_odict
)
from .base import ViewInspector from .base import ViewInspector
@ -18,6 +21,7 @@ class SwaggerAutoSchema(ViewInspector):
def get_operation(self, operation_keys): def get_operation(self, operation_keys):
consumes = self.get_consumes() consumes = self.get_consumes()
produces = self.get_produces()
body = self.get_request_body_parameters(consumes) body = self.get_request_body_parameters(consumes)
query = self.get_query_parameters() query = self.get_query_parameters()
@ -39,6 +43,7 @@ class SwaggerAutoSchema(ViewInspector):
responses=responses, responses=responses,
parameters=parameters, parameters=parameters,
consumes=consumes, consumes=consumes,
produces=produces,
tags=tags, tags=tags,
security=security security=security
) )
@ -296,7 +301,7 @@ class SwaggerAutoSchema(ViewInspector):
authentication schemes). Returning ``None`` will inherit the top-level secuirty requirements. authentication schemes). Returning ``None`` will inherit the top-level secuirty requirements.
:return: security requirements :return: security requirements
:rtype: list""" :rtype: list[dict[str,list[str]]]"""
return self.overrides.get('security', None) return self.overrides.get('security', None)
def get_tags(self, operation_keys): def get_tags(self, operation_keys):
@ -314,7 +319,11 @@ class SwaggerAutoSchema(ViewInspector):
:rtype: list[str] :rtype: list[str]
""" """
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])] return get_consumes(getattr(self.view, 'parser_classes', []))
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types def get_produces(self):
return media_types[:1] """Return the MIME types this endpoint can produce.
:rtype: list[str]
"""
return get_produces(getattr(self.view, 'renderer_classes', []))

View File

@ -211,7 +211,8 @@ class Info(SwaggerDict):
class Swagger(SwaggerDict): class Swagger(SwaggerDict):
def __init__(self, info=None, _url=None, _prefix=None, _version=None, paths=None, definitions=None, **extra): def __init__(self, info=None, _url=None, _prefix=None, _version=None, consumes=None, produces=None,
security_definitions=None, security=None, paths=None, definitions=None, **extra):
"""Root Swagger object. """Root Swagger object.
:param .Info info: info object :param .Info info: info object
@ -219,6 +220,10 @@ class Swagger(SwaggerDict):
:param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi :param str _prefix: api path prefix to use in setting basePath; this will be appended to the wsgi
SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable SCRIPT_NAME prefix or Django's FORCE_SCRIPT_NAME if applicable
:param str _version: version string to override Info :param str _version: version string to override Info
:param list[dict] security_definitions: list of supported authentication mechanisms
:param list[dict] security: authentication mechanisms accepted by default; can be overriden in Operation
:param list[str] consumes: consumed MIME types; can be overriden in Operation
:param list[str] produces: produced MIME types; can be overriden in Operation
:param .Paths paths: paths object :param .Paths paths: paths object
:param dict[str,.Schema] definitions: named models :param dict[str,.Schema] definitions: named models
""" """
@ -234,6 +239,10 @@ class Swagger(SwaggerDict):
self.schemes = [url.scheme] self.schemes = [url.scheme]
self.base_path = self.get_base_path(get_script_prefix(), _prefix) self.base_path = self.get_base_path(get_script_prefix(), _prefix)
self.consumes = consumes
self.produces = produces
self.security_definitions = filter_none(security_definitions)
self.security = filter_none(security)
self.paths = paths self.paths = paths
self.definitions = filter_none(definitions) self.definitions = filter_none(definitions)
self._insert_extras__() self._insert_extras__()
@ -304,8 +313,8 @@ class PathItem(SwaggerDict):
class Operation(SwaggerDict): class Operation(SwaggerDict):
def __init__(self, operation_id, responses, parameters=None, consumes=None, def __init__(self, operation_id, responses, parameters=None, consumes=None, produces=None, summary=None,
produces=None, summary=None, description=None, tags=None, **extra): description=None, tags=None, security=None, **extra):
"""Information about an API operation (path + http method combination) """Information about an API operation (path + http method combination)
:param str operation_id: operation ID, should be unique across all operations :param str operation_id: operation ID, should be unique across all operations
@ -316,6 +325,7 @@ class Operation(SwaggerDict):
:param str summary: operation summary; should be < 120 characters :param str summary: operation summary; should be < 120 characters
:param str description: operation description; can be of any length and supports markdown :param str description: operation description; can be of any length and supports markdown
:param list[str] tags: operation tags :param list[str] tags: operation tags
:param list[dict[str,list[str]]] security: list of security requirements
""" """
super(Operation, self).__init__(**extra) super(Operation, self).__init__(**extra)
self.operation_id = operation_id self.operation_id = operation_id
@ -326,6 +336,7 @@ class Operation(SwaggerDict):
self.consumes = filter_none(consumes) self.consumes = filter_none(consumes)
self.produces = filter_none(produces) self.produces = filter_none(produces)
self.tags = filter_none(tags) self.tags = filter_none(tags)
self.security = filter_none(security)
self._insert_extras__() self._insert_extras__()

View File

@ -4,6 +4,7 @@ from collections import OrderedDict
from rest_framework import serializers, status from rest_framework import serializers, status
from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin
from rest_framework.request import is_form_media_type
from rest_framework.views import APIView from rest_framework.views import APIView
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -248,3 +249,30 @@ def force_serializer_instance(serializer):
assert isinstance(serializer, serializers.BaseSerializer), \ assert isinstance(serializer, serializers.BaseSerializer), \
"Serializer class or instance required, not %s" % type(serializer).__name__ "Serializer class or instance required, not %s" % type(serializer).__name__
return serializer return serializer
def get_consumes(parser_classes):
"""Extract ``consumes`` MIME types from a list of parser classes.
:param list parser_classes: parser classes
:return: MIME types for ``consumes``
:rtype: list[str]
"""
media_types = [parser.media_type for parser in parser_classes or []]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
else:
media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
return media_types
def get_produces(renderer_classes):
"""Extract ``produces`` MIME types from a list of renderer classes.
:param list renderer_classes: renderer classes
:return: MIME types for ``produces``
:rtype: list[str]
"""
media_types = [renderer.media_type for renderer in renderer_classes or []]
media_types = [encoding for encoding in media_types if 'html' not in encoding]
return media_types

View File

@ -2,6 +2,7 @@ from djangorestframework_camel_case.parser import CamelCaseJSONParser
from djangorestframework_camel_case.render import CamelCaseJSONRenderer from djangorestframework_camel_case.render import CamelCaseJSONRenderer
from inflection import camelize from inflection import camelize
from rest_framework import generics from rest_framework import generics
from rest_framework.parsers import FormParser
from drf_yasg import openapi from drf_yasg import openapi
from drf_yasg.inspectors import SwaggerAutoSchema from drf_yasg.inspectors import SwaggerAutoSchema
@ -21,7 +22,7 @@ class SnippetList(generics.ListCreateAPIView):
queryset = Snippet.objects.all() queryset = Snippet.objects.all()
serializer_class = SnippetSerializer serializer_class = SnippetSerializer
parser_classes = (CamelCaseJSONParser,) parser_classes = (FormParser, CamelCaseJSONParser,)
renderer_classes = (CamelCaseJSONRenderer,) renderer_classes = (CamelCaseJSONRenderer,)
swagger_schema = CamelCaseOperationIDAutoSchema swagger_schema = CamelCaseOperationIDAutoSchema

View File

@ -16,6 +16,15 @@ host: test.local:8002
schemes: schemes:
- http - http
basePath: / basePath: /
consumes:
- application/json
produces:
- application/json
securityDefinitions:
basic:
type: basic
security:
- basic: []
paths: paths:
/articles/: /articles/:
get: get:
@ -63,8 +72,6 @@ paths:
type: array type: array
items: items:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
consumes:
- application/json
tags: tags:
- articles - articles
post: post:
@ -81,8 +88,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
consumes:
- application/json
tags: tags:
- articles - articles
parameters: [] parameters: []
@ -108,8 +113,6 @@ paths:
type: array type: array
items: items:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
consumes:
- application/json
tags: tags:
- articles - articles
parameters: [] parameters: []
@ -123,8 +126,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
consumes:
- application/json
tags: tags:
- articles - articles
put: put:
@ -141,8 +142,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
consumes:
- application/json
tags: tags:
- articles - articles
patch: patch:
@ -161,8 +160,6 @@ paths:
$ref: '#/definitions/Article' $ref: '#/definitions/Article'
'404': '404':
description: slug not found description: slug not found
consumes:
- application/json
tags: tags:
- articles - articles
delete: delete:
@ -172,8 +169,6 @@ paths:
responses: responses:
'204': '204':
description: '' description: ''
consumes:
- application/json
tags: tags:
- articles - articles
parameters: parameters:
@ -246,8 +241,6 @@ paths:
responses: responses:
'200': '200':
description: '' description: ''
consumes:
- application/json
tags: tags:
- plain - plain
parameters: [] parameters: []
@ -263,8 +256,6 @@ paths:
type: array type: array
items: items:
$ref: '#/definitions/Snippet' $ref: '#/definitions/Snippet'
consumes:
- application/json
tags: tags:
- snippets - snippets
post: post:
@ -281,8 +272,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Snippet' $ref: '#/definitions/Snippet'
consumes:
- application/json
tags: tags:
- snippets - snippets
parameters: [] parameters: []
@ -296,8 +285,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Snippet' $ref: '#/definitions/Snippet'
consumes:
- application/json
tags: tags:
- snippets - snippets
put: put:
@ -314,8 +301,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Snippet' $ref: '#/definitions/Snippet'
consumes:
- application/json
tags: tags:
- snippets - snippets
patch: patch:
@ -332,8 +317,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/Snippet' $ref: '#/definitions/Snippet'
consumes:
- application/json
tags: tags:
- snippets - snippets
delete: delete:
@ -348,8 +331,6 @@ paths:
responses: responses:
'204': '204':
description: '' description: ''
consumes:
- application/json
tags: tags:
- snippets - snippets
parameters: parameters:
@ -380,8 +361,6 @@ paths:
type: array type: array
items: items:
$ref: '#/definitions/UserSerializerrr' $ref: '#/definitions/UserSerializerrr'
consumes:
- application/json
tags: tags:
- users - users
post: post:
@ -408,8 +387,6 @@ paths:
properties: properties:
username: username:
type: string type: string
consumes:
- application/json
tags: tags:
- users - users
security: [] security: []
@ -420,8 +397,6 @@ paths:
responses: responses:
'200': '200':
description: '' description: ''
consumes:
- application/json
tags: tags:
- users - users
parameters: [] parameters: []
@ -439,8 +414,6 @@ paths:
description: response description description: response description
schema: schema:
$ref: '#/definitions/UserSerializerrr' $ref: '#/definitions/UserSerializerrr'
consumes:
- application/json
tags: tags:
- users - users
put: put:
@ -457,8 +430,6 @@ paths:
description: '' description: ''
schema: schema:
$ref: '#/definitions/UserSerializerrr' $ref: '#/definitions/UserSerializerrr'
consumes:
- application/json
tags: tags:
- users - users
parameters: parameters:
@ -1121,8 +1092,3 @@ definitions:
pattern: ^[-a-zA-Z0-9_]+$ pattern: ^[-a-zA-Z0-9_]+$
readOnly: true readOnly: true
uniqueItems: true uniqueItems: true
securityDefinitions:
basic:
type: basic
security:
- basic: []