753 lines
30 KiB
Python
753 lines
30 KiB
Python
import datetime
|
|
import inspect
|
|
import logging
|
|
import operator
|
|
import uuid
|
|
from collections import OrderedDict
|
|
from decimal import Decimal
|
|
|
|
from django.core import validators
|
|
from django.db import models
|
|
from rest_framework import serializers
|
|
from rest_framework.settings import api_settings as rest_framework_settings
|
|
|
|
from .. import openapi
|
|
from ..errors import SwaggerGenerationError
|
|
from ..utils import decimal_as_float, filter_none, get_serializer_class, get_serializer_ref_name
|
|
from .base import FieldInspector, NotHandled, SerializerInspector
|
|
|
|
# ``typing`` and ``inspect.signature`` are treated as optional (presumably for old Python versions that lack
# them); when an import fails the name is bound to ``None`` and SerializerMethodFieldInspector skips its
# return-type-hint fallback (it checks ``typing and inspect_signature``).
try:
    import typing
except ImportError:
    typing = None

try:
    from inspect import signature as inspect_signature
except ImportError:
    inspect_signature = None

#: module logger; used to warn about ref_name collisions and regex validators that fail to compile
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class InlineSerializerInspector(SerializerInspector):
    """Provides serializer conversions using :meth:`.FieldInspector.field_to_swagger_object`."""

    #: whether to output :class:`.Schema` definitions inline or into the ``definitions`` section
    use_definitions = False

    def get_schema(self, serializer):
        """Convert ``serializer`` into a schema by probing the configured field inspectors.

        :param serializer: serializer instance
        :return: the probed schema; may be a :class:`.SchemaRef` when ``use_definitions`` is set and the
            serializer has a ref name
        """
        return self.probe_field_inspectors(serializer, openapi.Schema, self.use_definitions)

    def add_manual_parameters(self, serializer, parameters):
        """Add/replace parameters from the given list of automatically generated request parameters. This method
        is called only when the serializer is converted into a list of parameters for use in a form data request.

        :param serializer: serializer instance
        :param list[openapi.Parameter] parameters: generated parameters
        :return: modified parameters
        :rtype: list[openapi.Parameter]
        """
        return parameters

    def get_request_parameters(self, serializer, in_):
        """Convert every field of ``serializer`` into a :class:`.Parameter` located in ``in_``.

        :param serializer: serializer instance; objects without a ``fields`` attribute produce no parameters
        :param str in_: parameter location
        :return: the parameters, after passing through :meth:`add_manual_parameters`
        :rtype: list[openapi.Parameter]
        """
        fields = getattr(serializer, 'fields', {})
        parameters = [
            self.probe_field_inspectors(
                value, openapi.Parameter, self.use_definitions,
                name=self.get_parameter_name(key), in_=in_
            )
            for key, value
            in fields.items()
        ]

        return self.add_manual_parameters(serializer, parameters)

    def get_property_name(self, field_name):
        """Return the schema property name for a serializer field name (identity by default)."""
        return field_name

    def get_parameter_name(self, field_name):
        """Return the request parameter name for a serializer field name (identity by default)."""
        return field_name

    def get_serializer_ref_name(self, serializer):
        """Return the definitions ref name for ``serializer``, delegating to the module-level helper."""
        return get_serializer_ref_name(serializer)

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

        if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
            child_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
            # propagate min_items/max_items derived from the list's length validators (see ``limit_validators``);
            # without this those table entries are never used
            limits = find_limits(field) or {}
            return SwaggerType(
                type=openapi.TYPE_ARRAY,
                items=child_schema,
                **limits
            )
        elif isinstance(field, serializers.Serializer):
            if swagger_object_type != openapi.Schema:
                raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)

            ref_name = self.get_serializer_ref_name(field)

            def make_schema_definition(serializer=field):
                # build an object schema from the serializer's fields; read-only children are never "required"
                properties = OrderedDict()
                required = []
                for property_name, child in serializer.fields.items():
                    property_name = self.get_property_name(property_name)
                    prop_kwargs = {
                        'read_only': bool(child.read_only) or None
                    }
                    prop_kwargs = filter_none(prop_kwargs)

                    child_schema = self.probe_field_inspectors(
                        child, ChildSwaggerType, use_references, **prop_kwargs
                    )
                    properties[property_name] = child_schema

                    if child.required and not getattr(child_schema, 'read_only', False):
                        required.append(property_name)

                result = SwaggerType(
                    type=openapi.TYPE_OBJECT,
                    properties=properties,
                    required=required or None,
                )
                if not ref_name and 'title' in result:
                    # on an inline model, the title is derived from the field name
                    # but is visually displayed like the model name, which is confusing
                    # it is better to just remove title from inline models
                    del result.title

                return result

            if not ref_name or not use_references:
                return make_schema_definition()

            # the definition is built lazily, only the first time this ref_name is seen
            definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
            actual_schema = definitions.setdefault(ref_name, make_schema_definition)
            actual_schema._remove_read_only()

            # NOTE(review): make_schema_definition never stores ``_serializer`` on the schema it builds, so this
            # collision check only fires if some other code sets that attribute
            actual_serializer = get_serializer_class(getattr(actual_schema, '_serializer', None))
            this_serializer = get_serializer_class(field)
            if actual_serializer and actual_serializer != this_serializer:  # pragma: no cover
                logger.warning("Schema for %s will override distinct serializer %s because they "
                               "share the same ref_name", actual_serializer, this_serializer)

            return openapi.SchemaRef(definitions, ref_name)

        return NotHandled
