mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 04:19:26 +00:00
Make security pluggable (#1671)
Make security pluggable - [x] Solution for standard security handlers: `security_deny`, `security_passthrough`, `verify_none` - [x] HTTP security handlers & overlap with basic from swagger 2 - [x] Do we need a separate handler for each `oauth2` flow?
This commit is contained in:
@@ -47,6 +47,7 @@ class AbstractApp:
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
@@ -77,6 +78,8 @@ class AbstractApp:
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`.
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`
|
||||
"""
|
||||
self.middleware = ConnexionMiddleware(
|
||||
self.middleware_app,
|
||||
@@ -95,6 +98,7 @@ class AbstractApp:
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
)
|
||||
|
||||
def add_api(
|
||||
@@ -113,6 +117,7 @@ class AbstractApp:
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> t.Any:
|
||||
"""
|
||||
@@ -143,6 +148,8 @@ class AbstractApp:
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`
|
||||
:param kwargs: Additional keyword arguments to pass to the `add_api` method of the managed
|
||||
middlewares. This can be used to pass arguments to middlewares added beyond the default
|
||||
ones.
|
||||
@@ -163,6 +170,7 @@ class AbstractApp:
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -135,6 +135,7 @@ class AsyncApp(AbstractApp):
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
) -> None:
|
||||
"""
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
@@ -165,6 +166,8 @@ class AsyncApp(AbstractApp):
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`.
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`
|
||||
"""
|
||||
self.middleware_app: AsyncMiddlewareApp = AsyncMiddlewareApp()
|
||||
|
||||
@@ -184,6 +187,7 @@ class AsyncApp(AbstractApp):
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
)
|
||||
|
||||
def add_url_rule(
|
||||
|
||||
@@ -192,6 +192,7 @@ class FlaskApp(AbstractApp):
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
@@ -225,6 +226,8 @@ class FlaskApp(AbstractApp):
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`.
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`
|
||||
"""
|
||||
self.middleware_app = FlaskMiddlewareApp(import_name, server_args or {})
|
||||
self.app = self.middleware_app.app
|
||||
@@ -244,6 +247,7 @@ class FlaskApp(AbstractApp):
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
)
|
||||
|
||||
def add_url_rule(
|
||||
|
||||
@@ -51,6 +51,7 @@ class _Options:
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None
|
||||
validate_responses: t.Optional[bool] = False
|
||||
validator_map: t.Optional[dict] = None
|
||||
security_map: t.Optional[dict] = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.resolver = (
|
||||
@@ -115,6 +116,7 @@ class ConnexionMiddleware:
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
):
|
||||
"""
|
||||
:param import_name: The name of the package or module that this object belongs to. If you
|
||||
@@ -145,6 +147,8 @@ class ConnexionMiddleware:
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`.
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`.
|
||||
"""
|
||||
import_name = import_name or str(pathlib.Path.cwd())
|
||||
self.root_path = utils.get_root_path(import_name)
|
||||
@@ -169,6 +173,7 @@ class ConnexionMiddleware:
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
)
|
||||
|
||||
self.extra_files: t.List[str] = []
|
||||
@@ -217,6 +222,7 @@ class ConnexionMiddleware:
|
||||
uri_parser_class: t.Optional[AbstractURIParser] = None,
|
||||
validate_responses: t.Optional[bool] = None,
|
||||
validator_map: t.Optional[dict] = None,
|
||||
security_map: t.Optional[dict] = None,
|
||||
**kwargs,
|
||||
) -> t.Any:
|
||||
"""
|
||||
@@ -247,6 +253,8 @@ class ConnexionMiddleware:
|
||||
an impact on performance. Defaults to False.
|
||||
:param validator_map: A dictionary of validators to use. Defaults to
|
||||
:obj:`validators.VALIDATOR_MAP`
|
||||
:param security_map: A dictionary of security handlers to use. Defaults to
|
||||
:obj:`security.SECURITY_HANDLERS`
|
||||
:param kwargs: Additional keyword arguments to pass to the `add_api` method of the managed
|
||||
middlewares. This can be used to pass arguments to middlewares added beyond the default
|
||||
ones.
|
||||
@@ -275,6 +283,7 @@ class ConnexionMiddleware:
|
||||
uri_parser_class=uri_parser_class,
|
||||
validate_responses=validate_responses,
|
||||
validator_map=validator_map,
|
||||
security_map=security_map,
|
||||
)
|
||||
|
||||
for app in self.apps:
|
||||
|
||||
@@ -70,6 +70,29 @@ class RequestValidationOperation:
|
||||
f"expected {self._operation.consumes}"
|
||||
)
|
||||
|
||||
@property
|
||||
def security_query_params(self) -> t.List[str]:
|
||||
"""Get the names of query parameters that are used for security."""
|
||||
if not hasattr(self, "_security_query_params"):
|
||||
security_query_params: t.List[str] = []
|
||||
if self._operation.security is None:
|
||||
self._security_query_params = security_query_params
|
||||
return self._security_query_params
|
||||
|
||||
for security_req in self._operation.security:
|
||||
for scheme_name in security_req:
|
||||
security_scheme = self._operation.security_schemes[scheme_name]
|
||||
|
||||
if (
|
||||
security_scheme["type"] == "apiKey"
|
||||
and security_scheme["in"] == "query"
|
||||
):
|
||||
# Only query parameters need to be considered for strict_validation
|
||||
security_query_params.append(security_scheme["name"])
|
||||
self._security_query_params = security_query_params
|
||||
|
||||
return self._security_query_params
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send):
|
||||
# Validate parameters & headers
|
||||
uri_parser_class = self._operation._uri_parser_class
|
||||
@@ -81,6 +104,7 @@ class RequestValidationOperation:
|
||||
self._operation.parameters,
|
||||
uri_parser=uri_parser,
|
||||
strict_validation=self.strict_validation,
|
||||
security_query_params=self.security_query_params,
|
||||
)
|
||||
parameter_validator.validate(scope)
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ class SecurityOperation:
|
||||
auth_funcs = []
|
||||
for security_req in self.security:
|
||||
if not security_req:
|
||||
auth_funcs.append(self.security_handler_factory.verify_none())
|
||||
auth_funcs.append(self.security_handler_factory.verify_none)
|
||||
continue
|
||||
|
||||
sec_req_funcs = {}
|
||||
@@ -67,123 +67,15 @@ class SecurityOperation:
|
||||
)
|
||||
break
|
||||
oauth = True
|
||||
token_info_func = self.security_handler_factory.get_tokeninfo_func(
|
||||
security_scheme
|
||||
|
||||
sec_req_func = self.security_handler_factory.parse_security_scheme(
|
||||
security_scheme, required_scopes
|
||||
)
|
||||
scope_validate_func = (
|
||||
self.security_handler_factory.get_scope_validate_func(
|
||||
security_scheme
|
||||
)
|
||||
)
|
||||
if not token_info_func:
|
||||
logger.warning("... x-tokenInfoFunc missing", extra=vars(self))
|
||||
if sec_req_func is None:
|
||||
break
|
||||
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_oauth(
|
||||
token_info_func, scope_validate_func, required_scopes
|
||||
)
|
||||
sec_req_funcs[scheme_name] = sec_req_func
|
||||
|
||||
# Swagger 2.0
|
||||
elif security_scheme["type"] == "basic":
|
||||
basic_info_func = self.security_handler_factory.get_basicinfo_func(
|
||||
security_scheme
|
||||
)
|
||||
if not basic_info_func:
|
||||
logger.warning("... x-basicInfoFunc missing", extra=vars(self))
|
||||
break
|
||||
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_basic(basic_info_func)
|
||||
|
||||
# OpenAPI 3.0.0
|
||||
elif security_scheme["type"] == "http":
|
||||
scheme = security_scheme["scheme"].lower()
|
||||
if scheme == "basic":
|
||||
basic_info_func = (
|
||||
self.security_handler_factory.get_basicinfo_func(
|
||||
security_scheme
|
||||
)
|
||||
)
|
||||
if not basic_info_func:
|
||||
logger.warning(
|
||||
"... x-basicInfoFunc missing", extra=vars(self)
|
||||
)
|
||||
break
|
||||
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_basic(basic_info_func)
|
||||
elif scheme == "bearer":
|
||||
bearer_info_func = (
|
||||
self.security_handler_factory.get_bearerinfo_func(
|
||||
security_scheme
|
||||
)
|
||||
)
|
||||
if not bearer_info_func:
|
||||
logger.warning(
|
||||
"... x-bearerInfoFunc missing", extra=vars(self)
|
||||
)
|
||||
break
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_bearer(
|
||||
bearer_info_func
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"... Unsupported http authorization scheme %s" % scheme,
|
||||
extra=vars(self),
|
||||
)
|
||||
break
|
||||
|
||||
elif security_scheme["type"] == "apiKey":
|
||||
scheme = security_scheme.get("x-authentication-scheme", "").lower()
|
||||
if scheme == "bearer":
|
||||
bearer_info_func = (
|
||||
self.security_handler_factory.get_bearerinfo_func(
|
||||
security_scheme
|
||||
)
|
||||
)
|
||||
if not bearer_info_func:
|
||||
logger.warning(
|
||||
"... x-bearerInfoFunc missing", extra=vars(self)
|
||||
)
|
||||
break
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_bearer(
|
||||
bearer_info_func
|
||||
)
|
||||
else:
|
||||
apikey_info_func = (
|
||||
self.security_handler_factory.get_apikeyinfo_func(
|
||||
security_scheme
|
||||
)
|
||||
)
|
||||
if not apikey_info_func:
|
||||
logger.warning(
|
||||
"... x-apikeyInfoFunc missing", extra=vars(self)
|
||||
)
|
||||
break
|
||||
|
||||
sec_req_funcs[
|
||||
scheme_name
|
||||
] = self.security_handler_factory.verify_api_key(
|
||||
apikey_info_func,
|
||||
security_scheme["in"],
|
||||
security_scheme["name"],
|
||||
)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"... Unsupported security scheme type %s"
|
||||
% security_scheme["type"],
|
||||
extra=vars(self),
|
||||
)
|
||||
break
|
||||
else:
|
||||
# No break encountered: no missing funcs
|
||||
if len(sec_req_funcs) == 1:
|
||||
@@ -199,16 +91,22 @@ class SecurityOperation:
|
||||
return self.security_handler_factory.verify_security(auth_funcs)
|
||||
|
||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||
if not self.security:
|
||||
await self.next_app(scope, receive, send)
|
||||
return
|
||||
|
||||
request = ASGIRequest(scope)
|
||||
await self.verification_fn(request)
|
||||
await self.next_app(scope, receive, send)
|
||||
|
||||
|
||||
class SecurityAPI(RoutedAPI[SecurityOperation]):
|
||||
def __init__(self, *args, auth_all_paths: bool = False, **kwargs):
|
||||
def __init__(
|
||||
self, *args, auth_all_paths: bool = False, security_map: dict = None, **kwargs
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.security_handler_factory = SecurityHandlerFactory()
|
||||
self.security_handler_factory = SecurityHandlerFactory(security_map)
|
||||
|
||||
if auth_all_paths:
|
||||
self.add_auth_on_not_found()
|
||||
|
||||
@@ -1,6 +1,46 @@
|
||||
"""
|
||||
This module defines an abstract SecurityHandlerFactory which supports the creation of security
|
||||
handlers for operations.
|
||||
This module defines a SecurityHandlerFactory which supports the creation of
|
||||
SecurityHandler instances for different security schemes.
|
||||
|
||||
It also exposes a `SECURITY_HANDLERS` dictionary which maps security scheme
|
||||
types to SecurityHandler classes. This dictionary can be used to register
|
||||
custom SecurityHandler classes for custom security schemes, or to overwrite
|
||||
existing SecurityHandler classes.
|
||||
This can be done by supplying a value for `security_map` argument of the
|
||||
SecurityHandlerFactory.
|
||||
|
||||
Swagger 2.0 lets you define the following authentication types for an API:
|
||||
|
||||
- Basic authentication
|
||||
- API key (as a header or a query string parameter)
|
||||
- OAuth 2 common flows (authorization code, implicit, resource owner password credentials, client credentials)
|
||||
|
||||
|
||||
Changes from OpenAPI 2.0 to OpenAPI 3.0
|
||||
If you used OpenAPI 2.0 before, here is a summary of changes to help you get started with OpenAPI 3.0:
|
||||
- securityDefinitions were renamed to securitySchemes and moved inside components.
|
||||
- type: basic was replaced with type: http and scheme: basic.
|
||||
- The new type: http is an umbrella type for all HTTP security schemes, including Basic, Bearer and other,
|
||||
and the scheme keyword indicates the scheme type.
|
||||
- API keys can now be sent in: cookie.
|
||||
- Added support for OpenID Connect Discovery (type: openIdConnect).
|
||||
- OAuth 2 security schemes can now define multiple flows.
|
||||
- OAuth 2 flows were renamed to match the OAuth 2 Specification: accessCode is now authorizationCode,
|
||||
and application is now clientCredentials.
|
||||
|
||||
|
||||
OpenAPI uses the term security scheme for authentication and authorization schemes.
|
||||
OpenAPI 3.0 lets you describe APIs protected using the following security schemes:
|
||||
|
||||
- HTTP authentication schemes (they use the Authorization header):
|
||||
- Basic
|
||||
- Bearer
|
||||
- other HTTP schemes as defined by RFC 7235 and HTTP Authentication Scheme Registry
|
||||
- API keys in headers, query string or cookies
|
||||
- Cookie authentication
|
||||
- OAuth 2
|
||||
- OpenID Connect Discovery
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -8,62 +48,274 @@ import base64
|
||||
import http.cookies
|
||||
import logging
|
||||
import os
|
||||
import textwrap
|
||||
import typing as t
|
||||
|
||||
import httpx
|
||||
|
||||
from connexion.decorators.parameter import inspect_function_arguments
|
||||
from connexion.exceptions import (
|
||||
ConnexionException,
|
||||
OAuthProblem,
|
||||
OAuthResponseProblem,
|
||||
OAuthScopeProblem,
|
||||
)
|
||||
from connexion.exceptions import OAuthProblem, OAuthResponseProblem, OAuthScopeProblem
|
||||
from connexion.lifecycle import ASGIRequest
|
||||
from connexion.utils import get_function_from_name
|
||||
|
||||
logger = logging.getLogger("connexion.api.security")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SecurityHandlerFactory:
|
||||
"""
|
||||
get_*_func -> _get_function -> get_function_from_name (name=security function defined in spec)
|
||||
(if url defined instead of a function -> get_token_info_remote)
|
||||
NO_VALUE = object()
|
||||
"""Sentinel value to indicate that no security credentials were found."""
|
||||
|
||||
std security functions: security_{passthrough,deny}
|
||||
|
||||
verify_* -> returns a security wrapper around the security function
|
||||
check_* -> returns a function tasked with doing auth for use inside the verify wrapper
|
||||
check helpers (used outside wrappers): _need_to_add_context_or_scopes
|
||||
the security function
|
||||
class AbstractSecurityHandler:
|
||||
|
||||
verify helpers (used inside wrappers): get_auth_header_value, get_cookie_value
|
||||
"""
|
||||
|
||||
no_value = object()
|
||||
required_scopes_kw = "required_scopes"
|
||||
context_kw = "context_"
|
||||
client = None
|
||||
security_definition_key: str
|
||||
"""The key which contains the value for the function name to resolve."""
|
||||
environ_key: str
|
||||
"""The name of the environment variable that can be used alternatively for the function name."""
|
||||
|
||||
@staticmethod
|
||||
def get_fn(self, security_scheme, required_scopes):
|
||||
"""Returns the handler function"""
|
||||
security_func = self._resolve_func(security_scheme)
|
||||
if not security_func:
|
||||
logger.warning("... %s missing", self.security_definition_key)
|
||||
return None
|
||||
|
||||
return self._get_verify_func(security_func)
|
||||
|
||||
@classmethod
|
||||
def _get_function(
|
||||
security_definition, security_definition_key, environ_key, default=None
|
||||
cls,
|
||||
security_definition: dict,
|
||||
security_definition_key: str,
|
||||
environ_key: str,
|
||||
default: t.Optional[t.Callable] = None,
|
||||
):
|
||||
"""
|
||||
Return function by getting its name from security_definition or environment variable
|
||||
|
||||
:param security_definition: Security Definition (scheme) from the spec.
|
||||
:param security_definition_key: The key which contains the value for the function name to resolve.
|
||||
:param environ_key: The name of the environment variable that can be used alternatively for the function name.
|
||||
:param default: The default to use in case the function cannot be found based on the security_definition_key or the environ_key
|
||||
"""
|
||||
func = security_definition.get(security_definition_key) or os.environ.get(
|
||||
func_name = security_definition.get(security_definition_key) or os.environ.get(
|
||||
environ_key
|
||||
)
|
||||
if func:
|
||||
return get_function_from_name(func)
|
||||
if func_name:
|
||||
return get_function_from_name(func_name)
|
||||
return default
|
||||
|
||||
def _generic_check(self, func, exception_msg):
|
||||
(
|
||||
need_to_add_context,
|
||||
need_to_add_required_scopes,
|
||||
) = self._need_to_add_context_or_scopes(func)
|
||||
|
||||
async def wrapper(request, *args, required_scopes=None):
|
||||
kwargs = {}
|
||||
if need_to_add_context:
|
||||
kwargs[self.context_kw] = request.context
|
||||
if need_to_add_required_scopes:
|
||||
kwargs[self.required_scopes_kw] = required_scopes
|
||||
token_info = func(*args, **kwargs)
|
||||
while asyncio.iscoroutine(token_info):
|
||||
token_info = await token_info
|
||||
if token_info is NO_VALUE:
|
||||
return NO_VALUE
|
||||
if token_info is None:
|
||||
raise OAuthResponseProblem(detail=exception_msg)
|
||||
return token_info
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def get_auth_header_value(request):
|
||||
"""
|
||||
Return Authorization type and value if any.
|
||||
If not Authorization, return (None, None)
|
||||
Raise OAuthProblem for invalid Authorization header
|
||||
"""
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
auth_type, value = authorization.split(maxsplit=1)
|
||||
except ValueError:
|
||||
raise OAuthProblem(detail="Invalid authorization header")
|
||||
return auth_type.lower(), value
|
||||
|
||||
def _need_to_add_context_or_scopes(self, func):
|
||||
arguments, has_kwargs = inspect_function_arguments(func)
|
||||
need_context = self.context_kw in arguments
|
||||
need_required_scopes = has_kwargs or self.required_scopes_kw in arguments
|
||||
return need_context, need_required_scopes
|
||||
|
||||
def _resolve_func(self, security_scheme):
|
||||
"""
|
||||
Get the user function object based on the security scheme or the environment variable.
|
||||
|
||||
:param security_scheme: Security Definition (scheme) from the spec.
|
||||
"""
|
||||
return self._get_function(
|
||||
security_scheme, self.security_definition_key, self.environ_key
|
||||
)
|
||||
|
||||
def _get_verify_func(self, function):
|
||||
"""
|
||||
Wraps the user security function in a function that checks the request for the correct
|
||||
security credentials and calls the user function with the correct arguments.
|
||||
"""
|
||||
return self._generic_check(function, "Provided authorization is not valid")
|
||||
|
||||
|
||||
class BasicSecurityHandler(AbstractSecurityHandler):
|
||||
"""
|
||||
Security Handler for
|
||||
- `type: basic` (Swagger 2), and
|
||||
- `type: http` and `scheme: basic` (OpenAPI 3)
|
||||
"""
|
||||
|
||||
security_definition_key = "x-basicInfoFunc"
|
||||
environ_key = "BASICINFO_FUNC"
|
||||
|
||||
def _get_verify_func(self, basic_info_func):
|
||||
check_basic_info_func = self.check_basic_auth(basic_info_func)
|
||||
|
||||
def wrapper(request):
|
||||
auth_type, user_pass = self.get_auth_header_value(request)
|
||||
if auth_type != "basic":
|
||||
return NO_VALUE
|
||||
|
||||
try:
|
||||
username, password = (
|
||||
base64.b64decode(user_pass).decode("latin1").split(":", 1)
|
||||
)
|
||||
except Exception:
|
||||
raise OAuthProblem(detail="Invalid authorization header")
|
||||
|
||||
return check_basic_info_func(request, username, password)
|
||||
|
||||
return wrapper
|
||||
|
||||
def check_basic_auth(self, basic_info_func):
|
||||
return self._generic_check(
|
||||
basic_info_func, "Provided authorization is not valid"
|
||||
)
|
||||
|
||||
|
||||
class BearerSecurityHandler(AbstractSecurityHandler):
|
||||
"""
|
||||
Security Handler for HTTP Bearer authentication.
|
||||
"""
|
||||
|
||||
security_definition_key = "x-bearerInfoFunc"
|
||||
environ_key = "BEARERINFO_FUNC"
|
||||
|
||||
def check_bearer_token(self, token_info_func):
|
||||
return self._generic_check(token_info_func, "Provided token is not valid")
|
||||
|
||||
def _get_verify_func(self, token_info_func):
|
||||
"""
|
||||
:param token_info_func: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
check_bearer_func = self.check_bearer_token(token_info_func)
|
||||
|
||||
def wrapper(request):
|
||||
auth_type, token = self.get_auth_header_value(request)
|
||||
if auth_type != "bearer":
|
||||
return NO_VALUE
|
||||
return check_bearer_func(request, token)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class ApiKeySecurityHandler(AbstractSecurityHandler):
|
||||
"""
|
||||
Security Handler for API Keys.
|
||||
"""
|
||||
|
||||
security_definition_key = "x-apikeyInfoFunc"
|
||||
environ_key = "APIKEYINFO_FUNC"
|
||||
|
||||
def get_fn(self, security_scheme, required_scopes):
|
||||
apikey_info_func = self._resolve_func(security_scheme)
|
||||
if not apikey_info_func:
|
||||
logger.warning("... %s missing", self.security_definition_key)
|
||||
return None
|
||||
|
||||
return self._get_verify_func(
|
||||
apikey_info_func,
|
||||
security_scheme["in"],
|
||||
security_scheme["name"],
|
||||
)
|
||||
|
||||
def _get_verify_func(self, api_key_info_func, loc, name):
|
||||
check_api_key_func = self.check_api_key(api_key_info_func)
|
||||
|
||||
def wrapper(request: ASGIRequest):
|
||||
if loc == "query":
|
||||
api_key = request.query_params.get(name)
|
||||
elif loc == "header":
|
||||
api_key = request.headers.get(name)
|
||||
elif loc == "cookie":
|
||||
cookie_list = request.headers.get("Cookie")
|
||||
api_key = self.get_cookie_value(cookie_list, name)
|
||||
else:
|
||||
return NO_VALUE
|
||||
|
||||
if api_key is None:
|
||||
return NO_VALUE
|
||||
|
||||
return check_api_key_func(request, api_key)
|
||||
|
||||
return wrapper
|
||||
|
||||
def check_api_key(self, api_key_info_func):
|
||||
return self._generic_check(api_key_info_func, "Provided apikey is not valid")
|
||||
|
||||
@staticmethod
|
||||
def get_cookie_value(cookies, name):
|
||||
"""
|
||||
Returns cookie value by its name. `None` if no such value.
|
||||
|
||||
:param cookies: str: cookies raw data
|
||||
:param name: str: cookies key
|
||||
"""
|
||||
cookie_parser = http.cookies.SimpleCookie()
|
||||
cookie_parser.load(str(cookies))
|
||||
try:
|
||||
return cookie_parser[name].value
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
class OAuthSecurityHandler(AbstractSecurityHandler):
|
||||
"""
|
||||
Security Handler for the OAuth security scheme.
|
||||
"""
|
||||
|
||||
def get_fn(self, security_scheme, required_scopes):
|
||||
token_info_func = self.get_tokeninfo_func(security_scheme)
|
||||
scope_validate_func = self.get_scope_validate_func(security_scheme)
|
||||
if not token_info_func:
|
||||
logger.warning("... x-tokenInfoFunc missing")
|
||||
return None
|
||||
|
||||
return self._get_verify_func(
|
||||
token_info_func, scope_validate_func, required_scopes
|
||||
)
|
||||
|
||||
def get_tokeninfo_func(self, security_definition: dict) -> t.Optional[t.Callable]:
|
||||
"""
|
||||
:type security_definition: dict
|
||||
Gets the function for retrieving the token info.
|
||||
It is possible to specify a function or a URL. The function variant is
|
||||
preferred. If it is not found, the URL variant is used with the
|
||||
`get_token_info_remote` function.
|
||||
|
||||
>>> get_tokeninfo_url({'x-tokenInfoFunc': 'foo.bar'})
|
||||
>>> get_tokeninfo_func({'x-tokenInfoFunc': 'foo.bar'})
|
||||
'<function foo.bar>'
|
||||
"""
|
||||
token_info_func = self._get_function(
|
||||
@@ -83,8 +335,8 @@ class SecurityHandlerFactory:
|
||||
@classmethod
|
||||
def get_scope_validate_func(cls, security_definition):
|
||||
"""
|
||||
:type security_definition: dict
|
||||
:rtype: function
|
||||
Gets the function for validating the token scopes.
|
||||
If it is not found, the default `validate_scope` function is used.
|
||||
|
||||
>>> get_scope_validate_func({'x-scopeValidateFunc': 'foo.bar'})
|
||||
'<function foo.bar>'
|
||||
@@ -96,61 +348,6 @@ class SecurityHandlerFactory:
|
||||
cls.validate_scope,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_basicinfo_func(cls, security_definition):
|
||||
"""
|
||||
:type security_definition: dict
|
||||
:rtype: function
|
||||
|
||||
>>> get_basicinfo_func({'x-basicInfoFunc': 'foo.bar'})
|
||||
'<function foo.bar>'
|
||||
"""
|
||||
return cls._get_function(
|
||||
security_definition, "x-basicInfoFunc", "BASICINFO_FUNC"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_apikeyinfo_func(cls, security_definition):
|
||||
"""
|
||||
:type security_definition: dict
|
||||
:rtype: function
|
||||
|
||||
>>> get_apikeyinfo_func({'x-apikeyInfoFunc': 'foo.bar'})
|
||||
'<function foo.bar>'
|
||||
"""
|
||||
return cls._get_function(
|
||||
security_definition, "x-apikeyInfoFunc", "APIKEYINFO_FUNC"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_bearerinfo_func(cls, security_definition):
|
||||
"""
|
||||
:type security_definition: dict
|
||||
:rtype: function
|
||||
|
||||
>>> get_bearerinfo_func({'x-bearerInfoFunc': 'foo.bar'})
|
||||
'<function foo.bar>'
|
||||
"""
|
||||
return cls._get_function(
|
||||
security_definition, "x-bearerInfoFunc", "BEARERINFO_FUNC"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def security_passthrough(request):
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def security_deny(function):
|
||||
"""
|
||||
:type function: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
def deny(*args, **kwargs):
|
||||
raise ConnexionException("Error in security definitions")
|
||||
|
||||
return deny
|
||||
|
||||
@staticmethod
|
||||
def validate_scope(required_scopes, token_scopes):
|
||||
"""
|
||||
@@ -167,210 +364,49 @@ class SecurityHandlerFactory:
|
||||
logger.debug("... Token scopes: %s", token_scopes)
|
||||
if not required_scopes <= token_scopes:
|
||||
logger.info(
|
||||
textwrap.dedent(
|
||||
"""
|
||||
... Token scopes (%s) do not match the scopes necessary to call endpoint (%s).
|
||||
Aborting with 403."""
|
||||
).replace("\n", ""),
|
||||
"... Token scopes (%s) do not match the scopes necessary to call endpoint (%s)."
|
||||
" Aborting with 403.",
|
||||
token_scopes,
|
||||
required_scopes,
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_auth_header_value(request):
|
||||
def get_token_info_remote(self, token_info_url: str) -> t.Callable:
|
||||
"""
|
||||
Called inside security wrapper functions
|
||||
Return a function which will call `token_info_url` to retrieve token info.
|
||||
|
||||
Return Authorization type and value if any.
|
||||
If not Authorization, return (None, None)
|
||||
Raise OAuthProblem for invalid Authorization header
|
||||
Returned function must accept oauth token in parameter.
|
||||
It must return a token_info dict in case of success, None otherwise.
|
||||
|
||||
:param token_info_url: URL to get information about the token
|
||||
"""
|
||||
authorization = request.headers.get("Authorization")
|
||||
if not authorization:
|
||||
return None, None
|
||||
|
||||
try:
|
||||
auth_type, value = authorization.split(None, 1)
|
||||
except ValueError:
|
||||
raise OAuthProblem(detail="Invalid authorization header")
|
||||
return auth_type.lower(), value
|
||||
async def wrapper(token):
|
||||
if self.client is None:
|
||||
self.client = httpx.AsyncClient()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
token_request = await self.client.get(
|
||||
token_info_url, headers=headers, timeout=5
|
||||
)
|
||||
if token_request.status_code != 200:
|
||||
return
|
||||
return token_request.json()
|
||||
|
||||
def verify_oauth(self, token_info_func, scope_validate_func, required_scopes):
|
||||
return wrapper
|
||||
|
||||
def _get_verify_func(self, token_info_func, scope_validate_func, required_scopes):
|
||||
check_oauth_func = self.check_oauth_func(token_info_func, scope_validate_func)
|
||||
|
||||
def wrapper(request):
|
||||
auth_type, token = self.get_auth_header_value(request)
|
||||
if auth_type != "bearer":
|
||||
return self.no_value
|
||||
return NO_VALUE
|
||||
|
||||
return check_oauth_func(request, token, required_scopes=required_scopes)
|
||||
|
||||
return wrapper
|
||||
|
||||
def verify_basic(self, basic_info_func):
|
||||
check_basic_info_func = self.check_basic_auth(basic_info_func)
|
||||
|
||||
def wrapper(request):
|
||||
auth_type, user_pass = self.get_auth_header_value(request)
|
||||
if auth_type != "basic":
|
||||
return self.no_value
|
||||
|
||||
try:
|
||||
username, password = (
|
||||
base64.b64decode(user_pass).decode("latin1").split(":", 1)
|
||||
)
|
||||
except Exception:
|
||||
raise OAuthProblem(detail="Invalid authorization header")
|
||||
|
||||
return check_basic_info_func(request, username, password)
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def get_cookie_value(cookies, name):
|
||||
"""
|
||||
Called inside security wrapper functions
|
||||
|
||||
Returns cookie value by its name. None if no such value.
|
||||
:param cookies: str: cookies raw data
|
||||
:param name: str: cookies key
|
||||
"""
|
||||
cookie_parser = http.cookies.SimpleCookie()
|
||||
cookie_parser.load(str(cookies))
|
||||
try:
|
||||
return cookie_parser[name].value
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def verify_api_key(self, api_key_info_func, loc, name):
|
||||
check_api_key_func = self.check_api_key(api_key_info_func)
|
||||
|
||||
def wrapper(request):
|
||||
def _immutable_pop(_dict, key):
|
||||
"""
|
||||
Pops the key from an immutable dict and returns the value that was popped,
|
||||
and a new immutable dict without the popped key.
|
||||
"""
|
||||
cls = type(_dict)
|
||||
try:
|
||||
_dict = _dict.to_dict(flat=False)
|
||||
return _dict.pop(key)[0], cls(_dict)
|
||||
except AttributeError:
|
||||
_dict = dict(_dict.items())
|
||||
return _dict.pop(key), cls(_dict)
|
||||
|
||||
if loc == "query":
|
||||
try:
|
||||
api_key, request.query = _immutable_pop(request.query, name)
|
||||
except KeyError:
|
||||
api_key = None
|
||||
elif loc == "header":
|
||||
api_key = request.headers.get(name)
|
||||
elif loc == "cookie":
|
||||
cookie_list = request.headers.get("Cookie")
|
||||
api_key = self.get_cookie_value(cookie_list, name)
|
||||
else:
|
||||
return self.no_value
|
||||
|
||||
if api_key is None:
|
||||
return self.no_value
|
||||
|
||||
return check_api_key_func(request, api_key)
|
||||
|
||||
return wrapper
|
||||
|
||||
def verify_bearer(self, token_info_func):
|
||||
"""
|
||||
:param token_info_func: types.FunctionType
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
check_bearer_func = self.check_bearer_token(token_info_func)
|
||||
|
||||
def wrapper(request):
|
||||
auth_type, token = self.get_auth_header_value(request)
|
||||
if auth_type != "bearer":
|
||||
return self.no_value
|
||||
return check_bearer_func(request, token)
|
||||
|
||||
return wrapper
|
||||
|
||||
def verify_multiple_schemes(self, schemes):
|
||||
"""
|
||||
Verifies multiple authentication schemes in AND fashion.
|
||||
If any scheme fails, the entire authentication fails.
|
||||
|
||||
:param schemes: mapping scheme_name to auth function
|
||||
:type schemes: dict
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
async def wrapper(request):
|
||||
token_info = {}
|
||||
for scheme_name, func in schemes.items():
|
||||
result = func(request)
|
||||
while asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
if result is self.no_value:
|
||||
return self.no_value
|
||||
token_info[scheme_name] = result
|
||||
|
||||
return token_info
|
||||
|
||||
return wrapper
|
||||
|
||||
@staticmethod
|
||||
def verify_none():
|
||||
"""
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
def wrapper(request):
|
||||
return {}
|
||||
|
||||
return wrapper
|
||||
|
||||
def _need_to_add_context_or_scopes(self, func):
|
||||
arguments, has_kwargs = inspect_function_arguments(func)
|
||||
need_context = self.context_kw in arguments
|
||||
need_required_scopes = has_kwargs or self.required_scopes_kw in arguments
|
||||
return need_context, need_required_scopes
|
||||
|
||||
def _generic_check(self, func, exception_msg):
|
||||
(
|
||||
need_to_add_context,
|
||||
need_to_add_required_scopes,
|
||||
) = self._need_to_add_context_or_scopes(func)
|
||||
|
||||
async def wrapper(request, *args, required_scopes=None):
|
||||
kwargs = {}
|
||||
if need_to_add_context:
|
||||
kwargs[self.context_kw] = request.context
|
||||
if need_to_add_required_scopes:
|
||||
kwargs[self.required_scopes_kw] = required_scopes
|
||||
token_info = func(*args, **kwargs)
|
||||
while asyncio.iscoroutine(token_info):
|
||||
token_info = await token_info
|
||||
if token_info is self.no_value:
|
||||
return self.no_value
|
||||
if token_info is None:
|
||||
raise OAuthResponseProblem(detail=exception_msg)
|
||||
return token_info
|
||||
|
||||
return wrapper
|
||||
|
||||
def check_bearer_token(self, token_info_func):
|
||||
return self._generic_check(token_info_func, "Provided token is not valid")
|
||||
|
||||
def check_basic_auth(self, basic_info_func):
|
||||
return self._generic_check(
|
||||
basic_info_func, "Provided authorization is not valid"
|
||||
)
|
||||
|
||||
def check_api_key(self, api_key_info_func):
|
||||
return self._generic_check(api_key_info_func, "Provided apikey is not valid")
|
||||
|
||||
def check_oauth_func(self, token_info_func, scope_validate_func):
|
||||
get_token_info = self._generic_check(
|
||||
token_info_func, "Provided token is not valid"
|
||||
@@ -403,17 +439,141 @@ class SecurityHandlerFactory:
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
SECURITY_HANDLERS = {
|
||||
# Swagger 2: `type: basic`
|
||||
# OpenAPI 3: `type: http` and `scheme: basic`
|
||||
"basic": BasicSecurityHandler,
|
||||
# Swagger 2 and OpenAPI 3
|
||||
"apiKey": ApiKeySecurityHandler,
|
||||
"oauth2": OAuthSecurityHandler,
|
||||
# OpenAPI 3: http schemes
|
||||
"bearer": BearerSecurityHandler,
|
||||
}
|
||||
|
||||
|
||||
class SecurityHandlerFactory:
|
||||
"""
|
||||
A factory class for parsing security schemes and returning the appropriate
|
||||
security handler.
|
||||
|
||||
By default, it will use the built-in security handlers specified in the
|
||||
SECURITY_HANDLERS dict, but you can also pass in your own security handlers
|
||||
to override the built-in ones.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
security_handlers: t.Optional[dict] = None,
|
||||
) -> None:
|
||||
self.security_handlers = SECURITY_HANDLERS.copy()
|
||||
if security_handlers is not None:
|
||||
self.security_handlers.update(security_handlers)
|
||||
|
||||
def parse_security_scheme(
|
||||
self,
|
||||
security_scheme: dict,
|
||||
required_scopes: t.List[str],
|
||||
) -> t.Optional[t.Callable]:
|
||||
"""Parses the security scheme and returns the function for verifying it.
|
||||
|
||||
:param security_scheme: The security scheme from the spec.
|
||||
:param required_scopes: List of scopes for this security scheme.
|
||||
"""
|
||||
security_type = security_scheme["type"]
|
||||
if security_type in ("basic", "oauth2"):
|
||||
security_handler = self.security_handlers[security_type]
|
||||
return security_handler().get_fn(security_scheme, required_scopes)
|
||||
|
||||
# OpenAPI 3.0.0
|
||||
elif security_type == "http":
|
||||
scheme = security_scheme["scheme"].lower()
|
||||
if scheme in self.security_handlers:
|
||||
security_handler = self.security_handlers[scheme]
|
||||
return security_handler().get_fn(security_scheme, required_scopes)
|
||||
else:
|
||||
logger.warning("... Unsupported http authorization scheme %s", scheme)
|
||||
return None
|
||||
|
||||
elif security_type == "apiKey":
|
||||
scheme = security_scheme.get("x-authentication-scheme", "").lower()
|
||||
if scheme == "bearer":
|
||||
return BearerSecurityHandler().get_fn(security_scheme, required_scopes)
|
||||
else:
|
||||
security_handler = self.security_handlers["apiKey"]
|
||||
return security_handler().get_fn(security_scheme, required_scopes)
|
||||
|
||||
else:
|
||||
logger.warning(
|
||||
"... Unsupported security scheme type %s",
|
||||
security_type,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def security_passthrough(request):
|
||||
"""Used when no security is required for the operation.
|
||||
|
||||
Equivalent OpenAPI snippet:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
/helloworld
|
||||
get:
|
||||
security: [] # No security
|
||||
...
|
||||
"""
|
||||
return request
|
||||
|
||||
@staticmethod
|
||||
def verify_none(request):
|
||||
"""Used for optional security.
|
||||
|
||||
Equivalent OpenAPI snippet:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
security:
|
||||
- {} # <--
|
||||
- myapikey: []
|
||||
"""
|
||||
return {}
|
||||
|
||||
def verify_multiple_schemes(self, schemes):
|
||||
"""
|
||||
Verifies multiple authentication schemes in AND fashion.
|
||||
If any scheme fails, the entire authentication fails.
|
||||
|
||||
:param schemes: mapping scheme_name to auth function
|
||||
:type schemes: dict
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
async def wrapper(request):
|
||||
token_info = {}
|
||||
for scheme_name, func in schemes.items():
|
||||
result = func(request)
|
||||
while asyncio.iscoroutine(result):
|
||||
result = await result
|
||||
if result is NO_VALUE:
|
||||
return NO_VALUE
|
||||
token_info[scheme_name] = result
|
||||
|
||||
return token_info
|
||||
|
||||
return wrapper
|
||||
|
||||
@classmethod
|
||||
def verify_security(cls, auth_funcs):
|
||||
async def verify_fn(request):
|
||||
token_info = cls.no_value
|
||||
token_info = NO_VALUE
|
||||
errors = []
|
||||
for func in auth_funcs:
|
||||
try:
|
||||
token_info = func(request)
|
||||
while asyncio.iscoroutine(token_info):
|
||||
token_info = await token_info
|
||||
if token_info is not cls.no_value:
|
||||
if token_info is not NO_VALUE:
|
||||
break
|
||||
except Exception as err:
|
||||
errors.append(err)
|
||||
@@ -465,28 +625,3 @@ class SecurityHandlerFactory:
|
||||
else:
|
||||
lowest_status_code = min(status_to_exc)
|
||||
raise status_to_exc[lowest_status_code]
|
||||
|
||||
def get_token_info_remote(self, token_info_url):
|
||||
"""
|
||||
Return a function which will call `token_info_url` to retrieve token info.
|
||||
|
||||
Returned function must accept oauth token in parameter.
|
||||
It must return a token_info dict in case of success, None otherwise.
|
||||
|
||||
:param token_info_url: Url to get information about the token
|
||||
:type token_info_url: str
|
||||
:rtype: types.FunctionType
|
||||
"""
|
||||
|
||||
async def wrapper(token):
|
||||
if self.client is None:
|
||||
self.client = httpx.AsyncClient()
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
token_request = await self.client.get(
|
||||
token_info_url, headers=headers, timeout=5
|
||||
)
|
||||
if token_request.status_code != 200:
|
||||
return
|
||||
return token_request.json()
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -19,11 +19,20 @@ except AttributeError: # jsonschema < 4.5.0
|
||||
|
||||
|
||||
class ParameterValidator:
|
||||
def __init__(self, parameters, uri_parser, strict_validation=False):
|
||||
def __init__(
|
||||
self,
|
||||
parameters,
|
||||
uri_parser,
|
||||
strict_validation=False,
|
||||
security_query_params=None,
|
||||
):
|
||||
"""
|
||||
:param parameters: List of request parameter dictionaries
|
||||
:param uri_parser: class to use for uri parsing
|
||||
:param strict_validation: Flag indicating if parameters not in spec are allowed
|
||||
:param security_query_params: List of query parameter names used for security.
|
||||
These parameters will be ignored when checking for extra parameters in case of
|
||||
strict validation.
|
||||
"""
|
||||
self.parameters = collections.defaultdict(list)
|
||||
for p in parameters:
|
||||
@@ -31,6 +40,7 @@ class ParameterValidator:
|
||||
|
||||
self.uri_parser = uri_parser
|
||||
self.strict_validation = strict_validation
|
||||
self.security_query_params = set(security_query_params or [])
|
||||
|
||||
@staticmethod
|
||||
def validate_parameter(parameter_type, value, param, param_name=None):
|
||||
@@ -59,9 +69,10 @@ class ParameterValidator:
|
||||
|
||||
return request_params.difference(spec_params)
|
||||
|
||||
def validate_query_parameter_list(self, request):
|
||||
def validate_query_parameter_list(self, request, security_params=None):
|
||||
request_params = request.query_params.keys()
|
||||
spec_params = [x["name"] for x in self.parameters.get("query", [])]
|
||||
spec_params.extend(security_params or [])
|
||||
return self.validate_parameter_list(request_params, spec_params)
|
||||
|
||||
def validate_query_parameter(self, param, request):
|
||||
@@ -99,9 +110,10 @@ class ParameterValidator:
|
||||
self.validate_request(request)
|
||||
|
||||
def validate_request(self, request):
|
||||
|
||||
if self.strict_validation:
|
||||
query_errors = self.validate_query_parameter_list(request)
|
||||
query_errors = self.validate_query_parameter_list(
|
||||
request, security_params=self.security_query_params
|
||||
)
|
||||
|
||||
if query_errors:
|
||||
raise ExtraParameterProblem(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
from connexion.security import SecurityHandlerFactory
|
||||
from connexion.security import OAuthSecurityHandler
|
||||
|
||||
|
||||
class FakeResponse:
|
||||
@@ -50,7 +50,7 @@ def oauth_requests(monkeypatch):
|
||||
)
|
||||
return url
|
||||
|
||||
monkeypatch.setattr(SecurityHandlerFactory, "client", FakeClient())
|
||||
monkeypatch.setattr(OAuthSecurityHandler, "client", FakeClient())
|
||||
|
||||
|
||||
def test_security_over_nonexistent_endpoints(oauth_requests, secure_api_app):
|
||||
|
||||
@@ -10,12 +10,19 @@ from connexion.exceptions import (
|
||||
OAuthResponseProblem,
|
||||
OAuthScopeProblem,
|
||||
)
|
||||
from connexion.security import SecurityHandlerFactory
|
||||
from connexion.lifecycle import ASGIRequest
|
||||
from connexion.security import (
|
||||
NO_VALUE,
|
||||
ApiKeySecurityHandler,
|
||||
BasicSecurityHandler,
|
||||
OAuthSecurityHandler,
|
||||
SecurityHandlerFactory,
|
||||
)
|
||||
|
||||
|
||||
def test_get_tokeninfo_url(monkeypatch):
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
security_handler_factory.get_token_info_remote = MagicMock(
|
||||
security_handler = OAuthSecurityHandler()
|
||||
security_handler.get_token_info_remote = MagicMock(
|
||||
return_value="get_token_info_remote_result"
|
||||
)
|
||||
env = {}
|
||||
@@ -24,24 +31,24 @@ def test_get_tokeninfo_url(monkeypatch):
|
||||
monkeypatch.setattr("connexion.security.logger", logger)
|
||||
|
||||
security_def = {}
|
||||
assert security_handler_factory.get_tokeninfo_func(security_def) is None
|
||||
assert security_handler.get_tokeninfo_func(security_def) is None
|
||||
logger.warn.assert_not_called()
|
||||
|
||||
env["TOKENINFO_URL"] = "issue-146"
|
||||
assert (
|
||||
security_handler_factory.get_tokeninfo_func(security_def)
|
||||
security_handler.get_tokeninfo_func(security_def)
|
||||
== "get_token_info_remote_result"
|
||||
)
|
||||
security_handler_factory.get_token_info_remote.assert_called_with("issue-146")
|
||||
security_handler.get_token_info_remote.assert_called_with("issue-146")
|
||||
logger.warn.assert_not_called()
|
||||
logger.warn.reset_mock()
|
||||
|
||||
security_def = {"x-tokenInfoUrl": "bar"}
|
||||
assert (
|
||||
security_handler_factory.get_tokeninfo_func(security_def)
|
||||
security_handler.get_tokeninfo_func(security_def)
|
||||
== "get_token_info_remote_result"
|
||||
)
|
||||
security_handler_factory.get_token_info_remote.assert_called_with("bar")
|
||||
security_handler.get_token_info_remote.assert_called_with("bar")
|
||||
logger.warn.assert_not_called()
|
||||
|
||||
|
||||
@@ -49,15 +56,14 @@ def test_verify_oauth_missing_auth_header():
|
||||
def somefunc(token):
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_oauth(
|
||||
somefunc, security_handler_factory.validate_scope, ["admin"]
|
||||
security_handler = OAuthSecurityHandler()
|
||||
wrapped_func = security_handler._get_verify_func(
|
||||
somefunc, security_handler.validate_scope, ["admin"]
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request = ASGIRequest(scope={"type": "http", "headers": []})
|
||||
|
||||
assert wrapped_func(request) is security_handler_factory.no_value
|
||||
assert wrapped_func(request) is NO_VALUE
|
||||
|
||||
|
||||
async def test_verify_oauth_scopes_remote(monkeypatch):
|
||||
@@ -69,20 +75,21 @@ async def test_verify_oauth_scopes_remote(monkeypatch):
|
||||
tokeninfo_response._content = json.dumps(tokeninfo).encode()
|
||||
return tokeninfo_response
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
token_info_func = security_handler_factory.get_tokeninfo_func(
|
||||
security_handler = OAuthSecurityHandler()
|
||||
token_info_func = security_handler.get_tokeninfo_func(
|
||||
{"x-tokenInfoUrl": "https://example.org/tokeninfo"}
|
||||
)
|
||||
wrapped_func = security_handler_factory.verify_oauth(
|
||||
token_info_func, security_handler_factory.validate_scope, ["admin"]
|
||||
wrapped_func = security_handler._get_verify_func(
|
||||
token_info_func, security_handler.validate_scope, ["admin"]
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer 123"}
|
||||
request = ASGIRequest(
|
||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||
)
|
||||
|
||||
client = MagicMock()
|
||||
client.get = get_tokeninfo_response
|
||||
monkeypatch.setattr(SecurityHandlerFactory, "client", client)
|
||||
monkeypatch.setattr(OAuthSecurityHandler, "client", client)
|
||||
|
||||
with pytest.raises(OAuthScopeProblem) as exc_info:
|
||||
await wrapped_func(request)
|
||||
@@ -112,13 +119,14 @@ async def test_verify_oauth_invalid_local_token_response_none():
|
||||
def somefunc(token):
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_oauth(
|
||||
somefunc, security_handler_factory.validate_scope, ["admin"]
|
||||
security_handler = OAuthSecurityHandler()
|
||||
wrapped_func = security_handler._get_verify_func(
|
||||
somefunc, security_handler.validate_scope, ["admin"]
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer 123"}
|
||||
request = ASGIRequest(
|
||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||
)
|
||||
|
||||
with pytest.raises(OAuthResponseProblem):
|
||||
await wrapped_func(request)
|
||||
@@ -130,13 +138,14 @@ async def test_verify_oauth_scopes_local():
|
||||
def token_info(token):
|
||||
return tokeninfo
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_oauth(
|
||||
token_info, security_handler_factory.validate_scope, ["admin"]
|
||||
security_handler = OAuthSecurityHandler()
|
||||
wrapped_func = security_handler._get_verify_func(
|
||||
token_info, security_handler.validate_scope, ["admin"]
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer 123"}
|
||||
request = ASGIRequest(
|
||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||
)
|
||||
|
||||
with pytest.raises(OAuthScopeProblem) as exc_info:
|
||||
await wrapped_func(request)
|
||||
@@ -166,13 +175,14 @@ def test_verify_basic_missing_auth_header():
|
||||
def somefunc(username, password, required_scopes=None):
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_basic(somefunc)
|
||||
security_handler = BasicSecurityHandler()
|
||||
wrapped_func = security_handler._get_verify_func(somefunc)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Bearer 123"}
|
||||
request = ASGIRequest(
|
||||
scope={"type": "http", "headers": [[b"authorization", b"Bearer 123"]]}
|
||||
)
|
||||
|
||||
assert wrapped_func(request) is security_handler_factory.no_value
|
||||
assert wrapped_func(request) is NO_VALUE
|
||||
|
||||
|
||||
async def test_verify_basic():
|
||||
@@ -181,11 +191,12 @@ async def test_verify_basic():
|
||||
return {"sub": "foo"}
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_basic(basic_info)
|
||||
security_handler = BasicSecurityHandler()
|
||||
wrapped_func = security_handler._get_verify_func(basic_info)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"Authorization": "Basic Zm9vOmJhcg=="}
|
||||
request = ASGIRequest(
|
||||
scope={"type": "http", "headers": [[b"authorization", b"Basic Zm9vOmJhcg=="]]}
|
||||
)
|
||||
|
||||
assert await wrapped_func(request) is not None
|
||||
|
||||
@@ -196,11 +207,12 @@ async def test_verify_apikey_query():
|
||||
return {"sub": "foo"}
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_api_key(apikey_info, "query", "auth")
|
||||
security_handler_factory = ApiKeySecurityHandler()
|
||||
wrapped_func = security_handler_factory._get_verify_func(
|
||||
apikey_info, "query", "auth"
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.query = {"auth": "foobar"}
|
||||
request = ASGIRequest(scope={"type": "http", "query_string": b"auth=foobar"})
|
||||
|
||||
assert await wrapped_func(request) is not None
|
||||
|
||||
@@ -211,13 +223,12 @@ async def test_verify_apikey_header():
|
||||
return {"sub": "foo"}
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func = security_handler_factory.verify_api_key(
|
||||
security_handler_factory = ApiKeySecurityHandler()
|
||||
wrapped_func = security_handler_factory._get_verify_func(
|
||||
apikey_info, "header", "X-Auth"
|
||||
)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"X-Auth": "foobar"}
|
||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth", b"foobar"]]})
|
||||
|
||||
assert await wrapped_func(request) is not None
|
||||
|
||||
@@ -234,10 +245,11 @@ async def test_multiple_schemes():
|
||||
return None
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
wrapped_func_key1 = security_handler_factory.verify_api_key(
|
||||
apikey_security_handler_factory = ApiKeySecurityHandler()
|
||||
wrapped_func_key1 = apikey_security_handler_factory._get_verify_func(
|
||||
apikey1_info, "header", "X-Auth-1"
|
||||
)
|
||||
wrapped_func_key2 = security_handler_factory.verify_api_key(
|
||||
wrapped_func_key2 = apikey_security_handler_factory._get_verify_func(
|
||||
apikey2_info, "header", "X-Auth-2"
|
||||
)
|
||||
schemes = {
|
||||
@@ -247,19 +259,21 @@ async def test_multiple_schemes():
|
||||
wrapped_func = security_handler_factory.verify_multiple_schemes(schemes)
|
||||
|
||||
# Single key does not succeed
|
||||
request = MagicMock()
|
||||
request.headers = {"X-Auth-1": "foobar"}
|
||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-1", b"foobar"]]})
|
||||
|
||||
assert await wrapped_func(request) is security_handler_factory.no_value
|
||||
assert await wrapped_func(request) is NO_VALUE
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"X-Auth-2": "bar"}
|
||||
request = ASGIRequest(scope={"type": "http", "headers": [[b"x-auth-2", b"bar"]]})
|
||||
|
||||
assert await wrapped_func(request) is security_handler_factory.no_value
|
||||
assert await wrapped_func(request) is NO_VALUE
|
||||
|
||||
# Supplying both keys does succeed
|
||||
request = MagicMock()
|
||||
request.headers = {"X-Auth-1": "foobar", "X-Auth-2": "bar"}
|
||||
request = ASGIRequest(
|
||||
scope={
|
||||
"type": "http",
|
||||
"headers": [[b"x-auth-1", b"foobar"], [b"x-auth-2", b"bar"]],
|
||||
}
|
||||
)
|
||||
|
||||
expected_token_info = {
|
||||
"key1": {"sub": "foo"},
|
||||
@@ -273,7 +287,7 @@ async def test_verify_security_oauthproblem():
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
security_func = security_handler_factory.verify_security([])
|
||||
|
||||
request = MagicMock()
|
||||
request = MagicMock(spec_set=ASGIRequest)
|
||||
with pytest.raises(OAuthProblem) as exc_info:
|
||||
await security_func(request)
|
||||
|
||||
|
||||
@@ -12,7 +12,11 @@ from connexion.jsonifier import Jsonifier
|
||||
from connexion.middleware.security import SecurityOperation
|
||||
from connexion.operations import Swagger2Operation
|
||||
from connexion.resolver import Resolver
|
||||
from connexion.security import SecurityHandlerFactory
|
||||
from connexion.security import (
|
||||
ApiKeySecurityHandler,
|
||||
OAuthSecurityHandler,
|
||||
SecurityHandlerFactory,
|
||||
)
|
||||
|
||||
TEST_FOLDER = pathlib.Path(__file__).parent
|
||||
|
||||
@@ -418,11 +422,14 @@ def test_operation(api):
|
||||
|
||||
|
||||
def test_operation_remote_token_info():
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
class MockOAuthHandler(OAuthSecurityHandler):
|
||||
"""Mock."""
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
|
||||
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
|
||||
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
|
||||
security_handler_factory.verify_oauth = verify_oauth
|
||||
security_handler_factory.get_token_info_remote = mock.MagicMock(
|
||||
oauth_security_handler._get_verify_func = verify_oauth
|
||||
oauth_security_handler.get_token_info_remote = mock.MagicMock(
|
||||
return_value="get_token_info_remote_result"
|
||||
)
|
||||
|
||||
@@ -434,9 +441,11 @@ def test_operation_remote_token_info():
|
||||
)
|
||||
|
||||
verify_oauth.assert_called_with(
|
||||
"get_token_info_remote_result", security_handler_factory.validate_scope, ["uid"]
|
||||
"get_token_info_remote_result",
|
||||
oauth_security_handler.validate_scope,
|
||||
["uid"],
|
||||
)
|
||||
security_handler_factory.get_token_info_remote.assert_called_with(
|
||||
oauth_security_handler.get_token_info_remote.assert_called_with(
|
||||
"https://oauth.example/token_info"
|
||||
)
|
||||
|
||||
@@ -491,10 +500,13 @@ def test_operation_composed_definition(api):
|
||||
|
||||
|
||||
def test_operation_local_security_oauth2():
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
class MockOAuthHandler(OAuthSecurityHandler):
|
||||
"""Mock."""
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
|
||||
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
|
||||
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
|
||||
security_handler_factory.verify_oauth = verify_oauth
|
||||
oauth_security_handler._get_verify_func = verify_oauth
|
||||
|
||||
SecurityOperation(
|
||||
next_app=mock.Mock,
|
||||
@@ -504,15 +516,24 @@ def test_operation_local_security_oauth2():
|
||||
)
|
||||
|
||||
verify_oauth.assert_called_with(
|
||||
math.ceil, security_handler_factory.validate_scope, ["uid"]
|
||||
math.ceil, oauth_security_handler.validate_scope, ["uid"]
|
||||
)
|
||||
|
||||
verify_oauth.assert_called_with(
|
||||
math.ceil,
|
||||
security_handler_factory.security_handlers["oauth2"].validate_scope,
|
||||
["uid"],
|
||||
)
|
||||
|
||||
|
||||
def test_operation_local_security_duplicate_token_info():
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
class MockOAuthHandler(OAuthSecurityHandler):
|
||||
"""Mock."""
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthHandler})
|
||||
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
|
||||
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
|
||||
security_handler_factory.verify_oauth = verify_oauth
|
||||
oauth_security_handler._get_verify_func = verify_oauth
|
||||
|
||||
SecurityOperation(
|
||||
next_app=mock.Mock,
|
||||
@@ -522,7 +543,11 @@ def test_operation_local_security_duplicate_token_info():
|
||||
)
|
||||
|
||||
verify_oauth.call_args.assert_called_with(
|
||||
math.ceil, security_handler_factory.validate_scope
|
||||
math.ceil, oauth_security_handler.validate_scope
|
||||
)
|
||||
|
||||
verify_oauth.call_args.assert_called_with(
|
||||
math.ceil, security_handler_factory.security_handlers["oauth2"].validate_scope
|
||||
)
|
||||
|
||||
|
||||
@@ -565,9 +590,13 @@ def test_multiple_security_schemes_and():
|
||||
def return_api_key_name(func, in_, name):
|
||||
return name
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
class MockApiKeyHandler(ApiKeySecurityHandler):
|
||||
"""Mock"""
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory({"apiKey": MockApiKeyHandler})
|
||||
apikey_security_handler = security_handler_factory.security_handlers["apiKey"]
|
||||
verify_api_key = mock.MagicMock(side_effect=return_api_key_name)
|
||||
security_handler_factory.verify_api_key = verify_api_key
|
||||
apikey_security_handler._get_verify_func = verify_api_key
|
||||
verify_multiple = mock.MagicMock(return_value="verify_multiple_result")
|
||||
security_handler_factory.verify_multiple_schemes = verify_multiple
|
||||
|
||||
@@ -595,9 +624,6 @@ def test_multiple_oauth_in_and(caplog):
|
||||
caplog.set_level(logging.WARNING, logger="connexion.operations.secure")
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
|
||||
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
|
||||
security_handler_factory.verify_oauth = verify_oauth
|
||||
|
||||
security = [{"oauth_1": ["uid"], "oauth_2": ["uid"]}]
|
||||
|
||||
SecurityOperation(
|
||||
@@ -688,10 +714,14 @@ def test_get_path_parameter_types(api):
|
||||
|
||||
def test_oauth_scopes_in_or():
|
||||
"""Tests whether an OAuth security scheme with 2 different possible scopes is correctly handled."""
|
||||
security_handler_factory = SecurityHandlerFactory()
|
||||
|
||||
class MockOAuthFactory(OAuthSecurityHandler):
|
||||
"""Mock."""
|
||||
|
||||
security_handler_factory = SecurityHandlerFactory({"oauth2": MockOAuthFactory})
|
||||
oauth_security_handler = security_handler_factory.security_handlers["oauth2"]
|
||||
verify_oauth = mock.MagicMock(return_value="verify_oauth_result")
|
||||
security_handler_factory.verify_oauth = verify_oauth
|
||||
oauth_security_handler._get_verify_func = verify_oauth
|
||||
|
||||
security = [{"oauth": ["myscope"]}, {"oauth": ["myscope2"]}]
|
||||
|
||||
@@ -704,8 +734,8 @@ def test_oauth_scopes_in_or():
|
||||
|
||||
verify_oauth.assert_has_calls(
|
||||
[
|
||||
mock.call(math.ceil, security_handler_factory.validate_scope, ["myscope"]),
|
||||
mock.call(math.ceil, security_handler_factory.validate_scope, ["myscope2"]),
|
||||
mock.call(math.ceil, oauth_security_handler.validate_scope, ["myscope"]),
|
||||
mock.call(math.ceil, oauth_security_handler.validate_scope, ["myscope2"]),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user