Make security pluggable (#1671)

Make security pluggable

- [x] Solution for standard security handlers: `security_deny`,
`security_passthrough`, `verify_none`
- [x] HTTP security handlers & overlap with basic from swagger 2
- [x] Do we need a separate handler for each `oauth2` flow?
This commit is contained in:
Ruwann
2023-04-08 18:19:26 +02:00
committed by GitHub
parent 55e376f816
commit 5b4beeb2ea
11 changed files with 639 additions and 501 deletions

View File

@@ -47,6 +47,7 @@ class AbstractApp:
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
) -> None:
"""
:param import_name: The name of the package or module that this object belongs to. If you
@@ -77,6 +78,8 @@ class AbstractApp:
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`.
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self.middleware = ConnexionMiddleware(
self.middleware_app,
@@ -95,6 +98,7 @@ class AbstractApp:
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
)
def add_api(
@@ -113,6 +117,7 @@ class AbstractApp:
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
**kwargs,
) -> t.Any:
"""
@@ -143,6 +148,8 @@ class AbstractApp:
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
:param kwargs: Additional keyword arguments to pass to the `add_api` method of the managed
middlewares. This can be used to pass arguments to middlewares added beyond the default
ones.
@@ -163,6 +170,7 @@ class AbstractApp:
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
**kwargs,
)

View File

