Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Faster joins: Support for calling /federation/v1/state #12013

Merged
merged 12 commits into from
Feb 22, 2022
1 change: 1 addition & 0 deletions changelog.d/12013.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
69 changes: 69 additions & 0 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import copy
import itertools
import logging
from functools import partial
from typing import (
TYPE_CHECKING,
Awaitable,
Expand Down Expand Up @@ -413,6 +414,74 @@ async def get_room_state_ids(

return state_event_ids, auth_event_ids

async def get_room_state(
self,
destination: str,
room_id: str,
event_id: str,
room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Calls the /state endpoint to fetch the state at a particular point
in the room.

Returns:
a tuple of (state events, auth events)
"""
result = await self.transport_layer.get_room_state(
room_version,
destination,
room_id,
event_id,
)
state_events = result.state
auth_events = result.auth_events

# we may as filter out any duplicates from the response, which avoids some
# potential failure modes later.
richvdh marked this conversation as resolved.
Show resolved Hide resolved
#
# We don't rely on the sort order of either of them, which makes it easier
richvdh marked this conversation as resolved.
Show resolved Hide resolved
state_event_map = {event.event_id: event for event in state_events}
auth_event_map = {
event.event_id: event
for event in auth_events
if event.event_id not in state_event_map
}

logger.info(
"Processing from /state: %d state events, %d auth events",
len(state_event_map),
len(auth_event_map),
)

valid_state_events: List[EventBase] = []
valid_auth_events: List[EventBase] = []

async def _append_valid_events_to_list(
richvdh marked this conversation as resolved.
Show resolved Hide resolved
pdu: EventBase,
target_list: List[EventBase],
) -> None:
valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
pdu=pdu,
origin=destination,
room_version=room_version,
)

if valid_pdu:
target_list.append(valid_pdu)

await concurrently_execute(
partial(_append_valid_events_to_list, target_list=valid_state_events),
state_events,
10000,
richvdh marked this conversation as resolved.
Show resolved Hide resolved
)
await concurrently_execute(
partial(_append_valid_events_to_list, target_list=valid_auth_events),
state_events,
10000,
)

return valid_state_events, valid_auth_events

async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
Expand Down
70 changes: 67 additions & 3 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ def __init__(self, hs):
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
"""Requests the IDs of all state for a given room at the given event.

Args:
destination: The host name of the remote homeserver we want
to get the state from.
context: The name of the context we want the state of
room_id: the room we want the state of
event_id: The event we want the context at.

Returns:
Expand All @@ -86,6 +85,29 @@ async def get_room_state_ids(
try_trailing_slash_on_400=True,
)

async def get_room_state(
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> "StateRequestResponse":
"""Requests the full state for a given room at the given event.

Args:
room_version: the version of the room (required to build the event objects)
destination: The host name of the remote homeserver we want
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.

Returns:
Results in a dict received from the remote homeserver.
"""
path = _create_v1_path("/state/%s", room_id)
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
parser=_StateParser(room_version),
)

async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
Expand Down Expand Up @@ -1272,6 +1294,14 @@ class SendJoinResponse:
event: Optional[EventBase] = None


@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
"""The parsed response of a `/state` request."""

auth_events: List[EventBase]
state: List[EventBase]


@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
Expand Down Expand Up @@ -1355,3 +1385,37 @@ def finish(self) -> SendJoinResponse:
self._response.event_dict, self._room_version
)
return self._response


class _StateParser(ByteParser[StateRequestResponse]):
"""A parser for the response to `/state` requests.

Args:
room_version: The version of the room.
"""

CONTENT_TYPE = "application/json"

def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
self._coros = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",
use_float=True,
),
ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
"auth_chain.item",
use_float=True,
),
]

def write(self, data: bytes) -> int:
for c in self._coros:
c.send(data)
return len(data)

def finish(self) -> StateRequestResponse:
return self._response
50 changes: 49 additions & 1 deletion synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ async def post_json(
)
return body

@overload
async def get_json(
self,
destination: str,
Expand All @@ -967,7 +968,38 @@ async def get_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
...

@overload
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser[T]] = None,
max_response_size: Optional[int] = None,
) -> T:
...
richvdh marked this conversation as resolved.
Show resolved Hide resolved

async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
max_response_size: Optional[int] = None,
):
"""GETs some json from the given host homeserver and path

Args:
Expand All @@ -992,6 +1024,13 @@ async def get_json(
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.

parser: The parser to use to decode the response. Defaults to
parsing as JSON.

max_response_size: The maximum size to read from the response. If None,
uses the default.

Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Expand Down Expand Up @@ -1026,8 +1065,17 @@ async def get_json(
else:
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()

body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=parser,
max_response_size=max_response_size,
)

return body
Expand Down