Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add type annotations to trace decorator. (#13328)
Browse files Browse the repository at this point in the history
Functions that are decorated with `trace` are now properly typed
and the type hints for them are fixed.
  • Loading branch information
clokep authored Jul 19, 2022
1 parent 47822fd commit a6895dd
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 55 deletions.
1 change: 1 addition & 0 deletions changelog.d/13328.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `trace` decorator.
2 changes: 1 addition & 1 deletion synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, content: JsonDict, timeout: int
self, destination: str, content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ async def query_user_devices(
)

async def claim_client_keys(
self, destination: str, query_content: JsonDict, timeout: int
self, destination: str, query_content: JsonDict, timeout: Optional[int]
) -> JsonDict:
"""Claim one-time keys for a list of devices hosted on a remote server.
Expand Down
16 changes: 9 additions & 7 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
Expand Down Expand Up @@ -92,7 +92,11 @@ def __init__(self, hs: "HomeServer"):

@trace
async def query_devices(
self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
self,
query_body: JsonDict,
timeout: int,
from_user_id: str,
from_device_id: Optional[str],
) -> JsonDict:
"""Handle a device key query from a client
Expand Down Expand Up @@ -120,9 +124,7 @@ async def query_devices(
the number of in-flight queries at a time.
"""
async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
device_keys_query: Dict[str, Iterable[str]] = query_body.get(
"device_keys", {}
)
device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})

# separate users by domain.
# make a map from domain to user_id to device_ids
Expand Down Expand Up @@ -392,7 +394,7 @@ async def get_cross_signing_keys_from_cache(

@trace
async def query_local_devices(
self, query: Dict[str, Optional[List[str]]]
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
"""Get E2E device keys for local users
Expand Down Expand Up @@ -461,7 +463,7 @@ async def on_federation_query_client_keys(

@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
Expand Down
50 changes: 28 additions & 22 deletions synapse/logging/opentracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,13 @@ def interesting_function(*args, **kwargs):
return something_usual_and_useful
Operation names can be explicitly set for a function by passing the
operation name to ``trace``
Operation names can be explicitly set for a function by using ``trace_with_opname``:
.. code-block:: python
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
@trace(opname="a_better_operation_name")
@trace_with_opname("a_better_operation_name")
def interesting_badly_named_function(*args, **kwargs):
# Does all kinds of cool and expected things
return something_usual_and_useful
Expand Down Expand Up @@ -798,33 +797,31 @@ def extract_text_map(carrier: Dict[str, str]) -> Optional["opentracing.SpanConte
# Tracing decorators


def trace(func=None, opname: Optional[str] = None):
def trace_with_opname(opname: str) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's or that given
as operation_name. See the module's doc string for usage
examples.
Decorator to trace a function with a custom opname.
See the module's doc string for usage examples.
"""

def decorator(func):
def decorator(func: Callable[P, R]) -> Callable[P, R]:
if opentracing is None:
return func # type: ignore[unreachable]

_opname = opname if opname else func.__name__

if inspect.iscoroutinefunction(func):

@wraps(func)
async def _trace_inner(*args, **kwargs):
with start_active_span(_opname):
return await func(*args, **kwargs)
async def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
with start_active_span(opname):
return await func(*args, **kwargs) # type: ignore[misc]

else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
scope = start_active_span(_opname)
def _trace_inner(*args: P.args, **kwargs: P.kwargs) -> R:
scope = start_active_span(opname)
scope.__enter__()

try:
Expand Down Expand Up @@ -858,12 +855,21 @@ def err_back(result: R) -> R:
scope.__exit__(type(e), None, e.__traceback__)
raise

return _trace_inner
return _trace_inner # type: ignore[return-value]

if func:
return decorator(func)
else:
return decorator
return decorator


def trace(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorator to trace a function.
Sets the operation name to that of the function's name.
See the module's doc string for usage examples.
"""

return trace_with_opname(func.__name__)(func)


def tag_args(func: Callable[P, R]) -> Callable[P, R]:
Expand Down
4 changes: 2 additions & 2 deletions synapse/replication/http/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from synapse.http.server import HttpServer, is_method_cancellable
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
Expand Down Expand Up @@ -196,7 +196,7 @@ def make_client(cls, hs: "HomeServer") -> Callable:
"ascii"
)

@trace(opname="outgoing_replication_request")
@trace_with_opname("outgoing_replication_request")
async def send_request(*, instance_name: str = "master", **kwargs: Any) -> Any:
with outgoing_gauge.track_inprogress():
if instance_name == local_instance_name:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.types import JsonDict, StreamToken

from ._base import client_patterns, interactive_auth_handler
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, hs: "HomeServer"):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()

@trace(opname="upload_keys")
@trace_with_opname("upload_keys")
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
Expand Down
13 changes: 8 additions & 5 deletions synapse/rest/client/room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple, cast

from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
Expand Down Expand Up @@ -127,7 +127,7 @@ async def on_PUT(
requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)

if session_id:
body = {"sessions": {session_id: body}}
Expand Down Expand Up @@ -196,8 +196,11 @@ async def on_GET(
user_id = requester.user.to_string()
version = parse_string(request, "version", required=True)

room_keys = await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
room_keys = cast(
JsonDict,
await self.e2e_room_keys_handler.get_room_keys(
user_id, version, room_id, session_id
),
)

# Convert room_keys to the right format to return.
Expand Down Expand Up @@ -240,7 +243,7 @@ async def on_DELETE(

requester = await self.auth.get_user_by_req(request, allow_guest=False)
user_id = requester.user.to_string()
version = parse_string(request, "version")
version = parse_string(request, "version", required=True)

ret = await self.e2e_room_keys_handler.delete_room_keys(
user_id, version, room_id, session_id
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/sendtodevice.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import set_tag, trace
from synapse.logging.opentracing import set_tag, trace_with_opname
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict

Expand All @@ -43,7 +43,7 @@ def __init__(self, hs: "HomeServer"):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()

@trace(opname="sendToDevice")
@trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
Expand Down
12 changes: 6 additions & 6 deletions synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.logging.opentracing import trace_with_opname
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder

Expand Down Expand Up @@ -210,7 +210,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.debug("Event formatting complete")
return 200, response_content

@trace(opname="sync.encode_response")
@trace_with_opname("sync.encode_response")
async def encode_response(
self,
time_now: int,
Expand Down Expand Up @@ -315,7 +315,7 @@ def encode_presence(events: List[UserPresenceState], time_now: int) -> JsonDict:
]
}

@trace(opname="sync.encode_joined")
@trace_with_opname("sync.encode_joined")
async def encode_joined(
self,
rooms: List[JoinedSyncResult],
Expand All @@ -340,7 +340,7 @@ async def encode_joined(

return joined

@trace(opname="sync.encode_invited")
@trace_with_opname("sync.encode_invited")
async def encode_invited(
self,
rooms: List[InvitedSyncResult],
Expand Down Expand Up @@ -371,7 +371,7 @@ async def encode_invited(

return invited

@trace(opname="sync.encode_knocked")
@trace_with_opname("sync.encode_knocked")
async def encode_knocked(
self,
rooms: List[KnockedSyncResult],
Expand Down Expand Up @@ -420,7 +420,7 @@ async def encode_knocked(

return knocked

@trace(opname="sync.encode_archived")
@trace_with_opname("sync.encode_archived")
async def encode_archived(
self,
rooms: List[ArchivedSyncResult],
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def get_device_stream_token(self) -> int:

@trace
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, str]]
self, query_list: List[Tuple[str, Optional[str]]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.
Expand Down
Loading

0 comments on commit a6895dd

Please sign in to comment.