Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Improvements to bundling aggregations. (#11815)
Browse files Browse the repository at this point in the history
These are some odds and ends found during the review of #11791
and while continuing to work on this code:

* Return attrs classes instead of dictionaries from some methods
  to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Add a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring
  to not include the bundle at all, as the docstring states).
  • Loading branch information
clokep authored Jan 26, 2022
1 parent d8df8e6 commit 2897fb6
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 139 deletions.
1 change: 1 addition & 0 deletions changelog.d/11815.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type safety of bundled aggregations code.
57 changes: 40 additions & 17 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)

from frozendict import frozendict

Expand All @@ -26,6 +36,10 @@

from . import EventBase

if TYPE_CHECKING:
from synapse.storage.databases.main.relations import BundledAggregations


# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
Expand Down Expand Up @@ -376,7 +390,7 @@ def serialize_event(
event: Union[JsonDict, EventBase],
time_now: int,
*,
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Expand Down Expand Up @@ -415,7 +429,7 @@ def _inject_bundled_aggregations(
self,
event: EventBase,
time_now: int,
aggregations: JsonDict,
aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
Expand All @@ -427,13 +441,18 @@ def _inject_bundled_aggregations(
serialized_event: The serialized event which may be modified.
"""
# Make a copy in-case the object is cached.
aggregations = aggregations.copy()
serialized_aggregations = {}

if aggregations.annotations:
serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations

if aggregations.references:
serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references

if RelationTypes.REPLACE in aggregations:
if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
edit = aggregations[RelationTypes.REPLACE]
edit = aggregations.replace

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
Expand All @@ -451,24 +470,28 @@ def _inject_bundled_aggregations(
else:
serialized_event["content"].pop("m.relates_to", None)

aggregations[RelationTypes.REPLACE] = {
serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}

# If this event is the start of a thread, include a summary of the replies.
if RelationTypes.THREAD in aggregations:
# Serialize the latest thread event.
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]

# Don't bundle aggregations as this could recurse forever.
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=None
)
if aggregations.thread:
serialized_aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": self.serialize_event(
aggregations.thread.latest_event, time_now, bundle_aggregations=None
),
"count": aggregations.thread.count,
"current_user_participated": aggregations.thread.current_user_participated,
}

# Include the bundled aggregations in the event.
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
if serialized_aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
serialized_aggregations
)

def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
Expand Down
77 changes: 41 additions & 36 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Tuple,
)

import attr
from typing_extensions import TypedDict

from synapse.api.constants import (
Expand Down Expand Up @@ -60,6 +61,7 @@
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
Expand Down Expand Up @@ -90,6 +92,17 @@
FIVE_MINUTES_IN_MS = 5 * 60 * 1000


@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventContext:
events_before: List[EventBase]
event: EventBase
events_after: List[EventBase]
state: List[EventBase]
aggregations: Dict[str, BundledAggregations]
start: str
end: str


class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
Expand Down Expand Up @@ -1119,7 +1132,7 @@ async def get_event_context(
limit: int,
event_filter: Optional[Filter],
use_admin_priviledge: bool = False,
) -> Optional[JsonDict]:
) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Expand Down Expand Up @@ -1167,48 +1180,38 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
events_before = results.events_before
events_after = results.events_after

if event_filter:
results["events_before"] = await event_filter.filter(
results["events_before"]
)
results["events_after"] = await event_filter.filter(results["events_after"])
events_before = await event_filter.filter(events_before)
events_after = await event_filter.filter(events_after)

results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"])
events_before = await filter_evts(events_before)
events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]
event = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
[results["event"]], user.to_string()
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_before"], user.to_string()
)
)
aggregations.update(
await self.store.get_bundled_aggregations(
results["events_after"], user.to_string()
)
)
results["aggregations"] = aggregations

if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event_id

if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
results["events_before"],
(results["event"],),
results["events_after"],
events_before,
(event,),
events_after,
)
)
else:
Expand All @@ -1226,21 +1229,23 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
if event_filter:
state_events = await event_filter.filter(state_events)

results["state"] = await filter_evts(state_events)

# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
token = StreamToken.START

results["start"] = await token.copy_and_replace(
"room_key", results["start"]
).to_string(self.store)

results["end"] = await token.copy_and_replace(
"room_key", results["end"]
).to_string(self.store)

return results
return EventContext(
events_before=events_before,
event=event,
events_after=events_after,
state=await filter_evts(state_events),
aggregations=aggregations,
start=await token.copy_and_replace("room_key", results.start).to_string(
self.store
),
end=await token.copy_and_replace("room_key", results.end).to_string(
self.store
),
)


class TimestampLookupHandler:
Expand Down
45 changes: 23 additions & 22 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,36 +361,37 @@ async def search(

logger.info(
"Context for search returned %d and %d events",
len(res["events_before"]),
len(res["events_after"]),
len(res.events_before),
len(res.events_after),
)

res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
)

res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
)

res["start"] = await now_token.copy_and_replace(
"room_key", res["start"]
).to_string(self.store)

res["end"] = await now_token.copy_and_replace(
"room_key", res["end"]
).to_string(self.store)
context = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
"room_key", res.start
).to_string(self.store),
"end": await now_token.copy_and_replace(
"room_key", res.end
).to_string(self.store),
}

if include_profile:
senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
for ev in itertools.chain(events_before, [event], events_after)
}

if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id

Expand All @@ -402,7 +403,7 @@ async def search(
last_event_id, state_filter
)

res["profile_info"] = {
context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
Expand All @@ -411,7 +412,7 @@ async def search(
if s.type == EventTypes.Member and s.state_key in senders
}

contexts[event.event_id] = res
contexts[event.event_id] = context
else:
contexts = {}

Expand All @@ -421,10 +422,10 @@ async def search(

for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
context["events_before"], time_now # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
context["events_after"], time_now # type: ignore[arg-type]
)

state_results = {}
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
Expand Down Expand Up @@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None

def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ async def _get_notif_vars(
}

the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"]
self.storage, user_id, results.events_before
)
the_events.append(notif_event)

Expand Down
Loading

0 comments on commit 2897fb6

Please sign in to comment.