Improve RelatedField and callable default handling

- callable default values will now be properly called 
  - PrimaryKeyRelatedField and SlugRelatedField will now return an appropriate type based on the relation model's Field
  - mock views now have a request object bound even when public is True
openapi3
Cristi Vîjdea 2017-12-23 11:52:31 +01:00
parent f05889292a
commit 9f6ee4da87
17 changed files with 274 additions and 142 deletions

View File

@ -2,14 +2,12 @@ import re
from collections import defaultdict, OrderedDict from collections import defaultdict, OrderedDict
import uritemplate import uritemplate
from coreapi.compat import force_text
from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator from rest_framework.schemas.generators import SchemaGenerator, EndpointEnumerator as _EndpointEnumerator
from rest_framework.schemas.inspectors import get_pk_description
from . import openapi from . import openapi
from .inspectors import SwaggerAutoSchema from .inspectors import SwaggerAutoSchema
from .openapi import ReferenceResolver from .openapi import ReferenceResolver
from .utils import get_schema_type_from_model_field from .utils import inspect_model_field, get_model_field
PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}') PATH_PARAMETER_RE = re.compile(r'{(?P<parameter>\w+)}')
@ -82,9 +80,9 @@ class OpenAPISchemaGenerator(object):
:return: the generated Swagger specification :return: the generated Swagger specification
:rtype: openapi.Swagger :rtype: openapi.Swagger
""" """
endpoints = self.get_endpoints(None if public else request) endpoints = self.get_endpoints(request)
components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS) components = ReferenceResolver(openapi.SCHEMA_DEFINITIONS)
paths = self.get_paths(endpoints, components) paths = self.get_paths(endpoints, components, public)
url = self._gen.url url = self._gen.url
if not url and request is not None: if not url and request is not None:
@ -114,9 +112,9 @@ class OpenAPISchemaGenerator(object):
return view return view
def get_endpoints(self, request=None): def get_endpoints(self, request=None):
"""Iterate over all the registered endpoints in the API. """Iterate over all the registered endpoints in the API and return a fake view with the right parameters.
:param rest_framework.request.Request request: used for returning only endpoints available to the given request :param rest_framework.request.Request request: request to bind to the endpoint views
:return: {path: (view_class, list[(http_method, view_instance)]) :return: {path: (view_class, list[(http_method, view_instance)])
:rtype: dict :rtype: dict
""" """
@ -151,11 +149,12 @@ class OpenAPISchemaGenerator(object):
""" """
return self._gen.get_keys(subpath, method, view) return self._gen.get_keys(subpath, method, view)
def get_paths(self, endpoints, components): def get_paths(self, endpoints, components, public):
"""Generate the Swagger Paths for the API from the given endpoints. """Generate the Swagger Paths for the API from the given endpoints.
:param dict endpoints: endpoints as returned by get_endpoints :param dict endpoints: endpoints as returned by get_endpoints
:param ReferenceResolver components: resolver/container for Swagger References :param ReferenceResolver components: resolver/container for Swagger References
:param bool public: if True, all endpoints are included regardless of access through `request`
:rtype: openapi.Paths :rtype: openapi.Paths
""" """
if not endpoints: if not endpoints:
@ -169,7 +168,7 @@ class OpenAPISchemaGenerator(object):
path_parameters = self.get_path_parameters(path, view_cls) path_parameters = self.get_path_parameters(path, view_cls)
operations = {} operations = {}
for method, view in methods: for method, view in methods:
if not self._gen.has_view_permissions(path, method, view): if not public and not self._gen.has_view_permissions(path, method, view):
continue continue
operation_keys = self.get_operation_keys(path[len(prefix):], method, view) operation_keys = self.get_operation_keys(path[len(prefix):], method, view)
@ -209,35 +208,20 @@ class OpenAPISchemaGenerator(object):
:rtype: list[openapi.Parameter] :rtype: list[openapi.Parameter]
""" """
parameters = [] parameters = []
queryset = getattr(view_cls, 'queryset', None)
model = getattr(getattr(view_cls, 'queryset', None), 'model', None) model = getattr(getattr(view_cls, 'queryset', None), 'model', None)
for variable in uritemplate.variables(path): for variable in uritemplate.variables(path):
pattern = None model, model_field = get_model_field(queryset, variable)
type = openapi.TYPE_STRING attrs = inspect_model_field(model, model_field)
description = None if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
if model is not None: attrs['pattern'] = view_cls.lookup_value_regex
# Attempt to infer a field description if possible.
try:
model_field = model._meta.get_field(variable)
except Exception: # pragma: no cover
model_field = None
else:
type = get_schema_type_from_model_field(model_field)
if model_field is not None and model_field.help_text:
description = force_text(model_field.help_text)
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view_cls, 'lookup_value_regex') and getattr(view_cls, 'lookup_field', None) == variable:
pattern = view_cls.lookup_value_regex
field = openapi.Parameter( field = openapi.Parameter(
name=variable, name=variable,
required=True, required=True,
in_=openapi.IN_PATH, in_=openapi.IN_PATH,
type=type, **attrs
pattern=pattern,
description=description,
) )
parameters.append(field) parameters.append(field)

View File

@ -309,6 +309,7 @@ class Items(SwaggerDict):
:param .Items items: only valid if `type` is ``array`` :param .Items items: only valid if `type` is ``array``
""" """
super(Items, self).__init__(**extra) super(Items, self).__init__(**extra)
assert type is not None, "type is required!"
self.type = type self.type = type
self.format = format self.format = format
self.enum = enum self.enum = enum
@ -372,6 +373,7 @@ class Schema(SwaggerDict):
# common error # common error
raise AssertionError( raise AssertionError(
"the `requires` attribute of schema must be an array of required properties, not a boolean!") "the `requires` attribute of schema must be an array of required properties, not a boolean!")
assert type is not None, "type is required!"
self.description = description self.description = description
self.required = required self.required = required
self.type = type self.type = type

View File

@ -1,3 +1,4 @@
import logging
from collections import OrderedDict from collections import OrderedDict
from django.core.validators import RegexValidator from django.core.validators import RegexValidator
@ -5,21 +6,19 @@ from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from rest_framework import serializers from rest_framework import serializers
from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin from rest_framework.mixins import RetrieveModelMixin, DestroyModelMixin, UpdateModelMixin
from rest_framework.schemas.inspectors import get_pk_description
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.utils import json, encoders
from . import openapi from . import openapi
from .errors import SwaggerGenerationError from .errors import SwaggerGenerationError
logger = logging.getLogger(__name__)
#: used to forcibly remove the body of a request via :func:`.swagger_auto_schema` #: used to forcibly remove the body of a request via :func:`.swagger_auto_schema`
no_body = object() no_body = object()
def get_schema_type_from_model_field(model_field):
if isinstance(model_field, models.AutoField):
return openapi.TYPE_INTEGER
return openapi.TYPE_STRING
def is_list_view(path, method, view): def is_list_view(path, method, view):
"""Check if the given path/method appears to represent a list view (as opposed to a detail/instance view). """Check if the given path/method appears to represent a list view (as opposed to a detail/instance view).
@ -164,6 +163,87 @@ def swagger_auto_schema(method=None, methods=None, auto_schema=None, request_bod
return decorator return decorator
def get_model_field(queryset, field_name):
"""Try to get information about a model and model field from a queryset.
:param queryset: the queryset
:param field_name: the target field name
:returns: the model and target field from the queryset as a 2-tuple; both elements can be ``None``
:rtype: tuple
"""
model = getattr(queryset, 'model', None)
try:
model_field = model._meta.get_field(field_name)
except Exception: # pragma: no cover
model_field = None
return model, model_field
model_field_to_swagger_type = {
models.AutoField: (openapi.TYPE_INTEGER, None),
models.BinaryField: (openapi.TYPE_STRING, openapi.FORMAT_BINARY),
models.BooleanField: (openapi.TYPE_BOOLEAN, None),
models.NullBooleanField: (openapi.TYPE_BOOLEAN, None),
models.DateTimeField: (openapi.TYPE_STRING, openapi.FORMAT_DATETIME),
models.DateField: (openapi.TYPE_STRING, openapi.FORMAT_DATE),
models.DecimalField: (openapi.TYPE_NUMBER, None),
models.DurationField: (openapi.TYPE_INTEGER, None),
models.FloatField: (openapi.TYPE_NUMBER, None),
models.IntegerField: (openapi.TYPE_INTEGER, None),
models.IPAddressField: (openapi.TYPE_STRING, openapi.FORMAT_IPV4),
models.GenericIPAddressField: (openapi.TYPE_STRING, openapi.FORMAT_IPV6),
models.SlugField: (openapi.TYPE_STRING, openapi.FORMAT_SLUG),
models.TextField: (openapi.TYPE_STRING, None),
models.TimeField: (openapi.TYPE_STRING, None),
models.UUIDField: (openapi.TYPE_STRING, openapi.FORMAT_UUID),
models.CharField: (openapi.TYPE_STRING, None),
}
def inspect_model_field(model, model_field):
"""Extract information from a django model field instance.
:param model: the django model
:param model_field: a field on the model
:return: description, type, format and pattern extracted from the model field
:rtype: OrderedDict
"""
if model is not None and model_field is not None:
for model_field_class, tf in model_field_to_swagger_type.items():
if isinstance(model_field, model_field_class):
swagger_type, format = tf
break
else:
swagger_type, format = None, None
if format is None or format == openapi.FORMAT_SLUG:
pattern = find_regex(model_field)
else:
pattern = None
if model_field.help_text:
description = force_text(model_field.help_text)
elif model_field.primary_key:
description = get_pk_description(model, model_field)
else:
description = None
else:
description = None
swagger_type = None
format = None
pattern = None
result = OrderedDict([
('description', description),
('type', swagger_type or openapi.TYPE_STRING),
('format', format),
('pattern', pattern)
])
# TODO: filter none
return result
def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs): def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **kwargs):
"""Convert a drf Serializer or Field instance into a Swagger object. """Convert a drf Serializer or Field instance into a Swagger object.
@ -183,17 +263,50 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
description = force_text(field.help_text) if field.help_text else None description = force_text(field.help_text) if field.help_text else None
description = description if swagger_object_type != openapi.Items else None # Items has no description either description = description if swagger_object_type != openapi.Items else None # Items has no description either
def SwaggerType(**instance_kwargs): def SwaggerType(existing_object=None, **instance_kwargs):
if swagger_object_type == openapi.Parameter and 'required' not in instance_kwargs: if swagger_object_type == openapi.Parameter and 'required' not in instance_kwargs:
instance_kwargs['required'] = field.required instance_kwargs['required'] = field.required
if swagger_object_type != openapi.Items and 'default' not in instance_kwargs: if swagger_object_type != openapi.Items and 'default' not in instance_kwargs:
default = getattr(field, 'default', serializers.empty) default = getattr(field, 'default', serializers.empty)
if default is not serializers.empty: if default is not serializers.empty:
instance_kwargs['default'] = default if callable(default):
try:
if hasattr(default, 'set_context'):
default.set_context(field)
default = default()
except Exception as e:
logger.warning("default for %s is callable but it raised an exception when "
"called; 'default' field will not be added to schema", field, exc_info=True)
default = None
if default is not None:
try:
default = field.to_representation(default)
# JSON roundtrip ensures that the value is valid JSON;
# for example, sets get transformed into lists
default = json.loads(json.dumps(default, cls=encoders.JSONEncoder))
except Exception as e:
logger.warning("'default' on schema for %s will not be set because "
"to_representation raised an exception", field, exc_info=True)
default = None
if default is not None:
instance_kwargs['default'] = default
if swagger_object_type == openapi.Schema and 'read_only' not in instance_kwargs: if swagger_object_type == openapi.Schema and 'read_only' not in instance_kwargs:
if field.read_only: if field.read_only:
instance_kwargs['read_only'] = True instance_kwargs['read_only'] = True
instance_kwargs.update(kwargs) instance_kwargs.update(kwargs)
instance_kwargs.pop('title', None)
instance_kwargs.pop('description', None)
if existing_object is not None:
existing_object.title = title
existing_object.description = description
for attr, val in instance_kwargs.items():
setattr(existing_object, attr, val)
return existing_object
return swagger_object_type(title=title, description=description, **instance_kwargs) return swagger_object_type(title=title, description=description, **instance_kwargs)
# arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements # arrays in Schema have Schema elements, arrays in Parameter and Items have Items elements
@ -246,8 +359,27 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
unique_items=True, # is this OK? unique_items=True, # is this OK?
) )
elif isinstance(field, serializers.PrimaryKeyRelatedField): elif isinstance(field, serializers.PrimaryKeyRelatedField):
model = field.queryset.model if field.pk_field:
return SwaggerType(type=get_schema_type_from_model_field(model._meta.pk)) result = serializer_field_to_swagger(field.pk_field, swagger_object_type, definitions, **kwargs)
return SwaggerType(existing_object=result)
attrs = {'type': openapi.TYPE_STRING}
try:
model = field.queryset.model
pk_field = model._meta.pk
except Exception:
logger.warning("an exception was raised when attempting to extract the primary key related to %s; "
"falling back to plain string" % field, exc_info=True)
else:
attrs.update(inspect_model_field(model, pk_field))
return SwaggerType(**attrs)
elif isinstance(field, serializers.HyperlinkedRelatedField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_URI)
elif isinstance(field, serializers.SlugRelatedField):
model, model_field = get_model_field(field.queryset, field.slug_field)
attrs = inspect_model_field(model, model_field)
return SwaggerType(**attrs)
elif isinstance(field, serializers.RelatedField): elif isinstance(field, serializers.RelatedField):
return SwaggerType(type=openapi.TYPE_STRING) return SwaggerType(type=openapi.TYPE_STRING)
# ------ CHOICES # ------ CHOICES
@ -262,7 +394,7 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
elif isinstance(field, serializers.ChoiceField): elif isinstance(field, serializers.ChoiceField):
return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys())) return SwaggerType(type=openapi.TYPE_STRING, enum=list(field.choices.keys()))
# ------ BOOL # ------ BOOL
elif isinstance(field, serializers.BooleanField): elif isinstance(field, (serializers.BooleanField, serializers.NullBooleanField)):
return SwaggerType(type=openapi.TYPE_BOOLEAN) return SwaggerType(type=openapi.TYPE_BOOLEAN)
# ------ NUMERIC # ------ NUMERIC
elif isinstance(field, (serializers.DecimalField, serializers.FloatField)): elif isinstance(field, (serializers.DecimalField, serializers.FloatField)):
@ -271,6 +403,8 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
elif isinstance(field, serializers.IntegerField): elif isinstance(field, serializers.IntegerField):
# TODO: min_value max_value # TODO: min_value max_value
return SwaggerType(type=openapi.TYPE_INTEGER) return SwaggerType(type=openapi.TYPE_INTEGER)
elif isinstance(field, serializers.DurationField):
return SwaggerType(type=openapi.TYPE_INTEGER)
# ------ STRING # ------ STRING
elif isinstance(field, serializers.EmailField): elif isinstance(field, serializers.EmailField):
return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL) return SwaggerType(type=openapi.TYPE_STRING, format=openapi.FORMAT_EMAIL)
@ -317,8 +451,10 @@ def serializer_field_to_swagger(field, swagger_object_type, definitions=None, **
type=openapi.TYPE_OBJECT, type=openapi.TYPE_OBJECT,
additional_properties=child_schema additional_properties=child_schema
) )
elif isinstance(field, serializers.ModelField):
return SwaggerType(type=openapi.TYPE_STRING)
# TODO unhandled fields: TimeField DurationField HiddenField ModelField NullBooleanField? JSONField # TODO unhandled fields: TimeField HiddenField JSONField
# everything else gets string by default # everything else gets string by default
return SwaggerType(type=openapi.TYPE_STRING) return SwaggerType(type=openapi.TYPE_STRING)