|
|
|
|
|
|
class ReferencingSerializerInspector(InlineSerializerInspector):
    """Variant of :class:`InlineSerializerInspector` that has ``use_definitions`` switched on."""

    #: emit named serializer schemas into the ``definitions`` section instead of inlining them
    use_definitions = True
|
|
|
|
|
|
def get_queryset_field(queryset, field_name):
    """Look up a queryset's model and one of that model's fields.

    :param queryset: the queryset; anything without a ``model`` attribute yields a ``None`` model
    :param field_name: name of the field to look up on the model (``'pk'`` selects the primary key)
    :returns: ``(model, model_field)``; either element may be ``None``
    :rtype: tuple
    """
    queryset_model = getattr(queryset, 'model', None)
    return queryset_model, get_model_field(queryset_model, field_name)
|
|
|
|
|
|
def get_model_field(model, field_name):
    """Fetch a field definition from a Django model, never raising.

    :param model: the model class (``None`` or anything without ``_meta`` simply yields ``None``)
    :param field_name: the field name; the special name ``'pk'`` returns the primary key field
    :return: model field or ``None``
    """
    try:
        meta = model._meta
        return meta.pk if field_name == 'pk' else meta.get_field(field_name)
    except Exception:  # pragma: no cover
        return None
|
|
|
|
|
|
def get_queryset_from_view(view, serializer=None):
    """Try to get the queryset of the given view

    :param view: the view instance or class
    :param serializer: if given, will check that the view's get_serializer_class return matches this serializer
    :return: queryset or ``None``
    """
    try:
        queryset = getattr(view, 'queryset', None)

        if queryset is not None and serializer is not None:
            # make sure the view is actually using *this* serializer; this is an explicit check rather than an
            # ``assert`` so that it still applies when Python runs with -O (which strips assert statements)
            if type(serializer) != view.get_serializer_class():
                return None

        return queryset
    except Exception:  # pragma: no cover
        # best-effort lookup: any failure (e.g. get_serializer_class raising) means "unknown"
        return None
|
|
|
|
|
|
def get_parent_serializer(field):
    """Walk up the ``parent`` chain from ``field`` (inclusive) to the first ``Serializer``.

    :return: ``Serializer`` or ``None`` when the chain ends without one
    """
    node = field
    while node is not None and not isinstance(node, serializers.Serializer):
        node = node.parent
    return node
|
|
|
|
|
|
def get_related_model(model, source):
    """Resolve the model on the other side of a relation, given the name of the related attribute.

    Two descriptor layouts are tried in turn: ``descriptor.rel.related_model`` and then
    ``descriptor.field.remote_field.model``.

    :param model: one side of the relationship
    :param str source: related field name
    :return: related model or ``None``
    """
    try:
        descriptor = getattr(model, source)
    except Exception:  # pragma: no cover
        return None

    try:
        return descriptor.rel.related_model
    except Exception:
        pass

    try:
        return descriptor.field.remote_field.model
    except Exception:  # pragma: no cover
        return None
