diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index bbbb086..def629b 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -8,8 +8,8 @@ from rest_framework.status import is_success from .. import openapi from ..errors import SwaggerGenerationError from ..utils import ( - force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, - merge_params, no_body, param_list_to_odict + filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, + is_list_view, merge_params, no_body, param_list_to_odict ) from .base import ViewInspector @@ -29,7 +29,7 @@ class SwaggerAutoSchema(ViewInspector): body = self.get_request_body_parameters(consumes) query = self.get_query_parameters() parameters = body + query - parameters = [param for param in parameters if param is not None] + parameters = filter_none(parameters) parameters = self.add_manual_parameters(parameters) operation_id = self.get_operation_id(operation_keys) @@ -167,12 +167,13 @@ class SwaggerAutoSchema(ViewInspector): if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body") if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover - if any(param.in_ == openapi.IN_BODY for param in parameters): + has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters) + if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()): raise SwaggerGenerationError("cannot add form parameters when the request has a request body; " "did you forget to set an appropriate parser class on the view?") if self.method not in self.body_methods: - raise SwaggerGenerationError("form parameters can only be applied to (" + ','.join(self.body_methods) + - ") HTTP methods") + raise SwaggerGenerationError("form parameters can only be applied to " + "(" + ','.join(self.body_methods) + ") HTTP methods") return merge_params(parameters, manual_parameters) diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index 1682c10..5edac06 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -335,11 +335,11 @@ def get_consumes(parser_classes): :rtype: list[str] """ media_types = [parser.media_type for parser in parser_classes or []] - if all(is_form_media_type(encoding) for encoding in media_types): + non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)] + if len(non_form_media_types) == 0: return media_types else: - media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)] - return media_types + return non_form_media_types def get_produces(renderer_classes): diff --git a/tests/test_form_parameters.py b/tests/test_form_parameters.py new file mode 100644 index 0000000..b4d9aef --- /dev/null +++ b/tests/test_form_parameters.py @@ -0,0 +1,46 @@ +import pytest +from django.conf.urls import url +from django.utils.decorators import method_decorator +from rest_framework.authtoken.views import ObtainAuthToken +from rest_framework.settings import api_settings + +from drf_yasg import openapi +from drf_yasg.errors import SwaggerGenerationError +from drf_yasg.generators import OpenAPISchemaGenerator +from drf_yasg.utils import swagger_auto_schema + + +def test_choice_field(): + @method_decorator(name='post', decorator=swagger_auto_schema( + operation_description="Logins a user and returns a token", + manual_parameters=[ + openapi.Parameter( + "username", + openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, + description="Valid username or email for authentication" + ), + openapi.Parameter( + "password", + openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, + description="Valid password for authentication", + ), + ] + )) + class CustomObtainAuthToken(ObtainAuthToken): + throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES + + urlpatterns = [ + url(r'token/$', CustomObtainAuthToken.as_view()), + ] + + generator = OpenAPISchemaGenerator( + info=openapi.Info(title="Test generator", default_version="v1"), + patterns=urlpatterns + ) + + with pytest.raises(SwaggerGenerationError): + generator.get_schema(None, True)