Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Support stable identifiers for MSC3440: Threading #12151

Merged
merged 13 commits into from
Mar 10, 2022
Merged
1 change: 1 addition & 0 deletions changelog.d/12151.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads.
4 changes: 3 additions & 1 deletion synapse/api/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ class RelationTypes:
ANNOTATION: Final = "m.annotation"
REPLACE: Final = "m.replace"
REFERENCE: Final = "m.reference"
THREAD: Final = "io.element.thread"
THREAD: Final = "m.thread"
# TODO Remove this in Synapse >= v1.57.0.
UNSTABLE_THREAD: Final = "io.element.thread"


class LimitBlockingTypes:
Expand Down
23 changes: 12 additions & 11 deletions synapse/api/filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
# MSC3440, filtering by event relations.
"related_by_senders": {"type": "array", "items": {"type": "string"}},
"io.element.relation_senders": {"type": "array", "items": {"type": "string"}},
"related_by_rel_types": {"type": "array", "items": {"type": "string"}},
"io.element.relation_types": {"type": "array", "items": {"type": "string"}},
},
}
Expand Down Expand Up @@ -318,19 +320,18 @@ def __init__(self, hs: "HomeServer", filter_json: JsonDict):
self.labels = filter_json.get("org.matrix.labels", None)
self.not_labels = filter_json.get("org.matrix.not_labels", [])

# Ideally these would be rejected at the endpoint if they were provided
# and not supported, but that would involve modifying the JSON schema
# based on the homeserver configuration.
self.related_by_senders = self.filter_json.get("related_by_senders", None)
self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None)

# Fallback to the unstable prefix if the stable version is not given.
if hs.config.experimental.msc3440_enabled:
self.relation_senders = self.filter_json.get(
self.related_by_senders = self.related_by_senders or self.filter_json.get(
"io.element.relation_senders", None
)
self.relation_types = self.filter_json.get(
"io.element.relation_types", None
self.related_by_rel_types = (
self.related_by_rel_types
or self.filter_json.get("io.element.relation_types", None)
)
else:
self.relation_senders = None
self.relation_types = None

def filters_all_types(self) -> bool:
return "*" in self.not_types
Expand Down Expand Up @@ -461,7 +462,7 @@ async def _check_event_relations(
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
event_ids_to_keep = set(
await self._store.events_have_relations(
event_ids, self.relation_senders, self.relation_types
event_ids, self.related_by_senders, self.related_by_rel_types
)
)

Expand All @@ -474,7 +475,7 @@ async def _check_event_relations(
async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]:
result = [event for event in events if self._check(event)]

if self.relation_senders or self.relation_types:
if self.related_by_senders or self.related_by_rel_types:
return await self._check_event_relations(result)

return result
Expand Down
4 changes: 3 additions & 1 deletion synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,11 +515,13 @@ def _inject_bundled_aggregations(
thread.latest_event, serialized_latest_event, thread.latest_edit
)

serialized_aggregations[RelationTypes.THREAD] = {
thread_summary = {
"latest_event": serialized_latest_event,
"count": thread.count,
"current_user_participated": thread.current_user_participated,
}
serialized_aggregations[RelationTypes.THREAD] = thread_summary
serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary
clokep marked this conversation as resolved.
Show resolved Hide resolved

# Include the bundled aggregations in the event.
if serialized_aggregations:
Expand Down
5 changes: 4 additions & 1 deletion synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,10 @@ async def _validate_event_relation(self, event: EventBase) -> None:
raise SynapseError(400, "Can't send same reaction twice")

# Don't attempt to start a thread if the parent event is a relation.
elif relation_type == RelationTypes.THREAD:
elif (
relation_type == RelationTypes.THREAD
or relation_type == RelationTypes.UNSTABLE_THREAD
):
if await self.store.event_includes_relation(relates_to):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
Expand Down
1 change: 1 addition & 0 deletions synapse/rest/client/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
# Adds support for thread relations, per MSC3440.
"org.matrix.msc3440": self.config.experimental.msc3440_enabled,
"org.matrix.msc3440.stable": True,
clokep marked this conversation as resolved.
Show resolved Hide resolved
},
},
)
Expand Down
5 changes: 4 additions & 1 deletion synapse/storage/databases/main/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -1811,7 +1811,10 @@ def _handle_event_relations(
if rel_type == RelationTypes.REPLACE:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))

