Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix @cachedList on _have_seen_events_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
MadLittleMods committed Sep 22, 2022
1 parent 2162ab5 commit 0cdc7bf
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _invalidate_caches_for_event(
# process triggering the invalidation is responsible for clearing any external
# cached objects.
self._invalidate_local_get_event_cache(event_id)
self.have_seen_event.invalidate(((room_id, event_id),))
self.have_seen_event.invalidate((room_id, event_id))

self.get_latest_event_ids_in_room.invalidate((room_id,))

Expand Down
40 changes: 22 additions & 18 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1474,32 +1474,38 @@ async def have_seen_events(
# the batches as big as possible.

results: Set[str] = set()
for chunk in batch_iter(event_ids, 500):
r = await self._have_seen_events_dict(
[(room_id, event_id) for event_id in chunk]
for event_ids_chunk in batch_iter(event_ids, 500):
events_seen_dict = await self._have_seen_events_dict(
room_id, event_ids_chunk
)
results.update(
eid for (eid, have_event) in events_seen_dict.items() if have_event
)
results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)

return results

@cachedList(cached_method_name="have_seen_event", list_name="keys")
@cachedList(cached_method_name="have_seen_event", list_name="event_ids")
async def _have_seen_events_dict(
self, keys: Collection[Tuple[str, str]]
) -> Dict[Tuple[str, str], bool]:
self,
room_id: str,
event_ids: Collection[str],
) -> Dict[str, bool]:
"""Helper for have_seen_events
Returns:
a dict {(room_id, event_id)-> bool}
a dict {event_id -> bool}
"""
# if the event cache contains the event, obviously we've seen it.

cache_results = {
(rid, eid)
for (rid, eid) in keys
if await self._get_event_cache.contains((eid,))
event_id
for event_id in event_ids
if await self._get_event_cache.contains((event_id,))
}
results = dict.fromkeys(cache_results, True)
remaining = [k for k in keys if k not in cache_results]
remaining = [
event_id for event_id in event_ids if event_id not in cache_results
]
if not remaining:
return results

Expand All @@ -1511,23 +1517,21 @@ def have_seen_events_txn(txn: LoggingTransaction) -> None:

sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
txn.database_engine, "e.event_id", remaining
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}

# ... and then we can update the results for each key
results.update(
{(rid, eid): (eid in found_events) for (rid, eid) in remaining}
)
results.update({eid: (eid in found_events) for eid in remaining})

await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results

@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
res = await self._have_seen_events_dict(((room_id, event_id),))
return res[(room_id, event_id)]
res = await self._have_seen_events_dict(room_id, [event_id])
return res[event_id]

def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str
Expand Down

0 comments on commit 0cdc7bf

Please sign in to comment.