Add swagger_auto_schema method decorator for Operation customization

See #5.
openapi3
Cristi Vîjdea 2017-12-08 05:33:01 +01:00
parent 82cac4ef0d
commit 652795f5db
11 changed files with 443 additions and 215 deletions

View File

@ -29,36 +29,61 @@ class OpenAPISchemaGenerator(object):
inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf) inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf)
self.endpoints = inspector.get_api_endpoints() self.endpoints = inspector.get_api_endpoints()
self.get_endpoints(None if public else request) self.clean_endpoints(None if public else request)
paths = self.get_paths() paths = self.get_paths()
url = self._gen.url url = self._gen.url
if not url and request is not None: if not url and request is not None:
url = request.build_absolute_uri() url = request.build_absolute_uri()
# distribute_links(links)
return openapi.Swagger( return openapi.Swagger(
info=self.info, paths=paths, info=self.info, paths=paths,
_url=url, _version=self.version, _url=url, _version=self.version,
) )
def get_endpoints(self, request): def create_view(self, callback, method, request):
view = self._gen.create_view(callback, method, request)
overrides = getattr(callback, 'swagger_auto_schema', None)
if overrides is not None:
# decorated function based view
for method, _ in overrides.items():
view_method = getattr(view, method, None)
if view_method is not None:
setattr(view_method.__func__, 'swagger_auto_schema', overrides)
return view
def clean_endpoints(self, request):
"""Generate {path: (view_class, [(method, view)]) given (path, method, callback).""" """Generate {path: (view_class, [(method, view)]) given (path, method, callback)."""
view_paths = defaultdict(list) view_paths = defaultdict(list)
view_cls = {} view_cls = {}
for path, method, callback in self.endpoints: for path, method, callback in self.endpoints:
view = self._gen.create_view(callback, method, request) view = self.create_view(callback, method, request)
path = self._gen.coerce_path(path, method, view) path = self._gen.coerce_path(path, method, view)
view_paths[path].append((method, view)) view_paths[path].append((method, view))
view_cls[path] = callback.cls view_cls[path] = callback.cls
self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()} self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()}
def get_operation_keys(self, subpath, method, view):
"""
Return a list of keys that should be used to layout a link within
the schema document.
/users/ ("users", "list"), ("users", "create")
/users/{pk}/ ("users", "read"), ("users", "update"), ("users", "delete")
/users/enabled/ ("users", "enabled") # custom viewset list action
/users/{pk}/star/ ("users", "star") # custom viewset detail action
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update")
"""
return self._gen.get_keys(subpath, method, view)
def get_paths(self): def get_paths(self):
if not self.endpoints: if not self.endpoints:
return [] return []
prefix = self._gen.determine_path_prefix(self.endpoints.keys()) prefix = self._gen.determine_path_prefix(self.endpoints.keys())
paths = OrderedDict() paths = OrderedDict()
default_schema_cls = SwaggerAutoSchema
for path, (view_cls, methods) in sorted(self.endpoints.items()): for path, (view_cls, methods) in sorted(self.endpoints.items()):
path_parameters = self.get_path_parameters(path, view_cls) path_parameters = self.get_path_parameters(path, view_cls)
operations = {} operations = {}
@ -66,14 +91,26 @@ class OpenAPISchemaGenerator(object):
if not self._gen.has_view_permissions(path, method, view): if not self._gen.has_view_permissions(path, method, view):
continue continue
schema = SwaggerAutoSchema(view) operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
operation_keys = self._gen.get_keys(path[len(prefix):], method, view) overrides = self.get_overrides(view, method)
operations[method.lower()] = schema.get_operation(operation_keys, path, method) auto_schema_cls = overrides.get('auto_schema', default_schema_cls)
schema = auto_schema_cls(view, path, method, overrides)
operations[method.lower()] = schema.get_operation(operation_keys)
paths[path] = openapi.PathItem(parameters=path_parameters, **operations) paths[path] = openapi.PathItem(parameters=path_parameters, **operations)
return openapi.Paths(paths=paths) return openapi.Paths(paths=paths)
def get_overrides(self, view, method):
method = method.lower()
action = getattr(view, 'action', method)
action_method = getattr(view, action, None)
overrides = getattr(action_method, 'swagger_auto_schema', {})
if method in overrides:
overrides = overrides[method]
return overrides
def get_path_parameters(self, path, view_cls): def get_path_parameters(self, path, view_cls):
"""Return a list of Parameter instances corresponding to any templated path variables. """Return a list of Parameter instances corresponding to any templated path variables.

