Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Convert retun type of get_events_around to attrs.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Jan 25, 2022
1 parent 0d3b625 commit 1777831
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 47 deletions.
32 changes: 15 additions & 17 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,40 +1180,38 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
events_before = results.events_before
events_after = results.events_after

if event_filter:
results["events_before"] = await event_filter.filter(
results["events_before"]
)
results["events_after"] = await event_filter.filter(results["events_after"])
events_before = await event_filter.filter(events_before)
events_after = await event_filter.filter(events_after)

results["events_before"] = await filter_evts(results["events_before"])
results["events_after"] = await filter_evts(results["events_after"])
events_before = await filter_evts(events_before)
events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
event = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
itertools.chain(
results["events_before"], (event,), results["events_after"]
),
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)

if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event_id

if event_filter and event_filter.lazy_load_members:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
results["events_before"],
events_before,
(event,),
results["events_after"],
events_after,
)
)
else:
Expand All @@ -1236,15 +1234,15 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
token = StreamToken.START

return EventContext(
events_before=results["events_before"],
events_before=events_before,
event=event,
events_after=results["events_after"],
events_after=events_after,
state=await filter_evts(state_events),
aggregations=aggregations,
start=await token.copy_and_replace("room_key", results["start"]).to_string(
start=await token.copy_and_replace("room_key", results.start).to_string(
self.store
),
end=await token.copy_and_replace("room_key", results["end"]).to_string(
end=await token.copy_and_replace("room_key", results.end).to_string(
self.store
),
)
Expand Down
45 changes: 23 additions & 22 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,36 +361,37 @@ async def search(

logger.info(
"Context for search returned %d and %d events",
len(res["events_before"]),
len(res["events_after"]),
len(res.events_before),
len(res.events_after),
)

res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
events_before = await filter_events_for_client(
self.storage, user.to_string(), res.events_before
)

res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
events_after = await filter_events_for_client(
self.storage, user.to_string(), res.events_after
)

res["start"] = await now_token.copy_and_replace(
"room_key", res["start"]
).to_string(self.store)

res["end"] = await now_token.copy_and_replace(
"room_key", res["end"]
).to_string(self.store)
context = {
"events_before": events_before,
"events_after": events_after,
"start": await now_token.copy_and_replace(
"room_key", res.start
).to_string(self.store),
"end": await now_token.copy_and_replace(
"room_key", res.end
).to_string(self.store),
}

if include_profile:
senders = {
ev.sender
for ev in itertools.chain(
res["events_before"], [event], res["events_after"]
)
for ev in itertools.chain(events_before, [event], events_after)
}

if res["events_after"]:
last_event_id = res["events_after"][-1].event_id
if events_after:
last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id

Expand All @@ -402,7 +403,7 @@ async def search(
last_event_id, state_filter
)

res["profile_info"] = {
context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
Expand All @@ -411,7 +412,7 @@ async def search(
if s.type == EventTypes.Member and s.state_key in senders
}

contexts[event.event_id] = res
contexts[event.event_id] = context
else:
contexts = {}

Expand All @@ -421,10 +422,10 @@ async def search(

for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
context["events_before"], time_now # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
context["events_after"], time_now # type: ignore[arg-type]
)

state_results = {}
Expand Down
2 changes: 1 addition & 1 deletion synapse/push/mailer.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ async def _get_notif_vars(
}

the_events = await filter_events_for_client(
self.storage, user_id, results["events_before"]
self.storage, user_id, results.events_before
)
the_events.append(notif_event)

Expand Down
22 changes: 15 additions & 7 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int


@attr.s(slots=True, frozen=True, auto_attribs=True)
class _EventsAround:
events_before: List[EventBase]
events_after: List[EventBase]
start: RoomStreamToken
end: RoomStreamToken


def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
Expand Down Expand Up @@ -846,7 +854,7 @@ async def get_events_around(
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
) -> dict:
) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a
room.
"""
Expand All @@ -869,12 +877,12 @@ async def get_events_around(
list(results["after"]["event_ids"]), get_prev_content=True
)

return {
"events_before": events_before,
"events_after": events_after,
"start": results["before"]["token"],
"end": results["after"]["token"],
}
return _EventsAround(
events_before=events_before,
events_after=events_after,
start=results["before"]["token"],
end=results["after"]["token"],
)

def _get_events_around_txn(
self,
Expand Down

0 comments on commit 1777831

Please sign in to comment.