Use resolver in security middleware (#1553)

* Use resolver in security middleware

* Initialize RoutingOperation with operation object
This commit is contained in:
Robbe Sneyders
2022-06-20 22:54:54 +02:00
committed by GitHub
parent b561ecfdaa
commit abbc8ff162
7 changed files with 82 additions and 28 deletions

View File

@@ -38,6 +38,7 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
self,
specification: t.Union[pathlib.Path, str, dict],
base_path: t.Optional[str] = None,
resolver: t.Optional[Resolver] = None,
arguments: t.Optional[dict] = None,
options: t.Optional[dict] = None,
*args,
@@ -48,6 +49,9 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
:param specification: OpenAPI specification. Can be provided either as dict, or as path
to file.
:param base_path: Base path to host the API.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param arguments: Jinja arguments to resolve in specification.
:param options: New style options dictionary.
"""
@@ -70,6 +74,8 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
self._set_base_path(base_path)
self.resolver = resolver or Resolver()
def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
if base_path is not None:
# update spec to include user-provided base_path
@@ -121,7 +127,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
def __init__(
self,
*args,
resolver: t.Optional[Resolver] = None,
resolver_error_handler: t.Optional[t.Callable] = None,
debug: bool = False,
pass_context_arg_name: t.Optional[str] = None,
@@ -129,9 +134,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
) -> None:
"""Minimal interface of an API, with only functionality related to routing.
:param resolver: Callable that maps operationID to a function
:param resolver_error_handler: Callable that generates an Operation used for handling
ResolveErrors
:param debug: Flag to run in debug mode
:param pass_context_arg_name: If not None URL request handling functions with an argument
matching this name will be passed the framework's request context.
@@ -140,8 +142,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
self.debug = debug
self.resolver_error_handler = resolver_error_handler
self.resolver = resolver or Resolver()
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
self.pass_context_arg_name = pass_context_arg_name

View File

@@ -8,6 +8,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis import AbstractRoutingAPI
from connexion.exceptions import NotFoundProblem
from connexion.middleware import AppMiddleware
from connexion.operations import AbstractOperation
from connexion.resolver import Resolver
ROUTING_CONTEXT = 'connexion_routing'
@@ -91,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
routing_operation = RoutingOperation(operation.operation_id, next_app=self.next_app)
routing_operation = RoutingOperation.from_operation(operation, next_app=self.next_app)
self._add_operation_internal(method, path, routing_operation)
def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None:
@@ -104,6 +105,10 @@ class RoutingOperation:
self.operation_id = operation_id
self.next_app = next_app
@classmethod
def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp):
return cls(operation.operation_id, next_app)
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"""Attach operation to scope and pass it to the next app"""
original_scope = _scope.get()

View File