@@ -135,6 +135,7 @@ class AsyncApp(AbstractApp):
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
) -> None:
"""
:param import_name: The name of the package or module that this object belongs to. If you
@@ -165,6 +166,8 @@ class AsyncApp(AbstractApp):
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`.
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self.middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
@@ -184,6 +187,7 @@ class AsyncApp(AbstractApp):
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
)
def add_url_rule(

View File

@@ -192,6 +192,7 @@ class FlaskApp(AbstractApp):
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
):
"""
:param import_name: The name of the package or module that this object belongs to. If you
@@ -225,6 +226,8 @@ class FlaskApp(AbstractApp):
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`.
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
"""
self.middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
self.app = self.middleware_app.app
@@ -244,6 +247,7 @@ class FlaskApp(AbstractApp):
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
)
def add_url_rule(

View File

@@ -51,6 +51,7 @@ class _Options:
uri_parser_class: t.Optional[AbstractURIParser] = None
validate_responses: t.Optional[bool] = False
validator_map: t.Optional[dict] = None
security_map: t.Optional[dict] = None
def __post_init__(self):
self.resolver = (
@@ -115,6 +116,7 @@ class ConnexionMiddleware:
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
):
"""
:param import_name: The name of the package or module that this object belongs to. If you
@@ -145,6 +147,8 @@ class ConnexionMiddleware:
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`.
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`.
"""
import_name = import_name or str(pathlib.Path.cwd())
self.root_path = utils.get_root_path(import_name)
@@ -169,6 +173,7 @@ class ConnexionMiddleware:
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
)
self.extra_files: t.List[str] = []
@@ -217,6 +222,7 @@ class ConnexionMiddleware:
uri_parser_class: t.Optional[AbstractURIParser] = None,
validate_responses: t.Optional[bool] = None,
validator_map: t.Optional[dict] = None,
security_map: t.Optional[dict] = None,
**kwargs,
) -> t.Any:
"""
@@ -247,6 +253,8 @@ class ConnexionMiddleware:
an impact on performance. Defaults to False.
:param validator_map: A dictionary of validators to use. Defaults to
:obj:`validators.VALIDATOR_MAP`
:param security_map: A dictionary of security handlers to use. Defaults to
:obj:`security.SECURITY_HANDLERS`
:param kwargs: Additional keyword arguments to pass to the `add_api` method of the managed
middlewares. This can be used to pass arguments to middlewares added beyond the default
ones.
@@ -275,6 +283,7 @@ class ConnexionMiddleware:
uri_parser_class=uri_parser_class,
validate_responses=validate_responses,
validator_map=validator_map,
security_map=security_map,
)
for app in self.apps:

View File

@@ -70,6 +70,29 @@ class RequestValidationOperation:
f"expected {self._operation.consumes}"
)
@property
def security_query_params(self) -> t.List[str]:
"""Get the names of query parameters that are used for security."""
if not hasattr(self, "_security_query_params"):
security_query_params: t.List[str] = []
if self._operation.security is None:
self._security_query_params = security_query_params
return self._security_query_params
for security_req in self._operation.security:
for scheme_name in security_req:
security_scheme = self._operation.security_schemes[scheme_name]
if (
security_scheme["type"] == "apiKey"
and security_scheme["in"] == "query"
):
# Only query parameters need to be considered for strict_validation
security_query_params.append(security_scheme["name"])
self._security_query_params = security_query_params
return self._security_query_params
async def __call__(self, scope: Scope, receive: Receive, send: Send):
# Validate parameters & headers
uri_parser_class = self._operation._uri_parser_class
@@ -81,6 +104,7 @@ class RequestValidationOperation:
self._operation.parameters,
uri_parser=uri_parser,
strict_validation=self.strict_validation,
security_query_params=self.security_query_params,
)
parameter_validator.validate(scope)

View File

@@ -51,7 +51,7 @@ class SecurityOperation:
auth_funcs = []
for security_req in self.security:
if not security_req:
auth_funcs.append(self.security_handler_factory.verify_none())
auth_funcs.append(self.security_handler_factory.verify_none)
continue
sec_req_funcs = {}
@@ -67,123 +67,15 @@ class SecurityOperation:
)
break
oauth = True
token_info_func = self.security_handler_factory.get_tokeninfo_func(
security_scheme
sec_req_func = self.security_handler_factory.parse_security_scheme(
security_scheme, required_scopes
)
scope_validate_func = (
self.security_handler_factory.get_scope_validate_func(
security_scheme
)
)
if not token_info_func:
logger.warning("... x-tokenInfoFunc missing", extra=vars(self))
if sec_req_func is None:
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_oauth(
token_info_func, scope_validate_func, required_scopes
)
sec_req_funcs[scheme_name] = sec_req_func
# Swagger 2.0
elif security_scheme["type"] == "basic":
basic_info_func = self.security_handler_factory.get_basicinfo_func(
security_scheme
)
if not basic_info_func:
logger.warning("... x-basicInfoFunc missing", extra=vars(self))
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_basic(basic_info_func)
# OpenAPI 3.0.0
elif security_scheme["type"] == "http":
scheme = security_scheme["scheme"].lower()
if scheme == "basic":
basic_info_func = (
self.security_handler_factory.get_basicinfo_func(
security_scheme
)
)
if not basic_info_func:
logger.warning(
"... x-basicInfoFunc missing", extra=vars(self)
)
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_basic(basic_info_func)
elif scheme == "bearer":
bearer_info_func = (
self.security_handler_factory.get_bearerinfo_func(
security_scheme
)
)
if not bearer_info_func:
logger.warning(
"... x-bearerInfoFunc missing", extra=vars(self)
)
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_bearer(
bearer_info_func
)
else:
logger.warning(
"... Unsupported http authorization scheme %s" % scheme,
extra=vars(self),
)
break
elif security_scheme["type"] == "apiKey":
scheme = security_scheme.get("x-authentication-scheme", "").lower()
if scheme == "bearer":
bearer_info_func = (
self.security_handler_factory.get_bearerinfo_func(
security_scheme
)
)
if not bearer_info_func:
logger.warning(
"... x-bearerInfoFunc missing", extra=vars(self)
)
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_bearer(
bearer_info_func
)
else:
apikey_info_func = (
self.security_handler_factory.get_apikeyinfo_func(
security_scheme
)
)
if not apikey_info_func:
logger.warning(
"... x-apikeyInfoFunc missing", extra=vars(self)
)
break
sec_req_funcs[
scheme_name
] = self.security_handler_factory.verify_api_key(
apikey_info_func,
security_scheme["in"],
security_scheme["name"],
)
else:
logger.warning(
"... Unsupported security scheme type %s"
% security_scheme["type"],
extra=vars(self),
)
break
else:
# No break encountered: no missing funcs
if len(sec_req_funcs) == 1:
@@ -199,16 +91,22 @@ class SecurityOperation:
return self.security_handler_factory.verify_security(auth_funcs)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if not self.security:
await self.next_app(scope, receive, send)
return
request = ASGIRequest(scope)
await self.verification_fn(request)
await self.next_app(scope, receive, send)
class SecurityAPI(RoutedAPI[SecurityOperation]):
def __init__(self, *args, auth_all_paths: bool = False, **kwargs):
def __init__(
self, *args, auth_all_paths: bool = False, security_map: dict = None, **kwargs
):
super().__init__(*args, **kwargs)
self.security_handler_factory = SecurityHandlerFactory()
self.security_handler_factory = SecurityHandlerFactory(security_map)
if auth_all_paths:
self.add_auth_on_not_found()

View File

@@ -1,6 +1,46 @@
"""
This module defines an abstract SecurityHandlerFactory which supports the creation of security
handlers for operations.
This module defines a SecurityHandlerFactory which supports the creation of
SecurityHandler instances for different security schemes.
It also exposes a `SECURITY_HANDLERS` dictionary which maps security scheme
types to SecurityHandler classes. This dictionary can be used to register
custom SecurityHandler classes for custom security schemes, or to overwrite
existing SecurityHandler classes.
This can be done by supplying a value for `security_map` argument of the
SecurityHandlerFactory.
Swagger 2.0 lets you define the following authentication types for an API:
- Basic authentication
- API key (as a header or a query string parameter)
- OAuth 2 common flows (authorization code, implicit, resource owner password credentials, client credentials)
Changes from OpenAPI 2.0 to OpenAPI 3.0
If you used OpenAPI 2.0 before, here is a summary of changes to help you get started with OpenAPI 3.0:
- securityDefinitions were renamed to securitySchemes and moved inside components.
- type: basic was replaced with type: http and scheme: basic.
- The new type: http is an umbrella type for all HTTP security schemes, including Basic, Bearer and other,
and the scheme keyword indicates the scheme type.
- API keys can now be sent in: cookie.
- Added support for OpenID Connect Discovery (type: openIdConnect).
- OAuth 2 security schemes can now define multiple flows.
- OAuth 2 flows were renamed to match the OAuth 2 Specification: accessCode is now authorizationCode,
and application is now clientCredentials.
OpenAPI uses the term security scheme for authentication and authorization schemes.
OpenAPI 3.0 lets you describe APIs protected using the following security schemes:
- HTTP authentication schemes (they use the Authorization header):
- Basic
- Bearer
- other HTTP schemes as defined by RFC 7235 and HTTP Authentication Scheme Registry
- API keys in headers, query string or cookies
- Cookie authentication
- OAuth 2
- OpenID Connect Discovery
"""
import asyncio
@@ -8,62 +48,274 @@ import base64
import http.cookies
import logging
import os
import textwrap
import typing as t
import httpx
from connexion.decorators.parameter import inspect_function_arguments
from connexion.exceptions import (
ConnexionException,
OAuthProblem,
OAuthResponseProblem,
OAuthScopeProblem,
)
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
from connexion.lifecycle import ASGIRequest
from connexion.utils import get_function_from_name
logger = logging.getLogger("connexion.api.security")
logger = logging.getLogger(__name__)
class SecurityHandlerFactory:
"""
get_*_func -> _get_function -> get_function_from_name (name=security function defined in spec)
(if url defined instead of a function -> get_token_info_remote)
NO_VALUE = object()
"""Sentinel value to indicate that no security credentials were found."""
std security functions: security_{passthrough,deny}
verify_* -> returns a security wrapper around the security function
check_* -> returns a function tasked with doing auth for use inside the verify wrapper
check helpers (used outside wrappers): _need_to_add_context_or_scopes
the security function
class AbstractSecurityHandler:
verify helpers (used inside wrappers): get_auth_header_value, get_cookie_value
"""
no_value = object()
required_scopes_kw = "required_scopes"
context_kw = "context_"
client = None
security_definition_key: str
"""The key which contains the value for the function name to resolve."""
environ_key: str
"""The name of the environment variable that can be used alternatively for the function name."""
@staticmethod
def get_fn(self, security_scheme, required_scopes):
"""Returns the handler function"""
security_func = self._resolve_func(security_scheme)
if not security_func:
logger.warning("... %s missing", self.security_definition_key)
return None
return self._get_verify_func(security_func)
@classmethod
def _get_function(
security_definition, security_definition_key, environ_key, default=None
cls,
security_definition: dict,
security_definition_key: str,
environ_key: str,
default: t.Optional[t.Callable] = None,
):
"""
Return function by getting its name from security_definition or environment variable
:param security_definition: Security Definition (scheme) from the spec.
:param security_definition_key: The key which contains the value for the function name to resolve.
:param environ_key: The name of the environment variable that can be used alternatively for the function name.
:param default: The default to use in case the function cannot be found based on the security_definition_key or the environ_key
"""
func = security_definition.get(security_definition_key) or os.environ.get(
func_name = security_definition.get(security_definition_key) or os.environ.get(
environ_key
)
if func:
return get_function_from_name(func)
if func_name:
return get_function_from_name(func_name)
return default
def _generic_check(self, func, exception_msg):
(
need_to_add_context,
need_to_add_required_scopes,
) = self._need_to_add_context_or_scopes(func)
async def wrapper(request, *args, required_scopes=None):
kwargs = {}
if need_to_add_context:
kwargs[self.context_kw] = request.context
if need_to_add_required_scopes:
kwargs[self.required_scopes_kw] = required_scopes
token_info = func(*args, **kwargs)
while asyncio.iscoroutine(token_info):
token_info = await token_info
if token_info is NO_VALUE:
return NO_VALUE
if token_info is None:
raise OAuthResponseProblem(detail=exception_msg)
return token_info
return wrapper
@staticmethod
def get_auth_header_value(request):
"""
Return Authorization type and value if any.
If not Authorization, return (None, None)
Raise OAuthProblem for invalid Authorization header
"""
authorization = request.headers.get("Authorization")
if not authorization:
return None, None
try:
auth_type, value = authorization.split(maxsplit=1)
except ValueError:
raise OAuthProblem(detail="Invalid authorization header")
return auth_type.lower(), value
def _need_to_add_context_or_scopes(self, func):
arguments, has_kwargs = inspect_function_arguments(func)
need_context = self.context_kw in arguments
need_required_scopes = has_kwargs or self.required_scopes_kw in arguments
return need_context, need_required_scopes
def _resolve_func(self, security_scheme):
"""
Get the user function object based on the security scheme or the environment variable.
:param security_scheme: Security Definition (scheme) from the spec.
"""
return self._get_function(
security_scheme, self.security_definition_key, self.environ_key
)
def _get_verify_func(self, function):
"""
Wraps the user security function in a function that checks the request for the correct
security credentials and calls the user function with the correct arguments.
"""
return self._generic_check(function, "Provided authorization is not valid")
class BasicSecurityHandler(AbstractSecurityHandler):
"""
Security Handler for
- `type: basic` (Swagger 2), and
- `type: http` and `scheme: basic` (OpenAPI 3)
"""
security_definition_key = "x-basicInfoFunc"
environ_key = "BASICINFO_FUNC"
def _get_verify_func(self, basic_info_func):
check_basic_info_func = self.check_basic_auth(basic_info_func)
def wrapper(request):
auth_type, user_pass = self.get_auth_header_value(request)
if auth_type != "basic":
return NO_VALUE
try:
username, password = (
base64.b64decode(user_pass).decode("latin1").split(":", 1)
)
except Exception:
raise OAuthProblem(detail="Invalid authorization header")
return check_basic_info_func(request, username, password)
return wrapper
def check_basic_auth(self, basic_info_func):
return self._generic_check(
basic_info_func, "Provided authorization is not valid"
)
class BearerSecurityHandler(AbstractSecurityHandler):
"""
Security Handler for HTTP Bearer authentication.
"""
security_definition_key = "x-bearerInfoFunc"
environ_key = "BEARERINFO_FUNC"
def check_bearer_token(self, token_info_func):
return self._generic_check(token_info_func, "Provided token is not valid")
def _get_verify_func(self, token_info_func):
"""
:param token_info_func: types.FunctionType
:rtype: types.FunctionType
"""
check_bearer_func = self.check_bearer_token(token_info_func)
def wrapper(request):
auth_type, token = self.get_auth_header_value(request)
if auth_type != "bearer":
return NO_VALUE
return check_bearer_func(request, token)
return wrapper
class ApiKeySecurityHandler(AbstractSecurityHandler):
"""
Security Handler for API Keys.
"""
security_definition_key = "x-apikeyInfoFunc"
environ_key = "APIKEYINFO_FUNC"
def get_fn(self, security_scheme, required_scopes):
apikey_info_func = self._resolve_func(security_scheme)
if not apikey_info_func:
logger.warning("... %s missing", self.security_definition_key)
return None
return self._get_verify_func(
apikey_info_func,
security_scheme["in"],
security_scheme["name"],
)
def _get_verify_func(self, api_key_info_func, loc, name):
check_api_key_func = self.check_api_key(api_key_info_func)
def wrapper(request: ASGIRequest):
if loc == "query":
api_key = request.query_params.get(name)
elif loc == "header":
api_key = request.headers.get(name)
elif loc == "cookie":
cookie_list = request.headers.get("Cookie")
api_key = self.get_cookie_value(cookie_list, name)
else:
return NO_VALUE
if api_key is None:
return NO_VALUE
return check_api_key_func(request, api_key)
return wrapper
def check_api_key(self, api_key_info_func):
return self._generic_check(api_key_info_func, "Provided apikey is not valid")
@staticmethod
def get_cookie_value(cookies, name):
"""
Returns cookie value by its name. `None` if no such value.
:param cookies: str: cookies raw data
:param name: str: cookies key
"""
cookie_parser = http.cookies.SimpleCookie()
cookie_parser.load(str(cookies))
try:
return cookie_parser[name].value
except KeyError:
return None
class OAuthSecurityHandler(AbstractSecurityHandler):
"""
Security Handler for the OAuth security scheme.
"""
def get_fn(self, security_scheme, required_scopes):
token_info_func = self.get_tokeninfo_func(security_scheme)
scope_validate_func = self.get_scope_validate_func(security_scheme)
if not token_info_func:
logger.warning("... x-tokenInfoFunc missing")
return None
return self._get_verify_func(
token_info_func, scope_validate_func, required_scopes
)
def get_tokeninfo_func(self, security_definition: dict) -> t.Optional[t.Callable]:
"""
:type security_definition: dict
Gets the function for retrieving the token info.
It is possible to specify a function or a URL. The function variant is
preferred. If it is not found, the URL variant is used with the
`get_token_info_remote` function.
>>> get_tokeninfo_url({'x-tokenInfoFunc': 'foo.bar'})
>>> get_tokeninfo_func({'x-tokenInfoFunc': 'foo.bar'})
'<function foo.bar>'
"""
token_info_func = self._get_function(
@@ -83,8 +335,8 @@ class SecurityHandlerFactory:
@classmethod
def get_scope_validate_func(cls, security_definition):
"""
:type security_definition: dict
:rtype: function
Gets the function for validating the token scopes.
If it is not found, the default `validate_scope` function is used.
>>> get_scope_validate_func({'x-scopeValidateFunc': 'foo.bar'})
'<function foo.bar>'
@@ -96,61 +348,6 @@ class SecurityHandlerFactory:
cls.validate_scope,
)
@classmethod
def get_basicinfo_func(cls, security_definition):
"""
:type security_definition: dict
:rtype: function
>>> get_basicinfo_func({'x-basicInfoFunc': 'foo.bar'})
'<function foo.bar>'
"""
return cls._get_function(
security_definition, "x-basicInfoFunc", "BASICINFO_FUNC"
)
@classmethod
def get_apikeyinfo_func(cls, security_definition):
"""
:type security_definition: dict
:rtype: function
>>> get_apikeyinfo_func({'x-apikeyInfoFunc': 'foo.bar'})
'<function foo.bar>'
"""
return cls._get_function(
security_definition, "x-apikeyInfoFunc", "APIKEYINFO_FUNC"
)
@classmethod
def get_bearerinfo_func(cls, security_definition):
"""
:type security_definition: dict
:rtype: function
>>> get_bearerinfo_func({'x-bearerInfoFunc': 'foo.bar'})
'<function foo.bar>'
"""
return cls._get_function(
security_definition, "x-bearerInfoFunc", "BEARERINFO_FUNC"
)
@staticmethod
async def security_passthrough(request):
return request
@staticmethod
def security_deny(function):
"""
:type function: types.FunctionType
:rtype: types.FunctionType
"""
def deny(*args, **kwargs):
raise ConnexionException("Error in security definitions")
return deny
@staticmethod
def validate_scope(required_scopes, token_scopes):
"""
@@ -167,210 +364,49 @@ class SecurityHandlerFactory:
logger.debug("... Token scopes: %s", token_scopes)
if not required_scopes <= token_scopes:
logger.info(
textwrap.dedent(
"""
... Token scopes (%s) do not match the scopes necessary to call endpoint (%s).
Aborting with 403."""
).replace("\n", ""),
"... Token scopes (%s) do not match the scopes necessary to call endpoint (%s)."
" Aborting with 403.",
token_scopes,
required_scopes,
)
return False
return True
@staticmethod
def get_auth_header_value(request):
def get_token_info_remote(self, token_info_url: str) -> t.Callable:
"""
Called inside security wrapper functions
Return a function which will call `token_info_url` to retrieve token info.
Return Authorization type and value if any.
If not Authorization, return (None, None)
Raise OAuthProblem for invalid Authorization header
Returned function must accept oauth token in parameter.
It must return a token_info dict in case of success, None otherwise.
:param token_info_url: URL to get information about the token
"""
authorization = request.headers.get("Authorization")
if not authorization:
return None, None
try:
auth_type, value = authorization.split(None, 1)
except ValueError:
raise OAuthProblem(detail="Invalid authorization header")
return auth_type.lower(), value
async def wrapper(token):
if self.client is None:
self.client = httpx.AsyncClient()
headers = {"Authorization": f"Bearer {token}"}
token_request = await self.client.get(
token_info_url, headers=headers, timeout=5
)
if token_request.status_code != 200:
return
return token_request.json()
def verify_oauth(self, token_info_func, scope_validate_func, required_scopes):
return wrapper
def _get_verify_func(self, token_info_func, scope_validate_func, required_scopes):
check_oauth_func = self.check_oauth_func(token_info_func, scope_validate_func)
def wrapper(request):
auth_type, token = self.get_auth_header_value(request)
if auth_type != "bearer":
return self.no_value
return NO_VALUE
return check_oauth_func(request, token, required_scopes=required_scopes)
return wrapper
def verify_basic(self, basic_info_func):
check_basic_info_func = self.check_basic_auth(basic_info_func)
def wrapper(request):
auth_type, user_pass = self.get_auth_header_value(request)
if auth_type != "basic":
return self.no_value
try:
username, password = (
base64.b64decode(user_pass).decode("latin1").split(":", 1)
)
except Exception:
raise OAuthProblem(detail="Invalid authorization header")
return check_basic_info_func(request, username, password)
return wrapper
@staticmethod
def get_cookie_value(cookies, name):
"""
Called inside security wrapper functions
Returns cookie value by its name. None if no such value.
:param cookies: str: cookies raw data
:param name: str: cookies key
"""
cookie_parser = http.cookies.SimpleCookie()
cookie_parser.load(str(cookies))
try:
return cookie_parser[name].value
except KeyError:
return None
def verify_api_key(self, api_key_info_func, loc, name):
check_api_key_func = self.check_api_key(api_key_info_func)
def wrapper(request):
def _immutable_pop(_dict, key):
"""
Pops the key from an immutable dict and returns the value that was popped,
and a new immutable dict without the popped key.
"""
cls = type(_dict)
try:
_dict = _dict.to_dict(flat=False)
return _dict.pop(key)[0], cls(_dict)
except AttributeError:
_dict = dict(_dict.items())
return _dict.pop(key), cls(_dict)
if loc == "query":
try:
api_key, request.query = _immutable_pop(request.query, name)
except KeyError:
api_key = None
elif loc == "header":
api_key = request.headers.get(name)
elif loc == "cookie":
cookie_list = request.headers.get("Cookie")
api_key = self.get_cookie_value(cookie_list, name)
else:
return self.no_value
if api_key is None:
return self.no_value
return check_api_key_func(request, api_key)
return wrapper
def verify_bearer(self, token_info_func):
"""
:param token_info_func: types.FunctionType
:rtype: types.FunctionType
"""
check_bearer_func = self.check_bearer_token(token_info_func)
def wrapper(request):
auth_type, token = self.get_auth_header_value(request)
if auth_type != "bearer":
return self.no_value
return check_bearer_func(request, token)
return wrapper
def verify_multiple_schemes(self, schemes):
"""
Verifies multiple authentication schemes in AND fashion.
If any scheme fails, the entire authentication fails.
:param schemes: mapping scheme_name to auth function
:type schemes: dict
:rtype: types.FunctionType
"""
async def wrapper(request):
token_info = {}
for scheme_name, func in schemes.items():
result = func(request)
while asyncio.iscoroutine(result):
result = await result
if result is self.no_value:
return self.no_value
token_info[scheme_name] = result
return token_info
return wrapper
@staticmethod
def verify_none():
"""
:rtype: types.FunctionType
"""
def wrapper(request):
return {}
return wrapper
def _need_to_add_context_or_scopes(self, func):
arguments, has_kwargs = inspect_function_arguments(func)
need_context = self.context_kw in arguments
need_required_scopes = has_kwargs or self.required_scopes_kw in arguments
return need_context, need_required_scopes
def _generic_check(self, func, exception_msg):
(
need_to_add_context,
need_to_add_required_scopes,
) = self._need_to_add_context_or_scopes(func)
async def wrapper(request, *args, required_scopes=None):
kwargs = {}
if need_to_add_context:
kwargs[self.context_kw] = request.context
if need_to_add_required_scopes:
kwargs[self.required_scopes_kw] = required_scopes
token_info = func(*args, **kwargs)
while asyncio.iscoroutine(token_info):
token_info = await token_info
if token_info is self.no_value:
return self.no_value
if token_info is None:
raise OAuthResponseProblem(detail=exception_msg)
return token_info
return wrapper
def check_bearer_token(self, token_info_func):
return self._generic_check(token_info_func, "Provided token is not valid")
def check_basic_auth(self, basic_info_func):
return self._generic_check(
basic_info_func, "Provided authorization is not valid"
)
def check_api_key(self, api_key_info_func):
return self._generic_check(api_key_info_func, "Provided apikey is not valid")
def check_oauth_func(self, token_info_func, scope_validate_func):
get_token_info = self._generic_check(
token_info_func, "Provided token is not valid"
@@ -403,17 +439,141 @@ class SecurityHandlerFactory:
return wrapper
SECURITY_HANDLERS = {
# Swagger 2: `type: basic`
# OpenAPI 3: `type: http` and `scheme: basic`
"basic": BasicSecurityHandler,
# Swagger 2 and OpenAPI 3
"apiKey": ApiKeySecurityHandler,
"oauth2": OAuthSecurityHandler,
# OpenAPI 3: http schemes
"bearer": BearerSecurityHandler,
}
class SecurityHandlerFactory:
"""
A factory class for parsing security schemes and returning the appropriate
security handler.
By default, it will use the built-in security handlers specified in the
SECURITY_HANDLERS dict, but you can also pass in your own security handlers
to override the built-in ones.
"""
def __init__(
self,
security_handlers: t.Optional[dict] = None,
) -> None:
self.security_handlers = SECURITY_HANDLERS.copy()
if security_handlers is not None:
self.security_handlers.update(security_handlers)
def parse_security_scheme(
self,
security_scheme: dict,
required_scopes: t.List[str],
) -> t.Optional[t.Callable]:
"""Parses the security scheme and returns the function for verifying it.
:param security_scheme: The security scheme from the spec.
:param required_scopes: List of scopes for this security scheme.
"""
security_type = security_scheme["type"]
if security_type in ("basic", "oauth2"):
security_handler = self.security_handlers[security_type]
return security_handler().get_fn(security_scheme, required_scopes)
# OpenAPI 3.0.0
elif security_type == "http":
scheme = security_scheme["scheme"].lower()
if scheme in self.security_handlers:
security_handler = self.security_handlers[scheme]
return security_handler().get_fn(security_scheme, required_scopes)
else:
logger.warning("... Unsupported http authorization scheme %s", scheme)
return None
elif security_type == "apiKey":
scheme = security_scheme.get("x-authentication-scheme", "").lower()
if scheme == "bearer":
return BearerSecurityHandler().get_fn(security_scheme, required_scopes)
else:
security_handler = self.security_handlers["apiKey"]
return security_handler().get_fn(security_scheme, required_scopes)
else:
logger.warning(
"... Unsupported security scheme type %s",
security_type,
)
return None
@staticmethod
async def security_passthrough(request):
"""Used when no security is required for the operation.
Equivalent OpenAPI snippet:
.. code-block:: yaml
/helloworld
get:
security: [] # No security
...
"""
return request
@staticmethod
def verify_none(request):
"""Used for optional security.
Equivalent OpenAPI snippet:
.. code-block:: yaml
security:
- {} # <--
- myapikey: []
"""
return {}
def verify_multiple_schemes(self, schemes):
"""
Verifies multiple authentication schemes in AND fashion.
If any scheme fails, the entire authentication fails.
:param schemes: mapping scheme_name to auth function
:type schemes: dict
:rtype: types.FunctionType
"""
async def wrapper(request):
token_info = {}
for scheme_name, func in schemes.items():
result = func(request)
while asyncio.iscoroutine(result):
result = await result
if result is NO_VALUE:
return NO_VALUE
token_info[scheme_name] = result
return token_info
return wrapper
@classmethod
def verify_security(cls, auth_funcs):
async def verify_fn(request):
token_info = cls.no_value
token_info = NO_VALUE
errors = []
for func in auth_funcs:
try:
token_info = func(request)
while asyncio.iscoroutine(token_info):
token_info = await token_info
if token_info is not cls.no_value:
if token_info is not NO_VALUE:
break
except Exception as err:
errors.append(err)
@@ -465,28 +625,3 @@ class SecurityHandlerFactory:
else:
lowest_status_code = min(status_to_exc)
raise status_to_exc[lowest_status_code]
def get_token_info_remote(self, token_info_url):
"""
Return a function which will call `token_info_url` to retrieve token info.
Returned function must accept oauth token in parameter.
It must return a token_info dict in case of success, None otherwise.
:param token_info_url: Url to get information about the token
:type token_info_url: str
:rtype: types.FunctionType
"""
async def wrapper(token):
if self.client is None:
self.client = httpx.AsyncClient()
headers = {"Authorization": f"Bearer {token}"}
token_request = await self.client.get(
token_info_url, headers=headers, timeout=5
)
if token_request.status_code != 200:
return
return token_request.json()
return wrapper

View File

@@ -19,11 +19,20 @@ except AttributeError: # jsonschema < 4.5.0
class ParameterValidator:
def __init__(self, parameters, uri_parser, strict_validation=False):
def __init__(
self,
parameters,
uri_parser,
strict_validation=False,
security_query_params=None,
):
"""
:param parameters: List of request parameter dictionaries
:param uri_parser: class to use for uri parsing
:param strict_validation: Flag indicating if parameters not in spec are allowed
:param security_query_params: List of query parameter names used for security.
These parameters will be ignored when checking for extra parameters in case of
strict validation.
"""
self.parameters = collections.defaultdict(list)
for p in parameters:
@@ -31,6 +40,7 @@ class ParameterValidator:
self.uri_parser = uri_parser
self.strict_validation = strict_validation
self.security_query_params = set(security_query_params or [])
@staticmethod
def validate_parameter(parameter_type, value, param, param_name=None):
@@ -59,9 +69,10 @@ class ParameterValidator:
return request_params.difference(spec_params)
def validate_query_parameter_list(self, request):
def validate_query_parameter_list(self, request, security_params=None):
request_params = request.query_params.keys()
spec_params = [x["name"] for x in self.parameters.get("query", [])]
spec_params.extend(security_params or [])
return self.validate_parameter_list(request_params, spec_params)
def validate_query_parameter(self, param, request):
@@ -99,9 +110,10 @@ class ParameterValidator:
self.validate_request(request)
def validate_request(self, request):
if self.strict_validation:
query_errors = self.validate_query_parameter_list(request)
query_errors = self.validate_query_parameter_list(
request, security_params=self.security_query_params
)
if query_errors:
raise ExtraParameterProblem(

View File

@@ -1,7 +1,7 @@
import json
import pytest
from connexion.security import SecurityHandlerFactory
from connexion.security import OAuthSecurityHandler
class FakeResponse:
@@ -50,7 +50,7 @@ def oauth_requests(monkeypatch):
)
return url
monkeypatch.setattr(SecurityHandlerFactory, "client", FakeClient())
monkeypatch.setattr(OAuthSecurityHandler, "client", FakeClient())
def test_security_over_nonexistent_endpoints(oauth_requests, secure_api_app):

View File

@@ -10,12 +10,19 @@ from connexion.exceptions import (
OAuthResponseProblem,
OAuthScopeProblem,
)
from connexion.security import SecurityHandlerFactory
from connexion.lifecycle import ASGIRequest
from connexion.security import (
NO_VALUE,
ApiKeySecurityHandler,
BasicSecurityHandler,
OAuthSecurityHandler,
SecurityHandlerFactory,
)
def test_get_tokeninfo_url(monkeypatch):
security_handler_factory = SecurityHandlerFactory()
security_handler_factory.get_token_info_remote = MagicMock(
security_handler = OAuthSecurityHandler()
security_handler.get_token_info_remote = MagicMock(
return_value="get_token_info_remote_result"
)
env = {}
@@ -24,24 +31,24 @@ def test_get_tokeninfo_url(monkeypatch):
monkeypatch.setattr("connexion.security.logger", logger)
security_def = {}
assert security_handler_factory.get_tokeninfo_func(security_def) is None
assert security_handler.get_tokeninfo_func(security_def) is None
logger.warn.assert_not_called()
env["TOKENINFO_URL"] = "issue-146"
assert (
security_handler_factory.get_tokeninfo_func(security_def)
security_handler.get_tokeninfo_func(security_def)
== "get_token_info_remote_result"
)
security_handler_factory.get_token_info_remote.assert_called_with("issue-146")
security_handler.get_token_info_remote.assert_called_with("issue-146")
logger.warn.assert_not_called()
logger.warn.reset_mock()
security_def = {"x-tokenInfoUrl": "bar"}
assert (
security_handler_factory.get_tokeninfo_func(security_def)
security_handler.get_tokeninfo_func(security_def)
== "get_token_info_remote_result"
)
security_handler_factory.get_token_info_remote.assert_called_with("bar")
security_handler.get_token_info_remote.assert_called_with("bar")
logger.warn.assert_not_called()
@@ -49,15 +56,14 @@ def test_verify_oauth_missing_auth_header():
def somefunc(token):
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_oauth(
somefunc, security_handler_factory.validate_scope, ["admin"]
security_handler = OAuthSecurityHandler()
wrapped_func = security_handler._get_verify_func(
somefunc, security_handler.validate_scope, ["admin"]
)
request = MagicMock()
request.headers = {}
request = ASGIRequest(scope={"type": "http", "headers": []})
assert wrapped_func(request) is security_handler_factory.no_value
assert wrapped_func(request) is NO_VALUE
async def test_verify_oauth_scopes_remote(monkeypatch):
@@ -69,20 +75,21 @@ async def test_verify_oauth_scopes_remote(monkeypatch):
tokeninfo_response._content = json.dumps(tokeninfo).encode()
return tokeninfo_response
security_handler_factory = SecurityHandlerFactory()
token_info_func = security_handler_factory.get_tokeninfo_func(
security_handler = OAuthSecurityHandler()
token_info_func = security_handler.get_tokeninfo_func(
{"x-tokenInfoUrl": "https://example.org/tokeninfo"}
)
wrapped_func = security_handler_factory.verify_oauth(
token_info_func, security_handler_factory.validate_scope, ["admin"]
wrapped_func = security_handler._get_verify_func(
token_info_func, security_handler.validate_scope, ["admin"]
)
request = MagicMock()
request.headers = {"Authorization": "Bearer 123"}
request = ASGIRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
client = MagicMock()
client.get = get_tokeninfo_response
monkeypatch.setattr(SecurityHandlerFactory, "client", client)
monkeypatch.setattr(OAuthSecurityHandler, "client", client)
with pytest.raises(OAuthScopeProblem) as exc_info:
await wrapped_func(request)
@@ -112,13 +119,14 @@ async def test_verify_oauth_invalid_local_token_response_none():
def somefunc(token):
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_oauth(
somefunc, security_handler_factory.validate_scope, ["admin"]
security_handler = OAuthSecurityHandler()
wrapped_func = security_handler._get_verify_func(
somefunc, security_handler.validate_scope, ["admin"]
)
request = MagicMock()
request.headers = {"Authorization": "Bearer 123"}
request = ASGIRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
with pytest.raises(OAuthResponseProblem):
await wrapped_func(request)
@@ -130,13 +138,14 @@ async def test_verify_oauth_scopes_local():
def token_info(token):
return tokeninfo
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_oauth(
token_info, security_handler_factory.validate_scope, ["admin"]
security_handler = OAuthSecurityHandler()
wrapped_func = security_handler._get_verify_func(
token_info, security_handler.validate_scope, ["admin"]
)
request = MagicMock()
request.headers = {"Authorization": "Bearer 123"}
request = ASGIRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
with pytest.raises(OAuthScopeProblem) as exc_info:
await wrapped_func(request)
@@ -166,13 +175,14 @@ def test_verify_basic_missing_auth_header():
def somefunc(username, password, required_scopes=None):
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_basic(somefunc)
security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(somefunc)
request = MagicMock()
request.headers = {"Authorization": "Bearer 123"}
request = ASGIRequest(
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
)
assert wrapped_func(request) is security_handler_factory.no_value
assert wrapped_func(request) is NO_VALUE
async def test_verify_basic():
@@ -181,11 +191,12 @@ async def test_verify_basic():
return {"sub": "foo"}
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_basic(basic_info)
security_handler = BasicSecurityHandler()
wrapped_func = security_handler._get_verify_func(basic_info)
request = MagicMock()
request.headers = {"Authorization": "Basic Zm9vOmJhcg=="}
request = ASGIRequest(
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
)
assert await wrapped_func(request) is not None
@@ -196,11 +207,12 @@ async def test_verify_apikey_query():
return {"sub": "foo"}
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_api_key(apikey_info, "query", "auth")
security_handler_factory = ApiKeySecurityHandler()
wrapped_func = security_handler_factory._get_verify_func(
apikey_info, "query", "auth"
)
request = MagicMock()
request.query = {"auth": "foobar"}
request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"})
assert await wrapped_func(request) is not None
@@ -211,13 +223,12 @@ async def test_verify_apikey_header():
return {"sub": "foo"}
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func = security_handler_factory.verify_api_key(
security_handler_factory = ApiKeySecurityHandler()
wrapped_func = security_handler_factory._get_verify_func(
apikey_info, "header", "X-Auth"
)
request = MagicMock()
request.headers = {"X-Auth": "foobar"}
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]})
assert await wrapped_func(request) is not None
@@ -234,10 +245,11 @@ async def test_multiple_schemes():
return None
security_handler_factory = SecurityHandlerFactory()
wrapped_func_key1 = security_handler_factory.verify_api_key(
apikey_security_handler_factory = ApiKeySecurityHandler()
wrapped_func_key1 = apikey_security_handler_factory._get_verify_func(
apikey1_info, "header", "X-Auth-1"
)
wrapped_func_key2 = security_handler_factory.verify_api_key(
wrapped_func_key2 = apikey_security_handler_factory._get_verify_func(
apikey2_info, "header", "X-Auth-2"
)
schemes = {
@@ -247,19 +259,21 @@ async def test_multiple_schemes():
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
# Single key does not succeed
request = MagicMock()
request.headers = {"X-Auth-1": "foobar"}
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]})
assert await wrapped_func(request) is security_handler_factory.no_value
assert await wrapped_func(request) is NO_VALUE
request = MagicMock()
request.headers = {"X-Auth-2": "bar"}
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]})
assert await wrapped_func(request) is security_handler_factory.no_value
assert await wrapped_func(request) is NO_VALUE
# Supplying both keys does succeed
request = MagicMock()
request.headers = {"X-Auth-1": "foobar", "X-Auth-2": "bar"}
request = ASGIRequest(
scope={
"type": "http",
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
}
)
expected_token_info = {
"key1": {"sub": "foo"},
@@ -273,7 +287,7 @@ async def test_verify_security_oauthproblem():
security_handler_factory = SecurityHandlerFactory()
security_func = security_handler_factory.verify_security([])
request = MagicMock()
request = MagicMock(spec_set=ASGIRequest)
with pytest.raises(OAuthProblem) as exc_info:
await security_func(request)

View File

@@ -12,7 +12,11 @@ from connexion.jsonifier import Jsonifier
from connexion.middleware.security import SecurityOperation
from connexion.operations import Swagger2Operation
from connexion.resolver import Resolver
from connexion.security import SecurityHandlerFactory
from connexion.security import (
ApiKeySecurityHandler,
OAuthSecurityHandler,
SecurityHandlerFactory,
)
TEST_FOLDER = pathlib.Path(__file__).parent
@@ -418,11 +422,14 @@ def test_operation(api):
def test_operation_remote_token_info():
security_handler_factory = SecurityHandlerFactory()
class MockOAuthHandler(OAuthSecurityHandler):
"""Mock."""
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
security_handler_factory.verify_oauth = verify_oauth
security_handler_factory.get_token_info_remote = mock.MagicMock(
oauth_security_handler._get_verify_func = verify_oauth
oauth_security_handler.get_token_info_remote = mock.MagicMock(
return_value="get_token_info_remote_result"
)
@@ -434,9 +441,11 @@ def test_operation_remote_token_info():
)
verify_oauth.assert_called_with(
"get_token_info_remote_result", security_handler_factory.validate_scope, ["uid"]
"get_token_info_remote_result",
oauth_security_handler.validate_scope,
["uid"],
)
security_handler_factory.get_token_info_remote.assert_called_with(
oauth_security_handler.get_token_info_remote.assert_called_with(
"https://oauth.example/token_info"
)
@@ -491,10 +500,13 @@ def test_operation_composed_definition(api):
def test_operation_local_security_oauth2():
security_handler_factory = SecurityHandlerFactory()
class MockOAuthHandler(OAuthSecurityHandler):
"""Mock."""
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
security_handler_factory.verify_oauth = verify_oauth
oauth_security_handler._get_verify_func = verify_oauth
SecurityOperation(
next_app=mock.Mock,
@@ -504,15 +516,24 @@ def test_operation_local_security_oauth2():
)
verify_oauth.assert_called_with(
math.ceil, security_handler_factory.validate_scope, ["uid"]
math.ceil, oauth_security_handler.validate_scope, ["uid"]
)
verify_oauth.assert_called_with(
math.ceil,
security_handler_factory.security_handlers["oauth2"].validate_scope,
["uid"],
)
def test_operation_local_security_duplicate_token_info():
security_handler_factory = SecurityHandlerFactory()
class MockOAuthHandler(OAuthSecurityHandler):
"""Mock."""
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
security_handler_factory.verify_oauth = verify_oauth
oauth_security_handler._get_verify_func = verify_oauth
SecurityOperation(
next_app=mock.Mock,
@@ -522,7 +543,11 @@ def test_operation_local_security_duplicate_token_info():
)
verify_oauth.call_args.assert_called_with(
math.ceil, security_handler_factory.validate_scope
math.ceil, oauth_security_handler.validate_scope
)
verify_oauth.call_args.assert_called_with(
math.ceil, security_handler_factory.security_handlers["oauth2"].validate_scope
)
@@ -565,9 +590,13 @@ def test_multiple_security_schemes_and():
def return_api_key_name(func, in_, name):
return name
security_handler_factory = SecurityHandlerFactory()
class MockApiKeyHandler(ApiKeySecurityHandler):
"""Mock"""
security_handler_factory = SecurityHandlerFactory({"apiKey": MockApiKeyHandler})
apikey_security_handler = security_handler_factory.security_handlers["apiKey"]
verify_api_key = mock.MagicMock(side_effect=return_api_key_name)
security_handler_factory.verify_api_key = verify_api_key
apikey_security_handler._get_verify_func = verify_api_key
verify_multiple = mock.MagicMock(return_value="verify_multiple_result")
security_handler_factory.verify_multiple_schemes = verify_multiple
@@ -595,9 +624,6 @@ def test_multiple_oauth_in_and(caplog):
caplog.set_level(logging.WARNING, logger="connexion.operations.secure")
security_handler_factory = SecurityHandlerFactory()
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
security_handler_factory.verify_oauth = verify_oauth
security = [{"oauth_1": ["uid"], "oauth_2": ["uid"]}]
SecurityOperation(
@@ -688,10 +714,14 @@ def test_get_path_parameter_types(api):
def test_oauth_scopes_in_or():
"""Tests whether an OAuth security scheme with 2 different possible scopes is correctly handled."""
security_handler_factory = SecurityHandlerFactory()
class MockOAuthFactory(OAuthSecurityHandler):
"""Mock."""
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthFactory})
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
security_handler_factory.verify_oauth = verify_oauth
oauth_security_handler._get_verify_func = verify_oauth
security = [{"oauth": ["myscope"]}, {"oauth": ["myscope2"]}]
@@ -704,8 +734,8 @@ def test_oauth_scopes_in_or():
verify_oauth.assert_has_calls(
[
mock.call(math.ceil, security_handler_factory.validate_scope, ["myscope"]),
mock.call(math.ceil, security_handler_factory.validate_scope, ["myscope2"]),
mock.call(math.ceil, oauth_security_handler.validate_scope, ["myscope"]),
mock.call(math.ceil, oauth_security_handler.validate_scope, ["myscope2"]),
]
)