Django rest framework recursive support (#110)

* add get_serializer_ref_name utility function
* implement RecursiveFieldInspector
* add option to allow non-existing reference in SchemaRef
* add examples and README
* Update changelog and docs
openapi3
Roman Sichny 2018-04-27 01:51:10 +03:00 committed by Cristi Vîjdea
parent d2dc09cb3c
commit 979ec84630
16 changed files with 311 additions and 33 deletions

View File

@ -353,6 +353,12 @@ Integration with `djangorestframework-camel-case <https://github.com/vbabiy/djan
provided out of the box - if you have ``djangorestframework-camel-case`` installed and your ``APIView`` uses
``CamelCaseJSONParser`` or ``CamelCaseJSONRenderer``, all property names will be converted to *camelCase* by default.
djangorestframework-recursive
===============================
Integration with `djangorestframework-recursive <https://github.com/heywbj/django-rest-framework-recursive>`_ is
provided out of the box - if you have ``djangorestframework-recursive`` installed.
.. |travis| image:: https://img.shields.io/travis/axnsan12/drf-yasg/master.svg
:target: https://travis-ci.org/axnsan12/drf-yasg
:alt: Travis CI

View File

@ -3,6 +3,20 @@ Changelog
#########
*********
**1.7.0**
*********
*Release date: Apr 27, 2018*
- **ADDED:** added integration with `djangorestframework-recursive <https://github.com/heywbj/django-rest-framework-recursive>`_
(:issue:`109`, :pr:`110`, thanks to :ghuser:`rsichny`)
*NOTE:* in order for this to work, you will have to add the new ``drf_yasg.inspectors.RecursiveFieldInspector`` to
your ``DEFAULT_FIELD_INSPECTORS`` array if you changed it from the default value
- **FIXED:** ``SchemaRef`` now supports cyclical references via the ``ignore_unresolved`` argument
*********
**1.6.2**
*********

View File

@ -67,6 +67,7 @@ to this list.
:class:`'drf_yasg.inspectors.FileFieldInspector' <.inspectors.FileFieldInspector>`, |br| \
:class:`'drf_yasg.inspectors.DictFieldInspector' <.inspectors.DictFieldInspector>`, |br| \
:class:`'drf_yasg.inspectors.HiddenFieldInspector' <.inspectors.HiddenFieldInspector>`, |br| \
:class:`'drf_yasg.inspectors.RecursiveFieldInspector' <.inspectors.RecursiveFieldInspector>`, |br| \
:class:`'drf_yasg.inspectors.SimpleFieldInspector' <.inspectors.SimpleFieldInspector>`, |br| \
:class:`'drf_yasg.inspectors.StringDefaultFieldInspector' <.inspectors.StringDefaultFieldInspector>`, |br| \
``]``

View File

@ -5,5 +5,6 @@ django-cors-headers>=2.1.0
django-filter>=1.1.0,<2.0; python_version == "2.7"
django-filter>=1.1.0; python_version >= "3.4"
djangorestframework-camel-case>=0.2.0
djangorestframework-recursive>=0.1.2
dj-database-url>=0.4.2
user_agents>=1.1.0

View File

