Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Return an immutable value from get_latest_event_ids_in_room. (#16326)
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep authored Sep 18, 2023
1 parent 63d28a8 commit 85bfd47
Show file tree
Hide file tree
Showing 12 changed files with 48 additions and 40 deletions.
1 change: 1 addition & 0 deletions changelog.d/16326.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
2 changes: 1 addition & 1 deletion synapse/events/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def is_state(self) -> bool:

async def build(
self,
prev_event_ids: StrCollection,
prev_event_ids: List[str],
auth_event_ids: Optional[List[str]],
depth: Optional[int] = None,
) -> EventBase:
Expand Down
8 changes: 3 additions & 5 deletions synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,12 +723,11 @@ async def _get_missing_events_for_pdu(
if not prevs - seen:
return

latest_list = await self._store.get_latest_event_ids_in_room(room_id)
latest_frozen = await self._store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest_list)
latest |= seen
latest = seen | latest_frozen

logger.info(
"Requesting missing events between %s and %s",
Expand Down Expand Up @@ -1976,8 +1975,7 @@ async def _check_for_soft_fail(
# partial and full state and may not be accurate.
return

extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
extrem_ids = await self._store.get_latest_event_ids_in_room(event.room_id)
prev_event_ids = set(event.prev_event_ids())

if extrem_ids == prev_event_ids:
Expand Down
9 changes: 4 additions & 5 deletions synapse/storage/controllers/persist_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import deque
from typing import (
TYPE_CHECKING,
AbstractSet,
Any,
Awaitable,
Callable,
Expand Down Expand Up @@ -618,7 +619,7 @@ async def _persist_event_batch(
)

for room_id, ev_ctx_rm in events_by_room.items():
latest_event_ids = set(
latest_event_ids = (
await self.main_store.get_latest_event_ids_in_room(room_id)
)
new_latest_event_ids = await self._calculate_new_extremities(
Expand Down Expand Up @@ -740,7 +741,7 @@ async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: Collection[str],
latest_event_ids: AbstractSet[str],
) -> Set[str]:
"""Calculates the new forward extremities for a room given events to
persist.
Expand All @@ -758,8 +759,6 @@ async def _calculate_new_extremities(
and not event.internal_metadata.is_soft_failed()
]

latest_event_ids = set(latest_event_ids)

# start with the existing forward extremities
result = set(latest_event_ids)

Expand Down Expand Up @@ -798,7 +797,7 @@ async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Set[str],
old_latest_event_ids: AbstractSet[str],
new_latest_event_ids: Set[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]], Set[str]]:
"""Calculate the current state dict after adding some new events to
Expand Down
8 changes: 5 additions & 3 deletions synapse/storage/databases/main/event_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Expand Down Expand Up @@ -47,7 +48,7 @@
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict, StrCollection, StrSequence
from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
Expand Down Expand Up @@ -1179,13 +1180,14 @@ def _get_rooms_with_many_extremities_txn(txn: LoggingTransaction) -> List[str]:
)

@cached(max_entries=5000, iterable=True)
async def get_latest_event_ids_in_room(self, room_id: str) -> StrSequence:
return await self.db_pool.simple_select_onecol(
async def get_latest_event_ids_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.db_pool.simple_select_onecol(
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
desc="get_latest_event_ids_in_room",
)
return frozenset(event_ids)

async def get_min_depth(self, room_id: str) -> Optional[int]:
"""For the given room, get the minimum depth we have seen for it."""
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ async def _persist_events_and_state_updates(

for room_id, latest_event_ids in new_forward_extremities.items():
self.store.get_latest_event_ids_in_room.prefill(
(room_id,), list(latest_event_ids)
(room_id,), frozenset(latest_event_ids)
)

async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -1858,7 +1858,7 @@ def _add_new_user(self, room_id: str, user_id: str) -> None:
)

event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
)

self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event))
Expand Down
4 changes: 2 additions & 2 deletions tests/replication/storage/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def tearDown(self) -> None:
def test_get_latest_event_ids_in_room(self) -> None:
create = self.persist(type="m.room.create", key="", creator=USER_ID)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id})

join = self.persist(
type="m.room.member",
Expand All @@ -99,7 +99,7 @@ def test_get_latest_event_ids_in_room(self) -> None:
prev_events=[(create.event_id, {})],
)
self.replicate()
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id})