View File

@ -1,6 +1,8 @@
# Generated by Django 2.0 on 2017-12-05 04:05 # Generated by Django 2.0 on 2017-12-23 09:07
from django.conf import settings
from django.db import migrations, models from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration): class Migration(migrations.Migration):
@ -8,6 +10,7 @@ class Migration(migrations.Migration):
initial = True initial = True
dependencies = [ dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
] ]
operations = [ operations = [
@ -15,12 +18,13 @@ class Migration(migrations.Migration):
name='Article', name='Article',
fields=[ fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('title', models.CharField(help_text='Main article headline', max_length=255, unique=True)), ('title', models.CharField(help_text='title model help_text', max_length=255, unique=True)),
('body', models.TextField(help_text='Article content', max_length=5000)), ('body', models.TextField(help_text='article model help_text', max_length=5000)),
('slug', models.SlugField(blank=True, help_text='Unique URL slug identifying the article', unique=True)), ('slug', models.SlugField(blank=True, help_text='slug model help_text', unique=True)),
('date_created', models.DateTimeField(auto_now_add=True)), ('date_created', models.DateTimeField(auto_now_add=True)),
('date_modified', models.DateTimeField(auto_now=True)), ('date_modified', models.DateTimeField(auto_now=True)),
('cover', models.ImageField(blank=True, upload_to='article/original/')), ('cover', models.ImageField(blank=True, upload_to='article/original/')),
('author', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='articles', to=settings.AUTH_USER_MODEL)),
], ],
), ),
] ]