View File

@ -1,249 +1,21 @@
import functools import functools
import inspect
from collections import OrderedDict from collections import OrderedDict
import coreschema import coreschema
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
from django.utils.encoding import force_text from django.utils.encoding import force_text
from rest_framework import serializers from rest_framework import serializers
from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema from rest_framework.schemas import AutoSchema
from rest_framework.schemas.utils import is_list_view from rest_framework.viewsets import GenericViewSet
from drf_swagger.errors import SwaggerGenerationError from drf_swagger.errors import SwaggerGenerationError
from . import openapi from . import openapi
from .utils import no_body, is_list_view
def find_regex(regex_field): def serializer_field_to_swagger(field, swagger_object_type, **kwargs):
regex_validator = None
for validator in regex_field.validators:
if isinstance(validator, RegexValidator):
if regex_validator is not None:
# bail if multiple validators are found - no obvious way to choose
return None
regex_validator = validator
# regex_validator.regex should be a compiled re object...
return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
class SwaggerAutoSchema(object):
def __init__(self, view):
super(SwaggerAutoSchema, self).__init__()
self._sch = AutoSchema()
self.view = view
self._sch.view = view
def get_operation(self, operation_keys, path, method):
"""Get an Operation for the given API endpoint (path, method).
This includes query, body parameters and response schemas.
:param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc.
:param str path: the view's path
:param str method: HTTP request method
:return openapi.Operation: the resulting Operation object
"""
body = self.get_request_body_parameters(path, method)
query = self.get_query_parameters(path, method)
parameters = body + query
parameters = [param for param in parameters if param is not None]
description = self.get_description(path, method)
responses = self.get_responses(path, method)
return openapi.Operation(
operation_id='_'.join(operation_keys),
description=description,
responses=responses,
parameters=parameters,
tags=[operation_keys[0]]
)
def get_request_body_parameters(self, path, method):
"""Return the request body parameters for this view.
This is either:
- a list with a single object Parameter with a Schema derived from the request serializer
- a list of primitive Parameters parsed as form data
:param str path: the view's path
:param str method: HTTP request method
:return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData`
"""
# only PUT, PATCH or POST can have a request body
if method not in ('PUT', 'PATCH', 'POST'):
return []
serializer = self.get_request_serializer(path, method)
if serializer is None:
return []
encoding = self._sch.get_encoding(path, method)
if 'form' in encoding:
return [
self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM)
for key, value
in serializer.fields.items()
]
else:
schema = self.get_request_body_schema(path, method, serializer)
return [openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)]
def get_request_serializer(self, path, method):
"""Return the request serializer (used for parsing the request payload) for this endpoint.
:param str path: the view's path
:param str method: HTTP request method
:return serializers.Serializer: the request serializer
"""
# TODO: only GenericAPIViews have defined serializers;
# APIViews and plain ViewSets will need some kind of manual treatment
if not hasattr(self.view, 'get_serializer'):
return None
return self.view.get_serializer()
def get_request_body_schema(self, path, method, serializer):
"""Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests.
:param str path: the view's path
:param str method: HTTP request method
:param serializer: the view's request serialzier
:return openapi.Schema: the request body schema
"""
return self.field_to_swagger(serializer, openapi.Schema)
def get_responses(self, path, method):
"""Get the possible responses for this view as a swagger Responses object.
:param str path: the view's path
:param str method: HTTP request method
:return Responses: the documented responses
"""
response_serializers = self.get_response_serializers(path, method)
return openapi.Responses(
responses=self.get_response_schemas(path, method, response_serializers)
)
def get_response_serializers(self, path, method):
"""Return the response codes that this view is expected to return, and the serializer for each response body.
The return value should be a dict where the keys are possible status codes, and values are either strings,
`Serializer`s or `openapi.Response` objects.
:param str path: the view's path
:param str method: HTTP request method
:return dict: the response serializers
"""
if method.lower() == 'post':
return {'201': ''}
if method.lower() == 'delete':
return {'204': ''}
return {'200': ''}
def get_response_schemas(self, path, method, response_serializers):
"""Return the `openapi.Response` objects calculated for this view.
:param str path: the view's path
:param str method: HTTP request method
:param dict response_serializers: result of get_response_serializers
:return dict[str, openapi.Response]: a dictionary of status code to Response object
"""
responses = {}
for status, serializer in response_serializers.items():
if isinstance(serializer, str):
response = openapi.Response(
description=serializer
)
elif isinstance(serializer, openapi.Response):
response = serializer
else:
response = openapi.Response(
description='',
schema=self.field_to_swagger(serializer, openapi.Schema)
)
responses[str(status)] = response
return responses
def get_query_parameters(self, path, method):
"""Return the query parameters accepted by this view.
:param str path: the view's path
:param str method: HTTP request method
:return list[openapi.Parameter]: the query parameters
"""
return self.get_filter_parameters(path, method) + self.get_pagination_parameters(path, method)
def get_filter_parameters(self, path, method):
"""Return the parameters added to the view by its filter backends.
:param str path: the view's path
:param str method: HTTP request method
:return list[openapi.Parameter]: the filter query parameters
"""
if not self._sch._allows_filters(path, method):
return []
fields = []
for filter_backend in self.view.filter_backends:
filter = filter_backend()
if hasattr(filter, 'get_schema_fields'):
fields += filter.get_schema_fields(self.view)
return [self.coreapi_field_to_parameter(field) for field in fields]
def get_pagination_parameters(self, path, method):
"""Return the parameters added to the view by its paginator.
:param str path: the view's path
:param str method: HTTP request method
:return list[openapi.Parameter]: the pagination query parameters
"""
if not is_list_view(path, method, self.view):
return []
paginator = getattr(self.view, 'paginator', None)
if paginator is None:
return []
return [
self.coreapi_field_to_parameter(field)
for field in paginator.get_schema_fields(self.view)
]
def coreapi_field_to_parameter(self, field):
"""Convert an instance of `coreapi.Field` to a swagger Parameter object.
:param coreapi.Field field: the coreapi field
:return openapi.Parameter: the equivalent openapi primitive Parameter
"""
location_to_in = {
'query': openapi.IN_QUERY,
'path': openapi.IN_PATH,
'form': openapi.IN_FORM,
'body': openapi.IN_FORM,
}
coreapi_types = {
coreschema.Integer: openapi.TYPE_INTEGER,
coreschema.Number: openapi.TYPE_NUMBER,
coreschema.String: openapi.TYPE_STRING,
coreschema.Boolean: openapi.TYPE_BOOLEAN,
}
return openapi.Parameter(
name=field.name,
in_=location_to_in[field.location],
type=coreapi_types.get(field.schema.__class__, openapi.TYPE_STRING),
required=field.required,
description=field.schema.description,
)
def get_description(self, path, method):
"""Return an operation description determined as appropriate from the view's method and class docstrings.
:param str path: the view's path
:param str method: HTTP request method
:return str: the operation description
"""
return self._sch.get_description(path, method)
def field_to_swagger(self, field, swagger_object_type, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object. """Convert a drf Serializer or Field instance into a Swagger object.
:param rest_framework.serializers.Field field: the source field :param rest_framework.serializers.Field field: the source field
@ -265,7 +37,7 @@ class SwaggerAutoSchema(object):
# ------ NESTED # ------ NESTED
if isinstance(field, (serializers.ListSerializer, serializers.ListField)): if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = self.field_to_swagger(field.child, ChildSwaggerType) child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType( return SwaggerType(
type=openapi.TYPE_ARRAY, type=openapi.TYPE_ARRAY,
items=child_schema, items=child_schema,
@ -277,13 +49,13 @@ class SwaggerAutoSchema(object):
return SwaggerType( return SwaggerType(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
properties=OrderedDict( properties=OrderedDict(
(key, self.field_to_swagger(value, ChildSwaggerType)) (key, serializer_field_to_swagger(value, ChildSwaggerType))
for key, value for key, value
in field.fields.items() in field.fields.items()
) )
) )
elif isinstance(field, serializers.ManyRelatedField): elif isinstance(field, serializers.ManyRelatedField):
child_schema = self.field_to_swagger(field.child_relation, ChildSwaggerType) child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType)
return SwaggerType( return SwaggerType(
type=openapi.TYPE_ARRAY, type=openapi.TYPE_ARRAY,
items=child_schema, items=child_schema,
@ -349,7 +121,7 @@ class SwaggerAutoSchema(object):
format=openapi.FORMAT_BINARY if field.binary else None format=openapi.FORMAT_BINARY if field.binary else None
) )
elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema: elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = self.field_to_swagger(field.child, ChildSwaggerType) child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType( return SwaggerType(
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
additional_properties=child_schema additional_properties=child_schema
@ -360,3 +132,288 @@ class SwaggerAutoSchema(object):
# everything else gets string by default # everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING) return SwaggerType(type=openapi.TYPE_STRING)
def find_regex(regex_field):
regex_validator = None
for validator in regex_field.validators:
if isinstance(validator, RegexValidator):
if regex_validator is not None:
# bail if multiple validators are found - no obvious way to choose
return None
regex_validator = validator
# regex_validator.regex should be a compiled re object...
return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
class SwaggerAutoSchema(object):
def __init__(self, view, path, method, overrides):
super(SwaggerAutoSchema, self).__init__()
self._sch = AutoSchema()
self.view = view
self.path = path
self.method = method
self.overrides = overrides
self._sch.view = view
def get_operation(self, operation_keys):
"""Get an Operation for the given API endpoint (path, method).
This includes query, body parameters and response schemas.
:param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc.
:return openapi.Operation: the resulting Operation object
"""
consumes = self.get_consumes()
body = self.get_request_body_parameters(consumes)
query = self.get_query_parameters()
parameters = body + query
parameters = [param for param in parameters if param is not None]
parameters = self.add_manual_parameters(parameters)
description = self.get_description()
responses = self.get_responses()
# manual_responses = self.overrides.get('responses', None) or {}
return openapi.Operation(
operation_id='_'.join(operation_keys),
description=description,
responses=responses,
parameters=parameters,
consumes=consumes,
tags=[operation_keys[0]],
)
def get_request_body_parameters(self, consumes):
"""Return the request body parameters for this view.
This is either:
- a list with a single object Parameter with a Schema derived from the request serializer
- a list of primitive Parameters parsed as form data
:param list[str] consumes: a list of MIME types this request accepts as body
:return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData`
"""
# only PUT, PATCH or POST can have a request body
if self.method not in ('PUT', 'PATCH', 'POST'):
return []
serializer = self.get_request_serializer()
schema = None
if serializer is None:
return []
if isinstance(serializer, openapi.Schema):
schema = serializer
if any(is_form_media_type(encoding) for encoding in consumes):
if schema is not None:
raise SwaggerGenerationError("form request body cannot be a Schema")
return self.get_request_form_parameters(serializer)
else:
if schema is None:
schema = self.get_request_body_schema(serializer)
return [self.make_body_parameter(schema)]
def get_request_serializer(self):
"""Return the request serializer (used for parsing the request payload) for this endpoint.
:return serializers.Serializer: the request serializer
"""
body_override = self.overrides.get('request_body', None)
if body_override is not None:
if body_override is no_body:
return None
if inspect.isclass(body_override):
assert issubclass(body_override, serializers.Serializer)
return body_override()
return body_override
else:
if not hasattr(self.view, 'get_serializer'):
return None
return self.view.get_serializer()
def get_request_form_parameters(self, serializer):
"""Given a Serializer, return a list of in: formData Parameters.
:param serializer: the view's request serialzier
"""
return [
self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM)
for key, value
in serializer.fields.items()
]
def get_request_body_schema(self, serializer):
"""Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests.
:param serializer: the view's request serialzier
:return openapi.Schema: the request body schema
"""
return self.field_to_swagger(serializer, openapi.Schema)
def make_body_parameter(self, schema):
"""Given a Schema object, create an in: body Parameter.
:param openapi.Schema schema: the request body schema
"""
return openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema)
def add_manual_parameters(self, parameters):
"""Add/replace parameters from the given list of automatically generated request parameters.
:param list[openapi.Parameter] parameters: genereated parameters
:return list[openapi.Parameter]: modified parameters
"""
parameters = OrderedDict(((param.name, param.in_), param) for param in parameters)
manual_parameters = self.overrides.get('manual_parameters', None) or []
if any(param.in_ == openapi.IN_BODY for param in manual_parameters):
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
if any(param.in_ == openapi.IN_FORM for param in manual_parameters):
if any(param.in_ == openapi.IN_BODY for param in parameters.values()):
raise SwaggerGenerationError("cannot add form parameters when the request has a request schema; "
"did you forget to set an appropriate parser class on the view?")
parameters.update(((param.name, param.in_), param) for param in manual_parameters)
return list(parameters.values())
def get_responses(self):
"""Get the possible responses for this view as a swagger Responses object.
:return Responses: the documented responses
"""
response_serializers = self.get_response_serializers()
return openapi.Responses(
responses=self.get_response_schemas(response_serializers)
)
def get_response_serializers(self):
"""Return the response codes that this view is expected to return, and the serializer for each response body.
The return value should be a dict where the keys are possible status codes, and values are either strings,
`Serializer` or `openapi.Response` objects.
:return dict: the response serializers
"""
if self.method.lower() == 'post':
return {'201': ''}
if self.method.lower() == 'delete':
return {'204': ''}
return {'200': ''}
def get_response_schemas(self, response_serializers):
"""Return the `openapi.Response` objects calculated for this view.
:param dict response_serializers: result of get_response_serializers
:return dict[str, openapi.Response]: a dictionary of status code to Response object
"""
responses = {}
for status, serializer in response_serializers.items():
if isinstance(serializer, str):
response = openapi.Response(
description=serializer
)
elif isinstance(serializer, openapi.Response):
response = serializer
else:
response = openapi.Response(
description='',
schema=self.field_to_swagger(serializer, openapi.Schema)
)
responses[str(status)] = response
return responses
def get_query_parameters(self):
"""Return the query parameters accepted by this view."""
return self.get_filter_parameters() + self.get_pagination_parameters()
def should_filter(self):
if getattr(self.view, 'filter_backends', None) is None:
return False
if self.method.lower() not in ["get", "put", "patch", "delete"]:
return False
if not isinstance(self.view, GenericViewSet):
return True
return is_list_view(self.path, self.method, self.view)
def get_filter_parameters(self):
"""Return the parameters added to the view by its filter backends."""
if not self.should_filter():
return []
fields = []
for filter_backend in self.view.filter_backends:
filter = filter_backend()
if hasattr(filter, 'get_schema_fields'):
fields += filter.get_schema_fields(self.view)
return [self.coreapi_field_to_parameter(field) for field in fields]
def should_page(self):
if not hasattr(self.view, 'paginator'):
return False
return is_list_view(self.path, self.method, self.view)
def get_pagination_parameters(self):
"""Return the parameters added to the view by its paginator."""
if not self.should_page():
return []
paginator = self.view.paginator
if not hasattr(paginator, 'get_schema_fields'):
return []
return [self.coreapi_field_to_parameter(field) for field in paginator.get_schema_fields(self.view)]
def coreapi_field_to_parameter(self, field):
"""Convert an instance of `coreapi.Field` to a swagger Parameter object.
:param coreapi.Field field: the coreapi field
"""
location_to_in = {
'query': openapi.IN_QUERY,
'path': openapi.IN_PATH,
'form': openapi.IN_FORM,
'body': openapi.IN_FORM,
}
coreapi_types = {
coreschema.Integer: openapi.TYPE_INTEGER,
coreschema.Number: openapi.TYPE_NUMBER,
coreschema.String: openapi.TYPE_STRING,
coreschema.Boolean: openapi.TYPE_BOOLEAN,
}
return openapi.Parameter(
name=field.name,
in_=location_to_in[field.location],
type=coreapi_types.get(field.schema.__class__, openapi.TYPE_STRING),
required=field.required,
description=field.schema.description,
)
def get_description(self):
"""Return an operation description determined as appropriate from the view's method and class docstrings.
:return str: the operation description
"""
description = self.overrides.get('operation_description', None)
if description is None:
description = self._sch.get_description(self.path, self.method)
return description
def get_consumes(self):
"""Return the MIME types this endpoint can consume."""
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
def field_to_swagger(self, field, swagger_object_type, **kwargs):
return serializer_field_to_swagger(field, swagger_object_type, **kwargs)

View File

@ -212,11 +212,11 @@ class Operation(SwaggerDict):
produces=None, description=None, tags=None, **extra): produces=None, description=None, tags=None, **extra):
super(Operation, self).__init__(**extra) super(Operation, self).__init__(**extra)
self.operation_id = operation_id self.operation_id = operation_id
self.responses = responses self.description = description
self.parameters = [param for param in parameters if param is not None] self.parameters = [param for param in parameters if param is not None]
self.responses = responses
self.consumes = consumes self.consumes = consumes
self.produces = produces self.produces = produces
self.description = description
self.tags = tags self.tags = tags
self._insert_extras__() self._insert_extras__()

View File

@ -0,0 +1,90 @@
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin
no_body = object()
class UpdateModelMixing(object):
pass
def is_list_view(path, method, view):
"""Return True if the given path/method appears to represent a list view (as opposed to a detail/instance view)."""
# for ViewSets, it could be the default 'list' view, or a list_route
action = getattr(view, 'action', '')
method = getattr(view, action, None)
detail = getattr(method, 'detail', None)
suffix = getattr(view, 'suffix', None)
if action == 'list' or detail is False or suffix == 'List':
return True
if action in ('retrieve', 'update', 'partial_update', 'destroy') or detail is True or suffix == 'Instance':
# a detail_route is surely not a list route
return False
# for APIView, if it's a detail view it can't also be a list view
if isinstance(view, (RetrieveModelMixin, UpdateModelMixing, DestroyModelMixin)):
return False
# if the last component in the path is parameterized it's probably not a list view
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
return False
# otherwise assume it's a list route
return True
def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_body=None, manual_parameters=None,
operation_description=None, responses=None):
def decorator(view_method):
data = {
'auto_schema': auto_schema,
'request_body': request_body,
'manual_parameters': manual_parameters,
'operation_description': operation_description,
'responses': responses,
}
data = {k: v for k, v in data.items() if v is not None}
bind_to_methods = getattr(view_method, 'bind_to_methods', [])
# if the method is actually a function based view
view_cls = getattr(view_method, 'cls', None)
http_method_names = getattr(view_cls, 'http_method_names', [])
if bind_to_methods or http_method_names:
# detail_route, list_route or api_view
assert bool(http_method_names) != bool(bind_to_methods), "this should never happen"
available_methods = http_method_names + bind_to_methods
existing_data = getattr(view_method, 'swagger_auto_schema', {})
if http_method_names:
_route = "api_view"
else:
_route = "detail_route" if view_method.detail else "list_route"
_methods = methods
if len(available_methods) > 1:
assert methods or method, \
"on multi-method %s, you must specify swagger_auto_schema on a per-method basis " \
"using one of the `method` or `methods` arguments" % _route
assert bool(methods) != bool(method), "specify either method or methods"
if method:
_methods = [method.lower()]
else:
_methods = [mth.lower() for mth in methods]
assert not isinstance(_methods, str)
assert not any(mth in existing_data for mth in _methods), "method defined multiple times"
assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route
existing_data.update((mth.lower(), data) for mth in _methods)
else:
existing_data[available_methods[0]] = data
view_method.swagger_auto_schema = existing_data
else:
assert methods is None, \
"the methods argument should only be specified when decorating a detail_route or list_route; you " \
"should also ensure that you put the swagger_auto_schema decorator AFTER (above) the _route decorator"
view_method.swagger_auto_schema = data
return view_method
return decorator

View File

@ -9,6 +9,7 @@ from rest_framework.response import Response
from articles import serializers from articles import serializers
from articles.models import Article from articles.models import Article
from drf_swagger.utils import swagger_auto_schema
class ArticleViewSet(viewsets.ModelViewSet): class ArticleViewSet(viewsets.ModelViewSet):
@ -20,6 +21,9 @@ class ArticleViewSet(viewsets.ModelViewSet):
destroy: destroy:
destroy class docstring destroy class docstring
partial_update:
partial_update class docstring
""" """
queryset = Article.objects.all() queryset = Article.objects.all()
lookup_field = 'slug' lookup_field = 'slug'
@ -40,11 +44,9 @@ class ArticleViewSet(viewsets.ModelViewSet):
serializer = self.serializer_class(articles, many=True) serializer = self.serializer_class(articles, many=True)
return Response(serializer.data) return Response(serializer.data)
@detail_route( @swagger_auto_schema(method='get', operation_description="image GET description override")
methods=['get', 'post'], @swagger_auto_schema(method='post', request_body=serializers.ImageUploadSerializer)
parser_classes=(MultiPartParser,), @detail_route(methods=['get', 'post'], parser_classes=(MultiPartParser,))
serializer_class=serializers.ImageUploadSerializer,
)
def image(self, request, slug=None): def image(self, request, slug=None):
""" """
image method docstring image method docstring
@ -55,6 +57,11 @@ class ArticleViewSet(viewsets.ModelViewSet):
"""update method docstring""" """update method docstring"""
return super(ArticleViewSet, self).update(request, *args, **kwargs) return super(ArticleViewSet, self).update(request, *args, **kwargs)
@swagger_auto_schema(operation_description="partial_update description override")
def partial_update(self, request, *args, **kwargs):
"""partial_update method docstring"""
return super(ArticleViewSet, self).partial_update(request, *args, **kwargs)
def destroy(self, request, *args, **kwargs): def destroy(self, request, *args, **kwargs):
"""destroy method docstring""" """destroy method docstring"""
return super(ArticleViewSet, self).destroy(request, *args, **kwargs) return super(ArticleViewSet, self).destroy(request, *args, **kwargs)

View File

@ -1,9 +1,12 @@
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import status
from rest_framework.decorators import api_view from rest_framework.decorators import api_view
from rest_framework.generics import get_object_or_404 from rest_framework.generics import get_object_or_404
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from drf_swagger import openapi
from drf_swagger.utils import swagger_auto_schema
from users.serializers import UserSerializer from users.serializers import UserSerializer
@ -15,8 +18,19 @@ class UserList(APIView):
serializer = UserSerializer(queryset, many=True) serializer = UserSerializer(queryset, many=True)
return Response(serializer.data) return Response(serializer.data)
@swagger_auto_schema(request_body=UserSerializer, operation_description="apiview post description override")
def post(self, request):
serializer = UserSerializer(request.data)
serializer.is_valid(raise_exception=True)
serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED)
@api_view(['GET'])
@swagger_auto_schema(method='put', request_body=UserSerializer)
@swagger_auto_schema(method='get', manual_parameters=[
openapi.Parameter('test', openapi.IN_QUERY, "test manual param", type=openapi.TYPE_BOOLEAN)
])
@api_view(['GET', 'PUT'])
def user_detail(request, pk): def user_detail(request, pk):
"""user_detail fbv docstring""" """user_detail fbv docstring"""
user = get_object_or_404(User.objects, pk=pk) user = get_object_or_404(User.objects, pk=pk)

View File

@ -1,3 +1,6 @@
import json
import os
import pytest import pytest
from ruamel import yaml from ruamel import yaml
@ -53,3 +56,9 @@ def bad_settings():
SWAGGER_DEFAULTS['SECURITY_DEFINITIONS'].update(bad_security) SWAGGER_DEFAULTS['SECURITY_DEFINITIONS'].update(bad_security)
yield swagger_settings yield swagger_settings
del SWAGGER_DEFAULTS['SECURITY_DEFINITIONS']['bad'] del SWAGGER_DEFAULTS['SECURITY_DEFINITIONS']['bad']
@pytest.fixture
def reference_schema():
with open(os.path.join(os.path.dirname(__file__), 'reference.json')) as reference:
return json.load(reference)

File diff suppressed because one or more lines are too long

View File

@ -1,6 +1,17 @@
from drf_swagger import openapi
def test_operation_docstrings(swagger_dict): def test_operation_docstrings(swagger_dict):
users_list = swagger_dict['paths']['/users/'] users_list = swagger_dict['paths']['/users/']
assert users_list['get']['description'] == "UserList cbv classdoc" assert users_list['get']['description'] == "UserList cbv classdoc"
assert users_list['post']['description'] == "apiview post description override"
users_detail = swagger_dict['paths']['/users/{id}/'] users_detail = swagger_dict['paths']['/users/{id}/']
assert users_detail['get']['description'] == "user_detail fbv docstring" assert users_detail['get']['description'] == "user_detail fbv docstring"
assert users_detail['put']['description'] == "user_detail fbv docstring"
def test_parameter_docstrings(swagger_dict):
users_detail = swagger_dict['paths']['/users/{id}/']
assert users_detail['get']['parameters'][0]['description'] == "test manual param"
assert users_detail['put']['parameters'][0]['in'] == openapi.IN_BODY

View File

@ -18,12 +18,12 @@ def test_operation_docstrings(swagger_dict):
articles_detail = swagger_dict['paths']['/articles/{slug}/'] articles_detail = swagger_dict['paths']['/articles/{slug}/']
assert articles_detail['get']['description'] == "retrieve class docstring" assert articles_detail['get']['description'] == "retrieve class docstring"
assert articles_detail['put']['description'] == "update method docstring" assert articles_detail['put']['description'] == "update method docstring"
assert articles_detail['patch']['description'] == "ArticleViewSet class docstring" assert articles_detail['patch']['description'] == "partial_update description override"
assert articles_detail['delete']['description'] == "destroy method docstring" assert articles_detail['delete']['description'] == "destroy method docstring"
articles_today = swagger_dict['paths']['/articles/today/'] articles_today = swagger_dict['paths']['/articles/today/']
assert articles_today['get']['description'] == "ArticleViewSet class docstring" assert articles_today['get']['description'] == "ArticleViewSet class docstring"
articles_image = swagger_dict['paths']['/articles/{slug}/image/'] articles_image = swagger_dict['paths']['/articles/{slug}/image/']
assert articles_image['get']['description'] == "image method docstring" assert articles_image['get']['description'] == "image GET description override"
assert articles_image['post']['description'] == "image method docstring" assert articles_image['post']['description'] == "image method docstring"

View File

@ -0,0 +1,2 @@
def test_reference_schema(swagger_dict, reference_schema):
return swagger_dict == reference_schema