From 1184ea8b4695b896769bd782e60bb77996f18f9e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Sat, 16 Jun 2018 15:54:51 +0300 Subject: [PATCH] Add --generator-class argument to management command Closes #140 --- docs/settings.rst | 10 ++++++++- src/drf_yasg/app_settings.py | 2 ++ .../management/commands/generate_swagger.py | 22 ++++++++++++++++--- src/drf_yasg/views.py | 3 +-- tests/test_management.py | 20 +++++++++++++---- 5 files changed, 47 insertions(+), 10 deletions(-) diff --git a/docs/settings.rst b/docs/settings.rst index 38b8735..ba7af28 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -43,6 +43,14 @@ The possible settings and their default values are as follows: Default classes =============== +DEFAULT_GENERATOR_CLASS +------------------------- + +:class:`~.generators.OpenAPISchemaGenerator` subclass that will be used by default for generating the final +:class:`.Schema` object. Can be overriden by the ``generator_class`` argument to :func:`.get_schema_view`. + +**Default**: :class:`drf_yasg.generators.OpenAPISchemaGenerator` + DEFAULT_AUTO_SCHEMA_CLASS ------------------------- @@ -102,7 +110,7 @@ DEFAULT_INFO ------------ An import string to an :class:`.openapi.Info` object. This will be used when running the ``generate_swagger`` -management command, or if no ``info`` argument is passed to ``get_schema_view``. +management command, or if no ``info`` argument is passed to :func:`.get_schema_view`. **Default**: :python:`None` diff --git a/src/drf_yasg/app_settings.py b/src/drf_yasg/app_settings.py index 02d7628..2edde3a 100644 --- a/src/drf_yasg/app_settings.py +++ b/src/drf_yasg/app_settings.py @@ -2,6 +2,7 @@ from django.conf import settings from rest_framework.settings import perform_import SWAGGER_DEFAULTS = { + 'DEFAULT_GENERATOR_CLASS': 'drf_yasg.generators.OpenAPISchemaGenerator', 'DEFAULT_AUTO_SCHEMA_CLASS': 'drf_yasg.inspectors.SwaggerAutoSchema', 'DEFAULT_FIELD_INSPECTORS': [ @@ -68,6 +69,7 @@ REDOC_DEFAULTS = { } IMPORT_STRINGS = [ + 'DEFAULT_GENERATOR_CLASS', 'DEFAULT_AUTO_SCHEMA_CLASS', 'DEFAULT_FIELD_INSPECTORS', 'DEFAULT_FILTER_INSPECTORS', diff --git a/src/drf_yasg/management/commands/generate_swagger.py b/src/drf_yasg/management/commands/generate_swagger.py index 2c36253..18fd815 100644 --- a/src/drf_yasg/management/commands/generate_swagger.py +++ b/src/drf_yasg/management/commands/generate_swagger.py @@ -2,6 +2,7 @@ import json import logging import os from collections import OrderedDict +from importlib import import_module from django.contrib.auth.models import User from django.core.exceptions import ImproperlyConfigured @@ -12,7 +13,15 @@ from rest_framework.views import APIView from ... import openapi from ...app_settings import swagger_settings from ...codecs import OpenAPICodecJson, OpenAPICodecYaml -from ...generators import OpenAPISchemaGenerator + + +def import_class(import_string): + if not import_string: + return None + + module_path, class_name = import_string.rsplit('.', 1) + module = import_module(module_path) + return getattr(module, class_name) class Command(BaseCommand): @@ -64,6 +73,11 @@ class Command(BaseCommand): 'OpenAPISchemaGenerator.get_schema().\n' 'This option implies --mock-request.' ) + parser.add_argument( + '-g', '--generator-class', dest='generator_class_name', + default='', + help='Import string pointing to an OpenAPISchemaGenerator subclass to use for schema generation.' + ) def write_schema(self, schema, stream, format): if format == 'json': @@ -89,7 +103,8 @@ class Command(BaseCommand): request = APIView().initialize_request(request) return request - def handle(self, output_file, overwrite, format, api_url, mock, user, private, *args, **options): + def handle(self, output_file, overwrite, format, api_url, mock, user, private, generator_class_name, + *args, **kwargs): # disable logs of WARNING and below logging.disable(logging.WARNING) @@ -117,7 +132,8 @@ class Command(BaseCommand): request = self.get_mock_request(api_url, format, user) if mock else None - generator = OpenAPISchemaGenerator( + generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS + generator = generator_class( info=info, url=api_url ) diff --git a/src/drf_yasg/views.py b/src/drf_yasg/views.py index 1330374..4f3c8df 100644 --- a/src/drf_yasg/views.py +++ b/src/drf_yasg/views.py @@ -11,7 +11,6 @@ from rest_framework.settings import api_settings from rest_framework.views import APIView from .app_settings import swagger_settings -from .generators import OpenAPISchemaGenerator from .renderers import OpenAPIRenderer, ReDocRenderer, SwaggerJSONRenderer, SwaggerUIRenderer, SwaggerYAMLRenderer SPEC_RENDERERS = (SwaggerYAMLRenderer, SwaggerJSONRenderer, OpenAPIRenderer) @@ -46,7 +45,7 @@ def deferred_never_cache(view_func): def get_schema_view(info=None, url=None, patterns=None, urlconf=None, public=False, validators=None, - generator_class=OpenAPISchemaGenerator, + generator_class=swagger_settings.DEFAULT_GENERATOR_CLASS, authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES, permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES): """Create a SchemaView class with default renderers and generators. diff --git a/tests/test_management.py b/tests/test_management.py index d39e237..70bf6d9 100644 --- a/tests/test_management.py +++ b/tests/test_management.py @@ -11,17 +11,18 @@ import pytest from django.contrib.auth.models import User from django.core.management import call_command +from drf_yasg import openapi from drf_yasg.codecs import yaml_sane_load +from drf_yasg.generators import OpenAPISchemaGenerator def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='', - mock=False, user='', private=False, **kwargs): + mock=False, user='', private=False, generator_class_name='', **kwargs): out = StringIO() call_command( 'generate_swagger', stdout=out, - output_file=output_file, overwrite=overwrite, format=format, - api_url=api_url, mock=mock, user=user, private=private, - **kwargs + output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user, + private=private, generator_class_name=generator_class_name, **kwargs ) return out.getvalue() @@ -46,6 +47,17 @@ def test_no_mock(db): assert len(output_schema['paths']) > 0 +class EmptySchemaGenerator(OpenAPISchemaGenerator): + def get_paths(self, endpoints, components, request, public): + return openapi.Paths(paths={}), '' + + +def test_generator_class(db): + output = call_generate_swagger(generator_class_name='test_management.EmptySchemaGenerator') + output_schema = json.loads(output, object_pairs_hook=OrderedDict) + assert len(output_schema['paths']) == 0 + + def silentremove(filename): try: os.remove(filename)