View File

@ -1,28 +0,0 @@
# Generated by Django 2.0 on 2017-12-21 15:35
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('articles', '0001_initial'),
]
operations = [
migrations.AlterField(
model_name='article',
name='body',
field=models.TextField(help_text='article model help_text', max_length=5000),
),
migrations.AlterField(
model_name='article',
name='slug',
field=models.SlugField(blank=True, help_text='slug model help_text', unique=True),
),
migrations.AlterField(
model_name='article',
name='title',
field=models.CharField(help_text='title model help_text', max_length=255, unique=True),
),
]

View File

@ -7,5 +7,6 @@ class Article(models.Model):
slug = models.SlugField(help_text="slug model help_text", unique=True, blank=True) slug = models.SlugField(help_text="slug model help_text", unique=True, blank=True)
date_created = models.DateTimeField(auto_now_add=True) date_created = models.DateTimeField(auto_now_add=True)
date_modified = models.DateTimeField(auto_now=True) date_modified = models.DateTimeField(auto_now=True)
author = models.ForeignKey('auth.User', related_name='articles', on_delete=models.CASCADE)
cover = models.ImageField(upload_to='article/original/', blank=True) cover = models.ImageField(upload_to='article/original/', blank=True)

View File

@ -14,12 +14,19 @@ class ArticleSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Article model = Article
fields = ('title', 'body', 'slug', 'date_created', 'date_modified', fields = ('title', 'author', 'body', 'slug', 'date_created', 'date_modified',
'references', 'uuid', 'cover', 'cover_name') 'references', 'uuid', 'cover', 'cover_name')
read_only_fields = ('date_created', 'date_modified', read_only_fields = ('date_created', 'date_modified',
'references', 'uuid', 'cover_name') 'references', 'uuid', 'cover_name')
lookup_field = 'slug' lookup_field = 'slug'
extra_kwargs = {'body': {'help_text': 'body serializer help_text'}} extra_kwargs = {
'body': {'help_text': 'body serializer help_text'},
'author': {
'default': serializers.CurrentUserDefault(),
'help_text': "The ID of the user that created this article; if none is provided, "
"defaults to the currently logged in user."
},
}
class ImageUploadSerializer(serializers.Serializer): class ImageUploadSerializer(serializers.Serializer):

