diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc new file mode 100644 index 000000000000..c0fca8dccbae --- /dev/null +++ b/changelog.d/12013.misc @@ -0,0 +1 @@ +Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server. diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index 896168c05c0a..fab6da3c087f 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -47,6 +47,11 @@ async def _check_sigs_and_hash( ) -> EventBase: """Checks that event is correctly signed by the sending server. + Also checks the content hash, and redacts the event if there is a mismatch. + + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + Args: room_version: The room version of the PDU pdu: the event to be checked @@ -55,7 +60,10 @@ async def _check_sigs_and_hash( * the original event if the checks pass * a redacted version of the event (if the signature matched but the hash did not) - * throws a SynapseError if the signature check failed.""" + + Raises: + SynapseError if the signature check failed. + """ try: await _check_sigs_on_pdu(self.keyring, room_version, pdu) except SynapseError as e: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 74f17aa4daa3..aadc147e6385 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -413,26 +413,90 @@ async def get_room_state_ids( return state_event_ids, auth_event_ids + async def get_room_state( + self, + destination: str, + room_id: str, + event_id: str, + room_version: RoomVersion, + ) -> Tuple[List[EventBase], List[EventBase]]: + """Calls the /state endpoint to fetch the state at a particular point + in the room. + + Any invalid events (those with incorrect or unverifiable signatures or hashes) + are filtered out from the response, and any duplicate events are removed. + + (Size limits and other event-format checks are *not* performed.) + + Note that the result is not ordered, so callers must be careful to process + the events in an order that handles dependencies. + + Returns: + a tuple of (state events, auth events) + """ + result = await self.transport_layer.get_room_state( + room_version, + destination, + room_id, + event_id, + ) + state_events = result.state + auth_events = result.auth_events + + # we may as well filter out any duplicates from the response, to save + # processing them multiple times. (In particular, events may be present in + # `auth_events` as well as `state`, which is redundant). + # + # We don't rely on the sort order of the events, so we can just stick them + # in a dict. + state_event_map = {event.event_id: event for event in state_events} + auth_event_map = { + event.event_id: event + for event in auth_events + if event.event_id not in state_event_map + } + + logger.info( + "Processing from /state: %d state events, %d auth events", + len(state_event_map), + len(auth_event_map), + ) + + valid_auth_events = await self._check_sigs_and_hash_and_fetch( + destination, auth_event_map.values(), room_version + ) + + valid_state_events = await self._check_sigs_and_hash_and_fetch( + destination, state_event_map.values(), room_version + ) + + return valid_state_events, valid_auth_events + async def _check_sigs_and_hash_and_fetch( self, origin: str, pdus: Collection[EventBase], room_version: RoomVersion, ) -> List[EventBase]: - """Takes a list of PDUs and checks the signatures and hashes of each - one. If a PDU fails its signature check then we check if we have it in - the database and if not then request if from the originating server of - that PDU. + """Checks the signatures and hashes of a list of events. + + If a PDU fails its signature check then we check if we have it in + the database, and if not then request it from the sender's server (if that + is different from `origin`). If that still fails, the event is omitted from + the returned list. If a PDU fails its content hash check then it is redacted. - The given list of PDUs are not modified, instead the function returns + Also runs each event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. + + The given list of PDUs are not modified; instead the function returns a new list. Args: - origin - pdu - room_version + origin: The server that sent us these events + pdus: The events to be checked + room_version: the version of the room these events are in Returns: A list of PDUs that have valid signatures and hashes. @@ -463,11 +527,16 @@ async def _check_sigs_and_hash_and_fetch_one( origin: str, room_version: RoomVersion, ) -> Optional[EventBase]: - """Takes a PDU and checks its signatures and hashes. If the PDU fails - its signature check then we check if we have it in the database and if - not then request if from the originating server of that PDU. + """Takes a PDU and checks its signatures and hashes. + + If the PDU fails its signature check then we check if we have it in the + database; if not, we then request it from sender's server (if that is not the + same as `origin`). If that still fails, we return None. + + If the PDU fails its content hash check, it is redacted. - If then PDU fails its content hash check then it is redacted. + Also runs the event through the spam checker; if it fails, redacts the event + and flags it as soft-failed. Args: origin diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8782586cd6b4..696e06d6b6af 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -64,13 +64,12 @@ def __init__(self, hs): async def get_room_state_ids( self, destination: str, room_id: str, event_id: str ) -> JsonDict: - """Requests all state for a given room from the given server at the - given event. Returns the state's event_id's + """Requests the IDs of all state for a given room at the given event. Args: destination: The host name of the remote homeserver we want to get the state from. - context: The name of the context we want the state of + room_id: the room we want the state of event_id: The event we want the context at. Returns: @@ -86,6 +85,29 @@ async def get_room_state_ids( try_trailing_slash_on_400=True, ) + async def get_room_state( + self, room_version: RoomVersion, destination: str, room_id: str, event_id: str + ) -> "StateRequestResponse": + """Requests the full state for a given room at the given event. + + Args: + room_version: the version of the room (required to build the event objects) + destination: The host name of the remote homeserver we want + to get the state from. + room_id: the room we want the state of + event_id: The event we want the context at. + + Returns: + Results in a dict received from the remote homeserver. + """ + path = _create_v1_path("/state/%s", room_id) + return await self.client.get_json( + destination, + path=path, + args={"event_id": event_id}, + parser=_StateParser(room_version), + ) + async def get_event( self, destination: str, event_id: str, timeout: Optional[int] = None ) -> JsonDict: @@ -1272,6 +1294,14 @@ class SendJoinResponse: event: Optional[EventBase] = None +@attr.s(slots=True, auto_attribs=True) +class StateRequestResponse: + """The parsed response of a `/state` request.""" + + auth_events: List[EventBase] + state: List[EventBase] + + @ijson.coroutine def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]: """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs @@ -1355,3 +1385,37 @@ def finish(self) -> SendJoinResponse: self._response.event_dict, self._room_version ) return self._response + + +class _StateParser(ByteParser[StateRequestResponse]): + """A parser for the response to `/state` requests. + + Args: + room_version: The version of the room. + """ + + CONTENT_TYPE = "application/json" + + def __init__(self, room_version: RoomVersion): + self._response = StateRequestResponse([], []) + self._room_version = room_version + self._coros = [ + ijson.items_coro( + _event_list_parser(room_version, self._response.state), + "pdus.item", + use_float=True, + ), + ijson.items_coro( + _event_list_parser(room_version, self._response.auth_events), + "auth_chain.item", + use_float=True, + ), + ] + + def write(self, data: bytes) -> int: + for c in self._coros: + c.send(data) + return len(data) + + def finish(self) -> StateRequestResponse: + return self._response diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c5f8fcbb2a7d..e7656fbb9f7c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -958,6 +958,7 @@ async def post_json( ) return body + @overload async def get_json( self, destination: str, @@ -967,7 +968,38 @@ async def get_json( timeout: Optional[int] = None, ignore_backoff: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, + max_response_size: Optional[int] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = ..., + retry_on_dns_fail: bool = ..., + timeout: Optional[int] = ..., + ignore_backoff: bool = ..., + try_trailing_slash_on_400: bool = ..., + parser: ByteParser[T] = ..., + max_response_size: Optional[int] = ..., + ) -> T: + ... + + async def get_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + max_response_size: Optional[int] = None, + ): """GETs some json from the given host homeserver and path Args: @@ -992,6 +1024,13 @@ async def get_json( try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. + + parser: The parser to use to decode the response. Defaults to + parsing as JSON. + + max_response_size: The maximum size to read from the response. If None, + uses the default. + Returns: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -1026,8 +1065,17 @@ async def get_json( else: _sec_timeout = self.default_timeout + if parser is None: + parser = JsonParser() + body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, + max_response_size=max_response_size, ) return body diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py new file mode 100644 index 000000000000..ec8864dafe37 --- /dev/null +++ b/tests/federation/test_federation_client.py @@ -0,0 +1,149 @@ +# Copyright 2022 Matrix.org Federation C.I.C +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from unittest import mock + +import twisted.web.client +from twisted.internet import defer +from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.room_versions import RoomVersions +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import FederatingHomeserverTestCase + + +class FederationClientTest(FederatingHomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + super().prepare(reactor, clock, homeserver) + + # mock out the Agent used by the federation client, which is easier than + # catching the HTTPS connection and do the TLS stuff. + self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True) + homeserver.get_federation_http_client().agent = self._mock_agent + + def test_get_room_state(self): + creator = f"@creator:{self.OTHER_SERVER_NAME}" + test_room_id = "!room_id" + + # mock up some events to use in the response. + # In real life, these would have things in `prev_events` and `auth_events`, but that's + # a bit annoying to mock up, and the code under test doesn't care, so we don't bother. + create_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.create", + "state_key": "", + "sender": creator, + "content": {"creator": creator}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 500, + } + ) + member_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.member", + "sender": creator, + "state_key": creator, + "content": {"membership": "join"}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 600, + } + ) + pl_event_dict = self.add_hashes_and_signatures( + { + "room_id": test_room_id, + "type": "m.room.power_levels", + "sender": creator, + "state_key": "", + "content": {}, + "prev_events": [], + "auth_events": [], + "origin_server_ts": 700, + } + ) + + # mock up the response, and have the agent return it + self._mock_agent.request.return_value = defer.succeed( + _mock_response( + { + "pdus": [ + create_event_dict, + member_event_dict, + pl_event_dict, + ], + "auth_chain": [ + create_event_dict, + member_event_dict, + ], + } + ) + ) + + # now fire off the request + state_resp, auth_resp = self.get_success( + self.hs.get_federation_client().get_room_state( + "yet_another_server", + test_room_id, + "event_id", + RoomVersions.V9, + ) + ) + + # check the right call got made to the agent + self._mock_agent.request.assert_called_once_with( + b"GET", + b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + headers=mock.ANY, + bodyProducer=None, + ) + + # ... and that the response is correct. + + # the auth_resp should be empty because all the events are also in state + self.assertEqual(auth_resp, []) + + # all of the events should be returned in state_resp, though not necessarily + # in the same order. We just check the type on the assumption that if the type + # is right, so is the rest of the event. + self.assertCountEqual( + [e.type for e in state_resp], + ["m.room.create", "m.room.member", "m.room.power_levels"], + ) + + +def _mock_response(resp: JsonDict): + body = json.dumps(resp).encode("utf-8") + + def deliver_body(p: Protocol): + p.dataReceived(body) + p.connectionLost(Failure(twisted.web.client.ResponseDone())) + + response = mock.Mock( + code=200, + phrase=b"OK", + headers=twisted.web.client.Headers({"content-Type": ["application/json"]}), + length=len(body), + deliverBody=deliver_body, + ) + mock.seal(response) + return response diff --git a/tests/unittest.py b/tests/unittest.py index a71892cb9dbe..7983c1e8b860 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -51,7 +51,10 @@ from synapse import events from synapse.api.constants import EventTypes, Membership +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.config.homeserver import HomeServerConfig +from synapse.config.server import DEFAULT_ROOM_VERSION +from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite @@ -839,6 +842,24 @@ def make_signed_federation_request( client_ip=client_ip, ) + def add_hashes_and_signatures( + self, + event_dict: JsonDict, + room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION], + ) -> JsonDict: + """Adds hashes and signatures to the given event dict + + Returns: + The modified event dict, for convenience + """ + add_hashes_and_signatures( + room_version, + event_dict, + signature_name=self.OTHER_SERVER_NAME, + signing_key=self.OTHER_SERVER_SIGNATURE_KEY, + ) + return event_dict + def _auth_header_for_request( origin: str,