Add suport for Response schemas (#10)

Schemas generated from Serializers will now be added to the `definitions` section by default, and used as `$ref` objects where needed.  
The Schema definition name is based on the serializer class name, and can be overriden by specifying a `__ref_name__` property on the Serializer. If this property is set to None, the schema will not be added to `definitions` and will be forced inline.

Closes #6, #7.
openapi3
Cristi Vîjdea 2017-12-10 03:06:49 +01:00 committed by GitHub
parent 53b2560063
commit bfced82ae4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1637 additions and 273 deletions

View File

@ -24,8 +24,6 @@ coverage:
changes: changes:
default: default:
enabled: yes enabled: no
if_no_uploads: error
if_ci_failed: error
comment: false comment: false

View File

@ -14,6 +14,7 @@ exclude_lines =
# Don't complain if tests don't hit defensive assertion code: # Don't complain if tests don't hit defensive assertion code:
raise AssertionError raise AssertionError
raise TypeError
raise NotImplementedError raise NotImplementedError
warnings.warn warnings.warn
@ -21,6 +22,9 @@ exclude_lines =
if 0: if 0:
if __name__ == .__main__.: if __name__ == .__main__.:
# Don't complain if we don't hit invalid schema configurations
raise SwaggerGenerationError
ignore_errors = True ignore_errors = True
precision = 0 precision = 0

View File

@ -27,7 +27,7 @@
</value> </value>
</option> </option>
</inspection_tool> </inspection_tool>
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true"> <inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false">
<option name="ignoredErrors"> <option name="ignoredErrors">
<list> <list>
<option value="N806" /> <option value="N806" />

View File

@ -1,3 +0,0 @@
[pytest]
DJANGO_SETTINGS_MODULE = testproj.settings
python_paths = testproj

View File

@ -2,6 +2,7 @@
pytest-django>=3.1.2 pytest-django>=3.1.2
pytest-pythonpath>=0.7.1 pytest-pythonpath>=0.7.1
pytest-cov>=2.5.1 pytest-cov>=2.5.1
datadiff==2.0.0
# test project requirements # test project requirements
Pillow>=4.3.0 Pillow>=4.3.0

View File

@ -1,3 +1,4 @@
import copy
import json import json
from collections import OrderedDict from collections import OrderedDict
@ -5,9 +6,9 @@ from coreapi.compat import force_bytes
from future.utils import raise_from from future.utils import raise_from
from ruamel import yaml from ruamel import yaml
from drf_swagger.app_settings import swagger_settings
from drf_swagger.errors import SwaggerValidationError
from . import openapi from . import openapi
from .app_settings import swagger_settings
from .errors import SwaggerValidationError
def _validate_flex(spec, codec): def _validate_flex(spec, codec):
@ -51,7 +52,9 @@ class _OpenAPICodec(object):
spec = self.generate_swagger_object(document) spec = self.generate_swagger_object(document)
for validator in self.validators: for validator in self.validators:
VALIDATORS[validator](spec, self) # validate a deepcopy of the spec to prevent the validator from messing with it
# for example, swagger_spec_validator adds an x-scope property to all references
VALIDATORS[validator](copy.deepcopy(spec), self)
return force_bytes(self._dump_dict(spec)) return force_bytes(self._dump_dict(spec))
def encode_error(self, err): def encode_error(self, err):
@ -119,6 +122,7 @@ class SaneYamlDumper(yaml.SafeDumper):
return node return node
SaneYamlDumper.add_representer(OrderedDict, SaneYamlDumper.represent_odict)
SaneYamlDumper.add_multi_representer(OrderedDict, SaneYamlDumper.represent_odict) SaneYamlDumper.add_multi_representer(OrderedDict, SaneYamlDumper.represent_odict)

View File

@ -6,8 +6,9 @@ from coreapi.compat import force_text
from rest_framework.schemas.generators import SchemaGenerator from rest_framework.schemas.generators import SchemaGenerator
from rest_framework.schemas.inspectors import get_pk_description from rest_framework.schemas.inspectors import get_pk_description
from drf_swagger.inspectors import SwaggerAutoSchema
from . import openapi from . import openapi
from .inspectors import SwaggerAutoSchema
from .openapi import ReferenceResolver
class OpenAPISchemaGenerator(object): class OpenAPISchemaGenerator(object):
@ -20,17 +21,13 @@ class OpenAPISchemaGenerator(object):
self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf) self._gen = SchemaGenerator(info.title, url, info.get('description', ''), patterns, urlconf)
self.info = info self.info = info
self.version = version self.version = version
self.endpoints = None
self.url = url self.url = url
def get_schema(self, request=None, public=False): def get_schema(self, request=None, public=False):
"""Generate an openapi.Swagger representing the API schema.""" """Generate an openapi.Swagger representing the API schema."""
if self.endpoints is None: endpoints = self.get_endpoints(None if public else request)
inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
self.endpoints = inspector.get_api_endpoints() paths = self.get_paths(endpoints, components)
self.clean_endpoints(None if public else request)
paths = self.get_paths()
url = self._gen.url url = self._gen.url
if not url and request is not None: if not url and request is not None:
@ -38,30 +35,37 @@ class OpenAPISchemaGenerator(object):
return openapi.Swagger( return openapi.Swagger(
info=self.info, paths=paths, info=self.info, paths=paths,
_url=url, _version=self.version, _url=url, _version=self.version, **components
) )
def create_view(self, callback, method, request): def create_view(self, callback, method, request=None):
"""Create a view instance from a view callback as registered in urlpatterns."""
view = self._gen.create_view(callback, method, request) view = self._gen.create_view(callback, method, request)
overrides = getattr(callback, 'swagger_auto_schema', None) overrides = getattr(callback, 'swagger_auto_schema', None)
if overrides is not None: if overrides is not None:
# decorated function based view # decorated function based view must have its decorator information passed on to th re-instantiated view
for method, _ in overrides.items(): for method, _ in overrides.items():
view_method = getattr(view, method, None) view_method = getattr(view, method, None)
if view_method is not None: if view_method is not None:
setattr(view_method.__func__, 'swagger_auto_schema', overrides) setattr(view_method.__func__, 'swagger_auto_schema', overrides)
return view return view
def clean_endpoints(self, request): def get_endpoints(self, request=None):
"""Generate {path: (view_class, [(method, view)]) given (path, method, callback).""" """Iterate over all the registered endpoints in the API.
:param rest_framework.request.Request request: used for returning only endpoints available to the given request
:return: {path: (view_class, list[(http_method, view_instance)])"""
inspector = self._gen.endpoint_inspector_cls(self._gen.patterns, self._gen.urlconf)
endpoints = inspector.get_api_endpoints()
view_paths = defaultdict(list) view_paths = defaultdict(list)
view_cls = {} view_cls = {}
for path, method, callback in self.endpoints: for path, method, callback in endpoints:
view = self.create_view(callback, method, request) view = self.create_view(callback, method, request)
path = self._gen.coerce_path(path, method, view) path = self._gen.coerce_path(path, method, view)
view_paths[path].append((method, view)) view_paths[path].append((method, view))
view_cls[path] = callback.cls view_cls[path] = callback.cls
self.endpoints = {path: (view_cls[path], methods) for path, methods in view_paths.items()} return {path: (view_cls[path], methods) for path, methods in view_paths.items()}
def get_operation_keys(self, subpath, method, view): def get_operation_keys(self, subpath, method, view):
""" """
@ -77,14 +81,20 @@ class OpenAPISchemaGenerator(object):
""" """
return self._gen.get_keys(subpath, method, view) return self._gen.get_keys(subpath, method, view)
def get_paths(self): def get_paths(self, endpoints, components):
if not self.endpoints: """Generate the Swagger Paths for the API from the given endpoints.
return []
prefix = self._gen.determine_path_prefix(self.endpoints.keys()) :param dict endpoints: endpoints as returned by get_endpoints
:param ReferenceResolver components: resolver/container for Swagger References
"""
if not endpoints:
return openapi.Paths(paths={})
prefix = self._gen.determine_path_prefix(endpoints.keys())
paths = OrderedDict() paths = OrderedDict()
default_schema_cls = SwaggerAutoSchema default_schema_cls = SwaggerAutoSchema
for path, (view_cls, methods) in sorted(self.endpoints.items()): for path, (view_cls, methods) in sorted(endpoints.items()):
path_parameters = self.get_path_parameters(path, view_cls) path_parameters = self.get_path_parameters(path, view_cls)
operations = {} operations = {}
for method, view in methods: for method, view in methods:
@ -94,7 +104,7 @@ class OpenAPISchemaGenerator(object):
operation_keys = self.get_operation_keys(path[len(prefix):], method, view) operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
overrides = self.get_overrides(view, method) overrides = self.get_overrides(view, method)
auto_schema_cls = overrides.get('auto_schema', default_schema_cls) auto_schema_cls = overrides.get('auto_schema', default_schema_cls)
schema = auto_schema_cls(view, path, method, overrides) schema = auto_schema_cls(view, path, method, overrides, components)
operations[method.lower()] = schema.get_operation(operation_keys) operations[method.lower()] = schema.get_operation(operation_keys)
paths[path] = openapi.PathItem(parameters=path_parameters, **operations) paths[path] = openapi.PathItem(parameters=path_parameters, **operations)

