Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sliding Sync: Add Account Data extension (MSC3959) #17477

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/17477.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add Account Data extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
138 changes: 138 additions & 0 deletions synapse/handlers/sliding_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from synapse.types import (
DeviceListUpdates,
JsonDict,
JsonMapping,
PersistedEventPosition,
Requester,
RoomStreamToken,
Expand Down Expand Up @@ -356,6 +357,7 @@ def __init__(self, hs: "HomeServer"):
self.event_sources = hs.get_event_sources()
self.relations_handler = hs.get_relations_handler()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync

async def wait_for_sync_for_user(
Expand Down Expand Up @@ -627,6 +629,7 @@ async def handle_room(room_id: str) -> None:

extensions = await self.get_extensions_response(
sync_config=sync_config,
lists=lists,
from_token=from_token,
to_token=to_token,
)
Expand Down Expand Up @@ -1785,13 +1788,15 @@ async def get_room_sync_data(
async def get_extensions_response(
self,
sync_config: SlidingSyncConfig,
lists: Dict[str, SlidingSyncResult.SlidingWindowList],
to_token: StreamToken,
from_token: Optional[StreamToken],
) -> SlidingSyncResult.Extensions:
"""Handle extension requests.

Args:
sync_config: Sync configuration
lists: Sliding window API. A map of list key to list results.
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
Expand All @@ -1816,9 +1821,20 @@ async def get_extensions_response(
from_token=from_token,
)

account_data_response = None
if sync_config.extensions.account_data is not None:
account_data_response = await self.get_account_data_extension_response(
sync_config=sync_config,
lists=lists,
account_data_request=sync_config.extensions.account_data,
to_token=to_token,
from_token=from_token,
)

return SlidingSyncResult.Extensions(
to_device=to_device_response,
e2ee=e2ee_response,
account_data=account_data_response,
)

async def get_to_device_extension_response(
Expand Down Expand Up @@ -1944,3 +1960,125 @@ async def get_e2ee_extension_response(
device_one_time_keys_count=device_one_time_keys_count,
device_unused_fallback_key_types=device_unused_fallback_key_types,
)

async def get_account_data_extension_response(
self,
sync_config: SlidingSyncConfig,
lists: Dict[str, SlidingSyncResult.SlidingWindowList],
account_data_request: SlidingSyncConfig.Extensions.AccountDataExtension,
to_token: StreamToken,
from_token: Optional[StreamToken],
) -> Optional[SlidingSyncResult.Extensions.AccountDataExtension]:
"""Handle Account Data extension (MSC3959)

Args:
sync_config: Sync configuration
lists: Sliding window API. A map of list key to list results.
account_data_request: The account_data extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
"""
user_id = sync_config.user.to_string()

# Skip if the extension is not enabled
if not account_data_request.enabled:
return None

global_account_data_map: Mapping[str, JsonMapping] = {}
if from_token is not None:
global_account_data_map = (
await self.store.get_updated_global_account_data_for_user(
user_id, from_token.account_data_key
)
)

have_push_rules_changed = await self.store.have_push_rules_changed_for_user(
user_id, from_token.push_rules_key
)
if have_push_rules_changed:
global_account_data_map = dict(global_account_data_map)
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
await self.push_rules_handler.push_rules_for_user(sync_config.user)
)
else:
all_global_account_data = await self.store.get_global_account_data_for_user(
user_id
)

global_account_data_map = dict(all_global_account_data)
global_account_data_map[AccountDataTypes.PUSH_RULES] = (
await self.push_rules_handler.push_rules_for_user(sync_config.user)
)

# We only want to include account data for rooms that are already in the sliding
# sync response AND that were requested in the account data request.
relevant_room_ids: Set[str] = set()

# See what rooms from the room subscriptions we should get account data for
if (
account_data_request.rooms is not None
and sync_config.room_subscriptions is not None
):
actual_room_ids = sync_config.room_subscriptions.keys()

for room_id in account_data_request.rooms:
# A wildcard means we process all rooms from the room subscriptions
if room_id == "*":
relevant_room_ids.update(sync_config.room_subscriptions.keys())
break

if room_id in actual_room_ids:
relevant_room_ids.add(room_id)

# See what rooms from the sliding window lists we should get account data for
if account_data_request.lists is not None:
for list_key in account_data_request.lists:
# Just some typing because we share the variable name in multiple places
actual_list: Optional[SlidingSyncResult.SlidingWindowList] = None

# A wildcard means we process rooms from all lists
if list_key == "*":
for actual_list in lists.values():
# We only expect a single SYNC operation for any list
assert len(actual_list.ops) == 1
sync_op = actual_list.ops[0]
assert sync_op.op == OperationType.SYNC

relevant_room_ids.update(sync_op.room_ids)

break

actual_list = lists.get(list_key)
if actual_list is not None:
# We only expect a single SYNC operation for any list
assert len(actual_list.ops) == 1
sync_op = actual_list.ops[0]
assert sync_op.op == OperationType.SYNC

relevant_room_ids.update(sync_op.room_ids)

# Fetch room account data
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]] = {}
if len(relevant_room_ids) > 0:
if from_token is not None:
account_data_by_room_map = (
await self.store.get_updated_room_account_data_for_user(
user_id, from_token.account_data_key
)
)
else:
account_data_by_room_map = (
await self.store.get_room_account_data_for_user(user_id)
)

# Filter down to the relevant rooms
account_data_by_room_map = {
room_id: account_data_map
for room_id, account_data_map in account_data_by_room_map.items()
if room_id in relevant_room_ids
}

return SlidingSyncResult.Extensions.AccountDataExtension(
global_account_data_map=global_account_data_map,
account_data_by_room_map=account_data_by_room_map,
)
19 changes: 18 additions & 1 deletion synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,6 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

return 200, response_content

# TODO: Is there a better way to encode things?
async def encode_response(
self,
requester: Requester,
Expand Down Expand Up @@ -1115,6 +1114,24 @@ async def encode_extensions(
extensions.e2ee.device_list_updates.left
)

if extensions.account_data is not None:
serialized_extensions["account_data"] = {
# Same as the the top-level `account_data.events` field in Sync v2.
"global": [
{"type": account_data_type, "content": content}
for account_data_type, content in extensions.account_data.global_account_data_map.items()
],
# Same as the joined room's account_data field in Sync v2, e.g the path
# `rooms.join["!foo:bar"].account_data.events`.
"rooms": {
room_id: [
{"type": account_data_type, "content": content}
for account_data_type, content in event_map.items()
]
for room_id, event_map in extensions.account_data.account_data_by_room_map.items()
},
}

return serialized_extensions


Expand Down
22 changes: 21 additions & 1 deletion synapse/types/handlers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,11 +323,31 @@ def __bool__(self) -> bool:
or self.device_unused_fallback_key_types
)

@attr.s(slots=True, frozen=True, auto_attribs=True)
class AccountDataExtension:
"""The Account Data extension (MSC3959)

Attributes:
global_account_data_map: Mapping from `type` to `content` of global account
data events.
account_data_by_room_map: Mapping from room_id to mapping of `type` to
`content` of room account data events.
"""

global_account_data_map: Mapping[str, JsonMapping]
account_data_by_room_map: Mapping[str, Mapping[str, JsonMapping]]

def __bool__(self) -> bool:
return bool(
self.global_account_data_map or self.account_data_by_room_map
)

to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None

def __bool__(self) -> bool:
return bool(self.to_device or self.e2ee)
return bool(self.to_device or self.e2ee or self.account_data)

next_pos: StreamToken
lists: Dict[str, SlidingWindowList]
Expand Down
18 changes: 18 additions & 0 deletions synapse/types/rest/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,8 +322,26 @@ class E2eeExtension(RequestBodyModel):

enabled: Optional[StrictBool] = False

class AccountDataExtension(RequestBodyModel):
"""The Account Data extension (MSC3959)

Attributes:
enabled
lists: List of list keys (from the Sliding Window API) to apply this
extension to.
rooms: List of room IDs (from the Room Subscription API) to apply this
extension to.
"""

enabled: Optional[StrictBool] = False
# Process all lists defined in the Sliding Window API. (This is the default.)
lists: Optional[List[StrictStr]] = ["*"]
# Process all room subscriptions defined in the Room Subscription API. (This is the default.)
rooms: Optional[List[StrictStr]] = ["*"]

to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None

# mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884
if TYPE_CHECKING:
Expand Down
Loading
Loading