|
|
|
|
|
|
class RelatedFieldInspector(FieldInspector):
    """Provides conversions for ``RelatedField``\\ s."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

        if isinstance(field, serializers.ManyRelatedField):
            # a to-many relation becomes an array of whatever the child relation converts to
            item_schema = self.probe_field_inspectors(field.child_relation, ChildSwaggerType, use_references)
            return SwaggerType(type=openapi.TYPE_ARRAY, items=item_schema, unique_items=True)

        if not isinstance(field, serializers.RelatedField):
            return NotHandled

        related_queryset = getattr(field, 'queryset', None)

        if not isinstance(field, (serializers.PrimaryKeyRelatedField, serializers.SlugRelatedField)):
            if isinstance(field, serializers.HyperlinkedRelatedField):
                return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
            return SwaggerType(type=openapi.TYPE_STRING)

        pk_field = getattr(field, 'pk_field', '')
        if pk_field:
            # a PrimaryKeyRelatedField may declare a `pk_field` serializer field that converts the PK value:
            # take type/format/etc. from it, and title/description/default from the related field itself
            pk_schema = self.probe_field_inspectors(pk_field, swagger_object_type, use_references, **kwargs)
            return SwaggerType(existing_object=pk_schema)

        lookup_name = getattr(field, 'slug_field', 'pk')
        if related_queryset is None:
            target_field = self._find_target_model_field(field, lookup_name)
        else:
            _, target_field = get_queryset_field(related_queryset, lookup_name)

        return SwaggerType(**(get_basic_type_info(target_field) or {'type': openapi.TYPE_STRING}))

    def _find_target_model_field(self, field, lookup_name):
        """Locate the related model field for a RelatedField without a queryset (e.g. a read-only one).

        The owning model is taken from the parent serializer's ``Meta.model`` or, failing that, from the
        view's queryset; the related model is then found through the field's source attribute.
        """
        parent_serializer = get_parent_serializer(field)
        owner_model = getattr(getattr(parent_serializer, 'Meta', None), 'model', None)
        if not owner_model:
            owner_model = getattr(get_queryset_from_view(self.view, parent_serializer), 'model', None)

        source = getattr(field, 'source', '') or field.field_name
        if not source and isinstance(field.parent, serializers.ManyRelatedField):
            # fall back to the enclosing ManyRelatedField's name when the child has none
            source = field.parent.field_name

        return get_model_field(get_related_model(owner_model, source), lookup_name)
|
|
|
|
|
|
def find_regex(regex_field):
    """Given a ``Field``, look for a ``RegexValidator`` and try to extract its pattern and return it as a string.

    Python-only anchors are rewritten to their JavaScript equivalents: a leading ``\\A`` becomes ``^`` and a
    trailing (unescaped) ``\\Z``/``\\z`` becomes ``$``.

    :param serializers.Field regex_field: the field instance
    :return: the extracted pattern, or ``None``
    :rtype: str
    """
    regex_validator = None
    for validator in regex_field.validators:
        if isinstance(validator, validators.RegexValidator):
            if isinstance(validator, validators.URLValidator) or validator == validators.validate_ipv4_address:
                # skip the default url and IP regexes because they are complex and unhelpful
                # validate_ipv4_address is a RegexValidator instance in Django 1.11
                continue
            if regex_validator is not None:
                # bail if multiple validators are found - no obvious way to choose
                return None  # pragma: no cover
            regex_validator = validator

    # regex_validator.regex should be a compiled re object...
    try:
        pattern = getattr(getattr(regex_validator, 'regex', None), 'pattern', None)
    except Exception:  # pragma: no cover
        logger.warning('failed to compile regex validator of ' + str(regex_field), exc_info=True)
        return None

    if pattern:
        # attempt some basic cleanup to remove regex constructs not supported by JavaScript
        # -- swagger uses javascript-style regexes - see https://github.com/swagger-api/swagger-editor/issues/1601
        if pattern.startswith('\\A'):
            # nothing precedes it, so this backslash is always a real escape
            pattern = '^' + pattern[2:]
        if pattern.endswith(('\\Z', '\\z')):
            # only rewrite when the backslash before Z/z is not itself escaped (an odd-length run of
            # backslashes); e.g. r'foo\\Z' means "literal backslash followed by Z" and must be kept
            stem = pattern[:-1]
            trailing_backslashes = len(stem) - len(stem.rstrip('\\'))
            if trailing_backslashes % 2 == 1:
                pattern = pattern[:-2] + '$'

    return pattern
|
|
|
|
|
|
#: serializer field classes for which ``minimum``/``maximum`` are meaningful
numeric_fields = (serializers.IntegerField, serializers.FloatField, serializers.DecimalField)

#: (validator class, applicable field class(es), swagger attribute, comparison) entries consumed by
#: :func:`find_limits`; the comparison tells whether a new limit value should replace the one already found
limit_validators = [
    # numbers: minimum / maximum
    (validators.MinValueValidator, numeric_fields, 'minimum', operator.gt),
    (validators.MaxValueValidator, numeric_fields, 'maximum', operator.lt),

    # strings: minLength / maxLength
    (validators.MinLengthValidator, serializers.CharField, 'min_length', operator.gt),
    (validators.MaxLengthValidator, serializers.CharField, 'max_length', operator.lt),

    # lists: minItems / maxItems
    (validators.MinLengthValidator, serializers.ListField, 'min_items', operator.gt),
    (validators.MaxLengthValidator, serializers.ListField, 'max_items', operator.lt),
]
|
|
|
|
|
|
def find_limits(field):
    """Given a ``Field``, look for min/max value/length validators and return appropriate limit validation attributes.

    :param serializers.Field field: the field instance
    :return: the extracted limits
    :rtype: OrderedDict
    """
    # decimals that are rendered as strings get no numeric limits at all
    if isinstance(field, serializers.DecimalField) and not decimal_as_float(field):
        return OrderedDict()

    applicable_limits = [
        (validator_class, attr, improves)
        for validator_class, field_class, attr, improves in limit_validators
        if isinstance(field, field_class)
    ]

    limits = {}
    for validator in field.validators:
        if not hasattr(validator, 'limit_value'):
            continue

        matching = [
            (attr, improves)
            for validator_class, attr, improves in applicable_limits
            if isinstance(validator, validator_class)
        ]
        if not matching:
            continue

        limit_value = validator.limit_value
        if callable(limit_value):
            # Django validators accept a callable limit_value that is evaluated when validating; resolve it here
            # instead of storing (and comparing) the function object itself
            limit_value = limit_value()
        if isinstance(limit_value, Decimal) and decimal_as_float(field):
            limit_value = float(limit_value)

        for attr, improves in matching:
            # with several validators for the same attribute keep the most restrictive one
            # (largest minimum, smallest maximum - see the comparisons in ``limit_validators``)
            if attr not in limits or improves(limit_value, limits[attr]):
                limits[attr] = limit_value

    if hasattr(field, "allow_blank") and not field.allow_blank:
        # a field that rejects blank strings needs at least one character
        if limits.get('min_length', 0) < 1:
            limits['min_length'] = 1

    return OrderedDict(sorted(limits.items()))
|
|
|
|
|
|
def decimal_field_type(field):
    """Swagger type for a decimal ``field``: number when it is rendered as a float, string otherwise."""
    if decimal_as_float(field):
        return openapi.TYPE_NUMBER
    return openapi.TYPE_STRING
|
|
|
|
|
|
#: (model field class, (swagger type, swagger format)) pairs used by :func:`get_basic_type_info`; the first
#: ``isinstance`` match wins, so subclasses must precede their base classes. The swagger type may be a callable
#: taking the field instance (see :func:`decimal_field_type`). Entries whose class is ``None`` are dropped.
model_field_to_basic_type = [
    (field_class, type_format)
    for field_class, type_format in [
        (models.AutoField, (openapi.TYPE_INTEGER, None)),
        (models.BinaryField, (openapi.TYPE_STRING, openapi.FORMAT_BINARY)),
        (models.BooleanField, (openapi.TYPE_BOOLEAN, None)),
        # NullBooleanField and IPAddressField are legacy fields that Django has removed from public use;
        # look them up defensively so a missing attribute cannot break importing this module
        (getattr(models, 'NullBooleanField', None), (openapi.TYPE_BOOLEAN, None)),
        (models.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
        (models.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
        (models.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
        (models.DurationField, (openapi.TYPE_INTEGER, None)),
        (models.FloatField, (openapi.TYPE_NUMBER, None)),
        (models.IntegerField, (openapi.TYPE_INTEGER, None)),
        (getattr(models, 'IPAddressField', None), (openapi.TYPE_STRING, openapi.FORMAT_IPV4)),
        (models.GenericIPAddressField, (openapi.TYPE_STRING, openapi.FORMAT_IPV6)),
        (models.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
        (models.TextField, (openapi.TYPE_STRING, None)),
        (models.TimeField, (openapi.TYPE_STRING, None)),
        (models.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
        (models.CharField, (openapi.TYPE_STRING, None)),
    ]
    if field_class is not None
]
|
|
|
|
#: swagger string format for each ``protocol`` value of ``serializers.IPAddressField`` (others get no format)
ip_format = {'ipv4': openapi.FORMAT_IPV4, 'ipv6': openapi.FORMAT_IPV6}

#: (serializer field class, (swagger type, swagger format)) pairs used by :func:`get_basic_type_info`; the first
#: ``isinstance`` match wins, so subclasses must precede their base classes. Type and format may be callables
#: taking the field instance. Entries whose class is ``None`` are dropped.
serializer_field_to_basic_type = [
    (field_class, type_format)
    for field_class, type_format in [
        (serializers.EmailField, (openapi.TYPE_STRING, openapi.FORMAT_EMAIL)),
        (serializers.SlugField, (openapi.TYPE_STRING, openapi.FORMAT_SLUG)),
        (serializers.URLField, (openapi.TYPE_STRING, openapi.FORMAT_URI)),
        (serializers.IPAddressField, (openapi.TYPE_STRING, lambda field: ip_format.get(field.protocol, None))),
        (serializers.UUIDField, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
        (serializers.RegexField, (openapi.TYPE_STRING, None)),
        (serializers.CharField, (openapi.TYPE_STRING, None)),
        (serializers.BooleanField, (openapi.TYPE_BOOLEAN, None)),
        # NullBooleanField was removed in Django REST framework 3.14; accessing it unconditionally would raise
        # AttributeError while importing this module
        (getattr(serializers, 'NullBooleanField', None), (openapi.TYPE_BOOLEAN, None)),
        (serializers.IntegerField, (openapi.TYPE_INTEGER, None)),
        (serializers.FloatField, (openapi.TYPE_NUMBER, None)),
        (serializers.DecimalField, (decimal_field_type, openapi.FORMAT_DECIMAL)),
        (serializers.DurationField, (openapi.TYPE_NUMBER, None)),  # ?
        (serializers.DateField, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
        (serializers.DateTimeField, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
        (serializers.ModelField, (openapi.TYPE_STRING, None)),
    ]
    if field_class is not None
]

#: combined lookup table; serializer field classes are checked before model field classes
basic_type_info = serializer_field_to_basic_type + model_field_to_basic_type
|
|
|
|
|
|
def get_basic_type_info(field):
    """Describe a serializer or model ``Field`` by its basic swagger attributes.

    The result holds ``type``, ``format`` and (for string types) ``pattern``, plus any min/max limits found
    on the field's validators; ``None``-valued entries are removed.

    :param field: the field instance
    :return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
    :rtype: OrderedDict
    """
    if field is None:
        return None

    type_format = next(
        (candidate for field_class, candidate in basic_type_info if isinstance(field, field_class)),
        None,
    )
    if type_format is None:  # pragma: no cover
        return None

    swagger_type, fmt = type_format
    if callable(swagger_type):
        swagger_type = swagger_type(field)
    if callable(fmt):
        fmt = fmt(field)

    pattern = find_regex(field) if swagger_type == openapi.TYPE_STRING else None

    info = OrderedDict((('type', swagger_type), ('format', fmt), ('pattern', pattern)))
    info.update(find_limits(field))
    return filter_none(info)
|
|
|
|
|
|
def decimal_return_type():
    """Swagger type for a ``Decimal`` return value, following DRF's ``COERCE_DECIMAL_TO_STRING`` setting."""
    if rest_framework_settings.COERCE_DECIMAL_TO_STRING:
        return openapi.TYPE_STRING
    return openapi.TYPE_NUMBER