View File

@ -1,160 +1,45 @@
import functools
import inspect import inspect
from collections import OrderedDict from collections import OrderedDict
import coreschema import coreschema
from django.core.validators import RegexValidator from rest_framework import serializers, status
from django.utils.encoding import force_text
from rest_framework import serializers
from rest_framework.request import is_form_media_type from rest_framework.request import is_form_media_type
from rest_framework.schemas import AutoSchema from rest_framework.schemas import AutoSchema
from rest_framework.status import is_success
from rest_framework.viewsets import GenericViewSet from rest_framework.viewsets import GenericViewSet
from drf_swagger.errors import SwaggerGenerationError
from . import openapi from . import openapi
from .utils import no_body, is_list_view from .errors import SwaggerGenerationError
from .utils import serializer_field_to_swagger, no_body, is_list_view
def serializer_field_to_swagger(field, swagger_object_type, **kwargs): def force_serializer_instance(serializer):
"""Convert a drf Serializer or Field instance into a Swagger object. if inspect.isclass(serializer):
assert issubclass(serializer, serializers.BaseSerializer), "Serializer required, not %s" % serializer.__name__
return serializer()
:param rest_framework.serializers.Field field: the source field assert isinstance(serializer, serializers.BaseSerializer), \
:param type swagger_object_type: should be one of Schema, Parameter, Items "Serializer class or instance required, not %s" % type(serializer).__name__
:param kwargs: extra attributes for constructing the object; return serializer
if swagger_object_type is Parameter, `name` and `in_` should be provided
:return Swagger,Parameter,Items: the swagger object
"""
assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
title = force_text(field.label) if field.label else None
title = title if swagger_object_type == openapi.Schema else None # only Schema has title
title = None
description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either
SwaggerType = functools.partial(swagger_object_type, title=title, description=description, **kwargs)
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
# ------ NESTED
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
)
elif isinstance(field, serializers.Serializer):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as "
+ swagger_object_type.__name__)
return SwaggerType(
type=openapi.TYPE_OBJECT,
properties=OrderedDict(
(key, serializer_field_to_swagger(value, ChildSwaggerType))
for key, value
in field.fields.items()
)
)
elif isinstance(field, serializers.ManyRelatedField):
child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
unique_items=True, # is this OK?
)
elif isinstance(field, serializers.RelatedField):
# TODO: infer type for PrimaryKeyRelatedField?
return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES
elif isinstance(field, serializers.MultipleChoiceField):
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=ChildSwaggerType(
type=openapi.TYPE_STRING,
enum=list(field.choices.keys())
)
)
elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL
elif isinstance(field, serializers.BooleanField):
return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_NUMBER)
elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING
elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
elif isinstance(field, serializers.RegexField):
return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
elif isinstance(field, serializers.SlugField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG)
elif isinstance(field, serializers.URLField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.IPAddressField):
format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
return SwaggerType(type=openapi.TYPE_STRING, format=format)
elif isinstance(field, serializers.CharField):
# TODO: min_length max_length (for all CharField subclasses above too)
return SwaggerType(type=openapi.TYPE_STRING)
elif isinstance(field, serializers.UUIDField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
# ------ DATE & TIME
elif isinstance(field, serializers.DateField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
elif isinstance(field, serializers.DateTimeField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
# ------ OTHERS
elif isinstance(field, serializers.FileField):
# swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
# OpenAPI 3.0 does support it, so a future implementation could handle this better
# TODO: appropriate produces/consumes somehow/somewhere?
if swagger_object_type != openapi.Parameter:
raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter")
return SwaggerType(type=openapi.TYPE_FILE)
elif isinstance(field, serializers.JSONField):
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_BINARY if field.binary else None
)
elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType)
return SwaggerType(
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
)
# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField?
# TODO: return info about required/allowed empty
# everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING)
def find_regex(regex_field):
regex_validator = None
for validator in regex_field.validators:
if isinstance(validator, RegexValidator):
if regex_validator is not None:
# bail if multiple validators are found - no obvious way to choose
return None
regex_validator = validator
# regex_validator.regex should be a compiled re object...
return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
class SwaggerAutoSchema(object): class SwaggerAutoSchema(object):
def __init__(self, view, path, method, overrides): def __init__(self, view, path, method, overrides, components):
"""Inspector class responsible for providing Operation definitions given a
:param view: the view associated with this endpoint
:param str path: the path component of the operation URL
:param str method: the http method of the operation
:param dict overrides: manual overrides as passed to @swagger_auto_schema
:param openapi.ReferenceResolver components: referenceable components
"""
super(SwaggerAutoSchema, self).__init__() super(SwaggerAutoSchema, self).__init__()
self._sch = AutoSchema() self._sch = AutoSchema()
self.view = view self.view = view
self.path = path self.path = path
self.method = method self.method = method
self.overrides = overrides self.overrides = overrides
self.components = components
self._sch.view = view self._sch.view = view
def get_operation(self, operation_keys): def get_operation(self, operation_keys):
@ -163,7 +48,7 @@ class SwaggerAutoSchema(object):
:param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API; :param tuple[str] operation_keys: an array of keys describing the hierarchical layout of this view in the API;
e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc. e.g. ('snippets', 'list'), ('snippets', 'retrieve'), etc.
:return openapi.Operation: the resulting Operation object :return: the resulting Operation object
""" """
consumes = self.get_consumes() consumes = self.get_consumes()
@ -177,7 +62,7 @@ class SwaggerAutoSchema(object):
description = self.get_description() description = self.get_description()
responses = self.get_responses() responses = self.get_responses()
# manual_responses = self.overrides.get('responses', None) or {}
return openapi.Operation( return openapi.Operation(
operation_id='_'.join(operation_keys), operation_id='_'.join(operation_keys),
description=description, description=description,
@ -194,7 +79,8 @@ class SwaggerAutoSchema(object):
- a list of primitive Parameters parsed as form data - a list of primitive Parameters parsed as form data
:param list[str] consumes: a list of MIME types this request accepts as body :param list[str] consumes: a list of MIME types this request accepts as body
:return list[Parameter]: a (potentially empty) list of openapi.Parameter in: either `body` or `formData` :return: a (potentially empty) list of openapi.Parameter in: either `body` or `formData`
:rtype: list[openapi.Parameter]
""" """
# only PUT, PATCH or POST can have a request body # only PUT, PATCH or POST can have a request body
if self.method not in ('PUT', 'PATCH', 'POST'): if self.method not in ('PUT', 'PATCH', 'POST'):
@ -205,7 +91,7 @@ class SwaggerAutoSchema(object):
if serializer is None: if serializer is None:
return [] return []
if isinstance(serializer, openapi.Schema): if isinstance(serializer, openapi.Schema.OR_REF):
schema = serializer schema = serializer
if any(is_form_media_type(encoding) for encoding in consumes): if any(is_form_media_type(encoding) for encoding in consumes):
@ -220,17 +106,17 @@ class SwaggerAutoSchema(object):
def get_request_serializer(self): def get_request_serializer(self):
"""Return the request serializer (used for parsing the request payload) for this endpoint. """Return the request serializer (used for parsing the request payload) for this endpoint.
:return serializers.Serializer: the request serializer :return: the request serializer
:rtype: serializers.BaseSerializer
""" """
body_override = self.overrides.get('request_body', None) body_override = self.overrides.get('request_body', None)
if body_override is not None: if body_override is not None:
if body_override is no_body: if body_override is no_body:
return None return None
if inspect.isclass(body_override): if isinstance(body_override, openapi.Schema.OR_REF):
assert issubclass(body_override, serializers.Serializer) return body_override
return body_override() return force_serializer_instance(body_override)
return body_override
else: else:
if not hasattr(self.view, 'get_serializer'): if not hasattr(self.view, 'get_serializer'):
return None return None
@ -240,33 +126,37 @@ class SwaggerAutoSchema(object):
"""Given a Serializer, return a list of in: formData Parameters. """Given a Serializer, return a list of in: formData Parameters.
:param serializer: the view's request serialzier :param serializer: the view's request serialzier
:rtype: list[openapi.Parameter]
""" """
fields = getattr(serializer, 'fields', {})
return [ return [
self.field_to_swagger(value, openapi.Parameter, name=key, in_=openapi.IN_FORM) self.field_to_parameter(value, key, openapi.IN_FORM)
for key, value for key, value
in serializer.fields.items() in fields.items()
] ]
def get_request_body_schema(self, serializer): def get_request_body_schema(self, serializer):
"""Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests. """Return the Schema for a given request's body data. Only applies to PUT, PATCH and POST requests.
:param serializer: the view's request serialzier :param serializers.BaseSerializer serializer: the view's request serialzier
:return openapi.Schema: the request body schema :return: the request body schema
:rtype: openapi.Schema
""" """
return self.field_to_swagger(serializer, openapi.Schema) return self.serializer_to_schema(serializer)
def make_body_parameter(self, schema): def make_body_parameter(self, schema):
"""Given a Schema object, create an in: body Parameter. """Given a Schema object, create an in: body Parameter.
:param openapi.Schema schema: the request body schema :param openapi.Schema schema: the request body schema
""" """
return openapi.Parameter(name='data', in_=openapi.IN_BODY, schema=schema) return openapi.Parameter(name='data', in_=openapi.IN_BODY, required=True, schema=schema)
def add_manual_parameters(self, parameters): def add_manual_parameters(self, parameters):
"""Add/replace parameters from the given list of automatically generated request parameters. """Add/replace parameters from the given list of automatically generated request parameters.
:param list[openapi.Parameter] parameters: genereated parameters :param list[openapi.Parameter] parameters: genereated parameters
:return list[openapi.Parameter]: modified parameters :return: modified parameters
:rtype: list[openapi.Parameter]
""" """
parameters = OrderedDict(((param.name, param.in_), param) for param in parameters) parameters = OrderedDict(((param.name, param.in_), param) for param in parameters)
manual_parameters = self.overrides.get('manual_parameters', None) or [] manual_parameters = self.overrides.get('manual_parameters', None) or []
@ -284,13 +174,56 @@ class SwaggerAutoSchema(object):
def get_responses(self): def get_responses(self):
"""Get the possible responses for this view as a swagger Responses object. """Get the possible responses for this view as a swagger Responses object.
:return Responses: the documented responses :return: the documented responses
""" """
response_serializers = self.get_response_serializers() response_serializers = self.get_response_serializers()
return openapi.Responses( return openapi.Responses(
responses=self.get_response_schemas(response_serializers) responses=self.get_response_schemas(response_serializers)
) )
def get_paged_response_schema(self, response_schema):
"""Add appropriate paging fields to a response Schema.
:param openapi.Schema response_schema: the response schema that must be paged.
"""
assert response_schema.type == openapi.TYPE_ARRAY, "array return expected for paged response"
paged_schema = openapi.Schema(
type=openapi.TYPE_OBJECT,
properties={
'count': openapi.Schema(type=openapi.TYPE_INTEGER),
'next': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI),
'previous': openapi.Schema(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI),
'results': response_schema,
},
required=['count', 'results']
)
return paged_schema
def get_default_responses(self):
method = self.method.lower()
default_status = status.HTTP_200_OK
default_schema = ''
if method == 'post':
default_status = status.HTTP_201_CREATED
default_schema = self.get_request_serializer()
elif method == 'delete':
default_status = status.HTTP_204_NO_CONTENT
elif method in ('get', 'put', 'patch'):
default_schema = self.get_request_serializer()
default_schema = default_schema or ''
if default_schema:
if not isinstance(default_schema, openapi.Schema):
default_schema = self.serializer_to_schema(default_schema)
if is_list_view(self.path, self.method, self.view) and self.method.lower() == 'get':
default_schema = openapi.Schema(type=openapi.TYPE_ARRAY, items=default_schema)
if self.should_page():
default_schema = self.get_paged_response_schema(default_schema)
return {str(default_status): default_schema}
def get_response_serializers(self): def get_response_serializers(self):
"""Return the response codes that this view is expected to return, and the serializer for each response body. """Return the response codes that this view is expected to return, and the serializer for each response body.
The return value should be a dict where the keys are possible status codes, and values are either strings, The return value should be a dict where the keys are possible status codes, and values are either strings,
@ -298,11 +231,15 @@ class SwaggerAutoSchema(object):
:return dict: the response serializers :return dict: the response serializers
""" """
if self.method.lower() == 'post': manual_responses = self.overrides.get('responses', None) or {}
return {'201': ''} manual_responses = OrderedDict((str(sc), resp) for sc, resp in manual_responses.items())
if self.method.lower() == 'delete':
return {'204': ''} responses = {}
return {'200': ''} if not any(is_success(int(sc)) for sc in manual_responses if sc != 'default'):
responses = self.get_default_responses()
responses.update((str(sc), resp) for sc, resp in manual_responses.items())
return responses
def get_response_schemas(self, response_serializers): def get_response_schemas(self, response_serializers):
"""Return the `openapi.Response` objects calculated for this view. """Return the `openapi.Response` objects calculated for this view.
@ -311,32 +248,43 @@ class SwaggerAutoSchema(object):
:return dict[str, openapi.Response]: a dictionary of status code to Response object :return dict[str, openapi.Response]: a dictionary of status code to Response object
""" """
responses = {} responses = {}
for status, serializer in response_serializers.items(): for sc, serializer in response_serializers.items():
if isinstance(serializer, str): if isinstance(serializer, str):
response = openapi.Response( response = openapi.Response(
description=serializer description=serializer
) )
elif isinstance(serializer, openapi.Response): elif isinstance(serializer, openapi.Response):
response = serializer response = serializer
else: if not isinstance(response.schema, openapi.Schema.OR_REF):
serializer = force_serializer_instance(response.schema)
response.schema = self.serializer_to_schema(serializer)
elif isinstance(serializer, openapi.Schema.OR_REF):
response = openapi.Response( response = openapi.Response(
description='', description='',
schema=self.field_to_swagger(serializer, openapi.Schema) schema=serializer,
)
else:
serializer = force_serializer_instance(serializer)
response = openapi.Response(
description='',
schema=self.serializer_to_schema(serializer),
) )
responses[str(status)] = response responses[str(sc)] = response
return responses return responses
def get_query_parameters(self): def get_query_parameters(self):
"""Return the query parameters accepted by this view.""" """Return the query parameters accepted by this view.
:rtype: list[openapi.Parameter]"""
return self.get_filter_parameters() + self.get_pagination_parameters() return self.get_filter_parameters() + self.get_pagination_parameters()
def should_filter(self): def should_filter(self):
if getattr(self.view, 'filter_backends', None) is None: if not getattr(self.view, 'filter_backends', None):
return False return False
if self.method.lower() not in ["get", "put", "patch", "delete"]: if self.method.lower() not in ["get", "delete"]:
return False return False
if not isinstance(self.view, GenericViewSet): if not isinstance(self.view, GenericViewSet):
@ -344,39 +292,109 @@ class SwaggerAutoSchema(object):
return is_list_view(self.path, self.method, self.view) return is_list_view(self.path, self.method, self.view)
def get_filter_backend_parameters(self, filter_backend):
"""Get the filter parameters for a single filter backend **instance**.
:param BaseFilterBackend filter_backend: the filter backend
:rtype: list[openapi.Parameter]
"""
fields = []
if hasattr(filter_backend, 'get_schema_fields'):
fields = filter_backend.get_schema_fields(self.view)
return [self.coreapi_field_to_parameter(field) for field in fields]
def get_filter_parameters(self): def get_filter_parameters(self):
"""Return the parameters added to the view by its filter backends.""" """Return the parameters added to the view by its filter backends.
:rtype: list[openapi.Parameter]
"""
if not self.should_filter(): if not self.should_filter():
return [] return []
fields = [] fields = []
for filter_backend in self.view.filter_backends: for filter_backend in self.view.filter_backends:
filter = filter_backend() fields += self.get_filter_backend_parameters(filter_backend())
if hasattr(filter, 'get_schema_fields'):
fields += filter.get_schema_fields(self.view) return fields
return [self.coreapi_field_to_parameter(field) for field in fields]
def should_page(self): def should_page(self):
if not hasattr(self.view, 'paginator'): if not hasattr(self.view, 'paginator'):
return False return False
if self.view.paginator is None:
return False
if self.method.lower() != 'get':
return False
return is_list_view(self.path, self.method, self.view) return is_list_view(self.path, self.method, self.view)
def get_paginator_parameters(self, paginator):
"""Get the pagination parameters for a single paginator **instance**.
:param BasePagination paginator: the paginator
:rtype: list[openapi.Parameter]
"""
fields = []
if hasattr(paginator, 'get_schema_fields'):
fields = paginator.get_schema_fields(self.view)
return [self.coreapi_field_to_parameter(field) for field in fields]
def get_pagination_parameters(self): def get_pagination_parameters(self):
"""Return the parameters added to the view by its paginator.""" """Return the parameters added to the view by its paginator.
:rtype: list[openapi.Parameter]"""
if not self.should_page(): if not self.should_page():
return [] return []
paginator = self.view.paginator return self.get_paginator_parameters(self.view.paginator)
if not hasattr(paginator, 'get_schema_fields'):
return []
return [self.coreapi_field_to_parameter(field) for field in paginator.get_schema_fields(self.view)] def get_description(self):
"""Return an operation description determined as appropriate from the view's method and class docstrings.
:return: the operation description
:rtype: str
"""
description = self.overrides.get('operation_description', None)
if description is None:
description = self._sch.get_description(self.path, self.method)
return description
def get_consumes(self):
"""Return the MIME types this endpoint can consume.
:rtype: list[str]
"""
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
def serializer_to_schema(self, serializer):
"""Convert a DRF Serializer instance to an openapi.Schema.
:param serializers.BaseSerializer serializer:
:rtype: openapi.Schema
"""
definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
return serializer_field_to_swagger(serializer, openapi.Schema, definitions)
def field_to_parameter(self, field, name, in_):
"""Convert a DRF serializer Field to a swagger Parameter object.
:param coreapi.Field field:
:param str name: the name of the parameter
:param str in_: the location of the parameter, one of the `openapi.IN_*` constants
:rtype: openapi.Parameter
"""
return serializer_field_to_swagger(field, openapi.Parameter, name=name, in_=in_)
def coreapi_field_to_parameter(self, field): def coreapi_field_to_parameter(self, field):
"""Convert an instance of `coreapi.Field` to a swagger Parameter object. """Convert an instance of `coreapi.Field` to a swagger Parameter object.
:param coreapi.Field field: the coreapi field :param coreapi.Field field:
:rtype: openapi.Parameter
""" """
location_to_in = { location_to_in = {
'query': openapi.IN_QUERY, 'query': openapi.IN_QUERY,
@ -393,27 +411,7 @@ class SwaggerAutoSchema(object):
return openapi.Parameter( return openapi.Parameter(
name=field.name, name=field.name,
in_=location_to_in[field.location], in_=location_to_in[field.location],
type=coreapi_types.get(field.schema.__class__, openapi.TYPE_STRING), type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING),
required=field.required, required=field.required,
description=field.schema.description, description=field.schema.description,
) )
def get_description(self):
"""Return an operation description determined as appropriate from the view's method and class docstrings.
:return str: the operation description
"""
description = self.overrides.get('operation_description', None)
if description is None:
description = self._sch.get_description(self.path, self.method)
return description
def get_consumes(self):
"""Return the MIME types this endpoint can consume."""
media_types = [parser.media_type for parser in getattr(self.view, 'parser_classes', [])]
if all(is_form_media_type(encoding) for encoding in media_types):
return media_types
return media_types[:1]
def field_to_swagger(self, field, swagger_object_type, **kwargs):
return serializer_field_to_swagger(field, swagger_object_type, **kwargs)

View File

@ -1,8 +1,8 @@
from django.http import HttpResponse from django.http import HttpResponse
from django.utils.deprecation import MiddlewareMixin from django.utils.deprecation import MiddlewareMixin
from drf_swagger.errors import SwaggerValidationError
from .codecs import _OpenAPICodec from .codecs import _OpenAPICodec
from .errors import SwaggerValidationError
class SwaggerExceptionMiddleware(MiddlewareMixin): class SwaggerExceptionMiddleware(MiddlewareMixin):
@ -15,4 +15,4 @@ class SwaggerExceptionMiddleware(MiddlewareMixin):
content_type = codec.media_type content_type = codec.media_type
return HttpResponse(err, status=500, content_type=content_type) return HttpResponse(err, status=500, content_type=content_type)
return None return None # pragma: no cover

View File

@ -1,3 +1,4 @@
import copy
from collections import OrderedDict from collections import OrderedDict
from coreapi.compat import urlparse from coreapi.compat import urlparse
@ -39,6 +40,8 @@ IN_QUERY = 'query'
IN_FORM = 'formData' IN_FORM = 'formData'
IN_HEADER = 'header' IN_HEADER = 'header'
SCHEMA_DEFINITIONS = 'definitions'
def make_swagger_name(attribute_name): def make_swagger_name(attribute_name):
""" """
@ -63,7 +66,7 @@ class SwaggerDict(OrderedDict):
def __init__(self, **attrs): def __init__(self, **attrs):
super(SwaggerDict, self).__init__() super(SwaggerDict, self).__init__()
self._extras__ = attrs self._extras__ = attrs
if self.__class__ == SwaggerDict: if type(self) == SwaggerDict:
self._insert_extras__() self._insert_extras__()
def __setattr__(self, key, value): def __setattr__(self, key, value):
@ -79,7 +82,7 @@ class SwaggerDict(OrderedDict):
try: try:
return self[make_swagger_name(item)] return self[make_swagger_name(item)]
except KeyError as e: except KeyError as e:
raise_from(AttributeError("no attribute " + item), e) raise_from(AttributeError("object of class " + type(self).__name__ + " has no attribute " + item), e)
def __delattr__(self, item): def __delattr__(self, item):
if item.startswith('_'): if item.startswith('_'):
@ -98,6 +101,12 @@ class SwaggerDict(OrderedDict):
for attr, val in self._extras__.items(): for attr, val in self._extras__.items():
setattr(self, attr, val) setattr(self, attr, val)
# noinspection PyArgumentList,PyDefaultArgument
def __deepcopy__(self, memodict={}):
result = OrderedDict(list(self.items()))
result.update(copy.deepcopy(result, memodict))
return result
class Contact(SwaggerDict): class Contact(SwaggerDict):
"""Swagger Contact object """Swagger Contact object
@ -111,7 +120,7 @@ class Contact(SwaggerDict):
def __init__(self, name=None, url=None, email=None, **extra): def __init__(self, name=None, url=None, email=None, **extra):
super(Contact, self).__init__(**extra) super(Contact, self).__init__(**extra)
if name is None and url is None and email is None: if name is None and url is None and email is None:
raise ValueError("one of name, url or email is requires for Swagger Contact object") raise AssertionError("one of name, url or email is requires for Swagger Contact object")
self.name = name self.name = name
self.url = url self.url = url
self.email = email self.email = email
@ -128,7 +137,7 @@ class License(SwaggerDict):
def __init__(self, name, url=None, **extra): def __init__(self, name, url=None, **extra):
super(License, self).__init__(**extra) super(License, self).__init__(**extra)
if name is None: if name is None:
raise ValueError("name is required for Swagger License object") raise AssertionError("name is required for Swagger License object")
self.name = name self.name = name
self.url = url self.url = url
self._insert_extras__() self._insert_extras__()
@ -149,11 +158,11 @@ class Info(SwaggerDict):
**extra): **extra):
super(Info, self).__init__(**extra) super(Info, self).__init__(**extra)
if title is None or default_version is None: if title is None or default_version is None:
raise ValueError("title and version are required for Swagger info object") raise AssertionError("title and version are required for Swagger info object")
if contact is not None and not isinstance(contact, Contact): if contact is not None and not isinstance(contact, Contact):
raise ValueError("contact must be a Contact object") raise AssertionError("contact must be a Contact object")
if license is not None and not isinstance(license, License): if license is not None and not isinstance(license, License):
raise ValueError("license must be a License object") raise AssertionError("license must be a License object")
self.title = title self.title = title
self._default_version = default_version self._default_version = default_version
self.description = description self.description = description
@ -164,12 +173,11 @@ class Info(SwaggerDict):
class Swagger(SwaggerDict): class Swagger(SwaggerDict):
def __init__(self, info=None, _url=None, _version=None, paths=None, **extra): def __init__(self, info=None, _url=None, _version=None, paths=None, definitions=None, **extra):
super(Swagger, self).__init__(**extra) super(Swagger, self).__init__(**extra)
self.swagger = '2.0' self.swagger = '2.0'
self.info = info self.info = info
self.info.version = _version or info._default_version self.info.version = _version or info._default_version
self.paths = paths
if _url: if _url:
url = urlparse.urlparse(_url) url = urlparse.urlparse(_url)
@ -177,8 +185,10 @@ class Swagger(SwaggerDict):
self.host = url.netloc self.host = url.netloc
if url.scheme: if url.scheme:
self.schemes = [url.scheme] self.schemes = [url.scheme]
self.base_path = '/' self.base_path = '/'
self.paths = paths
self.definitions = definitions
self._insert_extras__() self._insert_extras__()
@ -237,7 +247,7 @@ class Parameter(SwaggerDict):
type=None, format=None, enum=None, pattern=None, items=None, **extra): type=None, format=None, enum=None, pattern=None, items=None, **extra):
super(Parameter, self).__init__(**extra) super(Parameter, self).__init__(**extra)
if (not schema and not type) or (schema and type): if (not schema and not type) or (schema and type):
raise ValueError("either schema or type are required for Parameter object!") raise AssertionError("either schema or type are required for Parameter object!")
self.name = name self.name = name
self.in_ = in_ self.in_ = in_
self.description = description self.description = description
@ -252,9 +262,15 @@ class Parameter(SwaggerDict):
class Schema(SwaggerDict): class Schema(SwaggerDict):
OR_REF = ()
def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None, def __init__(self, description=None, required=None, type=None, properties=None, additional_properties=None,
format=None, enum=None, pattern=None, items=None, **extra): format=None, enum=None, pattern=None, items=None, **extra):
super(Schema, self).__init__(**extra) super(Schema, self).__init__(**extra)
if required is True or required is False:
# common error
raise AssertionError(
"the requires attribute of schema must be an array of required properties, not a boolean!")
self.description = description self.description = description
self.required = required self.required = required
self.type = type self.type = type
@ -267,11 +283,37 @@ class Schema(SwaggerDict):
self._insert_extras__() self._insert_extras__()
class Ref(SwaggerDict): class _Ref(SwaggerDict):
def __init__(self, ref): def __init__(self, resolver, name, scope, expected_type):
super(Ref, self).__init__() super(_Ref, self).__init__()
self.ref = ref assert not type(self) == _Ref, "do not instantiate _Ref directly"
self._insert_extras__() ref_name = "#/{scope}/{name}".format(scope=scope, name=name)
obj = resolver.get(name, scope)
assert isinstance(obj, expected_type), ref_name + " is a {actual}, not a {expected}" \
.format(actual=type(obj).__name__, expected=expected_type.__name__)
self.ref = ref_name
def __setitem__(self, key, value, **kwargs):
if key == "$ref":
return super(_Ref, self).__setitem__(key, value, **kwargs)
raise NotImplementedError("only $ref can be set on Reference objects (not %s)" % key)
def __delitem__(self, key, **kwargs):
raise NotImplementedError("cannot delete property of Reference object")
class SchemaRef(_Ref):
def __init__(self, resolver, schema_name):
"""Add a reference to a named Schema defined in the #/definitions/ object.
:param ReferenceResolver resolver: component resolver which must contain the definition
:param str schema_name: schema name
"""
assert SCHEMA_DEFINITIONS in resolver.scopes
super(SchemaRef, self).__init__(resolver, schema_name, SCHEMA_DEFINITIONS, Schema)
Schema.OR_REF = (Schema, SchemaRef)
class Responses(SwaggerDict): class Responses(SwaggerDict):
@ -291,3 +333,92 @@ class Response(SwaggerDict):
self.schema = schema self.schema = schema
self.examples = examples self.examples = examples
self._insert_extras__() self._insert_extras__()
class ReferenceResolver(object):
"""A mapping type intended for storing objects pointed at by Swagger Refs.
Provides support and checks for different refernce scopes, e.g. 'definitions'.
For example:
> components = ReferenceResolver('definitions', 'parameters')
> definitions = ReferenceResolver.with_scope('definitions')
> definitions.set('Article', Schema(...))
> print(components)
{'definitions': OrderedDict([('Article', Schema(...)]), 'parameters': OrderedDict()}
"""
def __init__(self, *scopes):
self._objects = OrderedDict()
self._force_scope = None
for scope in scopes:
assert isinstance(scope, str), "scope names must be strings"
self._objects[scope] = OrderedDict()
def with_scope(self, scope):
assert scope in self.scopes, "unknown scope %s" % scope
ret = ReferenceResolver()
ret._objects = self._objects
ret._force_scope = scope
return ret
def _check_scope(self, scope):
real_scope = self._force_scope or scope
if scope is not None:
assert not self._force_scope or scope == self._force_scope, "cannot overrride forced scope"
assert real_scope and real_scope in self._objects, "invalid scope %s" % scope
return real_scope
def set(self, name, obj, scope=None):
scope = self._check_scope(scope)
assert obj is not None, "referenced objects cannot be None/null"
assert name not in self._objects[scope], "#/%s/%s already exists" % (scope, name)
self._objects[scope][name] = obj
def setdefault(self, name, maker, scope=None):
scope = self._check_scope(scope)
assert callable(maker), "setdefault expects a callable, not %s" % type(maker).__name__
ret = self.getdefault(name, None, scope)
if ret is None:
ret = maker()
assert ret is not None, "maker returned None; referenced objects cannot be None/null"
self.set(name, ret, scope)
return ret
def get(self, name, scope=None):
scope = self._check_scope(scope)
assert name in self._objects[scope], "#/%s/%s is not defined" % (scope, name)
return self._objects[scope][name]
def getdefault(self, name, default=None, scope=None):
scope = self._check_scope(scope)
return self._objects[scope].get(name, default)
def has(self, name, scope=None):
scope = self._check_scope(scope)
return name in self._objects[scope]
def __iter__(self):
if self._force_scope:
return iter(self._objects[self._force_scope])
return iter(self._objects)
@property
def scopes(self):
if self._force_scope:
return [self._force_scope]
return list(self._objects.keys())
# act as mapping
def keys(self):
if self._force_scope:
return self._objects[self._force_scope].keys()
return self._objects.keys()
def __getitem__(self, item):
if self._force_scope:
return self._objects[self._force_scope][item]
return self._objects[item]
def __str__(self):
return str(dict(self))