@@ -6,12 +6,15 @@ from collections import defaultdict
from starlette.types import ASGIApp, Receive, Scope, Send
from connexion.apis.abstract import AbstractSpecAPI
from connexion.exceptions import MissingMiddleware
from connexion.exceptions import MissingMiddleware, ProblemException
from connexion.http_facts import METHODS
from connexion.lifecycle import MiddlewareRequest
from connexion.middleware import AppMiddleware
from connexion.middleware.routing import ROUTING_CONTEXT
from connexion.operations import AbstractOperation
from connexion.resolver import ResolverError
from connexion.security import SecurityHandlerFactory
from connexion.spec import Specification
logger = logging.getLogger("connexion.middleware.security")
@@ -69,8 +72,6 @@ class SecurityAPI(AbstractSpecAPI):
):
super().__init__(specification, *args, **kwargs)
self.security_handler_factory = SecurityHandlerFactory('context')
self.app_security = self.specification.security
self.security_schemes = self.specification.security_definitions
if auth_all_paths:
self.add_auth_on_not_found()
@@ -81,30 +82,36 @@ class SecurityAPI(AbstractSpecAPI):
def add_auth_on_not_found(self):
"""Register a default SecurityOperation for routes that are not found."""
default_operation = self.make_operation()
default_operation = self.make_operation(self.specification)
self.operations = defaultdict(lambda: default_operation)
def add_paths(self):
paths = self.specification.get('paths', {})
for path, methods in paths.items():
for method, operation in methods.items():
for method in methods:
if method not in METHODS:
continue
operation_id = operation.get('operationId')
if operation_id:
self.operations[operation_id] = self.make_operation(operation)
try:
self.add_operation(path, method)
except ResolverError:
# ResolverErrors are either raised or handled in routing middleware.
pass
def make_operation(self, operation_spec: dict = None):
security = self.app_security
if operation_spec:
security = operation_spec.get('security', self.app_security)
def add_operation(self, path: str, method: str) -> None:
operation_cls = self.specification.operation_cls
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
security_operation = self.make_operation(operation)
self._add_operation_internal(operation.operation_id, security_operation)
return SecurityOperation(
self.security_handler_factory,
security=security,
security_schemes=self.specification.security_definitions
def make_operation(self, operation: t.Union[AbstractOperation, Specification]):
return SecurityOperation.from_operation(
operation,
security_handler_factory=self.security_handler_factory,
)
def _add_operation_internal(self, operation_id: str, operation: 'SecurityOperation'):
self.operations[operation_id] = operation
class SecurityOperation:
@@ -119,6 +126,18 @@ class SecurityOperation:
self.security_schemes = security_schemes
self.verification_fn = self._get_verification_fn()
@classmethod
def from_operation(
cls,
operation: AbstractOperation,
security_handler_factory: SecurityHandlerFactory
):
return cls(
security_handler_factory,
security=operation.security,
security_schemes=operation.security_schemes
)
def _get_verification_fn(self):
logger.debug('... Security: %s', self.security, extra=vars(self))
if not self.security:
@@ -234,5 +253,5 @@ class SecurityOperation:
await self.verification_fn(request)
class MissingSecurityOperation(Exception):
class MissingSecurityOperation(ProblemException):
pass

View File

@@ -43,6 +43,7 @@ class AbstractOperation(metaclass=abc.ABCMeta):
serious_business(stuff)
"""
def __init__(self, api, method, path, operation, resolver,
app_security=None, security_schemes=None,
validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None,
@@ -57,7 +58,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
:param operation: swagger operation object
:type operation: dict
:param resolver: Callable that maps operationID to a function
:param app_produces: list of content types the application can return by default
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
@@ -85,6 +85,8 @@ class AbstractOperation(metaclass=abc.ABCMeta):
self._path = path
self._operation = operation
self._resolver = resolver
self._security = operation.get('security', app_security)
self._security_schemes = security_schemes
self._validate_responses = validate_responses
self._strict_validation = strict_validation
self._pythonic_params = pythonic_params
@@ -119,6 +121,14 @@ class AbstractOperation(metaclass=abc.ABCMeta):
"""
return self._path
@property
def security(self):
return self._security
@property
def security_schemes(self):
return self._security_schemes
@property
def responses(self):
"""

View File

@@ -22,6 +22,7 @@ class OpenAPIOperation(AbstractOperation):
"""
def __init__(self, api, method, path, operation, resolver, path_parameters=None,
app_security=None, security_schemes=None,
components=None, validate_responses=False, strict_validation=False,
randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
@@ -44,6 +45,11 @@ class OpenAPIOperation(AbstractOperation):
:param resolver: Callable that maps operationID to a function
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param components: `Components Object
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
:type components: dict
@@ -76,6 +82,8 @@ class OpenAPIOperation(AbstractOperation):
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
@@ -116,6 +124,8 @@ class OpenAPIOperation(AbstractOperation):
spec.get_operation(path, method),
resolver=resolver,
path_parameters=spec.get_path_params(path),
app_security=spec.security,
security_schemes=spec.security_schemes,
components=spec.components,
*args,
**kwargs

View File

@@ -27,7 +27,8 @@ class Swagger2Operation(AbstractOperation):
"""
def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
path_parameters=None, definitions=None, validate_responses=False,
path_parameters=None, app_security=None, security_schemes=None,
definitions=None, validate_responses=False,
strict_validation=False, randomize_endpoint=None, validator_map=None,
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
"""
@@ -47,6 +48,11 @@ class Swagger2Operation(AbstractOperation):
:type app_consumes: list
:param path_parameters: Parameters defined in the path level
:type path_parameters: list
:param app_security: list of security rules the application uses by default
:type app_security: list
:param security_schemes: `Security Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
:type security_schemes: dict
:param definitions: `Definitions Object
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
:type definitions: dict
@@ -77,6 +83,8 @@ class Swagger2Operation(AbstractOperation):
path=path,
operation=operation,
resolver=resolver,
app_security=app_security,
security_schemes=security_schemes,
validate_responses=validate_responses,
strict_validation=strict_validation,
randomize_endpoint=randomize_endpoint,
@@ -112,6 +120,8 @@ class Swagger2Operation(AbstractOperation):
path_parameters=spec.get_path_params(path),
app_produces=spec.produces,
app_consumes=spec.consumes,
app_security=spec.security,
security_schemes=spec.security_schemes,
definitions=spec.definitions,
*args,
**kwargs

View File

@@ -239,7 +239,7 @@ class Swagger2Specification(Specification):
return self._spec['responses']
@property
def security_definitions(self):
def security_schemes(self):
return self._spec.get('securityDefinitions', {})
@property
@@ -268,7 +268,7 @@ class OpenAPISpecification(Specification):
spec.setdefault('components', {})
@property
def security_definitions(self):
def security_schemes(self):
return self._spec['components'].get('securitySchemes', {})
@property