diff --git a/src/drf_yasg/codecs.py b/src/drf_yasg/codecs.py index c150d50..9739f75 100644 --- a/src/drf_yasg/codecs.py +++ b/src/drf_yasg/codecs.py @@ -5,6 +5,7 @@ import json from collections import OrderedDict from coreapi.compat import force_bytes +from django.utils.safestring import SafeData, SafeText from ruamel import yaml from . import openapi diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index a647b84..689192c 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -5,7 +5,6 @@ from collections import OrderedDict, defaultdict import uritemplate from coreapi.compat import urlparse -from django.utils.encoding import force_text from rest_framework import versioning from rest_framework.compat import URLPattern, URLResolver, get_original_route from rest_framework.schemas.generators import EndpointEnumerator as _EndpointEnumerator @@ -18,7 +17,7 @@ from .app_settings import swagger_settings from .errors import SwaggerGenerationError from .inspectors.field import get_basic_type_info, get_queryset_field from .openapi import ReferenceResolver -from .utils import get_consumes, get_produces +from .utils import force_real_str, get_consumes, get_produces logger = logging.getLogger(__name__) @@ -435,7 +434,7 @@ class OpenAPISchemaGenerator(object): attrs['pattern'] = getattr(view_cls, 'lookup_value_regex', attrs.get('pattern', None)) if model_field and getattr(model_field, 'help_text', False): - description = force_text(model_field.help_text) + description = model_field.help_text elif model_field and getattr(model_field, 'primary_key', False): description = get_pk_description(model, model_field) else: @@ -443,7 +442,7 @@ class OpenAPISchemaGenerator(object): field = openapi.Parameter( name=variable, - description=description, + description=force_real_str(description), required=True, in_=openapi.IN_PATH, **attrs diff --git a/src/drf_yasg/inspectors/base.py b/src/drf_yasg/inspectors/base.py index c54fe7d..3c04c40 100644 --- a/src/drf_yasg/inspectors/base.py +++ b/src/drf_yasg/inspectors/base.py @@ -1,12 +1,11 @@ import inspect import logging -from django.utils.encoding import force_text from rest_framework import serializers from rest_framework.utils import encoders, json from .. import openapi -from ..utils import decimal_as_float, is_list_view +from ..utils import decimal_as_float, force_real_str, is_list_view #: Sentinel value that inspectors must return to signal that they do not know how to handle an object NotHandled = object() @@ -196,9 +195,9 @@ class FieldInspector(BaseInspector): """ assert swagger_object_type in (openapi.Schema, openapi.Parameter, openapi.Items) assert not isinstance(field, openapi.SwaggerDict), "passed field is already a SwaggerDict object" - title = force_text(field.label) if field.label else None + title = force_real_str(field.label) if field.label else None title = title if swagger_object_type == openapi.Schema else None # only Schema has title - description = force_text(field.help_text) if field.help_text else None + description = force_real_str(field.help_text) if field.help_text else None description = description if swagger_object_type != openapi.Items else None # Items has no description either def SwaggerType(existing_object=None, **instance_kwargs): diff --git a/src/drf_yasg/inspectors/query.py b/src/drf_yasg/inspectors/query.py index d69ba00..ec8b2ba 100644 --- a/src/drf_yasg/inspectors/query.py +++ b/src/drf_yasg/inspectors/query.py @@ -3,6 +3,8 @@ from collections import OrderedDict import coreschema from rest_framework.pagination import CursorPagination, LimitOffsetPagination, PageNumberPagination +from drf_yasg.utils import force_real_str + from .. import openapi from .base import FilterInspector, PaginatorInspector @@ -48,7 +50,7 @@ class CoreAPICompatInspector(PaginatorInspector, FilterInspector): in_=location_to_in[field.location], type=coreapi_types.get(type(field.schema), openapi.TYPE_STRING), required=field.required, - description=field.schema.description if field.schema else None, + description=force_real_str(field.schema.description) if field.schema else None, ) diff --git a/src/drf_yasg/inspectors/view.py b/src/drf_yasg/inspectors/view.py index 957c8f7..03c3686 100644 --- a/src/drf_yasg/inspectors/view.py +++ b/src/drf_yasg/inspectors/view.py @@ -8,7 +8,7 @@ from rest_framework.status import is_success from .. import openapi from ..errors import SwaggerGenerationError from ..utils import ( - force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body, + force_real_str, force_serializer_instance, get_consumes, get_produces, guess_response_status, is_list_view, no_body, param_list_to_odict ) from .base import ViewInspector @@ -42,7 +42,7 @@ class SwaggerAutoSchema(ViewInspector): return openapi.Operation( operation_id=operation_id, - description=description, + description=force_real_str(description), responses=responses, parameters=parameters, consumes=consumes, @@ -246,7 +246,7 @@ class SwaggerAutoSchema(ViewInspector): for sc, serializer in response_serializers.items(): if isinstance(serializer, str): response = openapi.Response( - description=serializer + description=force_real_str(serializer) ) elif not serializer: continue diff --git a/src/drf_yasg/utils.py b/src/drf_yasg/utils.py index a457388..f0591a3 100644 --- a/src/drf_yasg/utils.py +++ b/src/drf_yasg/utils.py @@ -3,6 +3,7 @@ import logging from collections import OrderedDict from django.db import models +from django.utils.encoding import force_text from rest_framework import serializers, status from rest_framework.mixins import DestroyModelMixin, RetrieveModelMixin, UpdateModelMixin from rest_framework.request import is_form_media_type @@ -319,3 +320,17 @@ def get_serializer_ref_name(serializer): if ref_name.endswith('Serializer'): ref_name = ref_name[:-len('Serializer')] return ref_name + + +def force_real_str(s, encoding='utf-8', strings_only=False, errors='strict'): + """ + Force `s` into a ``str`` instance. + + Fix for https://github.com/axnsan12/drf-yasg/issues/159 + """ + if s is not None: + s = force_text(s, encoding, strings_only, errors) + if type(s) != str: + s = '' + s + + return s diff --git a/testproj/people/models.py b/testproj/people/models.py index 919212a..ca98b39 100644 --- a/testproj/people/models.py +++ b/testproj/people/models.py @@ -1,9 +1,10 @@ from django.db import models +from django.utils.safestring import mark_safe class Identity(models.Model): firstName = models.CharField(max_length=30, null=True) - lastName = models.CharField(max_length=30, null=True) + lastName = models.CharField(max_length=30, null=True, help_text=mark_safe("Here's some HTML!")) class Person(models.Model): diff --git a/tests/conftest.py b/tests/conftest.py index abcc090..0c9e0c1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,5 @@ +from six import StringIO + import copy import json import os @@ -9,7 +11,6 @@ from django.contrib.auth.models import User from django.core.management import call_command from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -from six import StringIO from drf_yasg import codecs, openapi from drf_yasg.codecs import yaml_sane_dump, yaml_sane_load diff --git a/tests/reference.yaml b/tests/reference.yaml index 4e5a570..3c13ddd 100644 --- a/tests/reference.yaml +++ b/tests/reference.yaml @@ -905,6 +905,7 @@ definitions: minLength: 1 lastName: title: LastName + description: Here's some HTML! type: string maxLength: 30 minLength: 1