View File

@ -20,6 +20,11 @@ class NoPagingAutoSchema(SwaggerAutoSchema):
return False return False
class ArticlePagination(LimitOffsetPagination):
default_limit = 5
max_limit = 25
@method_decorator(name='list', decorator=swagger_auto_schema( @method_decorator(name='list', decorator=swagger_auto_schema(
operation_description="description from swagger_auto_schema via method_decorator" operation_description="description from swagger_auto_schema via method_decorator"
)) ))
@ -41,8 +46,7 @@ class ArticleViewSet(viewsets.ModelViewSet):
lookup_value_regex = r'[a-z0-9]+(?:-[a-z0-9]+)' lookup_value_regex = r'[a-z0-9]+(?:-[a-z0-9]+)'
serializer_class = serializers.ArticleSerializer serializer_class = serializers.ArticleSerializer
pagination_class = LimitOffsetPagination pagination_class = ArticlePagination
max_page_size = 5
filter_backends = (DjangoFilterBackend, OrderingFilter) filter_backends = (DjangoFilterBackend, OrderingFilter)
filter_fields = ('title',) filter_fields = ('title',)
ordering_fields = ('date_modified','date_created') ordering_fields = ('date_modified','date_created')

Binary file not shown.

File diff suppressed because one or more lines are too long

