Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Raise an exception when getting state at an outlier #12191

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/12191.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid trying to calculate the state at outlier events.
16 changes: 12 additions & 4 deletions synapse/storage/databases/main/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple

from frozendict import frozendict

Expand Down Expand Up @@ -309,9 +309,13 @@ async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
num_args=1,
)
async def _get_state_group_for_events(
self, event_ids: Iterable[str]
self, event_ids: Collection[str]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh, because we now iterate over event_ids twice.

) -> Dict[str, int]:
"""Returns mapping event_id -> state_group"""
"""Returns mapping event_id -> state_group.

Raises:
RuntimeError if the state is unknown at any of the given events
"""
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
Expand All @@ -321,7 +325,11 @@ async def _get_state_group_for_events(
desc="_get_state_group_for_events",
)

return {row["event_id"]: row["state_group"] for row in rows}
res = {row["event_id"]: row["state_group"] for row in rows}
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
return res

async def get_referenced_state_groups(
self, state_groups: Iterable[int]
Expand Down
20 changes: 20 additions & 0 deletions synapse/storage/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,10 @@ async def get_state_groups_ids(

Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
if not event_ids:
return {}
Expand Down Expand Up @@ -659,6 +663,10 @@ async def get_state_for_events(

Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

Expand Down Expand Up @@ -696,6 +704,10 @@ async def get_state_ids_for_events(

Returns:
A dict from event_id -> (type, state_key) -> event_id

Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)

Expand Down Expand Up @@ -723,6 +735,10 @@ async def get_state_for_event(

Returns:
A dict from (type, state_key) -> state_event

Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all()
Expand All @@ -741,6 +757,10 @@ async def get_state_ids_for_event(

Returns:
A dict from (type, state_key) -> state_event_id

Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all()
Expand Down
72 changes: 54 additions & 18 deletions tests/handlers/test_federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock
from synapse.util.stringutils import random_string

from tests import unittest
from tests.test_utils import event_injection

logger = logging.getLogger(__name__)

Expand All @@ -39,7 +39,7 @@ def generate_fake_event_id() -> str:
return "$fake_" + random_string(43)


class FederationTestCase(unittest.HomeserverTestCase):
class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [
admin.register_servlets,
login.register_servlets,
Expand Down Expand Up @@ -219,41 +219,77 @@ def test_backfill_with_many_backward_extremities(self) -> None:
# create the room
user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
requester = create_requester(user_id)

room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))

# we need a user on the remote server to be a member, so that we can send
# extremity-causing events.
self.get_success(
event_injection.inject_member_event(
self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
)
)

ev1 = self.helper.send(room_id, "first message", tok=tok)
send_result = self.helper.send(room_id, "first message", tok=tok)
ev1 = self.get_success(
self.store.get_event(send_result["event_id"], allow_none=False)
)
current_state = self.get_success(
self.store.get_events_as_list(
(self.get_success(self.store.get_current_state_ids(room_id))).values()
)
)

# Create "many" backward extremities. The magic number we're trying to
# create more than is 5 which corresponds to the number of backward
# extremities we slice off in `_maybe_backfill_inner`
federation_event_handler = self.hs.get_federation_event_handler()
for _ in range(0, 8):
event_handler = self.hs.get_event_creation_handler()
event, context = self.get_success(
event_handler.create_event(
requester,
event = make_event_from_dict(
self.add_hashes_and_signatures(
{
"origin_server_ts": 1,
"type": "m.room.message",
"content": {
"msgtype": "m.text",
"body": "message connected to fake event",
},
"room_id": room_id,
"sender": user_id,
"sender": f"@user:{self.OTHER_SERVER_NAME}",
"prev_events": [
ev1.event_id,
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
# lazy: *everything* is an auth event
"auth_events": [ev.event_id for ev in current_state],
"depth": ev1.depth + 1,
},
prev_event_ids=[
ev1["event_id"],
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
)
room_version,
),
room_version,
)

# we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
self.get_success(
event_handler.handle_new_client_event(requester, event, context)
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
)
)

# we should now have 8 backwards extremities.
backwards_extremities = self.get_success(
self.store.db_pool.simple_select_list(
"event_backward_extremities",
keyvalues={"room_id": room_id},
retcols=["event_id"],
)
)
self.assertEqual(len(backwards_extremities), 8)

current_depth = 1
limit = 100
with LoggingContext("receive_pdu"):
Expand Down