diff --git a/changelog.d/13328.misc b/changelog.d/13328.misc new file mode 100644 index 000000000000..d15fb5fc3704 --- /dev/null +++ b/changelog.d/13328.misc @@ -0,0 +1 @@ +Add type hints to `trace` decorator. diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 66e6305562f5..7c450ecad008 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -217,7 +217,7 @@ async def query_user_devices( ) async def claim_client_keys( - self, destination: str, content: JsonDict, timeout: int + self, destination: str, content: JsonDict, timeout: Optional[int] ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 9e84bd677e98..32074b8ca690 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -619,7 +619,7 @@ async def query_user_devices( ) async def claim_client_keys( - self, destination: str, query_content: JsonDict, timeout: int + self, destination: str, query_content: JsonDict, timeout: Optional[int] ) -> JsonDict: """Claim one-time keys for a list of devices hosted on a remote server. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 52bb5c9c55ab..84c28c480e89 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -92,7 +92,11 @@ def __init__(self, hs: "HomeServer"): @trace async def query_devices( - self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str + self, + query_body: JsonDict, + timeout: int, + from_user_id: str, + from_device_id: Optional[str], ) -> JsonDict: """Handle a device key query from a client @@ -120,9 +124,7 @@ async def query_devices( the number of in-flight queries at a time. """ async with self._query_devices_linearizer.queue((from_user_id, from_device_id)): - device_keys_query: Dict[str, Iterable[str]] = query_body.get( - "device_keys", {} - ) + device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {}) # separate users by domain. # make a map from domain to user_id to device_ids @@ -392,7 +394,7 @@ async def get_cross_signing_keys_from_cache( @trace async def query_local_devices( - self, query: Dict[str, Optional[List[str]]] + self, query: Mapping[str, Optional[List[str]]] ) -> Dict[str, Dict[str, dict]]: """Get E2E device keys for local users @@ -461,7 +463,7 @@ async def on_federation_query_client_keys( @trace async def claim_one_time_keys( - self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] ) -> JsonDict: local_query: List[Tuple[str, str, str]] = [] remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {} diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 50c57940f94e..17e729f0c74b 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -84,14 +84,13 @@ def interesting_function(*args, **kwargs): return something_usual_and_useful -Operation names can be explicitly set for a function by passing the -operation name to ``trace`` +Operation names can be explicitly set for a function by using ``trace_with_opname``: .. code-block:: python - from synapse.logging.opentracing import trace + from synapse.logging.opentracing import trace_with_opname - @trace(opname="a_better_operation_name") + @trace_with_opname("a_better_operation_name") def interesting_badly_named_function(*args, **kwargs): # Does all kinds of cool and expected things return something_usual_and_useful @@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte # Tracing decorators -def trace(func=None, opname: Optional[str] = None): +def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]: """ - Decorator to trace a function. - Sets the operation name to that of the function's or that given - as operation_name. See the module's doc string for usage - examples. + Decorator to trace a function with a custom opname. + + See the module's doc string for usage examples. + """ - def decorator(func): + def decorator(func: Callable[P, R]) -> Callable[P, R]: if opentracing is None: return func # type: ignore[unreachable] - _opname = opname if opname else func.__name__ - if inspect.iscoroutinefunction(func): @wraps(func) - async def _trace_inner(*args, **kwargs): - with start_active_span(_opname): - return await func(*args, **kwargs) + async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: + with start_active_span(opname): + return await func(*args, **kwargs) # type: ignore[misc] else: # The other case here handles both sync functions and those # decorated with inlineDeferred. @wraps(func) - def _trace_inner(*args, **kwargs): - scope = start_active_span(_opname) + def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R: + scope = start_active_span(opname) scope.__enter__() try: @@ -858,12 +855,21 @@ def err_back(result: R) -> R: scope.__exit__(type(e), None, e.__traceback__) raise - return _trace_inner + return _trace_inner # type: ignore[return-value] - if func: - return decorator(func) - else: - return decorator + return decorator + + +def trace(func: Callable[P, R]) -> Callable[P, R]: + """ + Decorator to trace a function. + + Sets the operation name to that of the function's name. + + See the module's doc string for usage examples. + """ + + return trace_with_opname(func.__name__)(func) def tag_args(func: Callable[P, R]) -> Callable[P, R]: diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index a4ae4040c353..561ad5bf045c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -29,7 +29,7 @@ from synapse.http.server import HttpServer, is_method_cancellable from synapse.http.site import SynapseRequest from synapse.logging import opentracing -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import trace_with_opname from synapse.types import JsonDict from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string @@ -196,7 +196,7 @@ def make_client(cls, hs: "HomeServer") -> Callable: "ascii" ) - @trace(opname="outgoing_replication_request") + @trace_with_opname("outgoing_replication_request") async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any: with outgoing_gauge.track_inprogress(): if instance_name == local_instance_name: diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index ce806e3c11be..eb1b85721fce 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -26,7 +26,7 @@ parse_string, ) from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname from synapse.types import JsonDict, StreamToken from ._base import client_patterns, interactive_auth_handler @@ -71,7 +71,7 @@ def __init__(self, hs: "HomeServer"): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = hs.get_device_handler() - @trace(opname="upload_keys") + @trace_with_opname("upload_keys") async def on_POST( self, request: SynapseRequest, device_id: Optional[str] ) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/room_keys.py b/synapse/rest/client/room_keys.py index 37e39570f6ac..f7081f638e51 100644 --- a/synapse/rest/client/room_keys.py +++ b/synapse/rest/client/room_keys.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple, cast from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import HttpServer @@ -127,7 +127,7 @@ async def on_PUT( requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() body = parse_json_object_from_request(request) - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) if session_id: body = {"sessions": {session_id: body}} @@ -196,8 +196,11 @@ async def on_GET( user_id = requester.user.to_string() version = parse_string(request, "version", required=True) - room_keys = await self.e2e_room_keys_handler.get_room_keys( - user_id, version, room_id, session_id + room_keys = cast( + JsonDict, + await self.e2e_room_keys_handler.get_room_keys( + user_id, version, room_id, session_id + ), ) # Convert room_keys to the right format to return. @@ -240,7 +243,7 @@ async def on_DELETE( requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - version = parse_string(request, "version") + version = parse_string(request, "version", required=True) ret = await self.e2e_room_keys_handler.delete_room_keys( user_id, version, room_id, session_id diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py index 3322c8ef48da..1a8e9a96d4c7 100644 --- a/synapse/rest/client/sendtodevice.py +++ b/synapse/rest/client/sendtodevice.py @@ -19,7 +19,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import set_tag, trace +from synapse.logging.opentracing import set_tag, trace_with_opname from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import JsonDict @@ -43,7 +43,7 @@ def __init__(self, hs: "HomeServer"): self.txns = HttpTransactionCache(hs) self.device_message_handler = hs.get_device_message_handler() - @trace(opname="sendToDevice") + @trace_with_opname("sendToDevice") def on_PUT( self, request: SynapseRequest, message_type: str, txn_id: str ) -> Awaitable[Tuple[int, JsonDict]]: diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 8bbf35148d67..c2989765cee0 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -37,7 +37,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest -from synapse.logging.opentracing import trace +from synapse.logging.opentracing import trace_with_opname from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -210,7 +210,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: logger.debug("Event formatting complete") return 200, response_content - @trace(opname="sync.encode_response") + @trace_with_opname("sync.encode_response") async def encode_response( self, time_now: int, @@ -315,7 +315,7 @@ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict: ] } - @trace(opname="sync.encode_joined") + @trace_with_opname("sync.encode_joined") async def encode_joined( self, rooms: List[JoinedSyncResult], @@ -340,7 +340,7 @@ async def encode_joined( return joined - @trace(opname="sync.encode_invited") + @trace_with_opname("sync.encode_invited") async def encode_invited( self, rooms: List[InvitedSyncResult], @@ -371,7 +371,7 @@ async def encode_invited( return invited - @trace(opname="sync.encode_knocked") + @trace_with_opname("sync.encode_knocked") async def encode_knocked( self, rooms: List[KnockedSyncResult], @@ -420,7 +420,7 @@ async def encode_knocked( return knocked - @trace(opname="sync.encode_archived") + @trace_with_opname("sync.encode_archived") async def encode_archived( self, rooms: List[ArchivedSyncResult], diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index adde5d097810..7a6ed332aa61 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -669,7 +669,7 @@ def get_device_stream_token(self) -> int: @trace async def get_user_devices_from_cache( - self, query_list: List[Tuple[str, str]] + self, query_list: List[Tuple[str, Optional[str]]] ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]: """Get the devices (and keys if any) for remote users from the cache. diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 9b293475c871..60f622ad71bd 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -22,11 +22,14 @@ List, Optional, Tuple, + Union, cast, + overload, ) import attr from canonicaljson import encode_canonical_json +from typing_extensions import Literal from synapse.api.constants import DeviceKeyAlgorithms from synapse.appservice import ( @@ -113,7 +116,7 @@ async def get_e2e_device_keys_for_federation_query( user_devices = devices[user_id] results = [] for device_id, device in user_devices.items(): - result = {"device_id": device_id} + result: JsonDict = {"device_id": device_id} keys = device.keys if keys: @@ -156,6 +159,9 @@ async def get_e2e_device_keys_for_cs_api( rv[user_id] = {} for device_id, device_info in device_keys.items(): r = device_info.keys + if r is None: + continue + r["unsigned"] = {} display_name = device_info.display_name if display_name is not None: @@ -164,13 +170,42 @@ async def get_e2e_device_keys_for_cs_api( return rv + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: bool = False, + include_deleted_devices: Literal[False] = False, + ) -> Dict[str, Dict[str, DeviceKeyLookupResult]]: + ... + + @overload + async def get_e2e_device_keys_and_signatures( + self, + query_list: Collection[Tuple[str, Optional[str]]], + include_all_devices: Literal[True], + include_deleted_devices: Literal[True], + ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ... + @trace async def get_e2e_device_keys_and_signatures( self, - query_list: List[Tuple[str, Optional[str]]], + query_list: Collection[Tuple[str, Optional[str]]], include_all_devices: bool = False, include_deleted_devices: bool = False, - ) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]: + ) -> Union[ + Dict[str, Dict[str, DeviceKeyLookupResult]], + Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]], + ]: """Fetch a list of device keys Any cross-signatures made on the keys by the owner of the device are also @@ -1044,7 +1079,7 @@ def _claim_e2e_one_time_key_returning( _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple db_autocommit = False - row = await self.db_pool.runInteraction( + claim_row = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", _claim_e2e_one_time_key, user_id, @@ -1052,11 +1087,11 @@ def _claim_e2e_one_time_key_returning( algorithm, db_autocommit=db_autocommit, ) - if row: + if claim_row: device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[row[0]] = row[1] + device_results[claim_row[0]] = claim_row[1] continue # No one-time key available, so see if there's a fallback