View File

@ -1,27 +0,0 @@
# Generated by Django 2.0 on 2017-12-05 04:05
from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
('snippets', '0001_initial'),
]
operations = [
migrations.AddField(
model_name='snippet',
name='owner',
field=models.ForeignKey(default='', on_delete=django.db.models.deletion.CASCADE, related_name='snippets', to=settings.AUTH_USER_MODEL),
preserve_default=False,
),
migrations.AlterField(
model_name='snippet',
name='code',
field=models.TextField(help_text='code model help text'),
),
]

View File

@ -1,3 +1,4 @@
from django.contrib.auth import get_user_model
from rest_framework import serializers from rest_framework import serializers
from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES from snippets.models import Snippet, LANGUAGE_CHOICES, STYLE_CHOICES
@ -23,7 +24,18 @@ class SnippetSerializer(serializers.Serializer):
create: docstring for create from serializer classdoc create: docstring for create from serializer classdoc
""" """
id = serializers.IntegerField(read_only=True, help_text="id serializer help text") id = serializers.IntegerField(read_only=True, help_text="id serializer help text")
owner = serializers.ReadOnlyField(source='owner.username') owner = serializers.PrimaryKeyRelatedField(
queryset=get_user_model().objects.all(),
default=serializers.CurrentUserDefault(),
help_text="The ID of the user that created this snippet; if none is provided, "
"defaults to the currently logged in user."
)
owner_as_string = serializers.PrimaryKeyRelatedField(
help_text="The ID of the user that created this snippet.",
pk_field=serializers.CharField(help_text="this help text should not show up"),
read_only=True,
source='owner',
)
title = serializers.CharField(required=False, allow_blank=True, max_length=100) title = serializers.CharField(required=False, allow_blank=True, max_length=100)
code = serializers.CharField(style={'base_template': 'textarea.html'}) code = serializers.CharField(style={'base_template': 'textarea.html'})
linenos = serializers.BooleanField(required=False) linenos = serializers.BooleanField(required=False)
@ -31,7 +43,8 @@ class SnippetSerializer(serializers.Serializer):
styles = serializers.MultipleChoiceField(choices=STYLE_CHOICES, default=['friendly']) styles = serializers.MultipleChoiceField(choices=STYLE_CHOICES, default=['friendly'])
lines = serializers.ListField(child=serializers.IntegerField(), allow_empty=True, allow_null=True, required=False) lines = serializers.ListField(child=serializers.IntegerField(), allow_empty=True, allow_null=True, required=False)
example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True) example_projects = serializers.ListSerializer(child=ExampleProjectSerializer(), read_only=True)
difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField", read_only=True) difficulty_factor = serializers.FloatField(help_text="this is here just to test FloatField",
read_only=True, default=lambda: 6.9)
def create(self, validated_data): def create(self, validated_data):
""" """
@ -39,6 +52,7 @@ class SnippetSerializer(serializers.Serializer):
""" """
del validated_data['styles'] del validated_data['styles']
del validated_data['lines'] del validated_data['lines']
del validated_data['difficulty_factor']
return Snippet.objects.create(**validated_data) return Snippet.objects.create(**validated_data)
def update(self, instance, validated_data): def update(self, instance, validated_data):

