Add --generator-class argument to management command

Closes #140
openapi3
Cristi Vîjdea 2018-06-16 15:54:51 +03:00
parent 6ac58b8cf7
commit 1184ea8b46
5 changed files with 47 additions and 10 deletions

View File

@ -43,6 +43,14 @@ The possible settings and their default values are as follows:
Default classes Default classes
=============== ===============
DEFAULT_GENERATOR_CLASS
-------------------------
:class:`~.generators.OpenAPISchemaGenerator` subclass that will be used by default for generating the final
:class:`.Schema` object. Can be overriden by the ``generator_class`` argument to :func:`.get_schema_view`.
**Default**: :class:`drf_yasg.generators.OpenAPISchemaGenerator`
DEFAULT_AUTO_SCHEMA_CLASS DEFAULT_AUTO_SCHEMA_CLASS
------------------------- -------------------------
@ -102,7 +110,7 @@ DEFAULT_INFO
------------ ------------
An import string to an :class:`.openapi.Info` object. This will be used when running the ``generate_swagger`` An import string to an :class:`.openapi.Info` object. This will be used when running the ``generate_swagger``
management command, or if no ``info`` argument is passed to ``get_schema_view``. management command, or if no ``info`` argument is passed to :func:`.get_schema_view`.
**Default**: :python:`None` **Default**: :python:`None`

View File

@ -2,6 +2,7 @@ from django.conf import settings
from rest_framework.settings import perform_import from rest_framework.settings import perform_import
SWAGGER_DEFAULTS = { SWAGGER_DEFAULTS = {
'DEFAULT_GENERATOR_CLASS': 'drf_yasg.generators.OpenAPISchemaGenerator',
'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema', 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema',
'DEFAULT_FIELD_INSPECTORS': [ 'DEFAULT_FIELD_INSPECTORS': [
@ -68,6 +69,7 @@ REDOC_DEFAULTS = {
} }
IMPORT_STRINGS = [ IMPORT_STRINGS = [
'DEFAULT_GENERATOR_CLASS',
'DEFAULT_AUTO_SCHEMA_CLASS', 'DEFAULT_AUTO_SCHEMA_CLASS',
'DEFAULT_FIELD_INSPECTORS', 'DEFAULT_FIELD_INSPECTORS',
'DEFAULT_FILTER_INSPECTORS', 'DEFAULT_FILTER_INSPECTORS',

View File

@ -2,6 +2,7 @@ import json
import logging import logging
import os import os
from collections import OrderedDict from collections import OrderedDict
from importlib import import_module
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
@ -12,7 +13,15 @@ from rest_framework.views import APIView
from ... import openapi from ... import openapi
from ...app_settings import swagger_settings from ...app_settings import swagger_settings
from ...codecs import OpenAPICodecJson, OpenAPICodecYaml from ...codecs import OpenAPICodecJson, OpenAPICodecYaml
from ...generators import OpenAPISchemaGenerator
def import_class(import_string):
if not import_string:
return None
module_path, class_name = import_string.rsplit('.', 1)
module = import_module(module_path)
return getattr(module, class_name)
class Command(BaseCommand): class Command(BaseCommand):
@ -64,6 +73,11 @@ class Command(BaseCommand):
'OpenAPISchemaGenerator.get_schema().\n' 'OpenAPISchemaGenerator.get_schema().\n'
'This option implies --mock-request.' 'This option implies --mock-request.'
) )
parser.add_argument(
'-g', '--generator-class', dest='generator_class_name',
default='',
help='Import string pointing to an OpenAPISchemaGenerator subclass to use for schema generation.'
)
def write_schema(self, schema, stream, format): def write_schema(self, schema, stream, format):
if format == 'json': if format == 'json':
@ -89,7 +103,8 @@ class Command(BaseCommand):
request = APIView().initialize_request(request) request = APIView().initialize_request(request)
return request return request
def handle(self, output_file, overwrite, format, api_url, mock, user, private, *args, **options): def handle(self, output_file, overwrite, format, api_url, mock, user, private, generator_class_name,
*args, **kwargs):
# disable logs of WARNING and below # disable logs of WARNING and below
logging.disable(logging.WARNING) logging.disable(logging.WARNING)
@ -117,7 +132,8 @@ class Command(BaseCommand):
request = self.get_mock_request(api_url, format, user) if mock else None request = self.get_mock_request(api_url, format, user) if mock else None
generator = OpenAPISchemaGenerator( generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
generator = generator_class(
info=info, info=info,
url=api_url url=api_url
) )

View File

@ -11,7 +11,6 @@ from rest_framework.settings import api_settings
from rest_framework.views import APIView from rest_framework.views import APIView
from .app_settings import swagger_settings from .app_settings import swagger_settings
from .generators import OpenAPISchemaGenerator
from .renderers import OpenAPIRenderer, ReDocRenderer, SwaggerJSONRenderer, SwaggerUIRenderer, SwaggerYAMLRenderer from .renderers import OpenAPIRenderer, ReDocRenderer, SwaggerJSONRenderer, SwaggerUIRenderer, SwaggerYAMLRenderer
SPEC_RENDERERS = (SwaggerYAMLRenderer, SwaggerJSONRenderer, OpenAPIRenderer) SPEC_RENDERERS = (SwaggerYAMLRenderer, SwaggerJSONRenderer, OpenAPIRenderer)
@ -46,7 +45,7 @@ def deferred_never_cache(view_func):
def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None, def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None,
generator_class=OpenAPISchemaGenerator, generator_class=swagger_settings.DEFAULT_GENERATOR_CLASS,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
"""Create a SchemaView class with default renderers and generators. """Create a SchemaView class with default renderers and generators.

View File

@ -11,17 +11,18 @@ import pytest
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.core.management import call_command from django.core.management import call_command
from drf_yasg import openapi
from drf_yasg.codecs import yaml_sane_load from drf_yasg.codecs import yaml_sane_load
from drf_yasg.generators import OpenAPISchemaGenerator
def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='', def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='',
mock=False, user='', private=False, **kwargs): mock=False, user='', private=False, generator_class_name='', **kwargs):
out = StringIO() out = StringIO()
call_command( call_command(
'generate_swagger', stdout=out, 'generate_swagger', stdout=out,
output_file=output_file, overwrite=overwrite, format=format, output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user,
api_url=api_url, mock=mock, user=user, private=private, private=private, generator_class_name=generator_class_name, **kwargs
**kwargs
) )
return out.getvalue() return out.getvalue()
@ -46,6 +47,17 @@ def test_no_mock(db):
assert len(output_schema['paths']) > 0 assert len(output_schema['paths']) > 0
class EmptySchemaGenerator(OpenAPISchemaGenerator):
def get_paths(self, endpoints, components, request, public):
return openapi.Paths(paths={}), ''
def test_generator_class(db):
output = call_generate_swagger(generator_class_name='test_management.EmptySchemaGenerator')
output_schema = json.loads(output, object_pairs_hook=OrderedDict)
assert len(output_schema['paths']) == 0
def silentremove(filename): def silentremove(filename):
try: try:
os.remove(filename) os.remove(filename)