Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Improve event caching code #10119

Merged
merged 16 commits into from
Aug 4, 2021
Merged
6 changes: 1 addition & 5 deletions synapse/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __get__(self, instance, owner=None):


class _EventInternalMetadata:
__slots__ = ["_dict", "stream_ordering", "outlier", "redacted_by"]
__slots__ = ["_dict", "stream_ordering", "outlier"]

def __init__(self, internal_metadata_dict: JsonDict):
# we have to copy the dict, because it turns out that the same dict is
Expand All @@ -111,10 +111,6 @@ def __init__(self, internal_metadata_dict: JsonDict):
# in the DAG)
self.outlier = False

# Whether this event has a valid redaction event pointing at it (i.e.
# whether it should be redacted before giving to clients).
self.redacted_by: Optional[str] = None

out_of_band_membership: bool = DictProperty("out_of_band_membership")
send_on_behalf_of: str = DictProperty("send_on_behalf_of")
recheck_redaction: bool = DictProperty("recheck_redaction")
Expand Down
93 changes: 32 additions & 61 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import logging
import threading
import weakref
from typing import (
Collection,
Container,
Expand Down Expand Up @@ -171,13 +170,12 @@ def __init__(self, database: DatabasePool, db_conn, hs):
max_size=hs.config.caches.event_cache_size,
)

# Map from event ID to a deferred that will result in an
# Dict[str, _EventCacheEntry].
self._current_event_fetches: Dict[str, ObservableDeferred] = {}

# We keep track of the events we have currently loaded in memory so that
# we can reuse them even if they've been evicted from the cache.
self._event_ref: Dict[str, EventBase] = weakref.WeakValueDictionary()
# Map from event ID to a deferred that will result in a map from event
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
] = {}

self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
Expand Down Expand Up @@ -524,34 +522,30 @@ async def _get_events_from_cache_or_db(
# events out of the DB multiple times.
already_fetching: Dict[str, defer.Deferred] = {}

# We also add entries to `self._current_event_fetches` for each event
# we're going to pull from the DB. We use a single deferred that
# resolves to all the events we pulled from the DB (this will result in
# this function returning more events than requested, but that can
# happen already due to `_get_events_from_db`).
fetching_deferred = ObservableDeferred(defer.Deferred())

for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB, ad the deferred
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
already_fetching[event_id] = deferred.observe()
else:
# We're not already pulling the event from the DB, so add our
# deferred to the the map of events that are being fetched.
self._current_event_fetches[event_id] = fetching_deferred
fetching_deferred.observe().addBoth(
lambda _, event_id: self._current_event_fetches.pop(event_id, None),
event_id,
)

missing_events_ids.difference_update(already_fetching)

if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred

# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
Expand All @@ -564,10 +558,16 @@ async def _get_events_from_cache_or_db(

event_entry_map.update(missing_events)
except Exception as e:
fetching_deferred.errback(e)
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
Comment on lines +565 to +567
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if doing this after fetching_deferred.errback could cause races. I can't really think how it could, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the only thing that can happen is that another request waits on the deferred after it's already been resolved, which shouldn't be an issue, as awaiting an already-resolved deferred will just return immediately.


fetching_deferred.callback(missing_events)
with PreserveLoggingContext():
fetching_deferred.callback(missing_events)

if already_fetching:
# Wait for the other event requests to finish and add their results
Expand All @@ -593,13 +593,11 @@ async def _get_events_from_cache_or_db(

def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)

def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
"""Fetch events from the caches
"""Fetch events from the caches, may return rejected events.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Args:
events: list of event_ids to fetch
Expand All @@ -608,34 +606,13 @@ def _get_events_from_cache(
event_map = {}

for event_id in events:
# First check if its in the event cache
ret = self._get_event_cache.get(
(event_id,), None, update_metrics=update_metrics
)
if ret:
event_map[event_id] = ret

# Otherwise check if we still have the event in memory.
event = self._event_ref.get(event_id)
if event:
redacted_event = None
if event.internal_metadata.redacted_by is not None:
# The event has been redacted, so we generate a redacted
# version.
redacted_event = prune_event(event)
redacted_event.unsigned[
"redacted_by"
] = event.internal_metadata.redacted_by

cache_entry = _EventCacheEntry(
event=event,
redacted_event=redacted_event,
)
event_map[event_id] = cache_entry
if not ret:
continue

# We add the entry back into the cache as we want to keep
# recently queried events in the cache.
self._get_event_cache.set((event_id,), cache_entry)
event_map[event_id] = ret

return event_map

Expand Down Expand Up @@ -765,7 +742,8 @@ def fire(evs, exc):
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.
"""Fetch a bunch of events from the database, may return rejected
events.
erikjohnston marked this conversation as resolved.
Show resolved Hide resolved

Returned events will be added to the cache for future lookups.

Expand Down Expand Up @@ -905,20 +883,13 @@ async def _get_events_from_db(
original_ev, redactions, event_map
)

if redacted_event:
original_ev.internal_metadata.redacted_by = redacted_event.unsigned[
"redacted_by"
]

cache_entry = _EventCacheEntry(
event=original_ev, redacted_event=redacted_event
)

self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry

self._event_ref[event_id] = original_ev

return result_map

async def _enqueue_events(self, events):
Expand Down
25 changes: 0 additions & 25 deletions tests/storage/databases/main/test_events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,31 +131,6 @@ def test_simple(self):
# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)

def test_event_ref(self):
"""Test that we reuse events that are still in memory but have fallen
out of the cache, rather than requesting them from the DB.
"""

# Reset the event cache
self.store._get_event_cache.clear()

with LoggingContext("test") as ctx:
# We keep hold of the event event though we never use it.
event = self.get_success(self.store.get_event(self.event_id)) # noqa: F841

# We should have fetched the event from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1)

# Reset the event cache
self.store._get_event_cache.clear()

with LoggingContext("test") as ctx:
self.get_success(self.store.get_event(self.event_id))

# Since the event is still in memory we shouldn't have fetched it
# from the DB
self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 0)

def test_dedupe(self):
"""Test that if we request the same event multiple times we only pull it
out once.
Expand Down