parent
acfb0c5442
commit
1f95f4098b
|
|
@ -8,8 +8,8 @@ from rest_framework.status import is_success
|
||||||
from .. import openapi
|
from .. import openapi
|
||||||
from ..errors import SwaggerGenerationError
|
from ..errors import SwaggerGenerationError
|
||||||
from ..utils import (
|
from ..utils import (
|
||||||
force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view,
|
filter_none, force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status,
|
||||||
merge_params, no_body, param_list_to_odict
|
is_list_view, merge_params, no_body, param_list_to_odict
|
||||||
)
|
)
|
||||||
from .base import ViewInspector
|
from .base import ViewInspector
|
||||||
|
|
||||||
|
|
@ -29,7 +29,7 @@ class SwaggerAutoSchema(ViewInspector):
|
||||||
body = self.get_request_body_parameters(consumes)
|
body = self.get_request_body_parameters(consumes)
|
||||||
query = self.get_query_parameters()
|
query = self.get_query_parameters()
|
||||||
parameters = body + query
|
parameters = body + query
|
||||||
parameters = [param for param in parameters if param is not None]
|
parameters = filter_none(parameters)
|
||||||
parameters = self.add_manual_parameters(parameters)
|
parameters = self.add_manual_parameters(parameters)
|
||||||
|
|
||||||
operation_id = self.get_operation_id(operation_keys)
|
operation_id = self.get_operation_id(operation_keys)
|
||||||
|
|
@ -167,12 +167,13 @@ class SwaggerAutoSchema(ViewInspector):
|
||||||
if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover
|
if any(param.in_ == openapi.IN_BODY for param in manual_parameters): # pragma: no cover
|
||||||
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
|
raise SwaggerGenerationError("specify the body parameter as a Schema or Serializer in request_body")
|
||||||
if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover
|
if any(param.in_ == openapi.IN_FORM for param in manual_parameters): # pragma: no cover
|
||||||
if any(param.in_ == openapi.IN_BODY for param in parameters):
|
has_body_parameter = any(param.in_ == openapi.IN_BODY for param in parameters)
|
||||||
|
if has_body_parameter or not any(is_form_media_type(encoding) for encoding in self.get_consumes()):
|
||||||
raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
|
raise SwaggerGenerationError("cannot add form parameters when the request has a request body; "
|
||||||
"did you forget to set an appropriate parser class on the view?")
|
"did you forget to set an appropriate parser class on the view?")
|
||||||
if self.method not in self.body_methods:
|
if self.method not in self.body_methods:
|
||||||
raise SwaggerGenerationError("form parameters can only be applied to (" + ','.join(self.body_methods) +
|
raise SwaggerGenerationError("form parameters can only be applied to "
|
||||||
") HTTP methods")
|
"(" + ','.join(self.body_methods) + ") HTTP methods")
|
||||||
|
|
||||||
return merge_params(parameters, manual_parameters)
|
return merge_params(parameters, manual_parameters)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -335,11 +335,11 @@ def get_consumes(parser_classes):
|
||||||
:rtype: list[str]
|
:rtype: list[str]
|
||||||
"""
|
"""
|
||||||
media_types = [parser.media_type for parser in parser_classes or []]
|
media_types = [parser.media_type for parser in parser_classes or []]
|
||||||
if all(is_form_media_type(encoding) for encoding in media_types):
|
non_form_media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
|
||||||
|
if len(non_form_media_types) == 0:
|
||||||
return media_types
|
return media_types
|
||||||
else:
|
else:
|
||||||
media_types = [encoding for encoding in media_types if not is_form_media_type(encoding)]
|
return non_form_media_types
|
||||||
return media_types
|
|
||||||
|
|
||||||
|
|
||||||
def get_produces(renderer_classes):
|
def get_produces(renderer_classes):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,46 @@
|
||||||
|
import pytest
|
||||||
|
from django.conf.urls import url
|
||||||
|
from django.utils.decorators import method_decorator
|
||||||
|
from rest_framework.authtoken.views import ObtainAuthToken
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
from drf_yasg import openapi
|
||||||
|
from drf_yasg.errors import SwaggerGenerationError
|
||||||
|
from drf_yasg.generators import OpenAPISchemaGenerator
|
||||||
|
from drf_yasg.utils import swagger_auto_schema
|
||||||
|
|
||||||
|
|
||||||
|
def test_choice_field():
|
||||||
|
@method_decorator(name='post', decorator=swagger_auto_schema(
|
||||||
|
operation_description="Logins a user and returns a token",
|
||||||
|
manual_parameters=[
|
||||||
|
openapi.Parameter(
|
||||||
|
"username",
|
||||||
|
openapi.IN_FORM,
|
||||||
|
required=True,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
description="Valid username or email for authentication"
|
||||||
|
),
|
||||||
|
openapi.Parameter(
|
||||||
|
"password",
|
||||||
|
openapi.IN_FORM,
|
||||||
|
required=True,
|
||||||
|
type=openapi.TYPE_STRING,
|
||||||
|
description="Valid password for authentication",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
))
|
||||||
|
class CustomObtainAuthToken(ObtainAuthToken):
|
||||||
|
throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
url(r'token/$', CustomObtainAuthToken.as_view()),
|
||||||
|
]
|
||||||
|
|
||||||
|
generator = OpenAPISchemaGenerator(
|
||||||
|
info=openapi.Info(title="Test generator", default_version="v1"),
|
||||||
|
patterns=urlpatterns
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SwaggerGenerationError):
|
||||||
|
generator.get_schema(None, True)
|
||||||
Loading…
Reference in New Issue