|
|
|
|
|
|
#: (python type, (swagger type, swagger format)) pairs used to map return type hints. Entries are tested in order
#: with ``issubclass``, so ``bool`` must precede ``int`` and ``datetime.datetime`` must precede ``datetime.date``.
#: The swagger type may be a zero-argument callable (see :func:`decimal_return_type`).
raw_type_info = [
    (bool, (openapi.TYPE_BOOLEAN, None)),
    (int, (openapi.TYPE_INTEGER, None)),
    (float, (openapi.TYPE_NUMBER, None)),
    (Decimal, (decimal_return_type, openapi.FORMAT_DECIMAL)),
    (uuid.UUID, (openapi.TYPE_STRING, openapi.FORMAT_UUID)),
    (datetime.datetime, (openapi.TYPE_STRING, openapi.FORMAT_DATETIME)),
    (datetime.date, (openapi.TYPE_STRING, openapi.FORMAT_DATE)),
    # generic containers last, so they never take precedence over the more specific entries above;
    # without the dict entry a ``-> dict`` hint was documented as a plain string
    (str, (openapi.TYPE_STRING, None)),
    (dict, (openapi.TYPE_OBJECT, None)),
    # TODO - generic aliases such as typing.List[...] are not mapped by this table
]

#: table consulted by :func:`get_basic_type_info_from_hint`
hinting_type_info = raw_type_info
|
|
|
|
|
|
def get_basic_type_info_from_hint(hint_class):
    """Map a class (e.g. a SerializerMethodField's return type hint) to basic swagger attributes.

    The returned mapping always has the keys ``type``, ``format`` and ``pattern`` (``pattern`` is always
    ``None``; no limit values are derived from a bare class).

    :param hint_class: the class; must be an actual class since it is passed to ``issubclass``
    :return: the extracted attributes as a dictionary, or ``None`` if the class is not known
    :rtype: OrderedDict
    """
    for check_class, (swagger_type, fmt) in hinting_type_info:
        if issubclass(hint_class, check_class):
            break
    else:  # pragma: no cover
        return None

    if callable(swagger_type):
        swagger_type = swagger_type()

    return OrderedDict([
        ('type', swagger_type),
        ('format', fmt),
        ('pattern', None),
    ])