@ -6,12 +6,13 @@ SWAGGER_DEFAULTS = {
'DEFAULT_FIELD_INSPECTORS': [
'drf_yasg.inspectors.CamelCaseJSONFilter',
'drf_yasg.inspectors.RecursiveFieldInspector',
'drf_yasg.inspectors.ReferencingSerializerInspector',
'drf_yasg.inspectors.RelatedFieldInspector',
'drf_yasg.inspectors.ChoiceFieldInspector',
'drf_yasg.inspectors.FileFieldInspector',
'drf_yasg.inspectors.DictFieldInspector',
'drf_yasg.inspectors.HiddenFieldInspector',
'drf_yasg.inspectors.RelatedFieldInspector',
'drf_yasg.inspectors.SimpleFieldInspector',
'drf_yasg.inspectors.StringDefaultFieldInspector',
],

View File

@ -4,8 +4,8 @@ from .base import (
)
from .field import (
CamelCaseJSONFilter, ChoiceFieldInspector, DictFieldInspector, FileFieldInspector, HiddenFieldInspector,
InlineSerializerInspector, ReferencingSerializerInspector, RelatedFieldInspector, SimpleFieldInspector,
StringDefaultFieldInspector
InlineSerializerInspector, RecursiveFieldInspector, ReferencingSerializerInspector, RelatedFieldInspector,
SimpleFieldInspector, StringDefaultFieldInspector
)
from .query import CoreAPICompatInspector, DjangoRestResponsePagination
from .view import SwaggerAutoSchema
@ -23,9 +23,9 @@ __all__ = [
'CoreAPICompatInspector', 'DjangoRestResponsePagination',
# field inspectors
'InlineSerializerInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector', 'SimpleFieldInspector',
'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', 'StringDefaultFieldInspector',
'CamelCaseJSONFilter', 'HiddenFieldInspector',
'InlineSerializerInspector', 'RecursiveFieldInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector',
'SimpleFieldInspector', 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector',
'StringDefaultFieldInspector', 'CamelCaseJSONFilter', 'HiddenFieldInspector',
# view inspectors
'SwaggerAutoSchema',

View File

@ -10,7 +10,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import decimal_as_float, filter_none
from ..utils import decimal_as_float, filter_none, get_serializer_ref_name
from .base import FieldInspector, NotHandled, SerializerInspector
logger = logging.getLogger(__name__)
@ -55,23 +55,12 @@ class InlineSerializerInspector(SerializerInspector):
if swagger_object_type != openapi.Schema:
raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__)
serializer = field
serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__
if hasattr(serializer_meta, 'ref_name'):
ref_name = serializer_meta.ref_name
elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
logger.debug("Forcing inline output for ModelSerializer named 'NestedSerializer': " + str(serializer))
ref_name = None
else:
ref_name = serializer_name
if ref_name.endswith('Serializer'):
ref_name = ref_name[:-len('Serializer')]
ref_name = get_serializer_ref_name(field)
def make_schema_definition():
properties = OrderedDict()
required = []
for property_name, child in serializer.fields.items():
for property_name, child in field.fields.items():
property_name = self.get_property_name(property_name)
prop_kwargs = {
'read_only': child.read_only or None
@ -531,3 +520,24 @@ else:
return camelize_schema(result, self.components)
return result
try:
from rest_framework_recursive.fields import RecursiveField
except ImportError:
class RecursiveFieldInspector(FieldInspector):
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
pass
else:
class RecursiveFieldInspector(FieldInspector):
"""Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)"""
def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs):
if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema:
assert use_references is True, "Can not create schema for RecursiveField when use_references is False"
ref_name = get_serializer_ref_name(field.proxied)
assert ref_name is not None, "Can not create RecursiveField schema for inline ModelSerializer"
return openapi.SchemaRef(self.components.with_scope(openapi.SCHEMA_DEFINITIONS), ref_name,
ignore_unresolved=True)
return NotHandled

View File

@ -466,7 +466,7 @@ class Schema(SwaggerDict):
class _Ref(SwaggerDict):
ref_name_re = re.compile(r"#/(?P<scope>.+)/(?P<name>[^/]+)$")
def __init__(self, resolver, name, scope, expected_type):
def __init__(self, resolver, name, scope, expected_type, ignore_unresolved=False):
"""Base class for all reference types. A reference object has only one property, ``$ref``, which must be a JSON
reference to a valid object in the specification, e.g. ``#/definitions/Article`` to refer to an article model.
@ -474,13 +474,15 @@ class _Ref(SwaggerDict):
:param str name: referenced object name, e.g. "Article"
:param str scope: reference scope, e.g. "definitions"
:param type[.SwaggerDict] expected_type: the expected type that will be asserted on the object found in resolver
:param bool ignore_unresolved: allow the reference to be not defined in resolver
"""
super(_Ref, self).__init__()
assert not type(self) == _Ref, "do not instantiate _Ref directly"
ref_name = "#/{scope}/{name}".format(scope=scope, name=name)
obj = resolver.get(name, scope)
assert isinstance(obj, expected_type), ref_name + " is a {actual}, not a {expected}" \
.format(actual=type(obj).__name__, expected=expected_type.__name__)
if not ignore_unresolved:
obj = resolver.get(name, scope)
assert isinstance(obj, expected_type), ref_name + " is a {actual}, not a {expected}" \
.format(actual=type(obj).__name__, expected=expected_type.__name__)
self.ref = ref_name
def resolve(self, resolver):
@ -502,14 +504,15 @@ class _Ref(SwaggerDict):
class SchemaRef(_Ref):
def __init__(self, resolver, schema_name):
def __init__(self, resolver, schema_name, ignore_unresolved=False):
"""Adds a reference to a named Schema defined in the ``#/definitions/`` object.
:param .ReferenceResolver resolver: component resolver which must contain the definition
:param str schema_name: schema name
:param bool ignore_unresolved: allow the reference to be not defined in resolver
"""
assert SCHEMA_DEFINITIONS in resolver.scopes
super(SchemaRef, self).__init__(resolver, schema_name, SCHEMA_DEFINITIONS, Schema)
super(SchemaRef, self).__init__(resolver, schema_name, SCHEMA_DEFINITIONS, Schema, ignore_unresolved)
Schema.OR_REF = (Schema, SchemaRef)

View File

@ -295,3 +295,25 @@ def decimal_as_float(field):
if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField):
return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING)
return False
def get_serializer_ref_name(serializer):
"""
Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer')
:param serializer: Serializer instance
:return: Serializer's ref_name or None for inline serializer
:rtype: str or None
"""
serializer_meta = getattr(serializer, 'Meta', None)
serializer_name = type(serializer).__name__
if hasattr(serializer_meta, 'ref_name'):
ref_name = serializer_meta.ref_name
elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer):
logger.debug("Forcing inline output for ModelSerializer named 'NestedSerializer': " + str(serializer))
ref_name = None
else:
ref_name = serializer_name
if ref_name.endswith('Serializer'):
ref_name = ref_name[:-len('Serializer')]
return ref_name