if rel_type == RelationTypes.THREAD:
if (
rel_type == RelationTypes.THREAD
or rel_type == RelationTypes.UNSTABLE_THREAD
):
txn.call_after(
self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
)
Expand Down
73 changes: 45 additions & 28 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ def _get_thread_summaries_txn(
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
AND %s
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
Expand All @@ -514,16 +514,22 @@ def _get_thread_summaries_txn(
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
AND %s
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""

clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)

args.append(RelationTypes.THREAD)
if self._msc3440_enabled:
relations_clause = "(relation_type = ? OR relation_type = ?)"
args.append(RelationTypes.UNSTABLE_THREAD)
else:
relations_clause = "relation_type = ?"
clokep marked this conversation as resolved.
Show resolved Hide resolved

txn.execute(sql % (clause,), args)
txn.execute(sql % (clause, relations_clause), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
Expand All @@ -543,7 +549,7 @@ def _get_thread_summaries_txn(
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
AND %s
GROUP BY parent.event_id
"""

Expand All @@ -552,9 +558,15 @@ def _get_thread_summaries_txn(
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)

args.append(RelationTypes.THREAD)
if self._msc3440_enabled:
relations_clause = "(relation_type = ? OR relation_type = ?)"
args.append(RelationTypes.UNSTABLE_THREAD)
else:
relations_clause = "relation_type = ?"

txn.execute(sql % (clause,), args)
txn.execute(sql % (clause, relations_clause), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))

return counts, latest_event_ids
Expand Down Expand Up @@ -617,16 +629,24 @@ def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
AND parent.room_id = child.room_id
WHERE
%s
AND relation_type = ?
AND %s
AND child.sender = ?
"""

clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
args.extend((RelationTypes.THREAD, user_id))

txn.execute(sql % (clause,), args)
args.append(RelationTypes.THREAD)
if self._msc3440_enabled:
relations_clause = "(relation_type = ? OR relation_type = ?)"
args.append(RelationTypes.UNSTABLE_THREAD)
else:
relations_clause = "relation_type = ?"

args.append(user_id)

txn.execute(sql % (clause, relations_clause), args)
return {row[0] for row in txn.fetchall()}

participated_threads = await self.db_pool.runInteraction(
Expand Down Expand Up @@ -830,26 +850,23 @@ async def get_bundled_aggregations(
results.setdefault(event_id, BundledAggregations()).replace = edit

# Fetch thread summaries.
if self._msc3440_enabled:
summaries = await self._get_thread_summaries(seen_event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
summaries.keys(), user_id
)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)
summaries = await self._get_thread_summaries(seen_event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(summaries.keys(), user_id)
for event_id, summary in summaries.items():
if summary:
thread_count, latest_thread_event, edit = summary
results.setdefault(
event_id, BundledAggregations()
).thread = _ThreadAggregation(
latest_event=latest_thread_event,
latest_edit=edit,
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
)

return results

Expand Down
18 changes: 10 additions & 8 deletions synapse/storage/databases/main/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
args.extend(event_filter.labels)

# Filter on relation_senders / relation types from the joined tables.
if event_filter.relation_senders:
if event_filter.related_by_senders:
clauses.append(
"(%s)"
% " OR ".join(
"related_event.sender = ?" for _ in event_filter.relation_senders
"related_event.sender = ?" for _ in event_filter.related_by_senders
)
)
args.extend(event_filter.relation_senders)
args.extend(event_filter.related_by_senders)

if event_filter.relation_types:
if event_filter.related_by_rel_types:
clauses.append(
"(%s)"
% " OR ".join("relation_type = ?" for _ in event_filter.relation_types)
% " OR ".join(
"relation_type = ?" for _ in event_filter.related_by_rel_types
)
)
args.extend(event_filter.relation_types)
args.extend(event_filter.related_by_rel_types)

return " AND ".join(clauses), args

Expand Down Expand Up @@ -1203,15 +1205,15 @@ def _paginate_room_events_txn(
# If there is a filter on relation_senders and relation_types join to the
# relations table.
if event_filter and (
event_filter.relation_senders or event_filter.relation_types
event_filter.related_by_senders or event_filter.related_by_rel_types
):
# Filtering by relations could cause the same event to appear multiple
# times (since there's no limit on the number of relations to an event).
needs_distinct = True
join_clause += """
LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id)
"""
if event_filter.relation_senders:
if event_filter.related_by_senders:
join_clause += """
LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id)
"""
Expand Down
8 changes: 2 additions & 6 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,7 @@ def test_aggregation_must_be_annotation(self) -> None:
)
self.assertEqual(400, channel.code, channel.json_body)

@unittest.override_config(
{"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
)
@unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
def test_bundled_aggregations(self) -> None:
"""
Test that annotations, references, and threads get correctly bundled.
Expand Down Expand Up @@ -597,6 +595,7 @@ def assert_bundle(event_json: JsonDict) -> None:
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
RelationTypes.THREAD,
RelationTypes.UNSTABLE_THREAD,
),
)

Expand Down Expand Up @@ -758,7 +757,6 @@ def test_aggregation_get_event_for_thread(self) -> None:
},
)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
Expand Down Expand Up @@ -1065,7 +1063,6 @@ def test_edit_reply(self) -> None:
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_edit_thread(self) -> None:
"""Test that editing a thread works."""

Expand Down Expand Up @@ -1383,7 +1380,6 @@ def test_redact_relation_annotation(self) -> None:
chunk = self._get_aggregations()
self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])

@unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
def test_redact_relation_thread(self) -> None:
"""
Test that thread replies are properly handled after the thread reply redacted.
Expand Down
18 changes: 8 additions & 10 deletions tests/rest/client/test_rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2141,21 +2141,19 @@ def _filter_messages(self, filter: JsonDict) -> List[JsonDict]:

def test_filter_relation_senders(self) -> None:
# Messages which second user reacted to.
filter = {"io.element.relation_senders": [self.second_user_id]}
filter = {"related_by_senders": [self.second_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)

# Messages which third user reacted to.
filter = {"io.element.relation_senders": [self.third_user_id]}
filter = {"related_by_senders": [self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)

# Messages which either user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id, self.third_user_id]
}
filter = {"related_by_senders": [self.second_user_id, self.third_user_id]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 2, chunk)
self.assertCountEqual(
Expand All @@ -2164,20 +2162,20 @@ def test_filter_relation_senders(self) -> None:

def test_filter_relation_type(self) -> None:
# Messages which have annotations.
filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]}
filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_1)

# Messages which have references.
filter = {"io.element.relation_types": [RelationTypes.REFERENCE]}
filter = {"related_by_rel_types": [RelationTypes.REFERENCE]}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
self.assertEqual(chunk[0]["event_id"], self.event_id_2)

# Messages which have either annotations or references.
filter = {
"io.element.relation_types": [
"related_by_rel_types": [
RelationTypes.ANNOTATION,
RelationTypes.REFERENCE,
]
Expand All @@ -2191,8 +2189,8 @@ def test_filter_relation_type(self) -> None:
def test_filter_relation_senders_and_type(self) -> None:
# Messages which second user reacted to.
filter = {
"io.element.relation_senders": [self.second_user_id],
"io.element.relation_types": [RelationTypes.ANNOTATION],
"related_by_senders": [self.second_user_id],
"related_by_rel_types": [RelationTypes.ANNOTATION],
}
chunk = self._filter_messages(filter)
self.assertEqual(len(chunk), 1, chunk)
Expand Down
Loading