diff --git a/src/drf_yasg/management/commands/generate_swagger.py b/src/drf_yasg/management/commands/generate_swagger.py index 3f1db30..fb2fa94 100644 --- a/src/drf_yasg/management/commands/generate_swagger.py +++ b/src/drf_yasg/management/commands/generate_swagger.py @@ -1,12 +1,11 @@ -import json import logging import os -from collections import OrderedDict -from importlib import import_module from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured from django.core.management.base import BaseCommand +from django.utils.module_loading import import_string +from rest_framework.settings import api_settings from rest_framework.test import APIRequestFactory, force_authenticate from rest_framework.views import APIView @@ -15,15 +14,6 @@ from ...app_settings import swagger_settings from ...codecs import OpenAPICodecJson, OpenAPICodecYaml -def import_class(import_string): - if not import_string: - return None - - module_path, class_name = import_string.rsplit('.', 1) - module = import_module(module_path) - return getattr(module, class_name) - - class Command(BaseCommand): help = 'Write the Swagger schema to disk in JSON or YAML format.' @@ -105,6 +95,17 @@ class Command(BaseCommand): request = APIView().initialize_request(request) return request + def get_schema_generator(self, generator_class_name, api_info, api_version, api_url): + generator_class = import_string(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS + return generator_class( + info=api_info, + version=api_version, + url=api_url, + ) + + def get_schema(self, generator, request, public): + return generator.get_schema(request=request, public=public) + def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name, *args, **kwargs): # disable logs of WARNING and below @@ -144,13 +145,8 @@ class Command(BaseCommand): if request and api_version: request.version = api_version - generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS - generator = generator_class( - info=info, - version=api_version, - url=api_url, - ) - schema = generator.get_schema(request=request, public=not private) + generator = self.get_schema_generator(generator_class_name, info, api_version, api_url) + schema = self.get_schema(generator, request, not private) if output_file == '-': self.write_schema(schema, self.stdout, format)