View File

@ -1,19 +1,22 @@
from django.contrib.auth.models import User from django.contrib.auth.models import User
from rest_framework import serializers from rest_framework import serializers
from articles.models import Article
from snippets.models import Snippet from snippets.models import Snippet
class UserSerializerrr(serializers.ModelSerializer): class UserSerializerrr(serializers.ModelSerializer):
snippets = serializers.PrimaryKeyRelatedField(many=True, queryset=Snippet.objects.all()) snippets = serializers.PrimaryKeyRelatedField(many=True, queryset=Snippet.objects.all())
article_slugs = serializers.SlugRelatedField(read_only=True, slug_field='slug', many=True, source='articlessss')
last_connected_ip = serializers.IPAddressField(help_text="i'm out of ideas", protocol='ipv4', read_only=True) last_connected_ip = serializers.IPAddressField(help_text="i'm out of ideas", protocol='ipv4', read_only=True)
last_connected_at = serializers.DateField(help_text="really?", read_only=True) last_connected_at = serializers.DateField(help_text="really?", read_only=True)
class Meta: class Meta:
model = User model = User
fields = ('id', 'username', 'email', 'snippets', 'last_connected_ip', 'last_connected_at') fields = ('id', 'username', 'email', 'articles', 'snippets',
'last_connected_ip', 'last_connected_at', 'article_slugs')
class UserListQuerySerializer(serializers.Serializer): class UserListQuerySerializer(serializers.Serializer):
username = serializers.CharField(help_text="this field is generated from a query_serializer") username = serializers.CharField(help_text="this field is generated from a query_serializer", required=False)
is_staff = serializers.BooleanField(help_text="this one too!") is_staff = serializers.BooleanField(help_text="this one too!", required=False)

View File

