Don't allow form parameters with non-form consumes

Closes #270.
master
Cristi Vîjdea 2018-12-12 12:43:33 +02:00
parent acfb0c5442
commit 1f95f4098b
3 changed files with 56 additions and 9 deletions

View File

@ -8,8 +8,8 @@ from rest_framework.status import is_success
from .. import openapi
from ..errors import SwaggerGenerationError
from ..utils import (
force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view,
merge_params, no_body, param_list_to_odict
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
is_list_view, merge_params, no_body, param_list_to_odict
)
from .base import ViewInspector
@ -29,7 +29,7 @@ class SwaggerAutoSchema(ViewInspector):
body = self.get_request_body_parameters(consumes)
query = self.get_query_parameters()
parameters = body + query
parameters = [param for param in parameters if param is not None]
parameters = filter_none(parameters)
parameters = self.add_manual_parameters(parameters)
operation_id = self.get_operation_id(operation_keys)
@ -167,12 +167,13 @@ class SwaggerAutoSchema(ViewInspector):
if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover
if any(param.in_ == openapi.IN_BODY for param in parameters):
has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
"did you forget to set an appropriate parser class on the view?")
if self.method not in self.body_methods:
raise SwaggerGenerationError("form parameters can only be applied to (" + ','.join(self.body_methods) +
") HTTP methods")
raise SwaggerGenerationError("form parameters can only be applied to "
"(" + ','.join(self.body_methods) + ") HTTP methods")
return merge_params(parameters, manual_parameters)

View File

@ -335,11 +335,11 @@ def get_consumes(parser_classes):
:rtype: list[str]
"""
media_types = [parser.media_type for parser in parser_classes or []]
if all(is_form_media_type(encoding) for encoding in media_types):
non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
if len(non_form_media_types) == 0:
return media_types
else:
media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
return media_types
return non_form_media_types
def get_produces(renderer_classes):

View File

@ -0,0 +1,46 @@
import pytest
from django.conf.urls import url
from django.utils.decorators import method_decorator
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.settings import api_settings
from drf_yasg import openapi
from drf_yasg.errors import SwaggerGenerationError
from drf_yasg.generators import OpenAPISchemaGenerator
from drf_yasg.utils import swagger_auto_schema
def test_choice_field():
@method_decorator(name='post', decorator=swagger_auto_schema(
operation_description="Logins a user and returns a token",
manual_parameters=[
openapi.Parameter(
"username",
openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING,
description="Valid username or email for authentication"
),
openapi.Parameter(
"password",
openapi.IN_FORM,
required=True,
type=openapi.TYPE_STRING,
description="Valid password for authentication",
),
]
))
class CustomObtainAuthToken(ObtainAuthToken):
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
urlpatterns = [
url(r'token/$', CustomObtainAuthToken.as_view()),
]
generator = OpenAPISchemaGenerator(
info=openapi.Info(title="Test generator", default_version="v1"),
patterns=urlpatterns
)
with pytest.raises(SwaggerGenerationError):
generator.get_schema(None, True)