Refactor get_queryset_from_view
parent
a6ae8b0521
commit
90812f5c43
|
|
@ -15,7 +15,7 @@ from rest_framework.settings import api_settings as rest_framework_settings
|
||||||
from . import openapi
|
from . import openapi
|
||||||
from .app_settings import swagger_settings
|
from .app_settings import swagger_settings
|
||||||
from .errors import SwaggerGenerationError
|
from .errors import SwaggerGenerationError
|
||||||
from .inspectors.field import get_basic_type_info, get_queryset_field
|
from .inspectors.field import get_basic_type_info, get_queryset_field, get_queryset_from_view
|
||||||
from .openapi import ReferenceResolver
|
from .openapi import ReferenceResolver
|
||||||
from .utils import force_real_str, get_consumes, get_produces
|
from .utils import force_real_str, get_consumes, get_produces
|
||||||
|
|
||||||
|
|
@ -424,7 +424,7 @@ class OpenAPISchemaGenerator(object):
|
||||||
:rtype: list[openapi.Parameter]
|
:rtype: list[openapi.Parameter]
|
||||||
"""
|
"""
|
||||||
parameters = []
|
parameters = []
|
||||||
queryset = getattr(view_cls, 'queryset', None)
|
queryset = get_queryset_from_view(view_cls)
|
||||||
|
|
||||||
for variable in sorted(uritemplate.variables(path)):
|
for variable in sorted(uritemplate.variables(path)):
|
||||||
model, model_field = get_queryset_field(queryset, variable)
|
model, model_field = get_queryset_field(queryset, variable)
|
||||||
|
|
|
||||||
|
|
@ -165,6 +165,25 @@ def get_model_field(model, field_name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_queryset_from_view(view, serializer=None):
|
||||||
|
"""Try to get the queryset of the given view
|
||||||
|
|
||||||
|
:param view: the view instance or class
|
||||||
|
:param serializer: if given, will check that the view's get_serializer_class return matches this serialzier
|
||||||
|
:return: queryset or ``None``
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
queryset = getattr(view, 'queryset', None)
|
||||||
|
|
||||||
|
if queryset is not None and serializer is not None:
|
||||||
|
# make sure the view is actually using *this* serializer
|
||||||
|
assert type(serializer) == view.get_serializer_class()
|
||||||
|
|
||||||
|
return queryset
|
||||||
|
except Exception: # pragma: no cover
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def get_parent_serializer(field):
|
def get_parent_serializer(field):
|
||||||
"""Get the nearest parent ``Serializer`` instance for the given field.
|
"""Get the nearest parent ``Serializer`` instance for the given field.
|
||||||
|
|
||||||
|
|
@ -231,21 +250,18 @@ class RelatedFieldInspector(FieldInspector):
|
||||||
else:
|
else:
|
||||||
# if the RelatedField has no queryset (e.g. read only), try to find the target model
|
# if the RelatedField has no queryset (e.g. read only), try to find the target model
|
||||||
# from the view queryset or ModelSerializer model, if present
|
# from the view queryset or ModelSerializer model, if present
|
||||||
view_queryset = getattr(self.view, 'queryset', None)
|
|
||||||
parent_serializer = get_parent_serializer(field)
|
parent_serializer = get_parent_serializer(field)
|
||||||
if view_queryset is not None:
|
|
||||||
# make sure the view is actually using *this* serializer
|
|
||||||
try:
|
|
||||||
if type(parent_serializer) != self.view.get_serializer_class():
|
|
||||||
view_queryset = None
|
|
||||||
except Exception as e:
|
|
||||||
view_queryset = None
|
|
||||||
|
|
||||||
serializer_meta = getattr(parent_serializer, 'Meta', None)
|
serializer_meta = getattr(parent_serializer, 'Meta', None)
|
||||||
this_model = getattr(serializer_meta, 'model', None) or getattr(view_queryset, 'model', None)
|
this_model = getattr(serializer_meta, 'model', None)
|
||||||
|
if not this_model:
|
||||||
|
view_queryset = get_queryset_from_view(self.view, parent_serializer)
|
||||||
|
this_model = getattr(view_queryset, 'model', None)
|
||||||
|
|
||||||
source = getattr(field, 'source', '') or field.field_name
|
source = getattr(field, 'source', '') or field.field_name
|
||||||
if not source and isinstance(field.parent, serializers.ManyRelatedField):
|
if not source and isinstance(field.parent, serializers.ManyRelatedField):
|
||||||
source = field.parent.field_name
|
source = field.parent.field_name
|
||||||
|
|
||||||
model = get_related_model(this_model, source)
|
model = get_related_model(this_model, source)
|
||||||
model_field = get_model_field(model, target_field)
|
model_field = get_model_field(model, target_field)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue