Add swagger_auto_schema method decorator for Operation customization

See #5.
openapi3
Cristi Vîjdea 2017-12-08 05:33:01 +01:00
parent 82cac4ef0d
commit 652795f5db
11 changed files with 443 additions and 215 deletions

View File

@ -29,36 +29,61 @@ class OpenAPISchemaGenerator(object):
inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf) inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf)
self.endpoints = inspector.get_api_endpoints() self.endpoints = inspector.get_api_endpoints()
self.get_endpoints(None if public else request) self.clean_endpoints(None if public else request)
paths = self.get_paths() paths = self.get_paths()
url = self._gen.url url = self._gen.url
if not url and request is not None: if not url and request is not None:
url = request.build_absolute_uri() url = request.build_absolute_uri()
# distribute_links(links)
return openapi.Swagger( return openapi.Swagger(
info=self.info, paths=paths, info=self.info, paths=paths,
_url=url, _version=self.version, _url=url, _version=self.version,
) )
def get_endpoints(self, request): def create_view(self, callback, method, request):
view = self._gen.create_view(callback, method, request)
overrides = getattr(callback, 'swagger_auto_schema', None)
if overrides is not None:
# decorated function based view
for method, _ in overrides.items():
view_method = getattr(view, method, None)
if view_method is not None:
setattr(view_method.__func__, 'swagger_auto_schema', overrides)
return view
def clean_endpoints(self, request):
"""Generate {path: (view_class, [(method, view)]) given (path, method, callback).""" """Generate {path: (view_class, [(method, view)]) given (path, method, callback)."""
view_paths = defaultdict(list) view_paths = defaultdict(list)
view_cls = {} view_cls = {}
for path, method, callback in self.endpoints: for path, method, callback in self.endpoints:
view = self._gen.create_view(callback, method, request) view = self.create_view(callback, method, request)
path = self._gen.coerce_path(path, method, view) path = self._gen.coerce_path(path, method, view)
view_paths[path].append((method, view)) view_paths[path].append((method, view))
view_cls[path] = callback.cls view_cls[path] = callback.cls
self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()} self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()}
def get_operation_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update")
"""
return self._gen.get_keys(subpath, method, view)
def get_paths(self): def get_paths(self):
if not self.endpoints: if not self.endpoints:
return [] return []
prefix = self._gen.determine_path_prefix(self.endpoints.keys()) prefix = self._gen.determine_path_prefix(self.endpoints.keys())
paths = OrderedDict() paths = OrderedDict()
default_schema_cls = SwaggerAutoSchema
for path, (view_cls, methods) in sorted(self.endpoints.items()): for path, (view_cls, methods) in sorted(self.endpoints.items()):
path_parameters = self.get_path_parameters(path, view_cls) path_parameters = self.get_path_parameters(path, view_cls)
operations = {} operations = {}
@ -66,14 +91,26 @@ class OpenAPISchemaGenerator(object):
if not self._gen.has_view_permissions(path, method, view): if not self._gen.has_view_permissions(path, method, view):
continue continue
schema = SwaggerAutoSchema(view) operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
operation_keys = self._gen.get_keys(path[len(prefix):], method, view) overrides = self.get_overrides(view, method)
operations[method.lower()] = schema.get_operation(operation_keys, path, method) auto_schema_cls = overrides.get('auto_schema', default_schema_cls)
schema = auto_schema_cls(view, path, method, overrides)
operations[method.lower()] = schema.get_operation(operation_keys)
paths[path] = openapi.PathItem(parameters=path_parameters, **operations) paths[path] = openapi.PathItem(parameters=path_parameters, **operations)
return openapi.Paths(paths=paths) return openapi.Paths(paths=paths)
def get_overrides(self, view, method):
method = method.lower()
action = getattr(view, 'action', method)
action_method = getattr(view, action, None)
overrides = getattr(action_method, 'swagger_auto_schema', {})
if method in overrides:
overrides = overrides[method]
return overrides
def get_path_parameters(self, path, view_cls): def get_path_parameters(self, path, view_cls):
"""Return a list of Parameter instances corresponding to any templated path variables. """Return a list of Parameter instances corresponding to any templated path variables.

View File

@ -1,15 +1,137 @@
import functools import functools
import inspect
from collections import OrderedDict from collections import OrderedDict
import coreschema import coreschema
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.utils.encoding import force_text from django.utils.encoding import force_text
from rest_framework import serializers from rest_framework import serializers
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema from rest_framework.schemas import AutoSchema
from rest_framework.schemas.utils import is_list_view from rest_framework.viewsets import GenericViewSet
from drf_swagger.errors import SwaggerGenerationError from drf_swagger.errors import SwaggerGenerationError
from . import openapi from . import openapi
from .utils import no_body, is_list_view
def serializer_field_to_swagger(field, swagger_object_type, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object.
:param rest_framework.serializers.Field field: the source field
:param type swagger_object_type: should be one of Schema, Parameter, Items
:param kwargs: extra attributes for constructing the object;
if swagger_object_type is Parameter, `name` and `in_` should be provided
:return Swagger,Parameter,Items: the swagger object
"""
assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
title = force_text(field.label) if field.label else None
title = title if swagger_object_type == openapi.Schema else None # only Schema has title
title = None
description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either
SwaggerType = functools.partial(swagger_object_type, title=title, description=description, **kwargs)
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
# ------ NESTED
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
)
elif isinstance(field, serializers.Serializer):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as "
+ swagger_object_type.__name__)
return SwaggerType(
type=openapi.TYPE_OBJECT,
properties=OrderedDict(
(key, serializer_field_to_swagger(value, ChildSwaggerType))
for key, value
in field.fields.items()
)
)
elif isinstance(field, serializers.ManyRelatedField):
child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
unique_items=True, # is this OK?
)
elif isinstance(field, serializers.RelatedField):
# TODO: infer type for PrimaryKeyRelatedField?
return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES
elif isinstance(field, serializers.MultipleChoiceField):
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=ChildSwaggerType(
type=openapi.TYPE_STRING,
enum=list(field.choices.keys())
)
)
elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL
elif isinstance(field, serializers.BooleanField):
return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_NUMBER)
elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING
elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
elif isinstance(field, serializers.RegexField):
return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
elif isinstance(field, serializers.SlugField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG)
elif isinstance(field, serializers.URLField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.IPAddressField):
format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
return SwaggerType(type=openapi.TYPE_STRING, format=format)
elif isinstance(field, serializers.CharField):
# TODO: min_length max_length (for all CharField subclasses above too)
return SwaggerType(type=openapi.TYPE_STRING)
elif isinstance(field, serializers.UUIDField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
# ------ DATE & TIME
elif isinstance(field, serializers.DateField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
elif isinstance(field, serializers.DateTimeField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
# ------ OTHERS
elif isinstance(field, serializers.FileField):
# swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
# OpenAPI 3.0 does support it, so a future implementation could handle this better
# TODO: appropriate produces/consumes somehow/somewhere?
if swagger_object_type != openapi.Parameter:
raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter")
return SwaggerType(type=openapi.TYPE_FILE)
elif isinstance(field, serializers.JSONField):
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_BINARY if field.binary else None
)
elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
)
# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField?
# TODO: return info about required/allowed empty
# everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING)
def find_regex(regex_field): def find_regex(regex_field):
@ -26,122 +148,165 @@ def find_regex(regex_field):
class SwaggerAutoSchema(object): class SwaggerAutoSchema(object):
def __init__(self, view): def __init__(self, view, path, method, overrides):
super(SwaggerAutoSchema, self).__init__() super(SwaggerAutoSchema, self).__init__()
self._sch = AutoSchema() self._sch = AutoSchema()
self.view = view self.view = view
self.path = path
self.method = method
self.overrides = overrides
self._sch.view = view self._sch.view = view
def get_operation(self, operation_keys, path, method): def get_operation(self, operation_keys):
"""Get an Operation for the given API endpoint (path, method). """Get an Operation for the given API endpoint (path, method).
This includes query, body parameters and response schemas. This includes query, body parameters and response schemas.
:param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API; :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc. e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc.
:param str path: the view's path
:param str method: HTTP request method
:return openapi.Operation: the resulting Operation object :return openapi.Operation: the resulting Operation object
""" """
body = self.get_request_body_parameters(path, method) consumes = self.get_consumes()
query = self.get_query_parameters(path, method)
body = self.get_request_body_parameters(consumes)
query = self.get_query_parameters()
parameters = body + query parameters = body + query
parameters = [param for param in parameters if param is not None] parameters = [param for param in parameters if param is not None]
description = self.get_description(path, method) parameters = self.add_manual_parameters(parameters)
responses = self.get_responses(path, method)
description = self.get_description()
responses = self.get_responses()
# manual_responses = self.overrides.get('responses', None) or {}
return openapi.Operation( return openapi.Operation(
operation_id='_'.join(operation_keys), operation_id='_'.join(operation_keys),
description=description, description=description,
responses=responses, responses=responses,
parameters=parameters, parameters=parameters,
tags=[operation_keys[0]] consumes=consumes,
tags=[operation_keys[0]],
) )
def get_request_body_parameters(self, path, method): def get_request_body_parameters(self, consumes):
"""Return the request body parameters for this view. """Return the request body parameters for this view.
This is either: This is either:
- a list with a single object Parameter with a Schema derived from the request serializer - a list with a single object Parameter with a Schema derived from the request serializer
- a list of primitive Parameters parsed as form data - a list of primitive Parameters parsed as form data
:param str path: the view's path :param list[str] consumes: a list of MIME types this request accepts as body
:param str method: HTTP request method
:return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData` :return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData`
""" """
# only PUT, PATCH or POST can have a request body # only PUT, PATCH or POST can have a request body
if method not in ('PUT', 'PATCH', 'POST'): if self.method not in ('PUT', 'PATCH', 'POST'):
return [] return []
serializer = self.get_request_serializer(path, method) serializer = self.get_request_serializer()
schema = None
if serializer is None: if serializer is None:
return [] return []
encoding = self._sch.get_encoding(path, method) if isinstance(serializer, openapi.Schema):
if 'form' in encoding: schema = serializer
return [
self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM)
for key, value
in serializer.fields.items()
]
else:
schema = self.get_request_body_schema(path, method, serializer)
return [openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)]
def get_request_serializer(self, path, method): if any(is_form_media_type(encoding) for encoding in consumes):
if schema is not None:
raise SwaggerGenerationError("form request body cannot be a Schema")
return self.get_request_form_parameters(serializer)
else:
if schema is None:
schema = self.get_request_body_schema(serializer)
return [self.make_body_parameter(schema)]
def get_request_serializer(self):
"""Return the request serializer (used for parsing the request payload) for this endpoint. """Return the request serializer (used for parsing the request payload) for this endpoint.
:param str path: the view's path
:param str method: HTTP request method
:return serializers.Serializer: the request serializer :return serializers.Serializer: the request serializer
""" """
# TODO: only GenericAPIViews have defined serializers; body_override = self.overrides.get('request_body', None)
# APIViews and plain ViewSets will need some kind of manual treatment
if not hasattr(self.view, 'get_serializer'):
return None
return self.view.get_serializer() if body_override is not None:
if body_override is no_body:
return None
if inspect.isclass(body_override):
assert issubclass(body_override, serializers.Serializer)
return body_override()
return body_override
else:
if not hasattr(self.view, 'get_serializer'):
return None
return self.view.get_serializer()
def get_request_body_schema(self, path, method, serializer): def get_request_form_parameters(self, serializer):
"""Given a Serializer, return a list of in: formData Parameters.
:param serializer: the view's request serialzier
"""
return [
self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM)
for key, value
in serializer.fields.items()
]
def get_request_body_schema(self, serializer):
"""Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests. """Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests.
:param str path: the view's path
:param str method: HTTP request method
:param serializer: the view's request serialzier :param serializer: the view's request serialzier
:return openapi.Schema: the request body schema :return openapi.Schema: the request body schema
""" """
return self.field_to_swagger(serializer, openapi.Schema) return self.field_to_swagger(serializer, openapi.Schema)
def get_responses(self, path, method): def make_body_parameter(self, schema):
"""Given a Schema object, create an in: body Parameter.
:param openapi.Schema schema: the request body schema
"""
return openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)
def add_manual_parameters(self, parameters):
"""Add/replace parameters from the given list of automatically generated request parameters.
:param list[openapi.Parameter] parameters: genereated parameters
:return list[openapi.Parameter]: modified parameters
"""
parameters = OrderedDict(((param.name, param.in_), param) for param in parameters)
manual_parameters = self.overrides.get('manual_parameters', None) or []
if any(param.in_ == openapi.IN_BODY for param in manual_parameters):
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
if any(param.in_ == openapi.IN_FORM for param in manual_parameters):
if any(param.in_ == openapi.IN_BODY for param in parameters.values()):
raise SwaggerGenerationError("cannot add form parameters when the request has a request schema; "
"did you forget to set an appropriate parser class on the view?")
parameters.update(((param.name, param.in_), param) for param in manual_parameters)
return list(parameters.values())
def get_responses(self):
"""Get the possible responses for this view as a swagger Responses object. """Get the possible responses for this view as a swagger Responses object.
:param str path: the view's path
:param str method: HTTP request method
:return Responses: the documented responses :return Responses: the documented responses
""" """
response_serializers = self.get_response_serializers(path, method) response_serializers = self.get_response_serializers()
return openapi.Responses( return openapi.Responses(
responses=self.get_response_schemas(path, method, response_serializers) responses=self.get_response_schemas(response_serializers)
) )
def get_response_serializers(self, path, method): def get_response_serializers(self):
"""Return the response codes that this view is expected to return, and the serializer for each response body. """Return the response codes that this view is expected to return, and the serializer for each response body.
The return value should be a dict where the keys are possible status codes, and values are either strings, The return value should be a dict where the keys are possible status codes, and values are either strings,
`Serializer`s or `openapi.Response` objects. `Serializer` or `openapi.Response` objects.
:param str path: the view's path
:param str method: HTTP request method
:return dict: the response serializers :return dict: the response serializers
""" """
if method.lower() == 'post': if self.method.lower() == 'post':
return {'201': ''} return {'201': ''}
if method.lower() == 'delete': if self.method.lower() == 'delete':
return {'204': ''} return {'204': ''}
return {'200': ''} return {'200': ''}
def get_response_schemas(self, path, method, response_serializers): def get_response_schemas(self, response_serializers):
"""Return the `openapi.Response` objects calculated for this view. """Return the `openapi.Response` objects calculated for this view.
:param str path: the view's path
:param str method: HTTP request method
:param dict response_serializers: result of get_response_serializers :param dict response_serializers: result of get_response_serializers
:return dict[str, openapi.Response]: a dictionary of status code to Response object :return dict[str, openapi.Response]: a dictionary of status code to Response object
""" """
@ -163,23 +328,25 @@ class SwaggerAutoSchema(object):
return responses return responses
def get_query_parameters(self, path, method): def get_query_parameters(self):
"""Return the query parameters accepted by this view. """Return the query parameters accepted by this view."""
return self.get_filter_parameters() + self.get_pagination_parameters()
:param str path: the view's path def should_filter(self):
:param str method: HTTP request method if getattr(self.view, 'filter_backends', None) is None:
:return list[openapi.Parameter]: the query parameters return False
"""
return self.get_filter_parameters(path, method) + self.get_pagination_parameters(path, method)
def get_filter_parameters(self, path, method): if self.method.lower() not in ["get", "put", "patch", "delete"]:
"""Return the parameters added to the view by its filter backends. return False
:param str path: the view's path if not isinstance(self.view, GenericViewSet):
:param str method: HTTP request method return True
:return list[openapi.Parameter]: the filter query parameters
""" return is_list_view(self.path, self.method, self.view)
if not self._sch._allows_filters(path, method):
def get_filter_parameters(self):
"""Return the parameters added to the view by its filter backends."""
if not self.should_filter():
return [] return []
fields = [] fields = []
@ -189,30 +356,27 @@ class SwaggerAutoSchema(object):
fields += filter.get_schema_fields(self.view) fields += filter.get_schema_fields(self.view)
return [self.coreapi_field_to_parameter(field) for field in fields] return [self.coreapi_field_to_parameter(field) for field in fields]
def get_pagination_parameters(self, path, method): def should_page(self):
"""Return the parameters added to the view by its paginator. if not hasattr(self.view, 'paginator'):
return False
:param str path: the view's path return is_list_view(self.path, self.method, self.view)
:param str method: HTTP request method
:return list[openapi.Parameter]: the pagination query parameters def get_pagination_parameters(self):
""" """Return the parameters added to the view by its paginator."""
if not is_list_view(path, method, self.view): if not self.should_page():
return [] return []
paginator = getattr(self.view, 'paginator', None) paginator = self.view.paginator
if paginator is None: if not hasattr(paginator, 'get_schema_fields'):
return [] return []
return [ return [self.coreapi_field_to_parameter(field) for field in paginator.get_schema_fields(self.view)]
self.coreapi_field_to_parameter(field)
for field in paginator.get_schema_fields(self.view)
]
def coreapi_field_to_parameter(self, field): def coreapi_field_to_parameter(self, field):
"""Convert an instance of `coreapi.Field` to a swagger Parameter object. """Convert an instance of `coreapi.Field` to a swagger Parameter object.
:param coreapi.Field field: the coreapi field :param coreapi.Field field: the coreapi field
:return openapi.Parameter: the equivalent openapi primitive Parameter
""" """
location_to_in = { location_to_in = {
'query': openapi.IN_QUERY, 'query': openapi.IN_QUERY,
@ -234,129 +398,22 @@ class SwaggerAutoSchema(object):
description=field.schema.description, description=field.schema.description,
) )
def get_description(self, path, method): def get_description(self):
"""Return an operation description determined as appropriate from the view's method and class docstrings. """Return an operation description determined as appropriate from the view's method and class docstrings.
:param str path: the view's path
:param str method: HTTP request method
:return str: the operation description :return str: the operation description
""" """
return self._sch.get_description(path, method) description = self.overrides.get('operation_description', None)
if description is None:
description = self._sch.get_description(self.path, self.method)
return description
def get_consumes(self):
"""Return the MIME types this endpoint can consume."""
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
def field_to_swagger(self, field, swagger_object_type, **kwargs): def field_to_swagger(self, field, swagger_object_type, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object. return serializer_field_to_swagger(field, swagger_object_type, **kwargs)
:param rest_framework.serializers.Field field: the source field
:param type swagger_object_type: should be one of Schema, Parameter, Items
:param kwargs: extra attributes for constructing the object;
if swagger_object_type is Parameter, `name` and `in_` should be provided
:return Swagger,Parameter,Items: the swagger object
"""
assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
title = force_text(field.label) if field.label else None
title = title if swagger_object_type == openapi.Schema else None # only Schema has title
title = None
description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either
SwaggerType = functools.partial(swagger_object_type, title=title, description=description, **kwargs)
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
# ------ NESTED
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = self.field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
)
elif isinstance(field, serializers.Serializer):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as "
+ swagger_object_type.__name__)
return SwaggerType(
type=openapi.TYPE_OBJECT,
properties=OrderedDict(
(key, self.field_to_swagger(value, ChildSwaggerType))
for key, value
in field.fields.items()
)
)
elif isinstance(field, serializers.ManyRelatedField):
child_schema = self.field_to_swagger(field.child_relation, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
unique_items=True, # is this OK?
)
elif isinstance(field, serializers.RelatedField):
# TODO: infer type for PrimaryKeyRelatedField?
return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES
elif isinstance(field, serializers.MultipleChoiceField):
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=ChildSwaggerType(
type=openapi.TYPE_STRING,
enum=list(field.choices.keys())
)
)
elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL
elif isinstance(field, serializers.BooleanField):
return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_NUMBER)
elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING
elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
elif isinstance(field, serializers.RegexField):
return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
elif isinstance(field, serializers.SlugField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG)
elif isinstance(field, serializers.URLField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.IPAddressField):
format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
return SwaggerType(type=openapi.TYPE_STRING, format=format)
elif isinstance(field, serializers.CharField):
# TODO: min_length max_length (for all CharField subclasses above too)
return SwaggerType(type=openapi.TYPE_STRING)
elif isinstance(field, serializers.UUIDField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
# ------ DATE & TIME
elif isinstance(field, serializers.DateField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
elif isinstance(field, serializers.DateTimeField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
# ------ OTHERS
elif isinstance(field, serializers.FileField):
# swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
# OpenAPI 3.0 does support it, so a future implementation could handle this better
# TODO: appropriate produces/consumes somehow/somewhere?
if swagger_object_type != openapi.Parameter:
raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter")
return SwaggerType(type=openapi.TYPE_FILE)
elif isinstance(field, serializers.JSONField):
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_BINARY if field.binary else None
)
elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = self.field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
)
# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField?
# TODO: return info about required/allowed empty
# everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING)

View File

@ -212,11 +212,11 @@ class Operation(SwaggerDict):
produces=None, description=None, tags=None, **extra): produces=None, description=None, tags=None, **extra):
super(Operation, self).__init__(**extra) super(Operation, self).__init__(**extra)
self.operation_id = operation_id self.operation_id = operation_id
self.responses = responses self.description = description
self.parameters = [param for param in parameters if param is not None] self.parameters = [param for param in parameters if param is not None]
self.responses = responses
self.consumes = consumes self.consumes = consumes
self.produces = produces self.produces = produces
self.description = description
self.tags = tags self.tags = tags
self._insert_extras__() self._insert_extras__()

View File

@ -0,0 +1,90 @@
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin
no_body = object()
class UpdateModelMixing(object):
pass
def is_list_view(path, method, view):
"""Return True if the given path/method appears to represent a list view (as opposed to a detail/instance view)."""
# for ViewSets, it could be the default 'list' view, or a list_route
action = getattr(view, 'action', '')
method = getattr(view, action, None)
detail = getattr(method, 'detail', None)
suffix = getattr(view, 'suffix', None)
if action == 'list' or detail is False or suffix == 'List':
return True
if action in ('retrieve', 'update', 'partial_update', 'destroy') or detail is True or suffix == 'Instance':
# a detail_route is surely not a list route
return False
# for APIView, if it's a detail view it can't also be a list view
if isinstance(view, (RetrieveModelMixin, UpdateModelMixing, DestroyModelMixin)):
return False
# if the last component in the path is parameterized it's probably not a list view
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
# otherwise assume it's a list route
return True
def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, manual_parameters=None,
operation_description=None, responses=None):
def decorator(view_method):
data = {
'auto_schema': auto_schema,
'request_body': request_body,
'manual_parameters': manual_parameters,
'operation_description': operation_description,
'responses': responses,
}
data = {k: v for k, v in data.items() if v is not None}
bind_to_methods = getattr(view_method, 'bind_to_methods', [])
# if the method is actually a function based view
view_cls = getattr(view_method, 'cls', None)
http_method_names = getattr(view_cls, 'http_method_names', [])
if bind_to_methods or http_method_names:
# detail_route, list_route or api_view
assert bool(http_method_names) != bool(bind_to_methods), "this should never happen"
available_methods = http_method_names + bind_to_methods
existing_data = getattr(view_method, 'swagger_auto_schema', {})
if http_method_names:
_route = "api_view"
else:
_route = "detail_route" if view_method.detail else "list_route"
_methods = methods
if len(available_methods) > 1:
assert methods or method, \
"on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \
"using one of the `method` or `methods` arguments" % _route
assert bool(methods) != bool(method), "specify either method or methods"
if method:
_methods = [method.lower()]
else:
_methods = [mth.lower() for mth in methods]
assert not isinstance(_methods, str)
assert not any(mth in existing_data for mth in _methods), "method defined multiple times"
assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route
existing_data.update((mth.lower(), data) for mth in _methods)
else:
existing_data[available_methods[0]] = data
view_method.swagger_auto_schema = existing_data
else:
assert methods is None, \
"the methods argument should only be specified when decorating a detail_route or list_route; you " \
"should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator"
view_method.swagger_auto_schema = data
return view_method
return decorator

View File

@ -9,6 +9,7 @@ from rest_framework.response import Response
from articles import serializers from articles import serializers
from articles.models import Article from articles.models import Article
from drf_swagger.utils import swagger_auto_schema
class ArticleViewSet(viewsets.ModelViewSet): class ArticleViewSet(viewsets.ModelViewSet):
@ -20,6 +21,9 @@ class ArticleViewSet(viewsets.ModelViewSet):
destroy: destroy:
destroy class docstring destroy class docstring
partial_update:
partial_update class docstring
""" """
queryset = Article.objects.all() queryset = Article.objects.all()
lookup_field = 'slug' lookup_field = 'slug'
@ -40,11 +44,9 @@ class ArticleViewSet(viewsets.ModelViewSet):
serializer = self.serializer_class(articles, many=True) serializer = self.serializer_class(articles, many=True)
return Response(serializer.data) return Response(serializer.data)
@detail_route( @swagger_auto_schema(method='get', operation_description="image GET description override")
methods=['get', 'post'], @swagger_auto_schema(method='post', request_body=serializers.ImageUploadSerializer)
parser_classes=(MultiPartParser,), @detail_route(methods=['get', 'post'], parser_classes=(MultiPartParser,))
serializer_class=serializers.ImageUploadSerializer,
)
def image(self, request, slug=None): def image(self, request, slug=None):
""" """
image method docstring image method docstring
@ -55,6 +57,11 @@ class ArticleViewSet(viewsets.ModelViewSet):
"""update method docstring""" """update method docstring"""
return super(ArticleViewSet, self).update(request, *args, **kwargs) return super(ArticleViewSet, self).update(request, *args, **kwargs)
@swagger_auto_schema(operation_description="partial_update description override")
def partial_update(self, request, *args, **kwargs):
"""partial_update method docstring"""
return super(ArticleViewSet, self).partial_update(request, *args, **kwargs)
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""destroy method docstring""" """destroy method docstring"""
return super(ArticleViewSet, self).destroy(request, *args, **kwargs) return super(ArticleViewSet, self).destroy(request, *args, **kwargs)

View File

@ -1,9 +1,12 @@
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.generics import get_object_or_404 from rest_framework.generics import get_object_or_404
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from drf_swagger import openapi
from drf_swagger.utils import swagger_auto_schema
from users.serializers import UserSerializer from users.serializers import UserSerializer
@ -15,8 +18,19 @@ class UserList(APIView):
serializer = UserSerializer(queryset, many=True) serializer = UserSerializer(queryset, many=True)
return Response(serializer.data) return Response(serializer.data)
@swagger_auto_schema(request_body=UserSerializer, operation_description="apiview post description override")
def post(self, request):
serializer = UserSerializer(request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)
@api_view(['GET'])
@swagger_auto_schema(method='put', request_body=UserSerializer)
@swagger_auto_schema(method='get', manual_parameters=[
openapi.Parameter('test', openapi.IN_QUERY, "test manual param", type=openapi.TYPE_BOOLEAN)
])
@api_view(['GET', 'PUT'])
def user_detail(request, pk): def user_detail(request, pk):
"""user_detail fbv docstring""" """user_detail fbv docstring"""
user = get_object_or_404(User.objects, pk=pk) user = get_object_or_404(User.objects, pk=pk)

View File

@ -1,3 +1,6 @@
import json
import os
import pytest import pytest
from ruamel import yaml from ruamel import yaml
@ -53,3 +56,9 @@ def bad_settings():
SWAGGER_DEFAULTS['SECURITY_DEFINITIONS'].update(bad_security) SWAGGER_DEFAULTS['SECURITY_DEFINITIONS'].update(bad_security)
yield swagger_settings yield swagger_settings
del SWAGGER_DEFAULTS['SECURITY_DEFINITIONS']['bad'] del SWAGGER_DEFAULTS['SECURITY_DEFINITIONS']['bad']
@pytest.fixture
def reference_schema():
with open(os.path.join(os.path.dirname(__file__), 'reference.json')) as reference:
return json.load(reference)

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,17 @@
from drf_swagger import openapi
def test_operation_docstrings(swagger_dict): def test_operation_docstrings(swagger_dict):
users_list = swagger_dict['paths']['/users/'] users_list = swagger_dict['paths']['/users/']
assert users_list['get']['description'] == "UserList cbv classdoc" assert users_list['get']['description'] == "UserList cbv classdoc"
assert users_list['post']['description'] == "apiview post description override"
users_detail = swagger_dict['paths']['/users/{id}/'] users_detail = swagger_dict['paths']['/users/{id}/']
assert users_detail['get']['description'] == "user_detail fbv docstring" assert users_detail['get']['description'] == "user_detail fbv docstring"
assert users_detail['put']['description'] == "user_detail fbv docstring"
def test_parameter_docstrings(swagger_dict):
users_detail = swagger_dict['paths']['/users/{id}/']
assert users_detail['get']['parameters'][0]['description'] == "test manual param"
assert users_detail['put']['parameters'][0]['in'] == openapi.IN_BODY

View File

@ -18,12 +18,12 @@ def test_operation_docstrings(swagger_dict):
articles_detail = swagger_dict['paths']['/articles/{slug}/'] articles_detail = swagger_dict['paths']['/articles/{slug}/']
assert articles_detail['get']['description'] == "retrieve class docstring" assert articles_detail['get']['description'] == "retrieve class docstring"
assert articles_detail['put']['description'] == "update method docstring" assert articles_detail['put']['description'] == "update method docstring"
assert articles_detail['patch']['description'] == "ArticleViewSet class docstring" assert articles_detail['patch']['description'] == "partial_update description override"
assert articles_detail['delete']['description'] == "destroy method docstring" assert articles_detail['delete']['description'] == "destroy method docstring"
articles_today = swagger_dict['paths']['/articles/today/'] articles_today = swagger_dict['paths']['/articles/today/']
assert articles_today['get']['description'] == "ArticleViewSet class docstring" assert articles_today['get']['description'] == "ArticleViewSet class docstring"
articles_image = swagger_dict['paths']['/articles/{slug}/image/'] articles_image = swagger_dict['paths']['/articles/{slug}/image/']
assert articles_image['get']['description'] == "image method docstring" assert articles_image['get']['description'] == "image GET description override"
assert articles_image['post']['description'] == "image method docstring" assert articles_image['post']['description'] == "image method docstring"

View File

@ -0,0 +1,2 @@
def test_reference_schema(swagger_dict, reference_schema):
return swagger_dict == reference_schema