@ -3,6 +3,9 @@ import json
import os import os
import pytest import pytest
from django.contrib.auth.models import User
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from ruamel import yaml from ruamel import yaml
from drf_yasg import openapi, codecs from drf_yasg import openapi, codecs
@ -10,11 +13,16 @@ from drf_yasg.generators import OpenAPISchemaGenerator
@pytest.fixture @pytest.fixture
def generator(): def mock_schema_request(db):
return OpenAPISchemaGenerator( from rest_framework.test import force_authenticate
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2", factory = APIRequestFactory()
) user = User.objects.create_user(username='admin', is_staff=True, is_superuser=True)
request = factory.get('/swagger.json')
force_authenticate(request, user=user)
request = APIView().initialize_request(request)
return request
@pytest.fixture @pytest.fixture
@ -28,13 +36,16 @@ def codec_yaml():
@pytest.fixture @pytest.fixture
def swagger(generator): def swagger(mock_schema_request):
return generator.get_schema(None, True) generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
version="v2",
)
return generator.get_schema(mock_schema_request, True)
@pytest.fixture @pytest.fixture
def swagger_dict(generator): def swagger_dict(swagger):
swagger = generator.get_schema(None, True)
json_bytes = codec_json().encode(swagger) json_bytes = codec_json().encode(swagger)
return json.loads(json_bytes.decode('utf-8')) return json.loads(json_bytes.decode('utf-8'))

View File

@ -178,6 +178,7 @@ paths:
description: slug model help_text description: slug model help_text
required: true required: true
type: string type: string
format: slug
pattern: '[a-z0-9]+(?:-[a-z0-9]+)' pattern: '[a-z0-9]+(?:-[a-z0-9]+)'
/articles/{slug}/image/: /articles/{slug}/image/:
get: get:
@ -231,6 +232,7 @@ paths:
description: slug model help_text description: slug model help_text
required: true required: true
type: string type: string
format: slug
pattern: '[a-z0-9]+(?:-[a-z0-9]+)' pattern: '[a-z0-9]+(?:-[a-z0-9]+)'
/plain/: /plain/:
get: get:
@ -355,12 +357,12 @@ paths:
- name: username - name: username
in: query in: query
description: this field is generated from a query_serializer description: this field is generated from a query_serializer
required: true required: false
type: string type: string
- name: is_staff - name: is_staff
in: query in: query
description: this one too! description: this one too!
required: true required: false
type: boolean type: boolean
responses: responses:
'200': '200':
@ -464,6 +466,11 @@ definitions:
title: title:
description: title model help_text description: title model help_text
type: string type: string
author:
description: The ID of the user that created this article; if none is provided,
defaults to the currently logged in user.
type: integer
default: 1
body: body:
description: body serializer help_text description: body serializer help_text
type: string type: string
@ -523,6 +530,12 @@ definitions:
type: integer type: integer
readOnly: true readOnly: true
owner: owner:
description: The ID of the user that created this snippet; if none is provided,
defaults to the currently logged in user.
type: integer
default: 1
owner_as_string:
description: The ID of the user that created this snippet.
type: string type: string
readOnly: true readOnly: true
title: title:
@ -1022,10 +1035,12 @@ definitions:
difficulty_factor: difficulty_factor:
description: this is here just to test FloatField description: this is here just to test FloatField
type: number type: number
default: 6.9
readOnly: true readOnly: true
UserSerializerrr: UserSerializerrr:
required: required:
- username - username
- articles
- snippets - snippets
type: object type: object
properties: properties:
@ -1039,6 +1054,11 @@ definitions:
email: email:
type: string type: string
format: email format: email
articles:
type: array
items:
type: integer
uniqueItems: true
snippets: snippets:
type: array type: array
items: items:
@ -1054,6 +1074,13 @@ definitions:
type: string type: string
format: date format: date
readOnly: true readOnly: true
article_slugs:
type: array
items:
type: string
readOnly: true
uniqueItems: true
readOnly: true
securityDefinitions: securityDefinitions:
basic: basic:
type: basic type: basic

View File

