From 979ec84630f9fb57a3659cd6cc62b286153ceaef Mon Sep 17 00:00:00 2001 From: Roman Sichny Date: Fri, 27 Apr 2018 01:51:10 +0300 Subject: [PATCH] Django rest framework recursive support (#110) * add get_serializer_ref_name utility function * implement RecursiveFieldInspector * add option to allow non-existing reference in SchemaRef * add examples and README * Update changelog and docs --- README.rst | 6 + docs/changelog.rst | 14 ++ docs/settings.rst | 1 + requirements/testproj.txt | 1 + src/drf_yasg/app_settings.py | 3 +- src/drf_yasg/inspectors/__init__.py | 10 +- src/drf_yasg/inspectors/field.py | 38 ++++-- src/drf_yasg/openapi.py | 15 +- src/drf_yasg/utils.py | 22 +++ testproj/todo/migrations/0002_todotree.py | 22 +++ testproj/todo/models.py | 5 + testproj/todo/serializer.py | 22 ++- testproj/todo/urls.py | 4 +- testproj/todo/views.py | 16 ++- tests/reference.yaml | 159 ++++++++++++++++++++++ tox.ini | 6 +- 16 files changed, 311 insertions(+), 33 deletions(-) create mode 100644 testproj/todo/migrations/0002_todotree.py diff --git a/README.rst b/README.rst index 4908cfd..5ece519 100644 --- a/README.rst +++ b/README.rst @@ -353,6 +353,12 @@ Integration with `djangorestframework-camel-case `_ is +provided out of the box - if you have ``djangorestframework-recursive`` installed. + .. |travis| image:: https://img.shields.io/travis/axnsan12/drf-yasg/master.svg :target: https://travis-ci.org/axnsan12/drf-yasg :alt: Travis CI diff --git a/docs/changelog.rst b/docs/changelog.rst index ea684a1..bf81ee0 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -3,6 +3,20 @@ Changelog ######### +********* +**1.7.0** +********* + +*Release date: Apr 27, 2018* + +- **ADDED:** added integration with `djangorestframework-recursive `_ + (:issue:`109`, :pr:`110`, thanks to :ghuser:`rsichny`) + + *NOTE:* in order for this to work, you will have to add the new ``drf_yasg.inspectors.RecursiveFieldInspector`` to + your ``DEFAULT_FIELD_INSPECTORS`` array if you changed it from the default value + +- **FIXED:** ``SchemaRef`` now supports cyclical references via the ``ignore_unresolved`` argument + ********* **1.6.2** ********* diff --git a/docs/settings.rst b/docs/settings.rst index d3db470..4cea0d3 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -67,6 +67,7 @@ to this list. :class:`'drf_yasg.inspectors.FileFieldInspector' <.inspectors.FileFieldInspector>`, |br| \ :class:`'drf_yasg.inspectors.DictFieldInspector' <.inspectors.DictFieldInspector>`, |br| \ :class:`'drf_yasg.inspectors.HiddenFieldInspector' <.inspectors.HiddenFieldInspector>`, |br| \ +:class:`'drf_yasg.inspectors.RecursiveFieldInspector' <.inspectors.RecursiveFieldInspector>`, |br| \ :class:`'drf_yasg.inspectors.SimpleFieldInspector' <.inspectors.SimpleFieldInspector>`, |br| \ :class:`'drf_yasg.inspectors.StringDefaultFieldInspector' <.inspectors.StringDefaultFieldInspector>`, |br| \ ``]`` diff --git a/requirements/testproj.txt b/requirements/testproj.txt index 958b760..d082869 100644 --- a/requirements/testproj.txt +++ b/requirements/testproj.txt @@ -5,5 +5,6 @@ django-cors-headers>=2.1.0 django-filter>=1.1.0,<2.0; python_version == "2.7" django-filter>=1.1.0; python_version >= "3.4" djangorestframework-camel-case>=0.2.0 +djangorestframework-recursive>=0.1.2 dj-database-url>=0.4.2 user_agents>=1.1.0 diff --git a/src/drf_yasg/app_settings.py b/src/drf_yasg/app_settings.py index e3c1a64..ead07c3 100644 --- a/src/drf_yasg/app_settings.py +++ b/src/drf_yasg/app_settings.py @@ -6,12 +6,13 @@ SWAGGER_DEFAULTS = { 'DEFAULT_FIELD_INSPECTORS': [ 'drf_yasg.inspectors.CamelCaseJSONFilter', + 'drf_yasg.inspectors.RecursiveFieldInspector', 'drf_yasg.inspectors.ReferencingSerializerInspector', - 'drf_yasg.inspectors.RelatedFieldInspector', 'drf_yasg.inspectors.ChoiceFieldInspector', 'drf_yasg.inspectors.FileFieldInspector', 'drf_yasg.inspectors.DictFieldInspector', 'drf_yasg.inspectors.HiddenFieldInspector', + 'drf_yasg.inspectors.RelatedFieldInspector', 'drf_yasg.inspectors.SimpleFieldInspector', 'drf_yasg.inspectors.StringDefaultFieldInspector', ], diff --git a/src/drf_yasg/inspectors/__init__.py b/src/drf_yasg/inspectors/__init__.py index 74e4d85..d141b2d 100644 --- a/src/drf_yasg/inspectors/__init__.py +++ b/src/drf_yasg/inspectors/__init__.py @@ -4,8 +4,8 @@ from .base import ( ) from .field import ( CamelCaseJSONFilter, ChoiceFieldInspector, DictFieldInspector, FileFieldInspector, HiddenFieldInspector, - InlineSerializerInspector, ReferencingSerializerInspector, RelatedFieldInspector, SimpleFieldInspector, - StringDefaultFieldInspector + InlineSerializerInspector, RecursiveFieldInspector, ReferencingSerializerInspector, RelatedFieldInspector, + SimpleFieldInspector, StringDefaultFieldInspector ) from .query import CoreAPICompatInspector, DjangoRestResponsePagination from .view import SwaggerAutoSchema @@ -23,9 +23,9 @@ __all__ = [ 'CoreAPICompatInspector', 'DjangoRestResponsePagination', # field inspectors - 'InlineSerializerInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector', 'SimpleFieldInspector', - 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', 'StringDefaultFieldInspector', - 'CamelCaseJSONFilter', 'HiddenFieldInspector', + 'InlineSerializerInspector', 'RecursiveFieldInspector', 'ReferencingSerializerInspector', 'RelatedFieldInspector', + 'SimpleFieldInspector', 'FileFieldInspector', 'ChoiceFieldInspector', 'DictFieldInspector', + 'StringDefaultFieldInspector', 'CamelCaseJSONFilter', 'HiddenFieldInspector', # view inspectors 'SwaggerAutoSchema', diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 543bcf8..24cc8d5 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -10,7 +10,7 @@ from rest_framework.settings import api_settings as rest_framework_settings from .. import openapi from ..errors import SwaggerGenerationError -from ..utils import decimal_as_float, filter_none +from ..utils import decimal_as_float, filter_none, get_serializer_ref_name from .base import FieldInspector, NotHandled, SerializerInspector logger = logging.getLogger(__name__) @@ -55,23 +55,12 @@ class InlineSerializerInspector(SerializerInspector): if swagger_object_type != openapi.Schema: raise SwaggerGenerationError("cannot instantiate nested serializer as " + swagger_object_type.__name__) - serializer = field - serializer_meta = getattr(serializer, 'Meta', None) - serializer_name = type(serializer).__name__ - if hasattr(serializer_meta, 'ref_name'): - ref_name = serializer_meta.ref_name - elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer): - logger.debug("Forcing inline output for ModelSerializer named 'NestedSerializer': " + str(serializer)) - ref_name = None - else: - ref_name = serializer_name - if ref_name.endswith('Serializer'): - ref_name = ref_name[:-len('Serializer')] + ref_name = get_serializer_ref_name(field) def make_schema_definition(): properties = OrderedDict() required = [] - for property_name, child in serializer.fields.items(): + for property_name, child in field.fields.items(): property_name = self.get_property_name(property_name) prop_kwargs = { 'read_only': child.read_only or None @@ -531,3 +520,24 @@ else: return camelize_schema(result, self.components) return result + +try: + from rest_framework_recursive.fields import RecursiveField +except ImportError: + class RecursiveFieldInspector(FieldInspector): + """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)""" + pass +else: + class RecursiveFieldInspector(FieldInspector): + """Provides conversion for RecursiveField (https://github.com/heywbj/django-rest-framework-recursive)""" + def field_to_swagger_object(self, field, swagger_object_type, use_references, **kwargs): + if isinstance(field, RecursiveField) and swagger_object_type == openapi.Schema: + assert use_references is True, "Can not create schema for RecursiveField when use_references is False" + + ref_name = get_serializer_ref_name(field.proxied) + assert ref_name is not None, "Can not create RecursiveField schema for inline ModelSerializer" + + return openapi.SchemaRef(self.components.with_scope(openapi.SCHEMA_DEFINITIONS), ref_name, + ignore_unresolved=True) + + return NotHandled diff --git a/src/drf_yasg/openapi.py b/src/drf_yasg/openapi.py index c76b03b..5aee816 100644 --- a/src/drf_yasg/openapi.py +++ b/src/drf_yasg/openapi.py @@ -466,7 +466,7 @@ class Schema(SwaggerDict): class _Ref(SwaggerDict): ref_name_re = re.compile(r"#/(?P.+)/(?P[^/]+)$") - def __init__(self, resolver, name, scope, expected_type): + def __init__(self, resolver, name, scope, expected_type, ignore_unresolved=False): """Base class for all reference types. A reference object has only one property, ``$ref``, which must be a JSON reference to a valid object in the specification, e.g. ``#/definitions/Article`` to refer to an article model. @@ -474,13 +474,15 @@ class _Ref(SwaggerDict): :param str name: referenced object name, e.g. "Article" :param str scope: reference scope, e.g. "definitions" :param type[.SwaggerDict] expected_type: the expected type that will be asserted on the object found in resolver + :param bool ignore_unresolved: allow the reference to be not defined in resolver """ super(_Ref, self).__init__() assert not type(self) == _Ref, "do not instantiate _Ref directly" ref_name = "#/{scope}/{name}".format(scope=scope, name=name) - obj = resolver.get(name, scope) - assert isinstance(obj, expected_type), ref_name + " is a {actual}, not a {expected}" \ - .format(actual=type(obj).__name__, expected=expected_type.__name__) + if not ignore_unresolved: + obj = resolver.get(name, scope) + assert isinstance(obj, expected_type), ref_name + " is a {actual}, not a {expected}" \ + .format(actual=type(obj).__name__, expected=expected_type.__name__) self.ref = ref_name def resolve(self, resolver): @@ -502,14 +504,15 @@ class _Ref(SwaggerDict): class SchemaRef(_Ref): - def __init__(self, resolver, schema_name): + def __init__(self, resolver, schema_name, ignore_unresolved=False): """Adds a reference to a named Schema defined in the ``#/definitions/`` object. :param .ReferenceResolver resolver: component resolver which must contain the definition :param str schema_name: schema name + :param bool ignore_unresolved: allow the reference to be not defined in resolver """ assert SCHEMA_DEFINITIONS in resolver.scopes - super(SchemaRef, self).__init__(resolver, schema_name, SCHEMA_DEFINITIONS, Schema) + super(SchemaRef, self).__init__(resolver, schema_name, SCHEMA_DEFINITIONS, Schema, ignore_unresolved) Schema.OR_REF = (Schema, SchemaRef) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 9e1c5ed..85c8550 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -295,3 +295,25 @@ def decimal_as_float(field): if isinstance(field, serializers.DecimalField) or isinstance(field, models.DecimalField): return not getattr(field, 'coerce_to_string', rest_framework_settings.COERCE_DECIMAL_TO_STRING) return False + + +def get_serializer_ref_name(serializer): + """ + Get serializer's ref_name (or None for ModelSerializer if it is named 'NestedSerializer') + + :param serializer: Serializer instance + :return: Serializer's ref_name or None for inline serializer + :rtype: str or None + """ + serializer_meta = getattr(serializer, 'Meta', None) + serializer_name = type(serializer).__name__ + if hasattr(serializer_meta, 'ref_name'): + ref_name = serializer_meta.ref_name + elif serializer_name == 'NestedSerializer' and isinstance(serializer, serializers.ModelSerializer): + logger.debug("Forcing inline output for ModelSerializer named 'NestedSerializer': " + str(serializer)) + ref_name = None + else: + ref_name = serializer_name + if ref_name.endswith('Serializer'): + ref_name = ref_name[:-len('Serializer')] + return ref_name diff --git a/testproj/todo/migrations/0002_todotree.py b/testproj/todo/migrations/0002_todotree.py new file mode 100644 index 0000000..0160e83 --- /dev/null +++ b/testproj/todo/migrations/0002_todotree.py @@ -0,0 +1,22 @@ +# Generated by Django 2.0.4 on 2018-04-26 13:06 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + dependencies = [ + ('todo', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='TodoTree', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('title', models.CharField(max_length=50)), + ('parent', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, + related_name='children', to='todo.TodoTree')), + ], + ), + ] diff --git a/testproj/todo/models.py b/testproj/todo/models.py index cc29546..46e6300 100644 --- a/testproj/todo/models.py +++ b/testproj/todo/models.py @@ -13,3 +13,8 @@ class TodoAnother(models.Model): class TodoYetAnother(models.Model): todo = models.ForeignKey(TodoAnother, on_delete=models.CASCADE) title = models.CharField(max_length=50) + + +class TodoTree(models.Model): + parent = models.ForeignKey('self', on_delete=models.CASCADE, related_name='children', null=True) + title = models.CharField(max_length=50) diff --git a/testproj/todo/serializer.py b/testproj/todo/serializer.py index daa0c2f..ede243c 100644 --- a/testproj/todo/serializer.py +++ b/testproj/todo/serializer.py @@ -1,7 +1,8 @@ from django.utils import timezone from rest_framework import serializers +from rest_framework_recursive.fields import RecursiveField -from .models import Todo, TodoAnother, TodoYetAnother +from .models import Todo, TodoAnother, TodoTree, TodoYetAnother class TodoSerializer(serializers.ModelSerializer): @@ -25,3 +26,22 @@ class TodoYetAnotherSerializer(serializers.ModelSerializer): model = TodoYetAnother fields = ('title', 'todo') depth = 2 + + +class TodoTreeSerializer(serializers.ModelSerializer): + children = serializers.ListField(child=RecursiveField(), source='children.all') + + class Meta: + model = TodoTree + fields = ('id', 'title', 'children') + + +class TodoRecursiveSerializer(serializers.ModelSerializer): + parent = RecursiveField(read_only=True) + parent_id = serializers.PrimaryKeyRelatedField(queryset=TodoTree.objects.all(), pk_field=serializers.IntegerField(), + write_only=True, allow_null=True, required=False, default=None, + source='parent') + + class Meta: + model = TodoTree + fields = ('id', 'title', 'parent', 'parent_id') diff --git a/testproj/todo/urls.py b/testproj/todo/urls.py index cb0a3ba..637db7c 100644 --- a/testproj/todo/urls.py +++ b/testproj/todo/urls.py @@ -7,10 +7,12 @@ router = routers.DefaultRouter() router.register(r'', views.TodoViewSet) router.register(r'another', views.TodoAnotherViewSet) router.register(r'yetanother', views.TodoYetAnotherViewSet) +router.register(r'tree', views.TodoTreeView) +router.register(r'recursive', views.TodoRecursiveView) urlpatterns = router.urls urlpatterns += [ url(r'^(?P\d+)/yetanother/(?P\d+)/$', - views.NestedTodoView.as_view(),), + views.NestedTodoView.as_view(), ), ] diff --git a/testproj/todo/views.py b/testproj/todo/views.py index 848e461..786d6d7 100644 --- a/testproj/todo/views.py +++ b/testproj/todo/views.py @@ -1,8 +1,10 @@ from rest_framework import viewsets from rest_framework.generics import RetrieveAPIView -from .models import Todo, TodoAnother, TodoYetAnother -from .serializer import TodoAnotherSerializer, TodoSerializer, TodoYetAnotherSerializer +from .models import Todo, TodoAnother, TodoTree, TodoYetAnother +from .serializer import ( + TodoAnotherSerializer, TodoRecursiveSerializer, TodoSerializer, TodoTreeSerializer, TodoYetAnotherSerializer +) class TodoViewSet(viewsets.ReadOnlyModelViewSet): @@ -25,3 +27,13 @@ class TodoYetAnotherViewSet(viewsets.ReadOnlyModelViewSet): class NestedTodoView(RetrieveAPIView): serializer_class = TodoYetAnotherSerializer + + +class TodoTreeView(viewsets.ReadOnlyModelViewSet): + queryset = TodoTree.objects.all() + serializer_class = TodoTreeSerializer + + +class TodoRecursiveView(viewsets.ModelViewSet): + queryset = TodoTree.objects.all() + serializer_class = TodoRecursiveSerializer diff --git a/tests/reference.yaml b/tests/reference.yaml index f0d917c..0a27e44 100644 --- a/tests/reference.yaml +++ b/tests/reference.yaml @@ -495,6 +495,129 @@ paths: description: A unique integer value identifying this todo another. required: true type: integer + /todo/recursive/: + get: + operationId: todo_recursive_list + description: '' + parameters: [] + responses: + '200': + description: '' + schema: + type: array + items: + $ref: '#/definitions/TodoRecursive' + tags: + - todo + post: + operationId: todo_recursive_create + description: '' + parameters: + - name: data + in: body + required: true + schema: + $ref: '#/definitions/TodoRecursive' + responses: + '201': + description: '' + schema: + $ref: '#/definitions/TodoRecursive' + tags: + - todo + parameters: [] + /todo/recursive/{id}/: + get: + operationId: todo_recursive_read + description: '' + parameters: [] + responses: + '200': + description: '' + schema: + $ref: '#/definitions/TodoRecursive' + tags: + - todo + put: + operationId: todo_recursive_update + description: '' + parameters: + - name: data + in: body + required: true + schema: + $ref: '#/definitions/TodoRecursive' + responses: + '200': + description: '' + schema: + $ref: '#/definitions/TodoRecursive' + tags: + - todo + patch: + operationId: todo_recursive_partial_update + description: '' + parameters: + - name: data + in: body + required: true + schema: + $ref: '#/definitions/TodoRecursive' + responses: + '200': + description: '' + schema: + $ref: '#/definitions/TodoRecursive' + tags: + - todo + delete: + operationId: todo_recursive_delete + description: '' + parameters: [] + responses: + '204': + description: '' + tags: + - todo + parameters: + - name: id + in: path + description: A unique integer value identifying this todo tree. + required: true + type: integer + /todo/tree/: + get: + operationId: todo_tree_list + description: '' + parameters: [] + responses: + '200': + description: '' + schema: + type: array + items: + $ref: '#/definitions/TodoTree' + tags: + - todo + parameters: [] + /todo/tree/{id}/: + get: + operationId: todo_tree_read + description: '' + parameters: [] + responses: + '200': + description: '' + schema: + $ref: '#/definitions/TodoTree' + tags: + - todo + parameters: + - name: id + in: path + description: A unique integer value identifying this todo tree. + required: true + type: integer /todo/yetanother/: get: operationId: todo_yetanother_list @@ -1337,6 +1460,42 @@ definitions: maxLength: 50 todo: $ref: '#/definitions/Todo' + TodoRecursive: + required: + - title + type: object + properties: + id: + title: ID + type: integer + readOnly: true + title: + title: Title + type: string + maxLength: 50 + parent: + $ref: '#/definitions/TodoRecursive' + parent_id: + type: integer + title: Parent id + TodoTree: + required: + - title + - children + type: object + properties: + id: + title: ID + type: integer + readOnly: true + title: + title: Title + type: string + maxLength: 50 + children: + type: array + items: + $ref: '#/definitions/TodoTree' TodoYetAnother: required: - title diff --git a/tox.ini b/tox.ini index ef97c6b..3b58a8e 100644 --- a/tox.ini +++ b/tox.ini @@ -65,7 +65,7 @@ known_standard_library = collections,copy,distutils,functools,inspect,io,json,logging,operator,os,pkg_resources,re,setuptools,sys, types,warnings known_third_party = - coreapi,coreschema,datadiff,dj_database_url,django,django_filters,djangorestframework_camel_case,flex,gunicorn, - inflection,pygments,pytest,rest_framework,ruamel,setuptools_scm,swagger_spec_validator,uritemplate,user_agents, - whitenoise + coreapi,coreschema,datadiff,dj_database_url,django,django_filters,djangorestframework_camel_case, + rest_framework_recursive,flex,gunicorn,inflection,pygments,pytest,rest_framework,ruamel,setuptools_scm, + swagger_spec_validator,uritemplate,user_agents,whitenoise known_first_party = drf_yasg,testproj,articles,people,snippets,todo,users,urlconfs