View File

@ -0,0 +1,22 @@
# Generated by Django 2.0.4 on 2018-04-26 13:06
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
('todo', '0001_initial'),
]
operations = [
migrations.CreateModel(
name='TodoTree',
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(max_length=50)),
('parent', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE,
related_name='children', to='todo.TodoTree')),
],
),
]

View File

@ -13,3 +13,8 @@ class TodoAnother(models.Model):
class TodoYetAnother(models.Model):
todo = models.ForeignKey(TodoAnother, on_delete=models.CASCADE)
title = models.CharField(max_length=50)
class TodoTree(models.Model):
parent = models.ForeignKey('self', on_delete=models.CASCADE, related_name='children', null=True)
title = models.CharField(max_length=50)

View File

@ -1,7 +1,8 @@
from django.utils import timezone
from rest_framework import serializers
from rest_framework_recursive.fields import RecursiveField
from .models import Todo, TodoAnother, TodoYetAnother
from .models import Todo, TodoAnother, TodoTree, TodoYetAnother
class TodoSerializer(serializers.ModelSerializer):
@ -25,3 +26,22 @@ class TodoYetAnotherSerializer(serializers.ModelSerializer):
model = TodoYetAnother
fields = ('title', 'todo')
depth = 2
class TodoTreeSerializer(serializers.ModelSerializer):
children = serializers.ListField(child=RecursiveField(), source='children.all')
class Meta:
model = TodoTree
fields = ('id', 'title', 'children')
class TodoRecursiveSerializer(serializers.ModelSerializer):
parent = RecursiveField(read_only=True)
parent_id = serializers.PrimaryKeyRelatedField(queryset=TodoTree.objects.all(), pk_field=serializers.IntegerField(),
write_only=True, allow_null=True, required=False, default=None,
source='parent')
class Meta:
model = TodoTree
fields = ('id', 'title', 'parent', 'parent_id')

View File

@ -7,10 +7,12 @@ router = routers.DefaultRouter()
router.register(r'', views.TodoViewSet)
router.register(r'another', views.TodoAnotherViewSet)
router.register(r'yetanother', views.TodoYetAnotherViewSet)
router.register(r'tree', views.TodoTreeView)
router.register(r'recursive', views.TodoRecursiveView)
urlpatterns = router.urls
urlpatterns += [
url(r'^(?P<todo_id>\d+)/yetanother/(?P<yetanother_id>\d+)/$',
views.NestedTodoView.as_view(),),
views.NestedTodoView.as_view(), ),
]

View File

