diff --git a/src/drf_yasg/management/commands/generate_swagger.py b/src/drf_yasg/management/commands/generate_swagger.py index e4d3e8a..ba26e14 100644 --- a/src/drf_yasg/management/commands/generate_swagger.py +++ b/src/drf_yasg/management/commands/generate_swagger.py @@ -59,6 +59,11 @@ class Command(BaseCommand): help='Use a mock request when generating the swagger schema. This is useful if your views or serializers' 'depend on context from a request in order to function.' ) + parser.add_argument( + '--api-version', dest='api_version', + type=str, + help='Version to use to generate schema. This option implies --mock-request.' + ) parser.add_argument( '--user', dest='user', help='Username of an existing user to use for mocked authentication. This option implies --mock-request.' @@ -102,7 +107,7 @@ class Command(BaseCommand): request = APIView().initialize_request(request) return request - def handle(self, output_file, overwrite, format, api_url, mock, user, private, generator_class_name, + def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name, *args, **kwargs): # disable logs of WARNING and below logging.disable(logging.WARNING) @@ -126,19 +131,25 @@ class Command(BaseCommand): # avoid crashing if auth is not configured in the project user = get_user_model().objects.get(username=user) - mock = mock or private or (user is not None) + mock = mock or private or (user is not None) or (api_version is not None) if mock and not api_url: raise ImproperlyConfigured( '--mock-request requires an API url; either provide ' 'the --url argument or set the DEFAULT_API_URL setting' ) - request = self.get_mock_request(api_url, format, user) if mock else None + request = None + if mock: + request = self.get_mock_request(api_url, format, user) + + if request and api_version: + request.version = api_version generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS generator = generator_class( info=info, - url=api_url + version=api_version, + url=api_url, ) schema = generator.get_schema(request=request, public=not private) diff --git a/tests/conftest.py b/tests/conftest.py index e41845c..abcc090 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,10 @@ from collections import OrderedDict import pytest from datadiff.tools import assert_equal from django.contrib.auth.models import User +from django.core.management import call_command from rest_framework.test import APIRequestFactory from rest_framework.views import APIView +from six import StringIO from drf_yasg import codecs, openapi from drf_yasg.codecs import yaml_sane_dump, yaml_sane_load @@ -64,6 +66,21 @@ def validate_schema(db): return validate_schema +@pytest.fixture +def call_generate_swagger(): + def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='', + mock=False, user=None, private=False, generator_class_name='', **kwargs): + out = StringIO() + call_command( + 'generate_swagger', stdout=out, + output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user, + private=private, generator_class_name=generator_class_name, **kwargs + ) + return out.getvalue() + + return call_generate_swagger + + @pytest.fixture def compare_schemas(): def compare_schemas(schema1, schema2): diff --git a/tests/test_management.py b/tests/test_management.py index 7907c63..c3cdd40 100644 --- a/tests/test_management.py +++ b/tests/test_management.py @@ -9,25 +9,13 @@ from collections import OrderedDict import pytest from django.contrib.auth.models import User -from django.core.management import call_command from drf_yasg import openapi from drf_yasg.codecs import yaml_sane_load from drf_yasg.generators import OpenAPISchemaGenerator -def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='', - mock=False, user=None, private=False, generator_class_name='', **kwargs): - out = StringIO() - call_command( - 'generate_swagger', stdout=out, - output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user, - private=private, generator_class_name=generator_class_name, **kwargs - ) - return out.getvalue() - - -def test_reference_schema(db, reference_schema): +def test_reference_schema(call_generate_swagger, db, reference_schema): User.objects.create_superuser('admin', 'admin@admin.admin', 'blabla') output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', user='admin') @@ -35,13 +23,13 @@ def test_reference_schema(db, reference_schema): assert output_schema == reference_schema -def test_non_public(db): +def test_non_public(call_generate_swagger, db): output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', private=True) output_schema = yaml_sane_load(output) assert len(output_schema['paths']) == 0 -def test_no_mock(db): +def test_no_mock(call_generate_swagger, db): output = call_generate_swagger() output_schema = json.loads(output, object_pairs_hook=OrderedDict) assert len(output_schema['paths']) > 0 @@ -52,7 +40,7 @@ class EmptySchemaGenerator(OpenAPISchemaGenerator): return openapi.Paths(paths={}), '' -def test_generator_class(db): +def test_generator_class(call_generate_swagger, db): output = call_generate_swagger(generator_class_name='test_management.EmptySchemaGenerator') output_schema = json.loads(output, object_pairs_hook=OrderedDict) assert len(output_schema['paths']) == 0 @@ -65,7 +53,7 @@ def silentremove(filename): pass -def test_file_output(db): +def test_file_output(call_generate_swagger, db): prefix = os.path.join(tempfile.gettempdir(), tempfile.gettempprefix()) name = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8)) yaml_file = prefix + name + '.yaml' diff --git a/tests/test_versioning.py b/tests/test_versioning.py index 2679e65..79276c4 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -7,6 +7,18 @@ def _get_versioned_schema(prefix, client, validate_schema): response = client.get(prefix + '/swagger.yaml') assert response.status_code == 200 swagger = yaml_sane_load(response.content.decode('utf-8')) + _check_base(swagger, prefix, validate_schema) + return swagger + + +def _get_versioned_schema_management(prefix, call_generate_swagger, validate_schema, kwargs): + output = call_generate_swagger(format='yaml', api_url='http://localhost' + prefix + '/swagger.yaml', **kwargs) + swagger = yaml_sane_load(output) + _check_base(swagger, prefix, validate_schema) + return swagger + + +def _check_base(swagger, prefix, validate_schema): assert swagger['basePath'] == prefix validate_schema(swagger) assert '/snippets/' in swagger['paths'] @@ -51,3 +63,17 @@ def test_ns_v1(client, validate_schema): def test_ns_v2(client, validate_schema): swagger = _get_versioned_schema('/versioned/ns/v2.0', client, validate_schema) _check_v2(swagger) + + +@pytest.mark.urls('urlconfs.url_versioning') +def test_url_v2_management(call_generate_swagger, validate_schema): + kwargs = {'api_version': '2.0'} + swagger = _get_versioned_schema_management('/versioned/url/v2.0', call_generate_swagger, validate_schema, kwargs) + _check_v2(swagger) + + +@pytest.mark.urls('urlconfs.ns_versioning') +def test_ns_v2_management(call_generate_swagger, validate_schema): + kwargs = {'api_version': '2.0'} + swagger = _get_versioned_schema_management('/versioned/ns/v2.0', call_generate_swagger, validate_schema, kwargs) + _check_v2(swagger)