Use resolver in security middleware (#1553)

* Use resolver in security middleware

* Initialize RoutingOperation with operation object
This commit is contained in:
Robbe Sneyders
2022-06-20 22:54:54 +02:00
committed by GitHub
parent b561ecfdaa
commit abbc8ff162
7 changed files with 82 additions and 28 deletions

View File

@@ -38,6 +38,7 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
self, self,
specification: t.Union[pathlib.Path, str, dict], specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None, base_path: t.Optional[str] = None,
resolver: t.Optional[Resolver] = None,
arguments: t.Optional[dict] = None, arguments: t.Optional[dict] = None,
options: t.Optional[dict] = None, options: t.Optional[dict] = None,
*args, *args,
@@ -48,6 +49,9 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
:param specification: OpenAPI specification. Can be provided either as dict, or as path :param specification: OpenAPI specification. Can be provided either as dict, or as path
to file. to file.
:param base_path: Base path to host the API. :param base_path: Base path to host the API.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param arguments: Jinja arguments to resolve in specification. :param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary. :param options: New style options dictionary.
""" """
@@ -70,6 +74,8 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
self._set_base_path(base_path) self._set_base_path(base_path)
self.resolver = resolver or Resolver()
def _set_base_path(self, base_path: t.Optional[str] = None) -> None: def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
if base_path is not None: if base_path is not None:
# update spec to include user-provided base_path # update spec to include user-provided base_path
@@ -121,7 +127,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
def __init__( def __init__(
self, self,
*args, *args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None, resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False, debug: bool = False,
pass_context_arg_name: t.Optional[str] = None, pass_context_arg_name: t.Optional[str] = None,
@@ -129,9 +134,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
) -> None: ) -> None:
"""Minimal interface of an API, with only functionality related to routing. """Minimal interface of an API, with only functionality related to routing.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode :param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument :param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context. matching this name will be passed the framework's request context.
@@ -140,8 +142,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
self.debug = debug self.debug = debug
self.resolver_error_handler = resolver_error_handler self.resolver_error_handler = resolver_error_handler
self.resolver = resolver or Resolver()
logger.debug('pass_context_arg_name: %s', pass_context_arg_name) logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name self.pass_context_arg_name = pass_context_arg_name

View File

@@ -8,6 +8,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractRoutingAPI from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver from connexion.resolver import Resolver
ROUTING_CONTEXT = 'connexion_routing' ROUTING_CONTEXT = 'connexion_routing'
@@ -91,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
def add_operation(self, path: str, method: str) -> None: def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver) operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
routing_operation = RoutingOperation(operation.operation_id, next_app=self.next_app) routing_operation = RoutingOperation.from_operation(operation, next_app=self.next_app)
self._add_operation_internal(method, path, routing_operation) self._add_operation_internal(method, path, routing_operation)
def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None: def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None:
@@ -104,6 +105,10 @@ class RoutingOperation:
self.operation_id = operation_id self.operation_id = operation_id
self.next_app = next_app self.next_app = next_app
@classmethod
def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp):
return cls(operation.operation_id, next_app)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Attach operation to scope and pass it to the next app""" """Attach operation to scope and pass it to the next app"""
original_scope = _scope.get() original_scope = _scope.get()

View File

@@ -6,12 +6,15 @@ from collections import defaultdict
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis.abstract import AbstractSpecAPI from connexion.apis.abstract import AbstractSpecAPI
from connexion.exceptions import MissingMiddleware from connexion.exceptions import MissingMiddleware, ProblemException
from connexion.http_facts import METHODS from connexion.http_facts import METHODS
from connexion.lifecycle import MiddlewareRequest from connexion.lifecycle import MiddlewareRequest
from connexion.middleware import AppMiddleware from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.security import SecurityHandlerFactory from connexion.security import SecurityHandlerFactory
from connexion.spec import Specification
logger = logging.getLogger("connexion.middleware.security") logger = logging.getLogger("connexion.middleware.security")
@@ -69,8 +72,6 @@ class SecurityAPI(AbstractSpecAPI):
): ):
super().__init__(specification, *args, **kwargs) super().__init__(specification, *args, **kwargs)
self.security_handler_factory = SecurityHandlerFactory('context') self.security_handler_factory = SecurityHandlerFactory('context')
self.app_security = self.specification.security
self.security_schemes = self.specification.security_definitions
if auth_all_paths: if auth_all_paths:
self.add_auth_on_not_found() self.add_auth_on_not_found()
@@ -81,30 +82,36 @@ class SecurityAPI(AbstractSpecAPI):
def add_auth_on_not_found(self): def add_auth_on_not_found(self):
"""Register a default SecurityOperation for routes that are not found.""" """Register a default SecurityOperation for routes that are not found."""
default_operation = self.make_operation() default_operation = self.make_operation(self.specification)
self.operations = defaultdict(lambda: default_operation) self.operations = defaultdict(lambda: default_operation)
def add_paths(self): def add_paths(self):
paths = self.specification.get('paths', {}) paths = self.specification.get('paths', {})
for path, methods in paths.items(): for path, methods in paths.items():
for method, operation in methods.items(): for method in methods:
if method not in METHODS: if method not in METHODS:
continue continue
operation_id = operation.get('operationId') try:
if operation_id: self.add_operation(path, method)
self.operations[operation_id] = self.make_operation(operation) except ResolverError:
# ResolverErrors are either raised or handled in routing middleware.
pass
def make_operation(self, operation_spec: dict = None): def add_operation(self, path: str, method: str) -> None:
security = self.app_security operation_cls = self.specification.operation_cls
if operation_spec: operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
security = operation_spec.get('security', self.app_security) security_operation = self.make_operation(operation)
self._add_operation_internal(operation.operation_id, security_operation)
return SecurityOperation( def make_operation(self, operation: t.Union[AbstractOperation, Specification]):
self.security_handler_factory, return SecurityOperation.from_operation(
security=security, operation,
security_schemes=self.specification.security_definitions security_handler_factory=self.security_handler_factory,
) )
def _add_operation_internal(self, operation_id: str, operation: 'SecurityOperation'):
self.operations[operation_id] = operation
class SecurityOperation: class SecurityOperation:
@@ -119,6 +126,18 @@ class SecurityOperation:
self.security_schemes = security_schemes self.security_schemes = security_schemes
self.verification_fn = self._get_verification_fn() self.verification_fn = self._get_verification_fn()
@classmethod
def from_operation(
cls,
operation: AbstractOperation,
security_handler_factory: SecurityHandlerFactory
):
return cls(
security_handler_factory,
security=operation.security,
security_schemes=operation.security_schemes
)
def _get_verification_fn(self): def _get_verification_fn(self):
logger.debug('... Security: %s', self.security, extra=vars(self)) logger.debug('... Security: %s', self.security, extra=vars(self))
if not self.security: if not self.security:
@@ -234,5 +253,5 @@ class SecurityOperation:
await self.verification_fn(request) await self.verification_fn(request)
class MissingSecurityOperation(Exception): class MissingSecurityOperation(ProblemException):
pass pass

View File

@@ -43,6 +43,7 @@ class AbstractOperation(metaclass=abc.ABCMeta):
serious_business(stuff) serious_business(stuff)
""" """
def __init__(self, api, method, path, operation, resolver, def __init__(self, api, method, path, operation, resolver,
app_security=None, security_schemes=None,
validate_responses=False, strict_validation=False, validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pythonic_params=False, uri_parser_class=None,
@@ -57,7 +58,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
:param operation: swagger operation object :param operation: swagger operation object
:type operation: dict :type operation: dict
:param resolver: Callable that maps operationID to a function :param resolver: Callable that maps operationID to a function
:param app_produces: list of content types the application can return by default
:param app_security: list of security rules the application uses by default :param app_security: list of security rules the application uses by default
:type app_security: list :type app_security: list
:param security_schemes: `Security Definitions Object :param security_schemes: `Security Definitions Object
@@ -85,6 +85,8 @@ class AbstractOperation(metaclass=abc.ABCMeta):
self._path = path self._path = path
self._operation = operation self._operation = operation
self._resolver = resolver self._resolver = resolver
self._security = operation.get('security', app_security)
self._security_schemes = security_schemes
self._validate_responses = validate_responses self._validate_responses = validate_responses
self._strict_validation = strict_validation self._strict_validation = strict_validation
self._pythonic_params = pythonic_params self._pythonic_params = pythonic_params
@@ -119,6 +121,14 @@ class AbstractOperation(metaclass=abc.ABCMeta):
""" """
return self._path return self._path
@property
def security(self):
return self._security
@property
def security_schemes(self):
return self._security_schemes
@property @property
def responses(self): def responses(self):
""" """

View File

@@ -22,6 +22,7 @@ class OpenAPIOperation(AbstractOperation):
""" """
def __init__(self, api, method, path, operation, resolver, path_parameters=None, def __init__(self, api, method, path, operation, resolver, path_parameters=None,
app_security=None, security_schemes=None,
components=None, validate_responses=False, strict_validation=False, components=None, validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None): pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
@@ -44,6 +45,11 @@ class OpenAPIOperation(AbstractOperation):
:param resolver: Callable that maps operationID to a function :param resolver: Callable that maps operationID to a function
:param path_parameters: Parameters defined in the path level :param path_parameters: Parameters defined in the path level
:type path_parameters: list :type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param components: `Components Object :param components: `Components Object
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_ <https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
:type components: dict :type components: dict
@@ -76,6 +82,8 @@ class OpenAPIOperation(AbstractOperation):
path=path, path=path,
operation=operation, operation=operation,
resolver=resolver, resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses, validate_responses=validate_responses,
strict_validation=strict_validation, strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint, randomize_endpoint=randomize_endpoint,
@@ -116,6 +124,8 @@ class OpenAPIOperation(AbstractOperation):
spec.get_operation(path, method), spec.get_operation(path, method),
resolver=resolver, resolver=resolver,
path_parameters=spec.get_path_params(path), path_parameters=spec.get_path_params(path),
app_security=spec.security,
security_schemes=spec.security_schemes,
components=spec.components, components=spec.components,
*args, *args,
**kwargs **kwargs

View File

@@ -27,7 +27,8 @@ class Swagger2Operation(AbstractOperation):
""" """
def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes, def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
path_parameters=None, definitions=None, validate_responses=False, path_parameters=None, app_security=None, security_schemes=None,
definitions=None, validate_responses=False,
strict_validation=False, randomize_endpoint=None, validator_map=None, strict_validation=False, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None): pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
""" """
@@ -47,6 +48,11 @@ class Swagger2Operation(AbstractOperation):
:type app_consumes: list :type app_consumes: list
:param path_parameters: Parameters defined in the path level :param path_parameters: Parameters defined in the path level
:type path_parameters: list :type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param definitions: `Definitions Object :param definitions: `Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_ <https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
:type definitions: dict :type definitions: dict
@@ -77,6 +83,8 @@ class Swagger2Operation(AbstractOperation):
path=path, path=path,
operation=operation, operation=operation,
resolver=resolver, resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses, validate_responses=validate_responses,
strict_validation=strict_validation, strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint, randomize_endpoint=randomize_endpoint,
@@ -112,6 +120,8 @@ class Swagger2Operation(AbstractOperation):
path_parameters=spec.get_path_params(path), path_parameters=spec.get_path_params(path),
app_produces=spec.produces, app_produces=spec.produces,
app_consumes=spec.consumes, app_consumes=spec.consumes,
app_security=spec.security,
security_schemes=spec.security_schemes,
definitions=spec.definitions, definitions=spec.definitions,
*args, *args,
**kwargs **kwargs

View File

@@ -239,7 +239,7 @@ class Swagger2Specification(Specification):
return self._spec['responses'] return self._spec['responses']
@property @property
def security_definitions(self): def security_schemes(self):
return self._spec.get('securityDefinitions', {}) return self._spec.get('securityDefinitions', {})
@property @property
@@ -268,7 +268,7 @@ class OpenAPISpecification(Specification):
spec.setdefault('components', {}) spec.setdefault('components', {})
@property @property
def security_definitions(self): def security_schemes(self):
return self._spec['components'].get('securitySchemes', {}) return self._spec['components'].get('securitySchemes', {})
@property @property