@ -1,8 +1,10 @@
from rest_framework import viewsets
from rest_framework.generics import RetrieveAPIView
from .models import Todo, TodoAnother, TodoYetAnother
from .serializer import TodoAnotherSerializer, TodoSerializer, TodoYetAnotherSerializer
from .models import Todo, TodoAnother, TodoTree, TodoYetAnother
from .serializer import (
TodoAnotherSerializer, TodoRecursiveSerializer, TodoSerializer, TodoTreeSerializer, TodoYetAnotherSerializer
)
class TodoViewSet(viewsets.ReadOnlyModelViewSet):
@ -25,3 +27,13 @@ class TodoYetAnotherViewSet(viewsets.ReadOnlyModelViewSet):
class NestedTodoView(RetrieveAPIView):
serializer_class = TodoYetAnotherSerializer
class TodoTreeView(viewsets.ReadOnlyModelViewSet):
queryset = TodoTree.objects.all()
serializer_class = TodoTreeSerializer
class TodoRecursiveView(viewsets.ModelViewSet):
queryset = TodoTree.objects.all()
serializer_class = TodoRecursiveSerializer

View File

@ -495,6 +495,129 @@ paths:
description: A unique integer value identifying this todo another.
required: true
type: integer
/todo/recursive/:
get:
operationId: todo_recursive_list
description: ''
parameters: []
responses:
'200':
description: ''
schema:
type: array
items:
$ref: '#/definitions/TodoRecursive'
tags:
- todo
post:
operationId: todo_recursive_create
description: ''
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/TodoRecursive'
responses:
'201':
description: ''
schema:
$ref: '#/definitions/TodoRecursive'
tags:
- todo
parameters: []
/todo/recursive/{id}/:
get:
operationId: todo_recursive_read
description: ''
parameters: []
responses:
'200':
description: ''
schema:
$ref: '#/definitions/TodoRecursive'
tags:
- todo
put:
operationId: todo_recursive_update
description: ''
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/TodoRecursive'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/TodoRecursive'
tags:
- todo
patch:
operationId: todo_recursive_partial_update
description: ''
parameters:
- name: data
in: body
required: true
schema:
$ref: '#/definitions/TodoRecursive'
responses:
'200':
description: ''
schema:
$ref: '#/definitions/TodoRecursive'
tags:
- todo
delete:
operationId: todo_recursive_delete
description: ''
parameters: []
responses:
'204':
description: ''
tags:
- todo
parameters:
- name: id
in: path
description: A unique integer value identifying this todo tree.
required: true
type: integer
/todo/tree/:
get:
operationId: todo_tree_list
description: ''
parameters: []
responses:
'200':
description: ''
schema:
type: array
items:
$ref: '#/definitions/TodoTree'
tags:
- todo
parameters: []
/todo/tree/{id}/:
get:
operationId: todo_tree_read
description: ''
parameters: []
responses:
'200':
description: ''
schema:
$ref: '#/definitions/TodoTree'
tags:
- todo
parameters:
- name: id
in: path
description: A unique integer value identifying this todo tree.
required: true
type: integer
/todo/yetanother/:
get:
operationId: todo_yetanother_list
@ -1337,6 +1460,42 @@ definitions:
maxLength: 50
todo:
$ref: '#/definitions/Todo'
TodoRecursive:
required:
- title
type: object
properties:
id:
title: ID
type: integer
readOnly: true
title:
title: Title
type: string
maxLength: 50
parent:
$ref: '#/definitions/TodoRecursive'
parent_id:
type: integer
title: Parent id
TodoTree:
required:
- title
- children
type: object
properties:
id:
title: ID
type: integer
readOnly: true
title:
title: Title
type: string
maxLength: 50
children:
type: array
items:
$ref: '#/definitions/TodoTree'
TodoYetAnother:
required:
- title

View File

@ -65,7 +65,7 @@ known_standard_library =
collections,copy,distutils,functools,inspect,io,json,logging,operator,os,pkg_resources,re,setuptools,sys,
types,warnings
known_third_party =
coreapi,coreschema,datadiff,dj_database_url,django,django_filters,djangorestframework_camel_case,flex,gunicorn,
inflection,pygments,pytest,rest_framework,ruamel,setuptools_scm,swagger_spec_validator,uritemplate,user_agents,
whitenoise
coreapi,coreschema,datadiff,dj_database_url,django,django_filters,djangorestframework_camel_case,
rest_framework_recursive,flex,gunicorn,inflection,pygments,pytest,rest_framework,ruamel,setuptools_scm,
swagger_spec_validator,uritemplate,user_agents,whitenoise
known_first_party = drf_yasg,testproj,articles,people,snippets,todo,users,urlconfs