Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Refactor get_user_devices_from_cache to avoid mutating cached values. #15040

Merged
merged 2 commits into from
Feb 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15040.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid mutating a cached value in `get_user_devices_from_cache`.
11 changes: 7 additions & 4 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,19 +159,22 @@ async def query_devices(
# A map of destination -> user ID -> device IDs.
remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
if remote_queries:
query_list: List[Tuple[str, Optional[str]]] = []
user_ids = set()
user_and_device_ids: List[Tuple[str, str]] = []
for user_id, device_ids in remote_queries.items():
if device_ids:
query_list.extend(
user_and_device_ids.extend(
(user_id, device_id) for device_id in device_ids
)
else:
query_list.append((user_id, None))
user_ids.add(user_id)

(
user_ids_not_in_cache,
remote_results,
) = await self.store.get_user_devices_from_cache(query_list)
) = await self.store.get_user_devices_from_cache(
user_ids, user_and_device_ids
)

# Check that the homeserver still shares a room with all cached users.
# Note that this check may be slightly racy when a remote user leaves a
Expand Down
31 changes: 17 additions & 14 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,42 +745,45 @@ def _add_user_signature_change_txn(
@trace
@cancellable
async def get_user_devices_from_cache(
self, query_list: List[Tuple[str, Optional[str]]]
self, user_ids: Set[str], user_and_device_ids: List[Tuple[str, str]]
) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
"""Get the devices (and keys if any) for remote users from the cache.

Args:
query_list: List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
user_ids: users which should have all device IDs returned
user_and_device_ids: List of (user_id, device_id) pairs, one per specific device to look up

Returns:
A tuple of (user_ids_not_in_cache, results_map), where
user_ids_not_in_cache is a set of user_ids and results_map is a
mapping of user_id -> device_id -> device_info.
"""
user_ids = {user_id for user_id, _ in query_list}
user_map = await self.get_device_list_last_stream_id_for_remotes(list(user_ids))
unique_user_ids = user_ids | {user_id for user_id, _ in user_and_device_ids}
user_map = await self.get_device_list_last_stream_id_for_remotes(
list(unique_user_ids)
)

# We go and check if any of the users need to have their device lists
# resynced. If they do then we remove them from the cached list.
users_needing_resync = await self.get_user_ids_requiring_device_list_resync(
user_ids
unique_user_ids
)
user_ids_in_cache = {
user_id for user_id, stream_id in user_map.items() if stream_id
} - users_needing_resync
user_ids_not_in_cache = user_ids - user_ids_in_cache
user_ids_not_in_cache = unique_user_ids - user_ids_in_cache

# First fetch the users for whom all devices are to be returned.
results: Dict[str, Dict[str, JsonDict]] = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue

if device_id:
for user_id in user_ids:
if user_id in user_ids_in_cache:
results[user_id] = await self.get_cached_devices_for_user(user_id)
# Then fetch all device-specific requests, but skip users we've already
# fetched all devices for.
for user_id, device_id in user_and_device_ids:
if user_id in user_ids_in_cache and user_id not in user_ids:
device = await self._get_cached_user_device(user_id, device_id)
results.setdefault(user_id, {})[device_id] = device
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)

set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
Expand Down