Allow specific version generation in command

* Add --api-version parameter
* Fix request mocking
* Add tests
openapi3
Amoki 2018-07-27 12:33:26 +02:00 committed by Cristi Vîjdea
parent ca43a7de0c
commit bbc70a7e3d
4 changed files with 63 additions and 21 deletions

View File

@ -59,6 +59,11 @@ class Command(BaseCommand):
help='Use a mock request when generating the swagger schema. This is useful if your views or serializers'
'depend on context from a request in order to function.'
)
parser.add_argument(
'--api-version', dest='api_version',
type=str,
help='Version to use to generate schema. This option implies --mock-request.'
)
parser.add_argument(
'--user', dest='user',
help='Username of an existing user to use for mocked authentication. This option implies --mock-request.'
@ -102,7 +107,7 @@ class Command(BaseCommand):
request = APIView().initialize_request(request)
return request
def handle(self, output_file, overwrite, format, api_url, mock, user, private, generator_class_name,
def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name,
*args, **kwargs):
# disable logs of WARNING and below
logging.disable(logging.WARNING)
@ -126,19 +131,25 @@ class Command(BaseCommand):
# avoid crashing if auth is not configured in the project
user = get_user_model().objects.get(username=user)
mock = mock or private or (user is not None)
mock = mock or private or (user is not None) or (api_version is not None)
if mock and not api_url:
raise ImproperlyConfigured(
'--mock-request requires an API url; either provide '
'the --url argument or set the DEFAULT_API_URL setting'
)
request = self.get_mock_request(api_url, format, user) if mock else None
request = None
if mock:
request = self.get_mock_request(api_url, format, user)
if request and api_version:
request.version = api_version
generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
generator = generator_class(
info=info,
url=api_url
version=api_version,
url=api_url,
)
schema = generator.get_schema(request=request, public=not private)

View File

@ -6,8 +6,10 @@ from collections import OrderedDict
import pytest
from datadiff.tools import assert_equal
from django.contrib.auth.models import User
from django.core.management import call_command
from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from six import StringIO
from drf_yasg import codecs, openapi
from drf_yasg.codecs import yaml_sane_dump, yaml_sane_load
@ -64,6 +66,21 @@ def validate_schema(db):
return validate_schema
@pytest.fixture
def call_generate_swagger():
def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='',
mock=False, user=None, private=False, generator_class_name='', **kwargs):
out = StringIO()
call_command(
'generate_swagger', stdout=out,
output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user,
private=private, generator_class_name=generator_class_name, **kwargs
)
return out.getvalue()
return call_generate_swagger
@pytest.fixture
def compare_schemas():
def compare_schemas(schema1, schema2):

View File

@ -9,25 +9,13 @@ from collections import OrderedDict
import pytest
from django.contrib.auth.models import User
from django.core.management import call_command
from drf_yasg import openapi
from drf_yasg.codecs import yaml_sane_load
from drf_yasg.generators import OpenAPISchemaGenerator
def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='',
mock=False, user=None, private=False, generator_class_name='', **kwargs):
out = StringIO()
call_command(
'generate_swagger', stdout=out,
output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user,
private=private, generator_class_name=generator_class_name, **kwargs
)
return out.getvalue()
def test_reference_schema(db, reference_schema):
def test_reference_schema(call_generate_swagger, db, reference_schema):
User.objects.create_superuser('admin', 'admin@admin.admin', 'blabla')
output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', user='admin')
@ -35,13 +23,13 @@ def test_reference_schema(db, reference_schema):
assert output_schema == reference_schema
def test_non_public(db):
def test_non_public(call_generate_swagger, db):
output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', private=True)
output_schema = yaml_sane_load(output)
assert len(output_schema['paths']) == 0
def test_no_mock(db):
def test_no_mock(call_generate_swagger, db):
output = call_generate_swagger()
output_schema = json.loads(output, object_pairs_hook=OrderedDict)
assert len(output_schema['paths']) > 0
@ -52,7 +40,7 @@ class EmptySchemaGenerator(OpenAPISchemaGenerator):
return openapi.Paths(paths={}), ''
def test_generator_class(db):
def test_generator_class(call_generate_swagger, db):
output = call_generate_swagger(generator_class_name='test_management.EmptySchemaGenerator')
output_schema = json.loads(output, object_pairs_hook=OrderedDict)
assert len(output_schema['paths']) == 0
@ -65,7 +53,7 @@ def silentremove(filename):
pass
def test_file_output(db):
def test_file_output(call_generate_swagger, db):
prefix = os.path.join(tempfile.gettempdir(), tempfile.gettempprefix())
name = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
yaml_file = prefix + name + '.yaml'

View File

@ -7,6 +7,18 @@ def _get_versioned_schema(prefix, client, validate_schema):
response = client.get(prefix + '/swagger.yaml')
assert response.status_code == 200
swagger = yaml_sane_load(response.content.decode('utf-8'))
_check_base(swagger, prefix, validate_schema)
return swagger
def _get_versioned_schema_management(prefix, call_generate_swagger, validate_schema, kwargs):
output = call_generate_swagger(format='yaml', api_url='http://localhost' + prefix + '/swagger.yaml', **kwargs)
swagger = yaml_sane_load(output)
_check_base(swagger, prefix, validate_schema)
return swagger
def _check_base(swagger, prefix, validate_schema):
assert swagger['basePath'] == prefix
validate_schema(swagger)
assert '/snippets/' in swagger['paths']
@ -51,3 +63,17 @@ def test_ns_v1(client, validate_schema):
def test_ns_v2(client, validate_schema):
swagger = _get_versioned_schema('/versioned/ns/v2.0', client, validate_schema)
_check_v2(swagger)
@pytest.mark.urls('urlconfs.url_versioning')
def test_url_v2_management(call_generate_swagger, validate_schema):
kwargs = {'api_version': '2.0'}
swagger = _get_versioned_schema_management('/versioned/url/v2.0', call_generate_swagger, validate_schema, kwargs)
_check_v2(swagger)
@pytest.mark.urls('urlconfs.ns_versioning')
def test_ns_v2_management(call_generate_swagger, validate_schema):
kwargs = {'api_version': '2.0'}
swagger = _get_versioned_schema_management('/versioned/ns/v2.0', call_generate_swagger, validate_schema, kwargs)
_check_v2(swagger)