Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improvements to bundling aggregations #11815

Merged
merged 12 commits into from
Jan 26, 2022
Merged
1 change: 1 addition & 0 deletions changelog.d/11815.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type safety of bundled aggregations code.
57 changes: 40 additions & 17 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)

from frozendict import frozendict

Expand All @@ -26,6 +36,10 @@

from . import EventBase

if TYPE_CHECKING:
from synapse.storage.databases.main.relations import BundledAggregations


# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
Expand Down Expand Up @@ -376,7 +390,7 @@ def serialize_event(
event: Union[JsonDict, EventBase],
time_now: int,
*,
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Expand Down Expand Up @@ -415,7 +429,7 @@ def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
aggregations: JsonDict,
aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Expand All @@ -427,13 +441,18 @@ def _inject_bundled_aggregations(
serialized_event: The serialized event which may be modified.

"""
# Make a copy in case the object is cached.
aggregations = aggregations.copy()
serialized_aggregations = {}

if aggregations.annotations:
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations

if aggregations.references:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references

if RelationTypes.REPLACE in aggregations:
if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
edit = aggregations[RelationTypes.REPLACE]
edit = aggregations.replace

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
Expand All @@ -451,24 +470,28 @@ def _inject_bundled_aggregations(
else:
serialized_event["content"].pop("m.relates_to", None)

aggregations[RelationTypes.REPLACE] = {
serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}

# If this event is the start of a thread, include a summary of the replies.
if RelationTypes.THREAD in aggregations:
# Serialize the latest thread event.
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]

# Don't bundle aggregations as this could recurse forever.
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=None
)
if aggregations.thread:
serialized_aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": self.serialize_event(
aggregations.thread.latest_event, time_now, bundle_aggregations=None
),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
}

# Include the bundled aggregations in the event.
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
if serialized_aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
serialized_aggregations
)

def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
Expand Down
59 changes: 33 additions & 26 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Tuple,
)

import attr
from typing_extensions import TypedDict

from synapse.api.constants import (
Expand Down Expand Up @@ -60,6 +61,7 @@
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
Expand Down Expand Up @@ -90,6 +92,17 @@
FIVE_MINUTES_IN_MS = 5 * 60 * 1000


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventContext:
    """Immutable result of ``RoomContextHandler.get_event_context``: the target
    event plus the events, state, aggregations, and pagination tokens around it.

    Replaces the previous ad-hoc ``JsonDict`` return value so callers access
    well-typed attributes instead of arbitrary dictionary keys.
    """

    # Events preceding the target event, oldest first.
    events_before: List[EventBase]
    # The event the context was requested for (possibly pruned for visibility).
    event: EventBase
    # Events following the target event.
    events_after: List[EventBase]
    # Current room state events, filtered for the requesting user.
    state: List[EventBase]
    # Bundled aggregations keyed by event ID, covering events_before,
    # event, and events_after.
    aggregations: Dict[str, BundledAggregations]
    # Serialized stream token for paginating backwards from this context.
    start: str
    # Serialized stream token for paginating forwards from this context.
    end: str


class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
Expand Down Expand Up @@ -1119,7 +1132,7 @@ async def get_event_context(
limit: int,
event_filter: Optional[Filter],
use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event
in a room.

Expand Down Expand Up @@ -1179,23 +1192,15 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]
event = filtered[0]
Comment on lines -1182 to +1195
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we're still doing some mutation of results on lines +1184 to +1190 above. Not sure if that's intentional? It seems to conflict with the title of this commit!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, we are...I initially had converted that too and then stopped. I think it got messy, let me give it a try again though!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've updated this to not mutate those results, it got a bit messy though. Let me know if you'd prefer I back it out. See 1777831.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be honest I don't mind too much either way. (I was mostly checking that we hadn't missed a bit of work that we'd intended to do, judging by the commit message).

I like that the change makes it a bit easier to see what you get out of get_events_around (i.e. a thing with specific fields rather than an arbitrary dictionary), but I don't really have a strong opinion here. Happy to defer to your judgement!


# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string()
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
itertools.chain(
results["events_before"], (event,), results["events_after"]
),
user.to_string(),
)
results["aggregations"] = aggregations

if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
Expand All @@ -1207,7 +1212,7 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
ev.sender
for ev in itertools.chain(
results["events_before"],
(results["event"],),
(event,),
results["events_after"],
)
)
Expand All @@ -1226,21 +1231,23 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if event_filter:
state_events = await event_filter.filter(state_events)

results["state"] = await filter_evts(state_events)

# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
token = StreamToken.START

results["start"] = await token.copy_and_replace(
"room_key", results["start"]
).to_string(self.store)

results["end"] = await token.copy_and_replace(
"room_key", results["end"]
).to_string(self.store)

return results
return EventContext(
events_before=results["events_before"],
event=event,
events_after=results["events_after"],
state=await filter_evts(state_events),
aggregations=aggregations,
start=await token.copy_and_replace("room_key", results["start"]).to_string(
self.store
),
end=await token.copy_and_replace("room_key", results["end"]).to_string(
self.store
),
)
DMRobertson marked this conversation as resolved.
Show resolved Hide resolved


class TimestampLookupHandler:
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
Expand Down Expand Up @@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None

def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
Expand Down
39 changes: 24 additions & 15 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,7 +729,7 @@ async def on_GET(
else:
event_filter = None

results = await self.room_context_handler.get_event_context(
event_context = await self.room_context_handler.get_event_context(
requester,
room_id,
event_id,
Expand All @@ -738,25 +738,34 @@ async def on_GET(
use_admin_priviledge=True,
)

if not results:
if not event_context:
raise SynapseError(
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
)

time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None)
results["events_before"] = self._event_serializer.serialize_events(
results["events_before"], time_now, bundle_aggregations=aggregations
)
results["event"] = self._event_serializer.serialize_event(
results["event"], time_now, bundle_aggregations=aggregations
)
results["events_after"] = self._event_serializer.serialize_events(
results["events_after"], time_now, bundle_aggregations=aggregations
)
results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
results = {
"events_before": self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
),
"event": self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
),
"events_after": self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
"end": event_context.end,
}

return HTTPStatus.OK, results

Expand Down
39 changes: 24 additions & 15 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,27 +706,36 @@ async def on_GET(
else:
event_filter = None

results = await self.room_context_handler.get_event_context(
event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
)

if not results:
if not event_context:
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)

time_now = self.clock.time_msec()
aggregations = results.pop("aggregations", None)
results["events_before"] = self._event_serializer.serialize_events(
results["events_before"], time_now, bundle_aggregations=aggregations
)
results["event"] = self._event_serializer.serialize_event(
results["event"], time_now, bundle_aggregations=aggregations
)
results["events_after"] = self._event_serializer.serialize_events(
results["events_after"], time_now, bundle_aggregations=aggregations
)
results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)
results = {
"events_before": self._event_serializer.serialize_events(
event_context.events_before,
time_now,
bundle_aggregations=event_context.aggregations,
),
"event": self._event_serializer.serialize_event(
event_context.event,
time_now,
bundle_aggregations=event_context.aggregations,
),
"events_after": self._event_serializer.serialize_events(
event_context.events_after,
time_now,
bundle_aggregations=event_context.aggregations,
),
"state": self._event_serializer.serialize_events(
event_context.state, time_now
),
"start": event_context.start,
"end": event_context.end,
}

return 200, results

Expand Down
3 changes: 2 additions & 1 deletion synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder

Expand Down Expand Up @@ -526,7 +527,7 @@ async def encode_room(

def serialize(
events: Iterable[EventBase],
aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
Expand Down
Loading