Add extension points to management command

master
Cristi Vîjdea 2018-12-28 01:36:26 +02:00
parent 470c993b98
commit e5a569ebf7
1 changed files with 15 additions and 19 deletions

View File

@ -1,12 +1,11 @@
import json
import logging
import os
from collections import OrderedDict
from importlib import import_module
from django.contrib.auth import get_user_model
from django.core.exceptions import ImproperlyConfigured
from django.core.management.base import BaseCommand
from django.utils.module_loading import import_string
from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory, force_authenticate
from rest_framework.views import APIView
@ -15,15 +14,6 @@ from ...app_settings import swagger_settings
from ...codecs import OpenAPICodecJson, OpenAPICodecYaml
def import_class(import_string):
if not import_string:
return None
module_path, class_name = import_string.rsplit('.', 1)
module = import_module(module_path)
return getattr(module, class_name)
class Command(BaseCommand):
help = 'Write the Swagger schema to disk in JSON or YAML format.'
@ -105,6 +95,17 @@ class Command(BaseCommand):
request = APIView().initialize_request(request)
return request
def get_schema_generator(self, generator_class_name, api_info, api_version, api_url):
generator_class = import_string(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
return generator_class(
info=api_info,
version=api_version,
url=api_url,
)
def get_schema(self, generator, request, public):
return generator.get_schema(request=request, public=public)
def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name,
*args, **kwargs):
# disable logs of WARNING and below
@ -144,13 +145,8 @@ class Command(BaseCommand):
if request and api_version:
request.version = api_version
generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
generator = generator_class(
info=info,
version=api_version,
url=api_url,
)
schema = generator.get_schema(request=request, public=not private)
generator = self.get_schema_generator(generator_class_name, info, api_version, api_url)
schema = self.get_schema(generator, request, not private)
if output_file == '-':
self.write_schema(schema, self.stdout, format)