From 90812f5c43e0ccfc12827d076626c0e01419fec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Mon, 10 Sep 2018 14:34:30 +0300 Subject: [PATCH] Refactor get_queryset_from_view --- src/drf_yasg/generators.py | 4 ++-- src/drf_yasg/inspectors/field.py | 34 +++++++++++++++++++++++--------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index ecfba6a..215f597 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -15,7 +15,7 @@ from rest_framework.settings import api_settings as rest_framework_settings from . import openapi from .app_settings import swagger_settings from .errors import SwaggerGenerationError -from .inspectors.field import get_basic_type_info, get_queryset_field +from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view from .openapi import ReferenceResolver from .utils import force_real_str, get_consumes, get_produces @@ -424,7 +424,7 @@ class OpenAPISchemaGenerator(object): :rtype: list[openapi.Parameter] """ parameters = [] - queryset = getattr(view_cls, 'queryset', None) + queryset = get_queryset_from_view(view_cls) for variable in sorted(uritemplate.variables(path)): model, model_field = get_queryset_field(queryset, variable) diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index aa6ac90..24436d9 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -165,6 +165,25 @@ def get_model_field(model, field_name): return None +def get_queryset_from_view(view, serializer=None): + """Try to get the queryset of the given view + + :param view: the view instance or class + :param serializer: if given, will check that the view's get_serializer_class return matches this serialzier + :return: queryset or ``None`` + """ + try: + queryset = getattr(view, 'queryset', None) + + if queryset is not None and serializer is not None: + # make sure the view is actually using *this* serializer + assert type(serializer) == view.get_serializer_class() + + return queryset + except Exception: # pragma: no cover + return None + + def get_parent_serializer(field): """Get the nearest parent ``Serializer`` instance for the given field. @@ -231,21 +250,18 @@ class RelatedFieldInspector(FieldInspector): else: # if the RelatedField has no queryset (e.g. read only), try to find the target model # from the view queryset or ModelSerializer model, if present - view_queryset = getattr(self.view, 'queryset', None) parent_serializer = get_parent_serializer(field) - if view_queryset is not None: - # make sure the view is actually using *this* serializer - try: - if type(parent_serializer) != self.view.get_serializer_class(): - view_queryset = None - except Exception as e: - view_queryset = None serializer_meta = getattr(parent_serializer, 'Meta', None) - this_model = getattr(serializer_meta, 'model', None) or getattr(view_queryset, 'model', None) + this_model = getattr(serializer_meta, 'model', None) + if not this_model: + view_queryset = get_queryset_from_view(self.view, parent_serializer) + this_model = getattr(view_queryset, 'model', None) + source = getattr(field, 'source', '') or field.field_name if not source and isinstance(field.parent, serializers.ManyRelatedField): source = field.parent.field_name + model = get_related_model(this_model, source) model_field = get_model_field(model, target_field)