mirror of
https://github.com/LukeHagar/connexion.git
synced 2025-12-06 12:27:45 +00:00
Use resolver in security middleware (#1553)
* Use resolver in security middleware * Initialize RoutingOperation with operation object
This commit is contained in:
@@ -38,6 +38,7 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
|||||||
self,
|
self,
|
||||||
specification: t.Union[pathlib.Path, str, dict],
|
specification: t.Union[pathlib.Path, str, dict],
|
||||||
base_path: t.Optional[str] = None,
|
base_path: t.Optional[str] = None,
|
||||||
|
resolver: t.Optional[Resolver] = None,
|
||||||
arguments: t.Optional[dict] = None,
|
arguments: t.Optional[dict] = None,
|
||||||
options: t.Optional[dict] = None,
|
options: t.Optional[dict] = None,
|
||||||
*args,
|
*args,
|
||||||
@@ -48,6 +49,9 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
|||||||
:param specification: OpenAPI specification. Can be provided either as dict, or as path
|
:param specification: OpenAPI specification. Can be provided either as dict, or as path
|
||||||
to file.
|
to file.
|
||||||
:param base_path: Base path to host the API.
|
:param base_path: Base path to host the API.
|
||||||
|
:param resolver: Callable that maps operationID to a function
|
||||||
|
:param resolver_error_handler: Callable that generates an Operation used for handling
|
||||||
|
ResolveErrors
|
||||||
:param arguments: Jinja arguments to resolve in specification.
|
:param arguments: Jinja arguments to resolve in specification.
|
||||||
:param options: New style options dictionary.
|
:param options: New style options dictionary.
|
||||||
"""
|
"""
|
||||||
@@ -70,6 +74,8 @@ class AbstractSpecAPI(metaclass=AbstractAPIMeta):
|
|||||||
|
|
||||||
self._set_base_path(base_path)
|
self._set_base_path(base_path)
|
||||||
|
|
||||||
|
self.resolver = resolver or Resolver()
|
||||||
|
|
||||||
def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
|
def _set_base_path(self, base_path: t.Optional[str] = None) -> None:
|
||||||
if base_path is not None:
|
if base_path is not None:
|
||||||
# update spec to include user-provided base_path
|
# update spec to include user-provided base_path
|
||||||
@@ -121,7 +127,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
resolver: t.Optional[Resolver] = None,
|
|
||||||
resolver_error_handler: t.Optional[t.Callable] = None,
|
resolver_error_handler: t.Optional[t.Callable] = None,
|
||||||
debug: bool = False,
|
debug: bool = False,
|
||||||
pass_context_arg_name: t.Optional[str] = None,
|
pass_context_arg_name: t.Optional[str] = None,
|
||||||
@@ -129,9 +134,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Minimal interface of an API, with only functionality related to routing.
|
"""Minimal interface of an API, with only functionality related to routing.
|
||||||
|
|
||||||
:param resolver: Callable that maps operationID to a function
|
|
||||||
:param resolver_error_handler: Callable that generates an Operation used for handling
|
|
||||||
ResolveErrors
|
|
||||||
:param debug: Flag to run in debug mode
|
:param debug: Flag to run in debug mode
|
||||||
:param pass_context_arg_name: If not None URL request handling functions with an argument
|
:param pass_context_arg_name: If not None URL request handling functions with an argument
|
||||||
matching this name will be passed the framework's request context.
|
matching this name will be passed the framework's request context.
|
||||||
@@ -140,8 +142,6 @@ class AbstractRoutingAPI(AbstractSpecAPI):
|
|||||||
self.debug = debug
|
self.debug = debug
|
||||||
self.resolver_error_handler = resolver_error_handler
|
self.resolver_error_handler = resolver_error_handler
|
||||||
|
|
||||||
self.resolver = resolver or Resolver()
|
|
||||||
|
|
||||||
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
logger.debug('pass_context_arg_name: %s', pass_context_arg_name)
|
||||||
self.pass_context_arg_name = pass_context_arg_name
|
self.pass_context_arg_name = pass_context_arg_name
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from starlette.types import ASGIApp, Receive, Scope, Send
|
|||||||
from connexion.apis import AbstractRoutingAPI
|
from connexion.apis import AbstractRoutingAPI
|
||||||
from connexion.exceptions import NotFoundProblem
|
from connexion.exceptions import NotFoundProblem
|
||||||
from connexion.middleware import AppMiddleware
|
from connexion.middleware import AppMiddleware
|
||||||
|
from connexion.operations import AbstractOperation
|
||||||
from connexion.resolver import Resolver
|
from connexion.resolver import Resolver
|
||||||
|
|
||||||
ROUTING_CONTEXT = 'connexion_routing'
|
ROUTING_CONTEXT = 'connexion_routing'
|
||||||
@@ -91,7 +92,7 @@ class RoutingAPI(AbstractRoutingAPI):
|
|||||||
def add_operation(self, path: str, method: str) -> None:
|
def add_operation(self, path: str, method: str) -> None:
|
||||||
operation_cls = self.specification.operation_cls
|
operation_cls = self.specification.operation_cls
|
||||||
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
|
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
|
||||||
routing_operation = RoutingOperation(operation.operation_id, next_app=self.next_app)
|
routing_operation = RoutingOperation.from_operation(operation, next_app=self.next_app)
|
||||||
self._add_operation_internal(method, path, routing_operation)
|
self._add_operation_internal(method, path, routing_operation)
|
||||||
|
|
||||||
def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None:
|
def _add_operation_internal(self, method: str, path: str, operation: 'RoutingOperation') -> None:
|
||||||
@@ -104,6 +105,10 @@ class RoutingOperation:
|
|||||||
self.operation_id = operation_id
|
self.operation_id = operation_id
|
||||||
self.next_app = next_app
|
self.next_app = next_app
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_operation(cls, operation: AbstractOperation, next_app: ASGIApp):
|
||||||
|
return cls(operation.operation_id, next_app)
|
||||||
|
|
||||||
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
|
||||||
"""Attach operation to scope and pass it to the next app"""
|
"""Attach operation to scope and pass it to the next app"""
|
||||||
original_scope = _scope.get()
|
original_scope = _scope.get()
|
||||||
|
|||||||
@@ -6,12 +6,15 @@ from collections import defaultdict
|
|||||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||||
|
|
||||||
from connexion.apis.abstract import AbstractSpecAPI
|
from connexion.apis.abstract import AbstractSpecAPI
|
||||||
from connexion.exceptions import MissingMiddleware
|
from connexion.exceptions import MissingMiddleware, ProblemException
|
||||||
from connexion.http_facts import METHODS
|
from connexion.http_facts import METHODS
|
||||||
from connexion.lifecycle import MiddlewareRequest
|
from connexion.lifecycle import MiddlewareRequest
|
||||||
from connexion.middleware import AppMiddleware
|
from connexion.middleware import AppMiddleware
|
||||||
from connexion.middleware.routing import ROUTING_CONTEXT
|
from connexion.middleware.routing import ROUTING_CONTEXT
|
||||||
|
from connexion.operations import AbstractOperation
|
||||||
|
from connexion.resolver import ResolverError
|
||||||
from connexion.security import SecurityHandlerFactory
|
from connexion.security import SecurityHandlerFactory
|
||||||
|
from connexion.spec import Specification
|
||||||
|
|
||||||
logger = logging.getLogger("connexion.middleware.security")
|
logger = logging.getLogger("connexion.middleware.security")
|
||||||
|
|
||||||
@@ -69,8 +72,6 @@ class SecurityAPI(AbstractSpecAPI):
|
|||||||
):
|
):
|
||||||
super().__init__(specification, *args, **kwargs)
|
super().__init__(specification, *args, **kwargs)
|
||||||
self.security_handler_factory = SecurityHandlerFactory('context')
|
self.security_handler_factory = SecurityHandlerFactory('context')
|
||||||
self.app_security = self.specification.security
|
|
||||||
self.security_schemes = self.specification.security_definitions
|
|
||||||
|
|
||||||
if auth_all_paths:
|
if auth_all_paths:
|
||||||
self.add_auth_on_not_found()
|
self.add_auth_on_not_found()
|
||||||
@@ -81,30 +82,36 @@ class SecurityAPI(AbstractSpecAPI):
|
|||||||
|
|
||||||
def add_auth_on_not_found(self):
|
def add_auth_on_not_found(self):
|
||||||
"""Register a default SecurityOperation for routes that are not found."""
|
"""Register a default SecurityOperation for routes that are not found."""
|
||||||
default_operation = self.make_operation()
|
default_operation = self.make_operation(self.specification)
|
||||||
self.operations = defaultdict(lambda: default_operation)
|
self.operations = defaultdict(lambda: default_operation)
|
||||||
|
|
||||||
def add_paths(self):
|
def add_paths(self):
|
||||||
paths = self.specification.get('paths', {})
|
paths = self.specification.get('paths', {})
|
||||||
for path, methods in paths.items():
|
for path, methods in paths.items():
|
||||||
for method, operation in methods.items():
|
for method in methods:
|
||||||
if method not in METHODS:
|
if method not in METHODS:
|
||||||
continue
|
continue
|
||||||
operation_id = operation.get('operationId')
|
try:
|
||||||
if operation_id:
|
self.add_operation(path, method)
|
||||||
self.operations[operation_id] = self.make_operation(operation)
|
except ResolverError:
|
||||||
|
# ResolverErrors are either raised or handled in routing middleware.
|
||||||
|
pass
|
||||||
|
|
||||||
def make_operation(self, operation_spec: dict = None):
|
def add_operation(self, path: str, method: str) -> None:
|
||||||
security = self.app_security
|
operation_cls = self.specification.operation_cls
|
||||||
if operation_spec:
|
operation = operation_cls.from_spec(self.specification, self, path, method, self.resolver)
|
||||||
security = operation_spec.get('security', self.app_security)
|
security_operation = self.make_operation(operation)
|
||||||
|
self._add_operation_internal(operation.operation_id, security_operation)
|
||||||
|
|
||||||
return SecurityOperation(
|
def make_operation(self, operation: t.Union[AbstractOperation, Specification]):
|
||||||
self.security_handler_factory,
|
return SecurityOperation.from_operation(
|
||||||
security=security,
|
operation,
|
||||||
security_schemes=self.specification.security_definitions
|
security_handler_factory=self.security_handler_factory,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _add_operation_internal(self, operation_id: str, operation: 'SecurityOperation'):
|
||||||
|
self.operations[operation_id] = operation
|
||||||
|
|
||||||
|
|
||||||
class SecurityOperation:
|
class SecurityOperation:
|
||||||
|
|
||||||
@@ -119,6 +126,18 @@ class SecurityOperation:
|
|||||||
self.security_schemes = security_schemes
|
self.security_schemes = security_schemes
|
||||||
self.verification_fn = self._get_verification_fn()
|
self.verification_fn = self._get_verification_fn()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_operation(
|
||||||
|
cls,
|
||||||
|
operation: AbstractOperation,
|
||||||
|
security_handler_factory: SecurityHandlerFactory
|
||||||
|
):
|
||||||
|
return cls(
|
||||||
|
security_handler_factory,
|
||||||
|
security=operation.security,
|
||||||
|
security_schemes=operation.security_schemes
|
||||||
|
)
|
||||||
|
|
||||||
def _get_verification_fn(self):
|
def _get_verification_fn(self):
|
||||||
logger.debug('... Security: %s', self.security, extra=vars(self))
|
logger.debug('... Security: %s', self.security, extra=vars(self))
|
||||||
if not self.security:
|
if not self.security:
|
||||||
@@ -234,5 +253,5 @@ class SecurityOperation:
|
|||||||
await self.verification_fn(request)
|
await self.verification_fn(request)
|
||||||
|
|
||||||
|
|
||||||
class MissingSecurityOperation(Exception):
|
class MissingSecurityOperation(ProblemException):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
|||||||
serious_business(stuff)
|
serious_business(stuff)
|
||||||
"""
|
"""
|
||||||
def __init__(self, api, method, path, operation, resolver,
|
def __init__(self, api, method, path, operation, resolver,
|
||||||
|
app_security=None, security_schemes=None,
|
||||||
validate_responses=False, strict_validation=False,
|
validate_responses=False, strict_validation=False,
|
||||||
randomize_endpoint=None, validator_map=None,
|
randomize_endpoint=None, validator_map=None,
|
||||||
pythonic_params=False, uri_parser_class=None,
|
pythonic_params=False, uri_parser_class=None,
|
||||||
@@ -57,7 +58,6 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
|||||||
:param operation: swagger operation object
|
:param operation: swagger operation object
|
||||||
:type operation: dict
|
:type operation: dict
|
||||||
:param resolver: Callable that maps operationID to a function
|
:param resolver: Callable that maps operationID to a function
|
||||||
:param app_produces: list of content types the application can return by default
|
|
||||||
:param app_security: list of security rules the application uses by default
|
:param app_security: list of security rules the application uses by default
|
||||||
:type app_security: list
|
:type app_security: list
|
||||||
:param security_schemes: `Security Definitions Object
|
:param security_schemes: `Security Definitions Object
|
||||||
@@ -85,6 +85,8 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
|||||||
self._path = path
|
self._path = path
|
||||||
self._operation = operation
|
self._operation = operation
|
||||||
self._resolver = resolver
|
self._resolver = resolver
|
||||||
|
self._security = operation.get('security', app_security)
|
||||||
|
self._security_schemes = security_schemes
|
||||||
self._validate_responses = validate_responses
|
self._validate_responses = validate_responses
|
||||||
self._strict_validation = strict_validation
|
self._strict_validation = strict_validation
|
||||||
self._pythonic_params = pythonic_params
|
self._pythonic_params = pythonic_params
|
||||||
@@ -119,6 +121,14 @@ class AbstractOperation(metaclass=abc.ABCMeta):
|
|||||||
"""
|
"""
|
||||||
return self._path
|
return self._path
|
||||||
|
|
||||||
|
@property
|
||||||
|
def security(self):
|
||||||
|
return self._security
|
||||||
|
|
||||||
|
@property
|
||||||
|
def security_schemes(self):
|
||||||
|
return self._security_schemes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def responses(self):
|
def responses(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class OpenAPIOperation(AbstractOperation):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api, method, path, operation, resolver, path_parameters=None,
|
def __init__(self, api, method, path, operation, resolver, path_parameters=None,
|
||||||
|
app_security=None, security_schemes=None,
|
||||||
components=None, validate_responses=False, strict_validation=False,
|
components=None, validate_responses=False, strict_validation=False,
|
||||||
randomize_endpoint=None, validator_map=None,
|
randomize_endpoint=None, validator_map=None,
|
||||||
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
|
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
|
||||||
@@ -44,6 +45,11 @@ class OpenAPIOperation(AbstractOperation):
|
|||||||
:param resolver: Callable that maps operationID to a function
|
:param resolver: Callable that maps operationID to a function
|
||||||
:param path_parameters: Parameters defined in the path level
|
:param path_parameters: Parameters defined in the path level
|
||||||
:type path_parameters: list
|
:type path_parameters: list
|
||||||
|
:param app_security: list of security rules the application uses by default
|
||||||
|
:type app_security: list
|
||||||
|
:param security_schemes: `Security Definitions Object
|
||||||
|
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
|
||||||
|
:type security_schemes: dict
|
||||||
:param components: `Components Object
|
:param components: `Components Object
|
||||||
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
|
<https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.1.md#componentsObject>`_
|
||||||
:type components: dict
|
:type components: dict
|
||||||
@@ -76,6 +82,8 @@ class OpenAPIOperation(AbstractOperation):
|
|||||||
path=path,
|
path=path,
|
||||||
operation=operation,
|
operation=operation,
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
|
app_security=app_security,
|
||||||
|
security_schemes=security_schemes,
|
||||||
validate_responses=validate_responses,
|
validate_responses=validate_responses,
|
||||||
strict_validation=strict_validation,
|
strict_validation=strict_validation,
|
||||||
randomize_endpoint=randomize_endpoint,
|
randomize_endpoint=randomize_endpoint,
|
||||||
@@ -116,6 +124,8 @@ class OpenAPIOperation(AbstractOperation):
|
|||||||
spec.get_operation(path, method),
|
spec.get_operation(path, method),
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
path_parameters=spec.get_path_params(path),
|
path_parameters=spec.get_path_params(path),
|
||||||
|
app_security=spec.security,
|
||||||
|
security_schemes=spec.security_schemes,
|
||||||
components=spec.components,
|
components=spec.components,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
|
def __init__(self, api, method, path, operation, resolver, app_produces, app_consumes,
|
||||||
path_parameters=None, definitions=None, validate_responses=False,
|
path_parameters=None, app_security=None, security_schemes=None,
|
||||||
|
definitions=None, validate_responses=False,
|
||||||
strict_validation=False, randomize_endpoint=None, validator_map=None,
|
strict_validation=False, randomize_endpoint=None, validator_map=None,
|
||||||
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
|
pythonic_params=False, uri_parser_class=None, pass_context_arg_name=None):
|
||||||
"""
|
"""
|
||||||
@@ -47,6 +48,11 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
:type app_consumes: list
|
:type app_consumes: list
|
||||||
:param path_parameters: Parameters defined in the path level
|
:param path_parameters: Parameters defined in the path level
|
||||||
:type path_parameters: list
|
:type path_parameters: list
|
||||||
|
:param app_security: list of security rules the application uses by default
|
||||||
|
:type app_security: list
|
||||||
|
:param security_schemes: `Security Definitions Object
|
||||||
|
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#security-definitions-object>`_
|
||||||
|
:type security_schemes: dict
|
||||||
:param definitions: `Definitions Object
|
:param definitions: `Definitions Object
|
||||||
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
|
<https://github.com/swagger-api/swagger-spec/blob/master/versions/2.0.md#definitionsObject>`_
|
||||||
:type definitions: dict
|
:type definitions: dict
|
||||||
@@ -77,6 +83,8 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
path=path,
|
path=path,
|
||||||
operation=operation,
|
operation=operation,
|
||||||
resolver=resolver,
|
resolver=resolver,
|
||||||
|
app_security=app_security,
|
||||||
|
security_schemes=security_schemes,
|
||||||
validate_responses=validate_responses,
|
validate_responses=validate_responses,
|
||||||
strict_validation=strict_validation,
|
strict_validation=strict_validation,
|
||||||
randomize_endpoint=randomize_endpoint,
|
randomize_endpoint=randomize_endpoint,
|
||||||
@@ -112,6 +120,8 @@ class Swagger2Operation(AbstractOperation):
|
|||||||
path_parameters=spec.get_path_params(path),
|
path_parameters=spec.get_path_params(path),
|
||||||
app_produces=spec.produces,
|
app_produces=spec.produces,
|
||||||
app_consumes=spec.consumes,
|
app_consumes=spec.consumes,
|
||||||
|
app_security=spec.security,
|
||||||
|
security_schemes=spec.security_schemes,
|
||||||
definitions=spec.definitions,
|
definitions=spec.definitions,
|
||||||
*args,
|
*args,
|
||||||
**kwargs
|
**kwargs
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ class Swagger2Specification(Specification):
|
|||||||
return self._spec['responses']
|
return self._spec['responses']
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def security_definitions(self):
|
def security_schemes(self):
|
||||||
return self._spec.get('securityDefinitions', {})
|
return self._spec.get('securityDefinitions', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -268,7 +268,7 @@ class OpenAPISpecification(Specification):
|
|||||||
spec.setdefault('components', {})
|
spec.setdefault('components', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def security_definitions(self):
|
def security_schemes(self):
|
||||||
return self._spec['components'].get('securitySchemes', {})
|
return self._spec['components'].get('securitySchemes', {})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
Reference in New Issue
Block a user