From 3b31c54b9e16ca692c9ef7ae26266233b271326a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cristi=20V=C3=AEjdea?= Date: Fri, 21 Dec 2018 12:40:45 +0200 Subject: [PATCH] Add get_security_definitions and get_security_requirements hooks --- src/drf_yasg/generators.py | 44 +++++++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 10 deletions(-) diff --git a/src/drf_yasg/generators.py b/src/drf_yasg/generators.py index 2d20ff5..12ba61e 100644 --- a/src/drf_yasg/generators.py +++ b/src/drf_yasg/generators.py @@ -198,6 +198,35 @@ class OpenAPISchemaGenerator(object): def url(self): return self._gen.url + def get_security_definitions(self): + """Get the security schemes for this API. This determines what is usable in security requirements, + and helps clients configure their authorization credentials. + + :return: the security schemes usable with this API + :rtype: dict[str,dict]|None + """ + security_definitions = swagger_settings.SECURITY_DEFINITIONS + if security_definitions is not None: + security_definitions = SwaggerDict._as_odict(security_definitions, {}) + + return security_definitions + + def get_security_requirements(self, security_definitions): + """Get the base (global) security requirements of the API. This is never called if + :meth:`.get_security_definitions` returns `None`. + + :param security_definitions: security definitions as returned by :meth:`.get_security_definitions` + :return: + :rtype: dict[str,list[str]]|None + """ + security_requirements = swagger_settings.SECURITY_REQUIREMENTS + if security_requirements is None: + security_requirements = [{security_scheme: []} for security_scheme in security_definitions] + + security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements] + security_requirements = sorted(security_requirements, key=list) + return security_requirements + def get_schema(self, request=None, public=False): """Generate a :class:`.Swagger` object representing the API schema. @@ -214,16 +243,11 @@ class OpenAPISchemaGenerator(object): self.produces = get_produces(rest_framework_settings.DEFAULT_RENDERER_CLASSES) paths, prefix = self.get_paths(endpoints, components, request, public) - security_definitions = swagger_settings.SECURITY_DEFINITIONS - if security_definitions is not None: - security_definitions = SwaggerDict._as_odict(security_definitions, {}) - - security_requirements = swagger_settings.SECURITY_REQUIREMENTS - if security_requirements is None: - security_requirements = [{security_scheme: []} for security_scheme in swagger_settings.SECURITY_DEFINITIONS] - - security_requirements = [SwaggerDict._as_odict(sr, {}) for sr in security_requirements] - security_requirements = sorted(security_requirements, key=list) + security_definitions = self.get_security_definitions() + if security_definitions: + security_requirements = self.get_security_requirements(security_definitions) + else: + security_requirements = None url = self.url if url is None and request is not None: