diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2c8dc022ced..6cc45f9aaf3a 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1048,6 +1048,13 @@ async def _validate_event_relation(self, event: EventBase) -> None: if already_exists: raise SynapseError(400, "Can't send same reaction twice") + # If this relation is a thread, then ensure thread head is not part of + # a thread already. + elif relation_type == RelationTypes.THREAD: + already_thread = await self.store.get_event_thread(relates_to) + if already_thread: + raise SynapseError(400, "Can't fork threads") + @measure_func("handle_new_client_event") async def handle_new_client_event( self, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 40760fbd1b36..bb6e597188aa 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -372,6 +372,37 @@ def _get_if_user_has_annotated_event(txn): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + async def get_event_thread(self, event_id: str) -> Optional[str]: + """Return an event's thread. + + Args: + event_id: The event being used as the start of a new thread. + + Returns: + The thread ID of the event. + """ + + sql = """ + SELECT relates_to_id FROM event_relations + WHERE + event_id = ? + AND relation_type = ? + LIMIT 1; + """ + + def _get_thread_id(txn) -> Optional[str]: + txn.execute( + sql, + ( + event_id, + RelationTypes.THREAD, + ), + ) + + return txn.fetchone() + + return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index aad7f997ac98..867d4aab24a8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -119,6 +119,25 @@ def test_deny_double_react(self): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") self.assertEquals(400, channel.code, channel.json_body) + def test_deny_forked_thread(self): + """It is invalid to start a thread off a thread.""" + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=self.parent_id, + ) + self.assertEquals(200, channel.code, channel.json_body) + parent_id = channel.json_body["event_id"] + + channel = self._send_relation( + RelationTypes.THREAD, + "m.room.message", + content={"msgtype": "m.text", "body": "foo"}, + parent_id=parent_id, + ) + self.assertEquals(400, channel.code, channel.json_body) + def test_basic_paginate_relations(self): """Tests that calling pagination API correctly the latest relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")