|
|
|
|
|
|
class SerializerMethodFieldInspector(FieldInspector):
    """Provides conversion for SerializerMethodField, optionally using information from the swagger_serializer_method
    decorator.
    """

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        if not isinstance(field, serializers.SerializerMethodField):
            return NotHandled

        # use a default so a missing method yields NotHandled instead of an AttributeError
        method = getattr(field.parent, field.method_name, None)
        if method is None:
            return NotHandled

        # attribute added by the swagger_serializer_method decorator
        serializer = getattr(method, '_swagger_serializer', None)

        if serializer:
            # in order of preference for description, use:
            # 1) field.help_text from SerializerMethodField(help_text)
            # 2) serializer.help_text from swagger_serializer_method(serializer)
            # 3) method's docstring
            description = field.help_text
            if description is None:
                description = getattr(serializer, 'help_text', None)
            if description is None:
                description = method.__doc__

            label = field.label
            if label is None:
                label = getattr(serializer, 'label', None)

            if inspect.isclass(serializer):
                serializer_kwargs = {
                    "help_text": description,
                    "label": label,
                    "read_only": True,
                }

                serializer = serializer(**serializer_kwargs)
            else:
                serializer.help_text = description
                serializer.label = label
                serializer.read_only = True

            return self.probe_field_inspectors(serializer, swagger_object_type, use_references, read_only=True)
        elif typing and inspect_signature:
            # look for Python 3.5+ style type hinting of the return value
            hint_class, is_array = self._unwrap_return_hint(inspect_signature(method).return_annotation)

            if inspect.isclass(hint_class) and not issubclass(hint_class, inspect.Signature.empty):
                type_info = get_basic_type_info_from_hint(hint_class)

                if type_info is not None:
                    SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type,
                                                                            use_references, **kwargs)
                    if is_array:
                        return SwaggerType(type=openapi.TYPE_ARRAY, items=ChildSwaggerType(**type_info))
                    return SwaggerType(**type_info)

        return NotHandled

    @staticmethod
    def _unwrap_return_hint(hint):
        """Reduce a return annotation to ``(element_hint, is_array)``.

        ``List[X]``/``Set[X]``/``FrozenSet[X]`` (and builtin ``list[X]`` etc.) give ``(X, True)``; any other
        parameterized non-class hint such as ``Optional[X]`` gives ``(X, False)`` (its first argument, as
        before); plain classes and unparameterized hints are returned unchanged.
        """
        args = getattr(hint, '__args__', None)
        if not args:
            return hint, False

        origin = getattr(hint, '__origin__', None)
        if inspect.isclass(origin) and issubclass(origin, (list, set, frozenset)):
            return args[0], True

        if inspect.isclass(hint):
            return hint, False
        return args[0], False
