Refactor get_queryset_from_view
parent
a6ae8b0521
commit
90812f5c43
|
|
@ -15,7 +15,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
|
|||
from . import openapi
|
||||
from .app_settings import swagger_settings
|
||||
from .errors import SwaggerGenerationError
|
||||
from .inspectors.field import get_basic_type_info, get_queryset_field
|
||||
from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
|
||||
from .openapi import ReferenceResolver
|
||||
from .utils import force_real_str, get_consumes, get_produces
|
||||
|
||||
|
|
@ -424,7 +424,7 @@ class OpenAPISchemaGenerator(object):
|
|||
:rtype: list[openapi.Parameter]
|
||||
"""
|
||||
parameters = []
|
||||
queryset = getattr(view_cls, 'queryset', None)
|
||||
queryset = get_queryset_from_view(view_cls)
|
||||
|
||||
for variable in sorted(uritemplate.variables(path)):
|
||||
model, model_field = get_queryset_field(queryset, variable)
|
||||
|
|
|
|||
|
|
@ -165,6 +165,25 @@ def get_model_field(model, field_name):
|
|||
return None
|
||||
|
||||
|
||||
def get_queryset_from_view(view, serializer=None):
|
||||
"""Try to get the queryset of the given view
|
||||
|
||||
:param view: the view instance or class
|
||||
:param serializer: if given, will check that the view's get_serializer_class return matches this serialzier
|
||||
:return: queryset or ``None``
|
||||
"""
|
||||
try:
|
||||
queryset = getattr(view, 'queryset', None)
|
||||
|
||||
if queryset is not None and serializer is not None:
|
||||
# make sure the view is actually using *this* serializer
|
||||
assert type(serializer) == view.get_serializer_class()
|
||||
|
||||
return queryset
|
||||
except Exception: # pragma: no cover
|
||||
return None
|
||||
|
||||
|
||||
def get_parent_serializer(field):
|
||||
"""Get the nearest parent ``Serializer`` instance for the given field.
|
||||
|
||||
|
|
@ -231,21 +250,18 @@ class RelatedFieldInspector(FieldInspector):
|
|||
else:
|
||||
# if the RelatedField has no queryset (e.g. read only), try to find the target model
|
||||
# from the view queryset or ModelSerializer model, if present
|
||||
view_queryset = getattr(self.view, 'queryset', None)
|
||||
parent_serializer = get_parent_serializer(field)
|
||||
if view_queryset is not None:
|
||||
# make sure the view is actually using *this* serializer
|
||||
try:
|
||||
if type(parent_serializer) != self.view.get_serializer_class():
|
||||
view_queryset = None
|
||||
except Exception as e:
|
||||
view_queryset = None
|
||||
|
||||
serializer_meta = getattr(parent_serializer, 'Meta', None)
|
||||
this_model = getattr(serializer_meta, 'model', None) or getattr(view_queryset, 'model', None)
|
||||
this_model = getattr(serializer_meta, 'model', None)
|
||||
if not this_model:
|
||||
view_queryset = get_queryset_from_view(self.view, parent_serializer)
|
||||
this_model = getattr(view_queryset, 'model', None)
|
||||
|
||||
source = getattr(field, 'source', '') or field.field_name
|
||||
if not source and isinstance(field.parent, serializers.ManyRelatedField):
|
||||
source = field.parent.field_name
|
||||
|
||||
model = get_related_model(this_model, source)
|
||||
model_field = get_model_field(model, target_field)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue