Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert state resolution to async/await #7942

Merged
merged 11 commits into from
Jul 24, 2020
1 change: 1 addition & 0 deletions changelog.d/7942.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert state resolution to async/await.
12 changes: 8 additions & 4 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ def check_user_in_room(
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = yield self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
member = yield defer.ensureDeferred(
self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
)
membership = member.membership if member else None

Expand Down Expand Up @@ -665,8 +667,10 @@ def check_user_in_room_or_world_readable(
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
visibility = yield defer.ensureDeferred(
self.state.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
)
if (
visibility
Expand Down
4 changes: 2 additions & 2 deletions synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ def build(self, prev_event_ids):
Deferred[FrozenEvent]
"""

state_ids = yield self._state.get_current_state_ids(
self.room_id, prev_event_ids
state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)

Expand Down
4 changes: 3 additions & 1 deletion synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,9 @@ def send_read_receipt(self, receipt: ReadReceipt):
room_id = receipt.room_id

# Work out which remote servers should be poked and poke them.
domains = yield self.state.get_current_hosts_in_room(room_id)
domains = yield defer.ensureDeferred(
self.state.get_current_hosts_in_room(room_id)
)
domains = [
d
for d in domains
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -928,8 +928,8 @@ async def _on_user_joined_room(self, room_id: str, user_id: str) -> None:
# TODO: Check that this is actually a new server joining the
# room.

user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
users = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, users))

states_d = await self.current_state_for_users(user_ids)

Expand Down
4 changes: 3 additions & 1 deletion synapse/push/bulk_push_rule_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ def get_rules(self, event, context):

push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield context.get_current_state_ids()
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
push_rules_delta_state_cache_metric.inc_misses()

push_rules_state_size_counter.inc(len(current_state_ids))
Expand Down
95 changes: 42 additions & 53 deletions synapse/state/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,20 @@

import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Set
from typing import Awaitable, Dict, Iterable, List, Optional, Set

import attr
from frozendict import frozendict
from prometheus_client import Histogram

from twisted.internet import defer

from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
Expand Down Expand Up @@ -108,8 +107,7 @@ def __init__(self, hs):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()

@defer.inlineCallbacks
def get_current_state(
async def get_current_state(
self, room_id, event_type=None, state_key="", latest_event_ids=None
):
""" Retrieves the current state for the room. This is done by
Expand All @@ -126,20 +124,20 @@ def get_current_state(
map from (type, state_key) to event
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)

logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state

if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = yield self.store.get_event(event_id, allow_none=True)
event = await self.store.get_event(event_id, allow_none=True)
return event

state_map = yield self.store.get_events(
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
state = {
Expand All @@ -148,8 +146,7 @@ def get_current_state(

return state

@defer.inlineCallbacks
def get_current_state_ids(self, room_id, latest_event_ids=None):
async def get_current_state_ids(self, room_id, latest_event_ids=None):
"""Get the current state, or the state at a set of events, for a room

Args:
Expand All @@ -164,41 +161,38 @@ def get_current_state_ids(self, room_id, latest_event_ids=None):
(event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)

logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state

return state

@defer.inlineCallbacks
def get_current_users_in_room(self, room_id, latest_event_ids=None):
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.

Args:
room_id (str): The ID of the room.
latest_event_ids (List[str]|None): Precomputed list of latest
event IDs. Will be computed if None.
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
profileinfo.
Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_users_in_room")
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
joined_users = await self.store.get_joined_users_from_state(room_id, entry)
return joined_users

@defer.inlineCallbacks
def get_current_hosts_in_room(self, room_id):
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
async def get_current_hosts_in_room(self, room_id):
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)

@defer.inlineCallbacks
def get_hosts_in_room_at_events(self, room_id, event_ids):
async def get_hosts_in_room_at_events(self, room_id, event_ids):
"""Get the hosts that were in a room at the given event ids

Args:
Expand All @@ -208,12 +202,11 @@ def get_hosts_in_room_at_events(self, room_id, event_ids):
Returns:
Deferred[list[str]]: the hosts in the room at the given events
"""
entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
entry = await self.resolve_state_groups_for_events(room_id, event_ids)
joined_hosts = await self.store.get_joined_hosts(room_id, entry)
return joined_hosts

@defer.inlineCallbacks
def compute_event_context(
async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
):
"""Build an EventContext structure for the event.
Expand Down Expand Up @@ -278,7 +271,7 @@ def compute_event_context(
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")

entry = yield self.resolve_state_groups_for_events(
entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)

Expand All @@ -295,7 +288,7 @@ def compute_event_context(
#

if not state_group_before_event:
state_group_before_event = yield self.state_store.store_state_group(
state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
Expand Down Expand Up @@ -335,7 +328,7 @@ def compute_event_context(
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}

state_group_after_event = yield self.state_store.store_state_group(
state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
Expand All @@ -353,8 +346,7 @@ def compute_event_context(
)

@measure_func()
@defer.inlineCallbacks
def resolve_state_groups_for_events(self, room_id, event_ids):
async def resolve_state_groups_for_events(self, room_id, event_ids):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.

Expand All @@ -373,7 +365,7 @@ def resolve_state_groups_for_events(self, room_id, event_ids):
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
state_groups_ids = yield self.state_store.get_state_groups_ids(
state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)

Expand All @@ -382,7 +374,7 @@ def resolve_state_groups_for_events(self, room_id, event_ids):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()

prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
prev_group, delta_ids = await self.state_store.get_state_group_delta(name)

return _StateCacheEntry(
state=state_list,
Expand All @@ -391,9 +383,9 @@ def resolve_state_groups_for_events(self, room_id, event_ids):
delta_ids=delta_ids,
)

room_version = yield self.store.get_room_version_id(room_id)
room_version = await self.store.get_room_version_id(room_id)

result = yield self._state_resolution_handler.resolve_state_groups(
result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
Expand All @@ -402,8 +394,7 @@ def resolve_state_groups_for_events(self, room_id, event_ids):
)
return result

@defer.inlineCallbacks
def resolve_events(self, room_version, state_sets, event):
async def resolve_events(self, room_version, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
Expand All @@ -414,7 +405,7 @@ def resolve_events(self, room_version, state_sets, event):
state_map = {ev.event_id: ev for st in state_sets for ev in st}

with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
event.room_id,
room_version,
Expand Down Expand Up @@ -451,9 +442,8 @@ def __init__(self, hs):
reset_expiry_on_get=True,
)

@defer.inlineCallbacks
@log_function
def resolve_state_groups(
async def resolve_state_groups(
self, room_id, room_version, state_groups_ids, event_map, state_res_store
):
"""Resolves conflicts between a set of state groups
Expand All @@ -479,13 +469,13 @@ def resolve_state_groups(
state_res_store (StateResolutionStore)

Returns:
Deferred[_StateCacheEntry]: resolved state
_StateCacheEntry: resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())

group_names = frozenset(state_groups_ids.keys())

with (yield self.resolve_linearizer.queue(group_names)):
with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
Expand Down Expand Up @@ -517,7 +507,7 @@ def resolve_state_groups(
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events_with_store(
new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
Expand Down Expand Up @@ -598,7 +588,7 @@ def resolve_events_with_store(
state_sets: List[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
):
) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
Expand All @@ -619,8 +609,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from

Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
Expand Down
Loading