|
|
|
|
|
|
class SimpleFieldInspector(FieldInspector):
    """Provides conversions for fields which can be described using just ``type``, ``format``, ``pattern``
    and min/max validators.
    """

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        basic_info = get_basic_type_info(field)
        if basic_info is not None:
            make_object, _ = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
            return make_object(**basic_info)
        return NotHandled
|
|
|
|
|
|
class ChoiceFieldInspector(FieldInspector):
    """Provides conversions for ``ChoiceField`` and ``MultipleChoiceField``."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

        if isinstance(field, serializers.ChoiceField):
            enum_type = None

            # for ModelSerializer, try to infer the type from the associated model field
            serializer = get_parent_serializer(field)
            if isinstance(serializer, serializers.ModelSerializer):
                model = getattr(getattr(serializer, 'Meta', None), 'model', None)
                model_field = get_model_field(model, field.source)
                if model_field:
                    model_type = get_basic_type_info(model_field)
                    if model_type:
                        enum_type = model_type.get('type')

            choices = list(field.choices.keys())
            if enum_type is None:
                # otherwise derive the type from the choice values so that e.g. integer choices are not
                # declared with type "string"
                enum_type = self._enum_type_from_choices(choices)

            if isinstance(field, serializers.MultipleChoiceField):
                return SwaggerType(
                    type=openapi.TYPE_ARRAY,
                    items=ChildSwaggerType(
                        type=enum_type,
                        enum=choices
                    )
                )

            return SwaggerType(type=enum_type, enum=choices)

        return NotHandled

    @staticmethod
    def _enum_type_from_choices(choices):
        """Guess the swagger type of a list of choice values; mixed, empty or non-numeric lists give string."""
        if choices and all(isinstance(choice, bool) for choice in choices):
            return openapi.TYPE_BOOLEAN
        # bool is a subclass of int, so it is excluded explicitly from the numeric checks
        if choices and all(isinstance(choice, int) and not isinstance(choice, bool) for choice in choices):
            return openapi.TYPE_INTEGER
        if choices and all(isinstance(choice, (int, float)) and not isinstance(choice, bool) for choice in choices):
            return openapi.TYPE_NUMBER
        return openapi.TYPE_STRING
|
|
|
|
|
|
class FileFieldInspector(FieldInspector):
    """Provides conversions for ``FileField``\\ s."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

        if not isinstance(field, serializers.FileField):
            return NotHandled

        # swagger 2.0 does not support specifics about file fields, so ImageFile gets no special treatment
        # OpenAPI 3.0 does support it, so a future implementation could handle this better
        unsupported = "FileField is supported only in a formData Parameter or response Schema"

        if swagger_object_type == openapi.Schema:
            # in a response the field is a read-only string: a URL when use_url is on, else the file name
            schema = SwaggerType(type=openapi.TYPE_STRING, read_only=True)
            if getattr(field, 'use_url', rest_framework_settings.UPLOADED_FILES_USE_URL):
                schema.format = openapi.FORMAT_URI
            return schema

        if swagger_object_type == openapi.Parameter:
            param = SwaggerType(type=openapi.TYPE_FILE)
            if param['in'] == openapi.IN_FORM:
                return param

        raise SwaggerGenerationError(unsupported)  # pragma: no cover
