Allow specific version generation in command
* Add --api-version parameter * Fix request mocking * Add testsopenapi3
parent
ca43a7de0c
commit
bbc70a7e3d
|
|
@ -59,6 +59,11 @@ class Command(BaseCommand):
|
|||
help='Use a mock request when generating the swagger schema. This is useful if your views or serializers'
|
||||
'depend on context from a request in order to function.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--api-version', dest='api_version',
|
||||
type=str,
|
||||
help='Version to use to generate schema. This option implies --mock-request.'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--user', dest='user',
|
||||
help='Username of an existing user to use for mocked authentication. This option implies --mock-request.'
|
||||
|
|
@ -102,7 +107,7 @@ class Command(BaseCommand):
|
|||
request = APIView().initialize_request(request)
|
||||
return request
|
||||
|
||||
def handle(self, output_file, overwrite, format, api_url, mock, user, private, generator_class_name,
|
||||
def handle(self, output_file, overwrite, format, api_url, mock, api_version, user, private, generator_class_name,
|
||||
*args, **kwargs):
|
||||
# disable logs of WARNING and below
|
||||
logging.disable(logging.WARNING)
|
||||
|
|
@ -126,19 +131,25 @@ class Command(BaseCommand):
|
|||
# avoid crashing if auth is not configured in the project
|
||||
user = get_user_model().objects.get(username=user)
|
||||
|
||||
mock = mock or private or (user is not None)
|
||||
mock = mock or private or (user is not None) or (api_version is not None)
|
||||
if mock and not api_url:
|
||||
raise ImproperlyConfigured(
|
||||
'--mock-request requires an API url; either provide '
|
||||
'the --url argument or set the DEFAULT_API_URL setting'
|
||||
)
|
||||
|
||||
request = self.get_mock_request(api_url, format, user) if mock else None
|
||||
request = None
|
||||
if mock:
|
||||
request = self.get_mock_request(api_url, format, user)
|
||||
|
||||
if request and api_version:
|
||||
request.version = api_version
|
||||
|
||||
generator_class = import_class(generator_class_name) or swagger_settings.DEFAULT_GENERATOR_CLASS
|
||||
generator = generator_class(
|
||||
info=info,
|
||||
url=api_url
|
||||
version=api_version,
|
||||
url=api_url,
|
||||
)
|
||||
schema = generator.get_schema(request=request, public=not private)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,10 @@ from collections import OrderedDict
|
|||
import pytest
|
||||
from datadiff.tools import assert_equal
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.management import call_command
|
||||
from rest_framework.test import APIRequestFactory
|
||||
from rest_framework.views import APIView
|
||||
from six import StringIO
|
||||
|
||||
from drf_yasg import codecs, openapi
|
||||
from drf_yasg.codecs import yaml_sane_dump, yaml_sane_load
|
||||
|
|
@ -64,6 +66,21 @@ def validate_schema(db):
|
|||
return validate_schema
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def call_generate_swagger():
|
||||
def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='',
|
||||
mock=False, user=None, private=False, generator_class_name='', **kwargs):
|
||||
out = StringIO()
|
||||
call_command(
|
||||
'generate_swagger', stdout=out,
|
||||
output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user,
|
||||
private=private, generator_class_name=generator_class_name, **kwargs
|
||||
)
|
||||
return out.getvalue()
|
||||
|
||||
return call_generate_swagger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def compare_schemas():
|
||||
def compare_schemas(schema1, schema2):
|
||||
|
|
|
|||
|
|
@ -9,25 +9,13 @@ from collections import OrderedDict
|
|||
|
||||
import pytest
|
||||
from django.contrib.auth.models import User
|
||||
from django.core.management import call_command
|
||||
|
||||
from drf_yasg import openapi
|
||||
from drf_yasg.codecs import yaml_sane_load
|
||||
from drf_yasg.generators import OpenAPISchemaGenerator
|
||||
|
||||
|
||||
def call_generate_swagger(output_file='-', overwrite=False, format='', api_url='',
|
||||
mock=False, user=None, private=False, generator_class_name='', **kwargs):
|
||||
out = StringIO()
|
||||
call_command(
|
||||
'generate_swagger', stdout=out,
|
||||
output_file=output_file, overwrite=overwrite, format=format, api_url=api_url, mock=mock, user=user,
|
||||
private=private, generator_class_name=generator_class_name, **kwargs
|
||||
)
|
||||
return out.getvalue()
|
||||
|
||||
|
||||
def test_reference_schema(db, reference_schema):
|
||||
def test_reference_schema(call_generate_swagger, db, reference_schema):
|
||||
User.objects.create_superuser('admin', 'admin@admin.admin', 'blabla')
|
||||
|
||||
output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', user='admin')
|
||||
|
|
@ -35,13 +23,13 @@ def test_reference_schema(db, reference_schema):
|
|||
assert output_schema == reference_schema
|
||||
|
||||
|
||||
def test_non_public(db):
|
||||
def test_non_public(call_generate_swagger, db):
|
||||
output = call_generate_swagger(format='yaml', api_url='http://test.local:8002/', private=True)
|
||||
output_schema = yaml_sane_load(output)
|
||||
assert len(output_schema['paths']) == 0
|
||||
|
||||
|
||||
def test_no_mock(db):
|
||||
def test_no_mock(call_generate_swagger, db):
|
||||
output = call_generate_swagger()
|
||||
output_schema = json.loads(output, object_pairs_hook=OrderedDict)
|
||||
assert len(output_schema['paths']) > 0
|
||||
|
|
@ -52,7 +40,7 @@ class EmptySchemaGenerator(OpenAPISchemaGenerator):
|
|||
return openapi.Paths(paths={}), ''
|
||||
|
||||
|
||||
def test_generator_class(db):
|
||||
def test_generator_class(call_generate_swagger, db):
|
||||
output = call_generate_swagger(generator_class_name='test_management.EmptySchemaGenerator')
|
||||
output_schema = json.loads(output, object_pairs_hook=OrderedDict)
|
||||
assert len(output_schema['paths']) == 0
|
||||
|
|
@ -65,7 +53,7 @@ def silentremove(filename):
|
|||
pass
|
||||
|
||||
|
||||
def test_file_output(db):
|
||||
def test_file_output(call_generate_swagger, db):
|
||||
prefix = os.path.join(tempfile.gettempdir(), tempfile.gettempprefix())
|
||||
name = ''.join(random.choice(string.ascii_lowercase + string.digits) for _ in range(8))
|
||||
yaml_file = prefix + name + '.yaml'
|
||||
|
|
|
|||
|
|
@ -7,6 +7,18 @@ def _get_versioned_schema(prefix, client, validate_schema):
|
|||
response = client.get(prefix + '/swagger.yaml')
|
||||
assert response.status_code == 200
|
||||
swagger = yaml_sane_load(response.content.decode('utf-8'))
|
||||
_check_base(swagger, prefix, validate_schema)
|
||||
return swagger
|
||||
|
||||
|
||||
def _get_versioned_schema_management(prefix, call_generate_swagger, validate_schema, kwargs):
|
||||
output = call_generate_swagger(format='yaml', api_url='http://localhost' + prefix + '/swagger.yaml', **kwargs)
|
||||
swagger = yaml_sane_load(output)
|
||||
_check_base(swagger, prefix, validate_schema)
|
||||
return swagger
|
||||
|
||||
|
||||
def _check_base(swagger, prefix, validate_schema):
|
||||
assert swagger['basePath'] == prefix
|
||||
validate_schema(swagger)
|
||||
assert '/snippets/' in swagger['paths']
|
||||
|
|
@ -51,3 +63,17 @@ def test_ns_v1(client, validate_schema):
|
|||
def test_ns_v2(client, validate_schema):
|
||||
swagger = _get_versioned_schema('/versioned/ns/v2.0', client, validate_schema)
|
||||
_check_v2(swagger)
|
||||
|
||||
|
||||
@pytest.mark.urls('urlconfs.url_versioning')
|
||||
def test_url_v2_management(call_generate_swagger, validate_schema):
|
||||
kwargs = {'api_version': '2.0'}
|
||||
swagger = _get_versioned_schema_management('/versioned/url/v2.0', call_generate_swagger, validate_schema, kwargs)
|
||||
_check_v2(swagger)
|
||||
|
||||
|
||||
@pytest.mark.urls('urlconfs.ns_versioning')
|
||||
def test_ns_v2_management(call_generate_swagger, validate_schema):
|
||||
kwargs = {'api_version': '2.0'}
|
||||
swagger = _get_versioned_schema_management('/versioned/ns/v2.0', call_generate_swagger, validate_schema, kwargs)
|
||||
_check_v2(swagger)
|
||||
|
|
|
|||
Loading…
Reference in New Issue