@ -2,7 +2,6 @@ from datadiff.tools import assert_equal
def test_reference_schema(swagger_dict, reference_schema): def test_reference_schema(swagger_dict, reference_schema):
# formatted better than pytest diff
swagger_dict = dict(swagger_dict) swagger_dict = dict(swagger_dict)
reference_schema = dict(reference_schema) reference_schema = dict(reference_schema)
ignore = ['info', 'host', 'schemes', 'basePath', 'securityDefinitions'] ignore = ['info', 'host', 'schemes', 'basePath', 'securityDefinitions']
@ -10,4 +9,5 @@ def test_reference_schema(swagger_dict, reference_schema):
swagger_dict.pop(attr, None) swagger_dict.pop(attr, None)
reference_schema.pop(attr, None) reference_schema.pop(attr, None)
# formatted better than pytest diff
assert_equal(swagger_dict, reference_schema) assert_equal(swagger_dict, reference_schema)

View File

@ -7,16 +7,11 @@ from drf_yasg import openapi, codecs
from drf_yasg.generators import OpenAPISchemaGenerator from drf_yasg.generators import OpenAPISchemaGenerator
def test_schema_generates_without_errors(generator): def test_schema_is_valid(swagger, codec_yaml):
generator.get_schema(None, True)
def test_schema_is_valid(generator, codec_yaml):
swagger = generator.get_schema(request=None, public=False)
codec_yaml.encode(swagger) codec_yaml.encode(swagger)
def test_invalid_schema_fails(codec_json): def test_invalid_schema_fails(codec_json, mock_schema_request):
# noinspection PyTypeChecker # noinspection PyTypeChecker
bad_generator = OpenAPISchemaGenerator( bad_generator = OpenAPISchemaGenerator(
info=openapi.Info( info=openapi.Info(
@ -26,40 +21,37 @@ def test_invalid_schema_fails(codec_json):
version="v2", version="v2",
) )
swagger = bad_generator.get_schema(None, True) swagger = bad_generator.get_schema(mock_schema_request, True)
with pytest.raises(codecs.SwaggerValidationError): with pytest.raises(codecs.SwaggerValidationError):
codec_json.encode(swagger) codec_json.encode(swagger)
def test_json_codec_roundtrip(codec_json, generator, validate_schema): def test_json_codec_roundtrip(codec_json, swagger, validate_schema):
swagger = generator.get_schema(None, True)
json_bytes = codec_json.encode(swagger) json_bytes = codec_json.encode(swagger)
validate_schema(json.loads(json_bytes.decode('utf-8'))) validate_schema(json.loads(json_bytes.decode('utf-8')))
def test_yaml_codec_roundtrip(codec_yaml, generator, validate_schema): def test_yaml_codec_roundtrip(codec_yaml, swagger, validate_schema):
swagger = generator.get_schema(None, True)
yaml_bytes = codec_yaml.encode(swagger) yaml_bytes = codec_yaml.encode(swagger)
assert b'omap' not in yaml_bytes # ensure no ugly !!omap is outputted assert b'omap' not in yaml_bytes # ensure no ugly !!omap is outputted
assert b'&id' not in yaml_bytes and b'*id' not in yaml_bytes # ensure no YAML references are generated assert b'&id' not in yaml_bytes and b'*id' not in yaml_bytes # ensure no YAML references are generated
validate_schema(yaml.safe_load(yaml_bytes.decode('utf-8'))) validate_schema(yaml.safe_load(yaml_bytes.decode('utf-8')))
def test_yaml_and_json_match(codec_yaml, codec_json, generator): def test_yaml_and_json_match(codec_yaml, codec_json, swagger):
swagger = generator.get_schema(None, True)
yaml_schema = yaml.safe_load(codec_yaml.encode(swagger).decode('utf-8')) yaml_schema = yaml.safe_load(codec_yaml.encode(swagger).decode('utf-8'))
json_schema = json.loads(codec_json.encode(swagger).decode('utf-8')) json_schema = json.loads(codec_json.encode(swagger).decode('utf-8'))
assert yaml_schema == json_schema assert yaml_schema == json_schema
def test_basepath_only(): def test_basepath_only(mock_schema_request):
generator = OpenAPISchemaGenerator( generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"), info=openapi.Info(title="Test generator", default_version="v1"),
version="v2", version="v2",
url='/basepath/', url='/basepath/',
) )
swagger = generator.get_schema(None, public=True) swagger = generator.get_schema(mock_schema_request, public=True)
assert 'host' not in swagger assert 'host' not in swagger
assert 'schemes' not in swagger assert 'schemes' not in swagger
assert swagger['basePath'] == '/' # base path is not implemented for now assert swagger['basePath'] == '/' # base path is not implemented for now