|
|
|
|
|
|
class DictFieldInspector(FieldInspector):
    """Provides conversion for ``DictField``."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        SwaggerType, ChildSwaggerType = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)

        handles_field = isinstance(field, serializers.DictField) and swagger_object_type == openapi.Schema
        if not handles_field:
            return NotHandled

        # a dict maps arbitrary keys to values described by the child field
        value_schema = self.probe_field_inspectors(field.child, ChildSwaggerType, use_references)
        return SwaggerType(type=openapi.TYPE_OBJECT, additional_properties=value_schema)
|
|
|
|
|
|
class HiddenFieldInspector(FieldInspector):
    """Hide ``HiddenField``."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
        # returning None (rather than NotHandled) omits the field from the output
        return None if isinstance(field, serializers.HiddenField) else NotHandled
|
|
|
|
|
|
class StringDefaultFieldInspector(FieldInspector):
    """For otherwise unhandled fields, return them as plain :data:`.TYPE_STRING` objects."""

    def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):  # pragma: no cover
        # TODO unhandled fields: TimeField JSONField
        make_object, _ = self._get_partial_types(field, swagger_object_type, use_references, **kwargs)
        return make_object(type=openapi.TYPE_STRING)
|
|
|
|
|
|
# ``djangorestframework_camel_case`` is optional. When it is not installed, the parser/renderer names are bound to
# ``None`` and ``camelize`` becomes an identity function, so CamelCaseJSONFilter below leaves names untouched.
try:
    from djangorestframework_camel_case.parser import CamelCaseJSONParser
    from djangorestframework_camel_case.render import CamelCaseJSONRenderer
    from djangorestframework_camel_case.render import camelize
