From 163c68c0be3205d39fc01522c6553b2dca738c0b Mon Sep 17 00:00:00 2001 From: Sean Quah Date: Thu, 2 Feb 2023 23:42:54 +0000 Subject: [PATCH] convert list of servers into a set --- synapse/federation/federation_client.py | 21 +++++++++-------- synapse/handlers/federation.py | 20 ++++++++++++----- synapse/storage/controllers/state.py | 2 +- synapse/storage/databases/main/room.py | 30 ++++++++++++++----------- tests/handlers/test_federation.py | 2 +- tests/handlers/test_room_member.py | 2 +- 6 files changed, 47 insertions(+), 30 deletions(-) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 10c56ed7011b..95bd1444b72f 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,6 +19,7 @@ import logging from typing import ( TYPE_CHECKING, + AbstractSet, Awaitable, Callable, Collection, @@ -110,9 +111,9 @@ class SendJoinResult: # True if 'state' elides non-critical membership events partial_state: bool - # If 'partial_state' is set, a list of the servers in the room (otherwise empty). + # If 'partial_state' is set, a set of the servers in the room (otherwise empty). # Always contains the server we joined off. - servers_in_room: List[str] + servers_in_room: AbstractSet[str] class FederationClient(FederationBase): @@ -1153,30 +1154,32 @@ async def _execute(pdu: EventBase) -> None: % (auth_chain_create_events,) ) - servers_in_room = response.servers_in_room + servers_in_room = None + if response.servers_in_room is not None: + servers_in_room = set(response.servers_in_room) + if response.members_omitted: if not servers_in_room: raise InvalidResponseError( "members_omitted was set, but no servers were listed in the room" ) - if destination not in servers_in_room: - # `servers_in_room` is supposed to be a complete list. - # Fix things up if the remote homeserver is badly behaved. - servers_in_room = [destination] + servers_in_room - if not partial_state: raise InvalidResponseError( "members_omitted was set, but we asked for full state" ) + # `servers_in_room` is supposed to be a complete list. + # Fix things up in case the remote homeserver is badly behaved. + servers_in_room.add(destination) + return SendJoinResult( event=event, state=signed_state, auth_chain=signed_auth, origin=destination, partial_state=response.members_omitted, - servers_in_room=servers_in_room or [], + servers_in_room=servers_in_room or frozenset(), ) # MSC3083 defines additional error codes for room joins. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index dc1cbf5c3d13..7f64130e0aa1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -20,7 +20,17 @@ import logging from enum import Enum from http import HTTPStatus -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + AbstractSet, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import attr from prometheus_client import Histogram @@ -169,7 +179,7 @@ def __init__(self, hs: "HomeServer"): # A dictionary mapping room IDs to (initial destination, other destinations) # tuples. self._partial_state_syncs_maybe_needing_restart: Dict[ - str, Tuple[Optional[str], StrCollection] + str, Tuple[Optional[str], AbstractSet[str]] ] = {} # A lock guarding the partial state flag for rooms. # When the lock is held for a given room, no other concurrent code may @@ -1720,7 +1730,7 @@ async def _resume_partial_state_room_sync(self) -> None: def _start_partial_state_room_sync( self, initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> None: """Starts the background process to resync the state of a partial state room, @@ -1802,7 +1812,7 @@ async def _sync_partial_state_room_wrapper() -> None: async def _sync_partial_state_room( self, initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> None: """Background process to resync the state of a partial-state room @@ -1939,7 +1949,7 @@ async def _sync_partial_state_room( def _prioritise_destinations_for_partial_state_resync( initial_destination: Optional[str], - other_destinations: StrCollection, + other_destinations: AbstractSet[str], room_id: str, ) -> StrCollection: """Work out the order in which we should ask servers to resync events. diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 7d9fc7d3b746..52efd4a1719b 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -573,7 +573,7 @@ async def get_current_hosts_in_room_or_partial_state_approximation( room_id ) if hosts_at_join is None: - hosts_at_join = () + hosts_at_join = frozenset() hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 12ddd152d6b0..2e1d114f9920 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -18,6 +18,7 @@ from enum import Enum from typing import ( TYPE_CHECKING, + AbstractSet, Any, Awaitable, Collection, @@ -25,7 +26,6 @@ List, Mapping, Optional, - Sequence, Set, Tuple, Union, @@ -108,7 +108,7 @@ class RoomSortOrder(Enum): @attr.s(slots=True, frozen=True, auto_attribs=True) class PartialStateResyncInfo: joined_via: Optional[str] - servers_in_room: List[str] = attr.ib(factory=list) + servers_in_room: Set[str] = attr.ib(factory=set) class RoomWorkerStore(CacheInvalidationWorkerStore): @@ -1194,8 +1194,8 @@ def get_rooms_for_retention_period_in_range_txn( async def get_partial_state_servers_at_join( self, room_id: str - ) -> Optional[Sequence[str]]: - """Gets the list of servers in a partial state room at the time we joined it. + ) -> Optional[AbstractSet[str]]: + """Gets the set of servers in a partial state room at the time we joined it. Returns: The `servers_in_room` list from the `/send_join` response for partial state @@ -1211,12 +1211,16 @@ async def get_partial_state_servers_at_join( return servers_in_room @cached(iterable=True) - async def _get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]: - return await self.db_pool.simple_select_onecol( - "partial_state_rooms_servers", - keyvalues={"room_id": room_id}, - retcol="server_name", - desc="get_partial_state_servers_at_join", + async def _get_partial_state_servers_at_join( + self, room_id: str + ) -> AbstractSet[str]: + return frozenset( + await self.db_pool.simple_select_onecol( + "partial_state_rooms_servers", + keyvalues={"room_id": room_id}, + retcol="server_name", + desc="get_partial_state_servers_at_join", + ) ) async def get_partial_state_room_resync_info( @@ -1261,7 +1265,7 @@ async def get_partial_state_room_resync_info( # partial-joined between the two SELECTs, but this is unlikely to happen # in practice.) continue - entry.servers_in_room.append(server_name) + entry.servers_in_room.add(server_name) return room_servers @@ -1951,7 +1955,7 @@ async def upsert_room_on_join( async def store_partial_state_room( self, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: @@ -1986,7 +1990,7 @@ def _store_partial_state_room_txn( self, txn: LoggingTransaction, room_id: str, - servers: Collection[str], + servers: AbstractSet[str], device_lists_stream_id: int, joined_via: str, ) -> None: diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index c1558c40c370..57675fa407e4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -656,7 +656,7 @@ def test_failed_partial_join_is_clean(self) -> None: EVENT_INVITATION_MEMBERSHIP, ], partial_state=True, - servers_in_room=["example.com"], + servers_in_room={"example.com"}, ) ) ) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index 6bbfd5dc843f..6a38893b688a 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -171,7 +171,7 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None: state=[create_event], auth_chain=[create_event], partial_state=False, - servers_in_room=[], + servers_in_room=frozenset(), ) ) )