Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
convert list of servers into a set
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Quah committed Feb 2, 2023
1 parent f65d213 commit 163c68c
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 30 deletions.
21 changes: 12 additions & 9 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import logging
from typing import (
TYPE_CHECKING,
AbstractSet,
Awaitable,
Callable,
Collection,
Expand Down Expand Up @@ -110,9 +111,9 @@ class SendJoinResult:
# True if 'state' elides non-critical membership events
partial_state: bool

# If 'partial_state' is set, a list of the servers in the room (otherwise empty).
# If 'partial_state' is set, a set of the servers in the room (otherwise empty).
# Always contains the server we joined off.
servers_in_room: List[str]
servers_in_room: AbstractSet[str]


class FederationClient(FederationBase):
Expand Down Expand Up @@ -1153,30 +1154,32 @@ async def _execute(pdu: EventBase) -> None:
% (auth_chain_create_events,)
)

servers_in_room = response.servers_in_room
servers_in_room = None
if response.servers_in_room is not None:
servers_in_room = set(response.servers_in_room)

if response.members_omitted:
if not servers_in_room:
raise InvalidResponseError(
"members_omitted was set, but no servers were listed in the room"
)

if destination not in servers_in_room:
# `servers_in_room` is supposed to be a complete list.
# Fix things up if the remote homeserver is badly behaved.
servers_in_room = [destination] + servers_in_room

if not partial_state:
raise InvalidResponseError(
"members_omitted was set, but we asked for full state"
)

# `servers_in_room` is supposed to be a complete list.
# Fix things up in case the remote homeserver is badly behaved.
servers_in_room.add(destination)

return SendJoinResult(
event=event,
state=signed_state,
auth_chain=signed_auth,
origin=destination,
partial_state=response.members_omitted,
servers_in_room=servers_in_room or [],
servers_in_room=servers_in_room or frozenset(),
)

# MSC3083 defines additional error codes for room joins.
Expand Down
20 changes: 15 additions & 5 deletions synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,17 @@
import logging
from enum import Enum
from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
TYPE_CHECKING,
AbstractSet,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)

import attr
from prometheus_client import Histogram
Expand Down Expand Up @@ -169,7 +179,7 @@ def __init__(self, hs: "HomeServer"):
# A dictionary mapping room IDs to (initial destination, other destinations)
# tuples.
self._partial_state_syncs_maybe_needing_restart: Dict[
str, Tuple[Optional[str], StrCollection]
str, Tuple[Optional[str], AbstractSet[str]]
] = {}
# A lock guarding the partial state flag for rooms.
# When the lock is held for a given room, no other concurrent code may
Expand Down Expand Up @@ -1720,7 +1730,7 @@ async def _resume_partial_state_room_sync(self) -> None:
def _start_partial_state_room_sync(
self,
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Starts the background process to resync the state of a partial state room,
Expand Down Expand Up @@ -1802,7 +1812,7 @@ async def _sync_partial_state_room_wrapper() -> None:
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> None:
"""Background process to resync the state of a partial-state room
Expand Down Expand Up @@ -1939,7 +1949,7 @@ async def _sync_partial_state_room(

def _prioritise_destinations_for_partial_state_resync(
initial_destination: Optional[str],
other_destinations: StrCollection,
other_destinations: AbstractSet[str],
room_id: str,
) -> StrCollection:
"""Work out the order in which we should ask servers to resync events.
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/controllers/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ async def get_current_hosts_in_room_or_partial_state_approximation(
room_id
)
if hosts_at_join is None:
hosts_at_join = ()
hosts_at_join = frozenset()

hosts_from_state = await self.stores.main.get_current_hosts_in_room(room_id)

Expand Down
30 changes: 17 additions & 13 deletions synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
from enum import Enum
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Collection,
Dict,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
Expand Down Expand Up @@ -108,7 +108,7 @@ class RoomSortOrder(Enum):
@attr.s(slots=True, frozen=True, auto_attribs=True)
class PartialStateResyncInfo:
joined_via: Optional[str]
servers_in_room: List[str] = attr.ib(factory=list)
servers_in_room: Set[str] = attr.ib(factory=set)


class RoomWorkerStore(CacheInvalidationWorkerStore):
Expand Down Expand Up @@ -1194,8 +1194,8 @@ def get_rooms_for_retention_period_in_range_txn(

async def get_partial_state_servers_at_join(
self, room_id: str
) -> Optional[Sequence[str]]:
"""Gets the list of servers in a partial state room at the time we joined it.
) -> Optional[AbstractSet[str]]:
"""Gets the set of servers in a partial state room at the time we joined it.
Returns:
The `servers_in_room` list from the `/send_join` response for partial state
Expand All @@ -1211,12 +1211,16 @@ async def get_partial_state_servers_at_join(
return servers_in_room

@cached(iterable=True)
async def _get_partial_state_servers_at_join(self, room_id: str) -> Sequence[str]:
return await self.db_pool.simple_select_onecol(
"partial_state_rooms_servers",
keyvalues={"room_id": room_id},
retcol="server_name",
desc="get_partial_state_servers_at_join",
async def _get_partial_state_servers_at_join(
self, room_id: str
) -> AbstractSet[str]:
return frozenset(
await self.db_pool.simple_select_onecol(
"partial_state_rooms_servers",
keyvalues={"room_id": room_id},
retcol="server_name",
desc="get_partial_state_servers_at_join",
)
)

async def get_partial_state_room_resync_info(
Expand Down Expand Up @@ -1261,7 +1265,7 @@ async def get_partial_state_room_resync_info(
# partial-joined between the two SELECTs, but this is unlikely to happen
# in practice.)
continue
entry.servers_in_room.append(server_name)
entry.servers_in_room.add(server_name)

return room_servers

Expand Down Expand Up @@ -1951,7 +1955,7 @@ async def upsert_room_on_join(
async def store_partial_state_room(
self,
room_id: str,
servers: Collection[str],
servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
Expand Down Expand Up @@ -1986,7 +1990,7 @@ def _store_partial_state_room_txn(
self,
txn: LoggingTransaction,
room_id: str,
servers: Collection[str],
servers: AbstractSet[str],
device_lists_stream_id: int,
joined_via: str,
) -> None:
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,7 +656,7 @@ def test_failed_partial_join_is_clean(self) -> None:
EVENT_INVITATION_MEMBERSHIP,
],
partial_state=True,
servers_in_room=["example.com"],
servers_in_room={"example.com"},
)
)
)
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_room_member.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def test_remote_joins_contribute_to_rate_limit(self) -> None:
state=[create_event],
auth_chain=[create_event],
partial_state=False,
servers_in_room=[],
servers_in_room=frozenset(),
)
)
)
Expand Down

0 comments on commit 163c68c

Please sign in to comment.