View File

@ -1,12 +1,16 @@
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin from collections import OrderedDict
from django.core.validators import RegexValidator
from django.utils.encoding import force_text
from rest_framework import serializers
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin
from . import openapi
from .errors import SwaggerGenerationError
no_body = object() no_body = object()
class UpdateModelMixing(object):
pass
def is_list_view(path, method, view): def is_list_view(path, method, view):
"""Return True if the given path/method appears to represent a list view (as opposed to a detail/instance view).""" """Return True if the given path/method appears to represent a list view (as opposed to a detail/instance view)."""
# for ViewSets, it could be the default 'list' view, or a list_route # for ViewSets, it could be the default 'list' view, or a list_route
@ -22,7 +26,7 @@ def is_list_view(path, method, view):
return False return False
# for APIView, if it's a detail view it can't also be a list view # for APIView, if it's a detail view it can't also be a list view
if isinstance(view, (RetrieveModelMixin, UpdateModelMixing, DestroyModelMixin)): if isinstance(view, (RetrieveModelMixin, UpdateModelMixin, DestroyModelMixin)):
return False return False
# if the last component in the path is parameterized it's probably not a list view # if the last component in the path is parameterized it's probably not a list view
@ -71,7 +75,7 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_bod
_methods = [method.lower()] _methods = [method.lower()]
else: else:
_methods = [mth.lower() for mth in methods] _methods = [mth.lower() for mth in methods]
assert not isinstance(_methods, str) assert not isinstance(_methods, str), "`methods` expects to receive; use `method` for a single arg"
assert not any(mth in existing_data for mth in _methods), "method defined multiple times" assert not any(mth in existing_data for mth in _methods), "method defined multiple times"
assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route assert all(mth in available_methods for mth in _methods), "method not bound to %s" % _route
@ -88,3 +92,162 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_bod
return view_method return view_method
return decorator return decorator
def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object.
:param rest_framework.serializers.Field field: the source field
:param type swagger_object_type: should be one of Schema, Parameter, Items
:param drf_swagger.openapi.ReferenceResolver definitions: used to serialize Schemas by reference
:param kwargs: extra attributes for constructing the object;
if swagger_object_type is Parameter, `name` and `in_` should be provided
:return Swagger,Parameter,Items: the swagger object
"""
assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items)
assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object"
title = force_text(field.label) if field.label else None
title = title if swagger_object_type == openapi.Schema else None # only Schema has title
title = None
description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either
def SwaggerType(**instance_kwargs):
if swagger_object_type == openapi.Parameter:
instance_kwargs['required'] = field.required
instance_kwargs.update(kwargs)
return swagger_object_type(title=title, description=description, **instance_kwargs)
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
ChildSwaggerType = openapi.Schema if swagger_object_type == openapi.Schema else openapi.Items
# ------ NESTED
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
)
elif isinstance(field, serializers.Serializer):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
assert definitions is not None, "ReferenceResolver required when instantiating Schema"
serializer = field
if hasattr(serializer, '__ref_name__'):
ref_name = serializer.__ref_name__
else:
ref_name = type(serializer).__name__
if ref_name.endswith('Serializer'):
ref_name = ref_name[:-len('Serializer')]
def make_schema_definition():
properties = OrderedDict()
required = []
for key, value in serializer.fields.items():
properties[key] = serializer_field_to_swagger(value, ChildSwaggerType, definitions)
if value.read_only:
properties[key].read_only = value.read_only
if value.required:
required.append(key)
return SwaggerType(
type=openapi.TYPE_OBJECT,
properties=properties,
required=required or None,
)
if not ref_name:
return make_schema_definition()
definitions.setdefault(ref_name, make_schema_definition)
return openapi.SchemaRef(definitions, ref_name)
elif isinstance(field, serializers.ManyRelatedField):
child_schema = serializer_field_to_swagger(field.child_relation, ChildSwaggerType, definitions)
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=child_schema,
unique_items=True, # is this OK?
)
elif isinstance(field, serializers.RelatedField):
# TODO: infer type for PrimaryKeyRelatedField?
return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES
elif isinstance(field, serializers.MultipleChoiceField):
return SwaggerType(
type=openapi.TYPE_ARRAY,
items=ChildSwaggerType(
type=openapi.TYPE_STRING,
enum=list(field.choices.keys())
)
)
elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL
elif isinstance(field, serializers.BooleanField):
return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_NUMBER)
elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING
elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
elif isinstance(field, serializers.RegexField):
return SwaggerType(type=openapi.TYPE_STRING, pattern=find_regex(field))
elif isinstance(field, serializers.SlugField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_SLUG)
elif isinstance(field, serializers.URLField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.IPAddressField):
format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}.get(field.protocol, None)
return SwaggerType(type=openapi.TYPE_STRING, format=format)
elif isinstance(field, serializers.CharField):
# TODO: min_length max_length (for all CharField subclasses above too)
return SwaggerType(type=openapi.TYPE_STRING)
elif isinstance(field, serializers.UUIDField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_UUID)
# ------ DATE & TIME
elif isinstance(field, serializers.DateField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATE)
elif isinstance(field, serializers.DateTimeField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_DATETIME)
# ------ OTHERS
elif isinstance(field, serializers.FileField):
# swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
# OpenAPI 3.0 does support it, so a future implementation could handle this better
if swagger_object_type != openapi.Parameter:
raise SwaggerGenerationError("parameter of type file is supported only in formData Parameter")
return SwaggerType(type=openapi.TYPE_FILE)
elif isinstance(field, serializers.JSONField):
return SwaggerType(
type=openapi.TYPE_STRING,
format=openapi.FORMAT_BINARY if field.binary else None
)
elif isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema:
child_schema = serializer_field_to_swagger(field.child, ChildSwaggerType, definitions)
return SwaggerType(
type=openapi.TYPE_OBJECT,
additional_properties=child_schema
)
# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField?
# everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING)
def find_regex(regex_field):
regex_validator = None
for validator in regex_field.validators:
if isinstance(validator, RegexValidator):
if regex_validator is not None:
# bail if multiple validators are found - no obvious way to choose
return None
regex_validator = validator
# regex_validator.regex should be a compiled re object...
return getattr(getattr(regex_validator, 'regex', None), 'pattern', None)

View File

@ -84,7 +84,7 @@ def get_schema_view(info, url=None, patterns=None, urlconf=None, public=False, v
generator = self.generator_class(info, version, url, patterns, urlconf) generator = self.generator_class(info, version, url, patterns, urlconf)
schema = generator.get_schema(request, self.public) schema = generator.get_schema(request, self.public)
if schema is None: if schema is None:
raise exceptions.PermissionDenied() raise exceptions.PermissionDenied() # pragma: no cover
return Response(schema) return Response(schema)
@classmethod @classmethod

View File

@ -1,17 +1,24 @@
import datetime import datetime
from django_filters.rest_framework import DjangoFilterBackend, filters from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import viewsets from rest_framework import viewsets
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.filters import OrderingFilter
from rest_framework.pagination import LimitOffsetPagination from rest_framework.pagination import LimitOffsetPagination
from rest_framework.parsers import MultiPartParser from rest_framework.parsers import MultiPartParser
from rest_framework.response import Response from rest_framework.response import Response
from articles import serializers from articles import serializers
from articles.models import Article from articles.models import Article
from drf_swagger.inspectors import SwaggerAutoSchema
from drf_swagger.utils import swagger_auto_schema from drf_swagger.utils import swagger_auto_schema
class NoPagingAutoSchema(SwaggerAutoSchema):
def should_page(self):
return False
class ArticleViewSet(viewsets.ModelViewSet): class ArticleViewSet(viewsets.ModelViewSet):
""" """
ArticleViewSet class docstring ArticleViewSet class docstring
@ -27,15 +34,17 @@ class ArticleViewSet(viewsets.ModelViewSet):
""" """
queryset = Article.objects.all() queryset = Article.objects.all()
lookup_field = 'slug' lookup_field = 'slug'
lookup_value_regex = r'[a-z0-9]+(?:-[a-z0-9]+)'
serializer_class = serializers.ArticleSerializer serializer_class = serializers.ArticleSerializer
pagination_class = LimitOffsetPagination pagination_class = LimitOffsetPagination
max_page_size = 5 max_page_size = 5
filter_backends = (DjangoFilterBackend, filters.OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_fields = ('title',) filter_fields = ('title',)
ordering_fields = ('date_modified',) ordering_fields = ('date_modified',)
ordering = ('username',) ordering = ('username',)
@swagger_auto_schema(auto_schema=NoPagingAutoSchema)
@list_route(methods=['get']) @list_route(methods=['get'])
def today(self, request): def today(self, request):
today_min = datetime.datetime.combine(datetime.date.today(), datetime.time.min) today_min = datetime.datetime.combine(datetime.date.today(), datetime.time.min)
@ -45,7 +54,7 @@ class ArticleViewSet(viewsets.ModelViewSet):
return Response(serializer.data) return Response(serializer.data)
@swagger_auto_schema(method='get', operation_description="image GET description override") @swagger_auto_schema(method='get', operation_description="image GET description override")
@swagger_auto_schema(method='post', request_body=serializers.ImageUploadSerializer) @swagger_auto_schema(method='post', request_body=serializers.ImageUploadSerializer, responses={200: 'success'})
@detail_route(methods=['get', 'post'], parser_classes=(MultiPartParser,)) @detail_route(methods=['get', 'post'], parser_classes=(MultiPartParser,))
def image(self, request, slug=None): def image(self, request, slug=None):
""" """
@ -57,7 +66,7 @@ class ArticleViewSet(viewsets.ModelViewSet):
"""update method docstring""" """update method docstring"""
return super(ArticleViewSet, self).update(request, *args, **kwargs) return super(ArticleViewSet, self).update(request, *args, **kwargs)
@swagger_auto_schema(operation_description="partial_update description override") @swagger_auto_schema(operation_description="partial_update description override", responses={404: 'slug not found'})
def partial_update(self, request, *args, **kwargs): def partial_update(self, request, *args, **kwargs):
"""partial_update method docstring""" """partial_update method docstring"""
return super(ArticleViewSet, self).partial_update(request, *args, **kwargs) return super(ArticleViewSet, self).partial_update(request, *args, **kwargs)

View File

@ -4,11 +4,15 @@ from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES
class LanguageSerializer(serializers.Serializer): class LanguageSerializer(serializers.Serializer):
__ref_name__ = None
name = serializers.ChoiceField( name = serializers.ChoiceField(
choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language') choices=LANGUAGE_CHOICES, default='python', help_text='The name of the programming language')
class ExampleProjectsSerializer(serializers.Serializer): class ExampleProjectSerializer(serializers.Serializer):
__ref_name__ = 'Project'
project_name = serializers.CharField(help_text='Name of the project') project_name = serializers.CharField(help_text='Name of the project')
github_repo = serializers.CharField(required=True, help_text='Github repository of the project') github_repo = serializers.CharField(required=True, help_text='Github repository of the project')
@ -26,7 +30,7 @@ class SnippetSerializer(serializers.Serializer):
language = LanguageSerializer(help_text="Sample help text for language") language = LanguageSerializer(help_text="Sample help text for language")
style = serializers.ChoiceField(choices=STYLE_CHOICES, default='friendly') style = serializers.ChoiceField(choices=STYLE_CHOICES, default='friendly')
lines = serializers.ListField(child=serializers.IntegerField(), allow_empty=True, allow_null=True, required=False) lines = serializers.ListField(child=serializers.IntegerField(), allow_empty=True, allow_null=True, required=False)
example_projects = serializers.ListSerializer(child=ExampleProjectsSerializer()) example_projects = serializers.ListSerializer(child=ExampleProjectSerializer())
def create(self, validated_data): def create(self, validated_data):
""" """

View File

@ -29,6 +29,7 @@ class SnippetDetail(generics.RetrieveUpdateDestroyAPIView):
""" """
queryset = Snippet.objects.all() queryset = Snippet.objects.all()
serializer_class = SnippetSerializer serializer_class = SnippetSerializer
pagination_class = None
def patch(self, request, *args, **kwargs): def patch(self, request, *args, **kwargs):
"""patch method docstring""" """patch method docstring"""

View File

@ -6,30 +6,43 @@ from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from drf_swagger import openapi from drf_swagger import openapi
from drf_swagger.utils import swagger_auto_schema from drf_swagger.utils import swagger_auto_schema, no_body
from users.serializers import UserSerializer from users.serializers import UserSerializer
class UserList(APIView): class UserList(APIView):
"""UserList cbv classdoc""" """UserList cbv classdoc"""
@swagger_auto_schema(responses={200: UserSerializer(many=True)})
def get(self, request): def get(self, request):
queryset = User.objects.all() queryset = User.objects.all()
serializer = UserSerializer(queryset, many=True) serializer = UserSerializer(queryset, many=True)
return Response(serializer.data) return Response(serializer.data)
@swagger_auto_schema(request_body=UserSerializer, operation_description="apiview post description override") @swagger_auto_schema(operation_description="apiview post description override", request_body=openapi.Schema(
type=openapi.TYPE_OBJECT,
required=['username'],
properties={
'username': openapi.Schema(type=openapi.TYPE_STRING)
},
))
def post(self, request): def post(self, request):
serializer = UserSerializer(request.data) serializer = UserSerializer(request.data)
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
serializer.save() serializer.save()
return Response(serializer.data, status=status.HTTP_201_CREATED) return Response(serializer.data, status=status.HTTP_201_CREATED)
@swagger_auto_schema(request_body=no_body, operation_description="dummy operation")
def patch(self, request):
pass
@swagger_auto_schema(method='put', request_body=UserSerializer) @swagger_auto_schema(method='put', request_body=UserSerializer)
@swagger_auto_schema(method='get', manual_parameters=[ @swagger_auto_schema(methods=['get'], manual_parameters=[
openapi.Parameter('test', openapi.IN_QUERY, "test manual param", type=openapi.TYPE_BOOLEAN) openapi.Parameter('test', openapi.IN_QUERY, "test manual param", type=openapi.TYPE_BOOLEAN),
]) ], responses={
200: openapi.Response('response description', UserSerializer),
})
@api_view(['GET', 'PUT']) @api_view(['GET', 'PUT'])
def user_detail(request, pk): def user_detail(request, pk):
"""user_detail fbv docstring""" """user_detail fbv docstring"""

View File

@ -29,8 +29,8 @@ def codec_yaml():
@pytest.fixture @pytest.fixture
def swagger_dict(): def swagger_dict():
swagger = generator().get_schema(None, True) swagger = generator().get_schema(None, True)
json_bytes = codec_yaml().encode(swagger) json_bytes = codec_json().encode(swagger)
return yaml.safe_load(json_bytes.decode('utf-8')) return json.loads(json_bytes.decode('utf-8'))
@pytest.fixture @pytest.fixture
@ -60,5 +60,5 @@ def bad_settings():
@pytest.fixture @pytest.fixture
def reference_schema(): def reference_schema():
with open(os.path.join(os.path.dirname(__file__), 'reference.json')) as reference: with open(os.path.join(os.path.dirname(__file__), 'reference.yaml')) as reference:
return json.load(reference) return yaml.safe_load(reference)

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,971 @@
swagger: '2.0'
info:
title: Snippets API
description: Test description
termsOfService: https://www.google.com/policies/terms/
contact:
email: contact@snippets.local
license:
name: BSD License
version: v1
host: test.local:8002
schemes:
- http
basePath: /
paths:
/articles/:
get:
operationId: articles_list
description: ArticleViewSet class docstring
parameters:
- name: title
in: query
description: ''
required: false
type: string
- name: ordering
in: query
description: Which field to use when ordering the results.
required: false
type: string
- name: limit
in: query
description: Number of results to return per page.
required: false
type: integer
- name: offset
in: query
description: The initial index from which to return the results.
required: false
type: integer
responses:
'200':
description: ''
schema:
required:
- count
- results
type: object
properties:
count:
type: integer
next:
type: string
format: uri
previous:
type: string
format: uri
results:
type: array
items:
$ref: '#/definitions/Article'
consumes:
- application/json
tags:
- articles
post:
operationId: articles_create
description: ArticleViewSet class docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Article'
responses:
'201':
description: ''
schema:
$ref: '#/definitions/Article'
consumes:
- application/json
tags:
- articles
parameters: []
/articles/today/:
get:
operationId: articles_today
description: ArticleViewSet class docstring
parameters:
- name: title
in: query
description: ''
required: false
type: string
- name: ordering
in: query
description: Which field to use when ordering the results.
required: false
type: string
responses:
'200':
description: ''
schema:
type: array
items:
$ref: '#/definitions/Article'
consumes:
- application/json
tags:
- articles
parameters: []
/articles/{slug}/:
get:
operationId: articles_read
description: retrieve class docstring
parameters: []
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Article'
consumes:
- application/json
tags:
- articles
put:
operationId: articles_update
description: update method docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Article'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Article'
consumes:
- application/json
tags:
- articles
delete:
operationId: articles_delete
description: destroy method docstring
parameters: []
responses:
'204':
description: ''
consumes:
- application/json
tags:
- articles
patch:
operationId: articles_partial_update
description: partial_update description override
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Article'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Article'
'404':
description: slug not found
consumes:
- application/json
tags:
- articles
parameters:
- name: slug
in: path
description: slug model help_text
required: true
type: string
pattern: '[a-z0-9]+(?:-[a-z0-9]+)'
/articles/{slug}/image/:
get:
operationId: articles_image_read
description: image GET description override
parameters: []
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Article'
consumes:
- multipart/form-data
tags:
- articles
post:
operationId: articles_image_create
description: image method docstring
parameters:
- name: upload
in: formData
description: image serializer help_text
required: true
type: file
responses:
'200':
description: success
consumes:
- multipart/form-data
tags:
- articles
parameters:
- name: slug
in: path
description: slug model help_text
required: true
type: string
pattern: '[a-z0-9]+(?:-[a-z0-9]+)'
/snippets/:
get:
operationId: snippets_list
description: SnippetList classdoc
parameters: []
responses:
'200':
description: ''
schema:
type: array
items:
$ref: '#/definitions/Snippet'
consumes:
- application/json
tags:
- snippets
post:
operationId: snippets_create
description: post method docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Snippet'
responses:
'201':
description: ''
schema:
$ref: '#/definitions/Snippet'
consumes:
- application/json
tags:
- snippets
parameters: []
/snippets/{id}/:
get:
operationId: snippets_read
description: SnippetDetail classdoc
parameters: []
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Snippet'
consumes:
- application/json
tags:
- snippets
put:
operationId: snippets_update
description: put class docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Snippet'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Snippet'
consumes:
- application/json
tags:
- snippets
delete:
operationId: snippets_delete
description: delete method docstring
parameters: []
responses:
'204':
description: ''
consumes:
- application/json
tags:
- snippets
patch:
operationId: snippets_partial_update
description: patch method docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/Snippet'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/Snippet'
consumes:
- application/json
tags:
- snippets
parameters:
- name: id
in: path
description: A unique integer value identifying this snippet.
required: true
type: integer
/users/:
get:
operationId: users_list
description: UserList cbv classdoc
parameters: []
responses:
'200':
description: ''
schema:
type: array
items:
$ref: '#/definitions/User'
consumes:
- application/json
tags:
- users
post:
operationId: users_create
description: apiview post description override
parameters:
- name: data
in: body
required: true
schema: &id001
required:
- username
type: object
properties:
username:
type: string
responses:
'201':
description: ''
schema: *id001
consumes:
- application/json
tags:
- users
patch:
operationId: users_partial_update
description: dummy operation
parameters: []
responses:
'200':
description: ''
consumes:
- application/json
tags:
- users
parameters: []
/users/{id}/:
get:
operationId: users_read
description: user_detail fbv docstring
parameters:
- name: test
in: query
description: test manual param
type: boolean
responses:
'200':
description: response description
schema:
$ref: '#/definitions/User'
consumes:
- application/json
tags:
- users
put:
operationId: users_update
description: user_detail fbv docstring
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/User'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/User'
consumes:
- application/json
tags:
- users
parameters:
- name: id
in: path
required: true
type: string
definitions:
Article:
required:
- title
- body
type: object
properties:
title:
description: title model help_text
type: string
body:
description: body serializer help_text
type: string
slug:
description: slug model help_text
type: string
format: slug
date_created:
type: string
format: date-time
readOnly: true
date_modified:
type: string
format: date-time
readOnly: true
Project:
required:
- project_name
- github_repo
type: object
properties:
project_name:
description: Name of the project
type: string
github_repo:
description: Github repository of the project
type: string
Snippet:
required:
- code
- language
- example_projects
type: object
properties:
id:
description: id serializer help text
type: integer
readOnly: true
owner:
type: string
readOnly: true
title:
type: string
code:
type: string
linenos:
type: boolean
language:
description: Sample help text for language
type: object
properties:
name:
description: The name of the programming language
type: string
enum:
- abap
- abnf
- ada
- adl
- agda
- aheui
- ahk
- alloy
- ampl
- antlr
- antlr-as
- antlr-cpp
- antlr-csharp
- antlr-java
- antlr-objc
- antlr-perl
- antlr-python
- antlr-ruby
- apacheconf
- apl
- applescript
- arduino
- as
- as3
- aspectj
- aspx-cs
- aspx-vb
- asy
- at
- autoit
- awk
- basemake
- bash
- bat
- bbcode
- bc
- befunge
- bib
- blitzbasic
- blitzmax
- bnf
- boo
- boogie
- brainfuck
- bro
- bst
- bugs
- c
- c-objdump
- ca65
- cadl
- camkes
- capdl
- capnp
- cbmbas
- ceylon
- cfc
- cfengine3
- cfm
- cfs
- chai
- chapel
- cheetah
- cirru
- clay
- clean
- clojure
- clojurescript
- cmake
- cobol
- cobolfree
- coffee-script
- common-lisp
- componentpascal
- console
- control
- coq
- cpp
- cpp-objdump
- cpsa
- cr
- crmsh
- croc
- cryptol
- csharp
- csound
- csound-document
- csound-score
- css
- css+django
- css+erb
- css+genshitext
- css+lasso
- css+mako
- css+mozpreproc
- css+myghty
- css+php
- css+smarty
- cucumber
- cuda
- cypher
- cython
- d
- d-objdump
- dart
- delphi
- dg
- diff
- django
- docker
- doscon
- dpatch
- dtd
- duel
- dylan
- dylan-console
- dylan-lid
- earl-grey
- easytrieve
- ebnf
- ec
- ecl
- eiffel
- elixir
- elm
- emacs
- erb
- erl
- erlang
- evoque
- extempore
- ezhil
- factor
- fan
- fancy
- felix
- fish
- flatline
- forth
- fortran
- fortranfixed
- foxpro
- fsharp
- gap
- gas
- genshi
- genshitext
- glsl
- gnuplot
- go
- golo
- gooddata-cl
- gosu
- groff
- groovy
- gst
- haml
- handlebars
- haskell
- haxeml
- hexdump
- hsail
- html
- html+cheetah
- html+django
- html+evoque
- html+genshi
- html+handlebars
- html+lasso
- html+mako
- html+myghty
- html+ng2
- html+php
- html+smarty
- html+twig
- html+velocity
- http
- hx
- hybris
- hylang
- i6t
- idl
- idris
- iex
- igor
- inform6
- inform7
- ini
- io
- ioke
- irc
- isabelle
- j
- jags
- jasmin
- java
- javascript+mozpreproc
- jcl
- jlcon
- js
- js+cheetah
- js+django
- js+erb
- js+genshitext
- js+lasso
- js+mako
- js+myghty
- js+php
- js+smarty
- jsgf
- json
- json-object
- jsonld
- jsp
- julia
- juttle
- kal
- kconfig
- koka
- kotlin
- lagda
- lasso
- lcry
- lean
- less
- lhs
- lidr
- lighty
- limbo
- liquid
- live-script
- llvm
- logos
- logtalk
- lsl
- lua
- make
- mako
- maql
- mask
- mason
- mathematica
- matlab
- matlabsession
- md
- minid
- modelica
- modula2
- monkey
- monte
- moocode
- moon
- mozhashpreproc
- mozpercentpreproc
- mql
- mscgen
- mupad
- mxml
- myghty
- mysql
- nasm
- ncl
- nemerle
- nesc
- newlisp
- newspeak
- ng2
- nginx
- nim
- nit
- nixos
- nsis
- numpy
- nusmv
- objdump
- objdump-nasm
- objective-c
- objective-c++
- objective-j
- ocaml
- octave
- odin
- ooc
- opa
- openedge
- pacmanconf
- pan
- parasail
- pawn
- perl
- perl6
- php
- pig
- pike
- pkgconfig
- plpgsql
- postgresql
- postscript
- pot
- pov
- powershell
- praat
- prolog
- properties
- protobuf
- ps1con
- psql
- pug
- puppet
- py3tb
- pycon
- pypylog
- pytb
- python
- python3
- qbasic
- qml
- qvto
- racket
- ragel
- ragel-c
- ragel-cpp
- ragel-d
- ragel-em
- ragel-java
- ragel-objc
- ragel-ruby
- raw
- rb
- rbcon
- rconsole
- rd
- rebol
- red
- redcode
- registry
- resource
- rexx
- rhtml
- rnc
- roboconf-graph
- roboconf-instances
- robotframework
- rql
- rsl
- rst
- rts
- rust
- sas
- sass
- sc
- scala
- scaml
- scheme
- scilab
- scss
- shen
- silver
- slim
- smali
- smalltalk
- smarty
- sml
- snobol
- snowball
- sourceslist
- sp
- sparql
- spec
- splus
- sql
- sqlite3
- squidconf
- ssp
- stan
- stata
- swift
- swig
- systemverilog
- tads3
- tap
- tasm
- tcl
- tcsh
- tcshcon
- tea
- termcap
- terminfo
- terraform
- tex
- text
- thrift
- todotxt
- trac-wiki
- treetop
- ts
- tsql
- turtle
- twig
- typoscript
- typoscriptcssdata
- typoscripthtmldata
- urbiscript
- vala
- vb.net
- vcl
- vclsnippets
- vctreestatus
- velocity
- verilog
- vgl
- vhdl
- vim
- wdiff
- whiley
- x10
- xml
- xml+cheetah
- xml+django
- xml+erb
- xml+evoque
- xml+lasso
- xml+mako
- xml+myghty
- xml+php
- xml+smarty
- xml+velocity
- xquery
- xslt
- xtend
- xul+mozpreproc
- yaml
- yaml+jinja
- zephir
style:
type: string
enum:
- abap
- algol
- algol_nu
- arduino
- autumn
- borland
- bw
- colorful
- default
- emacs
- friendly
- fruity
- igor
- lovelace
- manni
- monokai
- murphy
- native
- paraiso-dark
- paraiso-light
- pastie
- perldoc
- rainbow_dash
- rrt
- tango
- trac
- vim
- vs
- xcode
lines:
type: array
items:
type: integer
example_projects:
type: array
items:
$ref: '#/definitions/Project'
User:
required:
- username
- snippets
type: object
properties:
id:
type: integer
readOnly: true
username:
description: Required. 150 characters or fewer. Letters, digits and @/./+/-/_
only.
type: string
snippets:
type: array
items:
type: string
uniqueItems: true
securityDefinitions:
basic:
type: basic

View File

@ -1,2 +1,13 @@
from datadiff.tools import assert_equal
def test_reference_schema(swagger_dict, reference_schema): def test_reference_schema(swagger_dict, reference_schema):
return swagger_dict == reference_schema # formatted better than pytest diff
swagger_dict = dict(swagger_dict)
reference_schema = dict(reference_schema)
ignore = ['info', 'host', 'schemes', 'basePath', 'securityDefinitions']
for attr in ignore:
swagger_dict.pop(attr, None)
reference_schema.pop(attr, None)
assert_equal(swagger_dict, reference_schema)

View File

@ -0,0 +1,45 @@
import pytest
from drf_swagger.openapi import ReferenceResolver
def test_basic():
scopes = ['s1', 's2']
rr = ReferenceResolver(*scopes)
assert scopes == rr.scopes == list(rr.keys()) == list(rr)
rr.set('o1', 1, scope='s1')
assert rr.has('o1', scope='s1')
assert rr.get('o1', scope='s1') == 1
rr.setdefault('o1', lambda: 2, scope='s1')
assert rr.get('o1', scope='s1') == 1
assert not rr.has('o1', scope='s2')
rr.setdefault('o3', lambda: 3, scope='s2')
assert rr.get('o3', scope='s2') == 3
assert rr['s1'] == {'o1': 1}
assert dict(rr) == {'s1': {'o1': 1}, 's2': {'o3': 3}}
assert str(rr) == str(dict(rr))
def test_scoped():
scopes = ['s1', 's2']
rr = ReferenceResolver(*scopes)
r1 = rr.with_scope('s1')
r2 = rr.with_scope('s2')
with pytest.raises(AssertionError):
rr.with_scope('bad')
assert r1.scopes == ['s1']
assert list(r1.keys()) == list(r1) == []
r2.set('o2', 2)
assert r2.scopes == ['s2']
assert list(r2.keys()) == list(r2) == ['o2']
assert r2['o2'] == 2
with pytest.raises(AssertionError):
r2.get('o2', scope='s1')
assert rr.get('o2', scope='s2') == 2

View File

@ -39,5 +39,6 @@ def test_json_codec_roundtrip(codec_json, generator, validate_schema):
def test_yaml_codec_roundtrip(codec_yaml, generator, validate_schema): def test_yaml_codec_roundtrip(codec_yaml, generator, validate_schema):
swagger = generator.get_schema(None, True) swagger = generator.get_schema(None, True)
json_bytes = codec_yaml.encode(swagger) yaml_bytes = codec_yaml.encode(swagger)
validate_schema(yaml.safe_load(json_bytes.decode('utf-8'))) assert b'omap' not in yaml_bytes
validate_schema(yaml.safe_load(yaml_bytes.decode('utf-8')))

View File

@ -42,6 +42,10 @@ deps =
commands= commands=
flake8 src/drf_swagger testproj tests setup.py flake8 src/drf_swagger testproj tests setup.py
[pytest]
DJANGO_SETTINGS_MODULE = testproj.settings
python_paths = testproj
[flake8] [flake8]
max-line-length = 120 max-line-length = 120
exclude = **/migrations/* exclude = **/migrations/*