Add extension points to management command
parent
470c993b98
commit
e5a569ebf7
|
|
@ -1,12 +1,11 @@
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import OrderedDict
|
|
||||||
from importlib import import_module
|
|
||||||
|
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.core.management.base import BaseCommand
|
from django.core.management.base import BaseCommand
|
||||||
|
from django.utils.module_loading import import_string
|
||||||
|
from rest_framework.settings import api_settings
|
||||||
from rest_framework.test import APIRequestFactory, force_authenticate
|
from rest_framework.test import APIRequestFactory, force_authenticate
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
@ -15,15 +14,6 @@ from ...app_settings import swagger_settings
|
||||||
from ...codecs import OpenAPICodecJson, OpenAPICodecYaml
|
from ...codecs import OpenAPICodecJson, OpenAPICodecYaml
|
||||||
|
|
||||||
|
|
||||||
def import_class(import_string):
|
|
||||||
if not import_string:
|
|
||||||
return None
|
|
||||||
|
|
||||||
module_path, class_name = import_string.rsplit('.', 1)
|
|
||||||
module = import_module(module_path)
|
|
||||||
return getattr(module, class_name)
|
|
||||||
|
|
||||||
|
|
||||||
class Command(BaseCommand):
|
class Command(BaseCommand):
|
||||||
help = 'Write the Swagger schema to disk in JSON or YAML format.'
|
help = 'Write the Swagger schema to disk in JSON or YAML format.'
|
||||||
|
|
||||||
|
|
@ -105,6 +95,17 @@ class Command(BaseCommand):
|
||||||
request = APIView().initialize_request(request)
|
request = APIView().initialize_request(request)
|
||||||
return request
|
return request
|
||||||
|
|
||||||
|
def get_schema_generator(self, generator_class_name, api_info, api_version, api_url):
|
||||||
|
generator_class = import_string(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
|
||||||
|
return generator_class(
|
||||||
|
info=api_info,
|
||||||
|
version=api_version,
|
||||||
|
url=api_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_schema(self, generator, request, public):
|
||||||
|
return generator.get_schema(request=request, public=public)
|
||||||
|
|
||||||
def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name,
|
def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name,
|
||||||
*args, **kwargs):
|
*args, **kwargs):
|
||||||
# disable logs of WARNING and below
|
# disable logs of WARNING and below
|
||||||
|
|
@ -144,13 +145,8 @@ class Command(BaseCommand):
|
||||||
if request and api_version:
|
if request and api_version:
|
||||||
request.version = api_version
|
request.version = api_version
|
||||||
|
|
||||||
generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
|
generator = self.get_schema_generator(generator_class_name, info, api_version, api_url)
|
||||||
generator = generator_class(
|
schema = self.get_schema(generator, request, not private)
|
||||||
info=info,
|
|
||||||
version=api_version,
|
|
||||||
url=api_url,
|
|
||||||
)
|
|
||||||
schema = generator.get_schema(request=request, public=not private)
|
|
||||||
|
|
||||||
if output_file == '-':
|
if output_file == '-':
|
||||||
self.write_schema(schema, self.stdout, format)
|
self.write_schema(schema, self.stdout, format)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue