Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…tools-python into feat/batch-new-processor

* 'develop' of https://github.com/awslabs/aws-lambda-powertools-python:
  fix(parser): kinesis sequence number is str, not int (aws-powertools#907)
  feat(apigateway): add exception_handler support (aws-powertools#898)
  fix(event-sources): Pass authorizer data to APIGatewayEventAuthorizer (aws-powertools#897)
  chore(deps): bump fastjsonschema from 2.15.1 to 2.15.2 (aws-powertools#891)
  • Loading branch information
heitorlessa committed Dec 17, 2021
2 parents b0f170e + 99227ce commit 01eb5a7
Show file tree
Hide file tree
Showing 9 changed files with 175 additions and 26 deletions.
63 changes: 49 additions & 14 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
from enum import Enum
from functools import partial
from http import HTTPStatus
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

from aws_lambda_powertools.event_handler import content_types
from aws_lambda_powertools.event_handler.exceptions import ServiceError
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import resolve_truthy_env_var_choice
from aws_lambda_powertools.shared.json_encoder import Encoder
Expand All @@ -27,7 +27,6 @@
_SAFE_URI = "-._~()'!*:@,;" # https://www.ietf.org/rfc/rfc3986.txt
# API GW/ALB decode non-safe URI chars; we must support them too
_UNSAFE_URI = "%<>\[\]{}|^" # noqa: W605

_NAMED_GROUP_BOUNDARY_PATTERN = fr"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"


Expand Down Expand Up @@ -435,6 +434,7 @@ def __init__(
self._proxy_type = proxy_type
self._routes: List[Route] = []
self._route_keys: List[str] = []
self._exception_handlers: Dict[Type, Callable] = {}
self._cors = cors
self._cors_enabled: bool = cors is not None
self._cors_methods: Set[str] = {"OPTIONS"}
Expand Down Expand Up @@ -596,6 +596,10 @@ def _not_found(self, method: str) -> ResponseBuilder:
headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods))
return ResponseBuilder(Response(status_code=204, content_type=None, headers=headers, body=None))

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return ResponseBuilder(handler(NotFoundError()))

return ResponseBuilder(
Response(
status_code=HTTPStatus.NOT_FOUND.value,
Expand All @@ -609,16 +613,11 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
"""Actually call the matching route with any provided keyword arguments."""
try:
return ResponseBuilder(self._to_response(route.func(**args)), route)
except ServiceError as e:
return ResponseBuilder(
Response(
status_code=e.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": e.status_code, "message": e.msg}),
),
route,
)
except Exception:
except Exception as exc:
response_builder = self._call_exception_handler(exc, route)
if response_builder:
return response_builder

if self._debug:
# If the user has turned on debug mode,
# we'll let the original exception propagate so
Expand All @@ -628,10 +627,46 @@ def _call_route(self, route: Route, args: Dict[str, str]) -> ResponseBuilder:
status_code=500,
content_type=content_types.TEXT_PLAIN,
body="".join(traceback.format_exc()),
)
),
route,
)

raise

def not_found(self, func: Callable):
return self.exception_handler(NotFoundError)(func)

def exception_handler(self, exc_class: Type[Exception]):
def register_exception_handler(func: Callable):
self._exception_handlers[exc_class] = func

return register_exception_handler

def _lookup_exception_handler(self, exp_type: Type) -> Optional[Callable]:
# Use "Method Resolution Order" to allow for matching against a base class
# of an exception
for cls in exp_type.__mro__:
if cls in self._exception_handlers:
return self._exception_handlers[cls]
return None

def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[ResponseBuilder]:
handler = self._lookup_exception_handler(type(exp))
if handler:
return ResponseBuilder(handler(exp), route)

if isinstance(exp, ServiceError):
return ResponseBuilder(
Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
body=self._json_dump({"statusCode": exp.status_code, "message": exp.msg}),
),
route,
)

return None

def _to_response(self, result: Union[Dict, Response]) -> Response:
"""Convert the route's result to a Response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,22 @@
class APIGatewayEventAuthorizer(DictWrapper):
@property
def claims(self) -> Optional[Dict[str, Any]]:
return self["requestContext"]["authorizer"].get("claims")
return self.get("claims")

@property
def scopes(self) -> Optional[List[str]]:
return self["requestContext"]["authorizer"].get("scopes")
return self.get("scopes")

@property
def principal_id(self) -> Optional[str]:
"""The principal user identification associated with the token sent by the client and returned from an
API Gateway Lambda authorizer (formerly known as a custom authorizer)"""
return self.get("principalId")

@property
def integration_latency(self) -> Optional[int]:
"""The authorizer latency in ms."""
return self.get("integrationLatency")


class APIGatewayEventRequestContext(BaseRequestContext):
Expand Down Expand Up @@ -56,7 +67,7 @@ def route_key(self) -> Optional[str]:

@property
def authorizer(self) -> APIGatewayEventAuthorizer:
return APIGatewayEventAuthorizer(self._data)
return APIGatewayEventAuthorizer(self._data["requestContext"]["authorizer"])


class APIGatewayProxyEvent(BaseProxyEvent):
Expand Down
4 changes: 2 additions & 2 deletions aws_lambda_powertools/utilities/data_classes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def __eq__(self, other: Any) -> bool:

return self._data == other._data

def get(self, key: str) -> Optional[Any]:
return self._data.get(key)
def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
return self._data.get(key, default)

@property
def raw_event(self) -> Dict[str, Any]:
Expand Down
3 changes: 1 addition & 2 deletions aws_lambda_powertools/utilities/parser/models/kinesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import List, Union

from pydantic import BaseModel, validator
from pydantic.types import PositiveInt

from aws_lambda_powertools.utilities.parser.types import Literal, Model

Expand All @@ -14,7 +13,7 @@
class KinesisDataStreamRecordPayload(BaseModel):
kinesisSchemaVersion: str
partitionKey: str
sequenceNumber: PositiveInt
sequenceNumber: str
data: Union[bytes, Model] # base64 encoded str is parsed into bytes
approximateArrivalTimestamp: float

Expand Down
10 changes: 7 additions & 3 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/events/apiGatewayProxyEventPrincipalId.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"resource": "/trip",
"path": "/trip",
"httpMethod": "POST",
"requestContext": {
"requestId": "34972478-2843-4ced-a657-253108738274",
"authorizer": {
"user_id": "fake_username",
"principalId": "fake",
"integrationLatency": 451
}
}
}
75 changes: 74 additions & 1 deletion tests/functional/event_handler/test_api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def patch_func():
def handler(event, context):
return app.resolve(event, context)

# Also check check the route configurations
# Also check the route configurations
routes = app._routes
assert len(routes) == 5
for route in routes:
Expand Down Expand Up @@ -1076,3 +1076,76 @@ def foo():

assert result["statusCode"] == 200
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON


def test_exception_handler():
# GIVEN a resolver with an exception handler defined for ValueError
app = ApiGatewayResolver()

@app.exception_handler(ValueError)
def handle_value_error(ex: ValueError):
print(f"request path is '{app.current_event.path}'")
return Response(
status_code=418,
content_type=content_types.TEXT_HTML,
body=str(ex),
)

@app.get("/my/path")
def get_lambda() -> Response:
raise ValueError("Foo!")

# WHEN calling the event handler
# AND a ValueError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 418
assert result["headers"]["Content-Type"] == content_types.TEXT_HTML
assert result["body"] == "Foo!"


def test_exception_handler_service_error():
# GIVEN
app = ApiGatewayResolver()

@app.exception_handler(ServiceError)
def service_error(ex: ServiceError):
print(ex.msg)
return Response(
status_code=ex.status_code,
content_type=content_types.APPLICATION_JSON,
body="CUSTOM ERROR FORMAT",
)

@app.get("/my/path")
def get_lambda() -> Response:
raise InternalServerError("Something sensitive")

# WHEN calling the event handler
# AND a ServiceError is raised
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 500
assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON
assert result["body"] == "CUSTOM ERROR FORMAT"


def test_exception_handler_not_found():
# GIVEN a resolver with an exception handler defined for a 404 not found
app = ApiGatewayResolver()

@app.not_found
def handle_not_found(exc: NotFoundError) -> Response:
assert isinstance(exc, NotFoundError)
return Response(status_code=404, content_type=content_types.TEXT_PLAIN, body="I am a teapot!")

# WHEN calling the event handler
# AND not route is found
result = app(LOAD_GW_EVENT, {})

# THEN call the exception_handler
assert result["statusCode"] == 404
assert result["headers"]["Content-Type"] == content_types.TEXT_PLAIN
assert result["body"] == "I am a teapot!"
2 changes: 1 addition & 1 deletion tests/functional/parser/test_kinesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def handle_kinesis_no_envelope(event: KinesisDataStreamModel, _: LambdaContext):
assert kinesis.approximateArrivalTimestamp == 1545084650.987
assert kinesis.kinesisSchemaVersion == "1.0"
assert kinesis.partitionKey == "1"
assert kinesis.sequenceNumber == 49590338271490256608559692538361571095921575989136588898
assert kinesis.sequenceNumber == "49590338271490256608559692538361571095921575989136588898"
assert kinesis.data == b"Hello, this is a test."


Expand Down
14 changes: 14 additions & 0 deletions tests/functional/test_data_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,20 @@ def test_api_gateway_proxy_event():
assert request_context.identity.client_cert.subject_dn == "www.example.com"


def test_api_gateway_proxy_event_with_principal_id():
event = APIGatewayProxyEvent(load_event("apiGatewayProxyEventPrincipalId.json"))

request_context = event.request_context
authorizer = request_context.authorizer
assert authorizer.claims is None
assert authorizer.scopes is None
assert authorizer["principalId"] == "fake"
assert authorizer.get("principalId") == "fake"
assert authorizer.principal_id == "fake"
assert authorizer.integration_latency == 451
assert authorizer.get("integrationStatus", "failed") == "failed"


def test_api_gateway_proxy_v2_event():
event = APIGatewayProxyEventV2(load_event("apiGatewayProxyV2Event.json"))

Expand Down

0 comments on commit 01eb5a7

Please sign in to comment.