def test_redactions(self) -> None:
self.persist(type="m.room.create", key="", creator=USER_ID)
Expand Down
10 changes: 5 additions & 5 deletions tests/replication/tcp/streams/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Optional, Sequence
from typing import Any, List, Optional

from twisted.test.proto_helpers import MemoryReactor

Expand Down Expand Up @@ -139,7 +139,7 @@ def test_update_function_huge_state_change(self) -> None:
)

# this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success(
fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)

Expand Down Expand Up @@ -294,7 +294,7 @@ def test_update_function_state_row_limit(self) -> None:
)

# this is the point in the DAG where we make a fork
fork_point: Sequence[str] = self.get_success(
fork_point = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)

Expand All @@ -316,14 +316,14 @@ def test_update_function_state_row_limit(self) -> None:
self.test_handler.received_rdata_rows.clear()

# now roll back all that state by de-modding the users
prev_events = fork_point
prev_events = list(fork_point)
pl_events = []
for u in user_ids:
pls["users"][u] = 0
e = self.get_success(
inject_event(
self.hs,
prev_event_ids=list(prev_events),
prev_event_ids=prev_events,
type=EventTypes.PowerLevels,
state_key="",
sender=self.user_id,
Expand Down
2 changes: 1 addition & 1 deletion tests/replication/test_federation_sender_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def create_room_with_remote_server(

builder = factory.for_room_version(room_version, event_dict)
join_event = self.get_success(
builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None)
builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None)
)

self.get_success(federation.on_send_membership_event(remote_server, join_event))
Expand Down
14 changes: 7 additions & 7 deletions tests/storage/test_cleanup_extrems.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_soft_failed_extremities_handled_correctly(self) -> None:
self.store.get_latest_event_ids_in_room(self.room_id)
)

self.assertEqual(latest_event_ids, [event_id_4])
self.assertEqual(latest_event_ids, {event_id_4})

def test_basic_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
Expand All @@ -147,15 +147,15 @@ def test_basic_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(latest_event_ids, [event_id_b])
self.assertEqual(latest_event_ids, {event_id_b})

def test_chain_of_fail_cleanup(self) -> None:
"""Test that extremities are correctly calculated in the presence of
Expand Down Expand Up @@ -185,15 +185,15 @@ def test_chain_of_fail_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(latest_event_ids, [event_id_b])
self.assertEqual(latest_event_ids, {event_id_b})

def test_forked_graph_cleanup(self) -> None:
r"""Test that extremities are correctly calculated in the presence of
Expand Down Expand Up @@ -240,15 +240,15 @@ def test_forked_graph_cleanup(self) -> None:
latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c})
self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c})

# Run the background update and check it did the right thing
self.run_background_update()

latest_event_ids = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)
self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c})
self.assertEqual(latest_event_ids, {event_id_b, event_id_c})


class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
Expand Down
26 changes: 17 additions & 9 deletions tests/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,15 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = self.hs.get_datastores().main

# Figure out what the most recent event is
most_recent = self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
)[0]
most_recent = next(
iter(
self.get_success(
self.hs.get_datastores().main.get_latest_event_ids_in_room(
self.room_id
)
)
)
)

join_event = make_event_from_dict(
{
Expand Down Expand Up @@ -100,8 +106,8 @@ async def _check_sigs_and_hash_for_pulled_events_and_fetch(

# Make sure we actually joined the room
self.assertEqual(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)),
{"$join:test.serv"},
)

def test_cant_hide_direct_ancestors(self) -> None:
Expand All @@ -127,9 +133,11 @@ async def post_json(
self.http_client.post_json = post_json

# Figure out what the most recent event is
most_recent = self.get_success(
self.store.get_latest_event_ids_in_room(self.room_id)
)[0]
most_recent = next(
iter(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
)
)

# Now lie about an event
lying_event = make_event_from_dict(
Expand Down Expand Up @@ -165,7 +173,7 @@ async def post_json(

# Make sure the invalid event isn't there
extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
self.assertEqual(extrem[0], "$join:test.serv")
self.assertEqual(extrem, {"$join:test.serv"})

def test_retry_device_list_resync(self) -> None:
"""Tests that device lists are marked as stale if they couldn't be synced, and
Expand Down

0 comments on commit 85bfd47

Please sign in to comment.