except ImportError:  # pragma: no cover
    CamelCaseJSONParser = CamelCaseJSONRenderer = None

    def camelize(data):
        # fallback: return the data unchanged
        return data
|
|
|
|
|
|
class CamelCaseJSONFilter(FieldInspector):
    """Converts property names to camelCase if ``djangorestframework_camel_case`` is used."""

    def camelize_string(self, s):
        """Camelize a single string by passing it through ``camelize`` as the only key of a dict.

        :param str s: the string
        :return: camelized string
        :rtype: str
        """
        camelized = camelize({s: ''})
        return next(iter(camelized))

    def camelize_schema(self, schema):
        """Camelize ``properties`` keys and ``required`` entries of ``schema`` in place, recursing into the
        (reference-resolved) property schemas.

        :param openapi.Schema schema: the :class:`.Schema` object
        """
        properties = getattr(schema, 'properties', {})
        if properties:
            renamed = OrderedDict()
            for name, prop in properties.items():
                new_name = self.camelize_string(name)
                renamed[new_name] = self.camelize_schema(openapi.resolve_ref(prop, self.components)) or prop
            schema.properties = renamed

        required = getattr(schema, 'required', [])
        if required:
            schema.required = [self.camelize_string(name) for name in required]

    def process_result(self, result, method_name, obj, **kwargs):
        # only schemas (or references to them) are rewritten; the result object itself is returned as-is
        if isinstance(result, openapi.Schema.OR_REF) and self.is_camel_case():
            self.camelize_schema(openapi.resolve_ref(result, self.components))
        return result

    def is_camel_case(self):
        """Whether the view uses a camel-case parser or renderer (always ``False`` if the package is missing)."""
        if not (CamelCaseJSONParser and CamelCaseJSONRenderer):
            return False
        uses_parser = any(issubclass(parser, CamelCaseJSONParser) for parser in self.view.parser_classes)
        return uses_parser or any(
            issubclass(renderer, CamelCaseJSONRenderer) for renderer in self.view.renderer_classes
        )
|
|
|
|
|
|
try:
    from rest_framework_recursive.fields import RecursiveField
except ImportError:  # pragma: no cover
    class RecursiveFieldInspector(FieldInspector):
        """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
        pass
else:
    class RecursiveFieldInspector(FieldInspector):
        """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""

        def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
            if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
                # explicit raises instead of ``assert`` so these checks are not stripped under ``python -O``;
                # SwaggerGenerationError matches how this module reports other unsupported conversions
                if use_references is not True:
                    raise SwaggerGenerationError(
                        "Can not create schema for RecursiveField when use_references is False"
                    )

                ref_name = get_serializer_ref_name(field.proxied)
                if ref_name is None:
                    raise SwaggerGenerationError(
                        "Can't create RecursiveField schema for inline " + str(type(field.proxied))
                    )

                # the target definition may not exist yet while the recursive serializer is being built
                definitions = self.components.with_scope(openapi.SCHEMA_DEFINITIONS)
                return openapi.SchemaRef(definitions, ref_name, ignore_unresolved=True)

            return NotHandled
|