Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cross-signing [4/4] -- federation edition #5727

Merged
merged 18 commits into from
Nov 1, 2019
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions synapse/federation/sender/per_destination_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def _get_device_update_edus(self, limit):
last_device_list = self._last_device_list_stream_id

# Retrieve list of new device updates to send to the destination
now_stream_id, results = yield self._store.get_devices_by_remote(
now_stream_id, results = yield self._store.get_device_updates_by_remote(
self._destination, last_device_list, limit=limit
)
edus = [
Expand All @@ -372,7 +372,7 @@ def _get_device_update_edus(self, limit):
for (edu_type, content) in results
]

assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"

return (edus, now_stream_id)

Expand Down
24 changes: 12 additions & 12 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,7 +1072,7 @@ class SignatureListItem:


class SigningKeyEduUpdater(object):
"Handles incoming signing key updates from federation and updates the DB"
"""Handles incoming signing key updates from federation and updates the DB"""

def __init__(self, hs, e2e_keys_handler):
self.store = hs.get_datastore()
Expand Down Expand Up @@ -1111,7 +1111,6 @@ def incoming_signing_key_update(self, origin, edu_content):
self_signing_key = edu_content.pop("self_signing_key", None)

if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got signing key update edu for %r from %r", user_id, origin)
return

Expand All @@ -1122,7 +1121,7 @@ def incoming_signing_key_update(self, origin, edu_content):
return

self._pending_updates.setdefault(user_id, []).append(
(master_key, self_signing_key, edu_content)
(master_key, self_signing_key)
)

yield self._handle_signing_key_updates(user_id)
Expand All @@ -1147,22 +1146,23 @@ def _handle_signing_key_updates(self, user_id):

logger.info("pending updates: %r", pending_updates)

for master_key, self_signing_key, edu_content in pending_updates:
for master_key, self_signing_key in pending_updates:
if master_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "master", master_key
)
device_id = get_verify_key_from_cross_signing_key(master_key)[
1
].version
device_ids.append(device_id)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
device_id = get_verify_key_from_cross_signing_key(self_signing_key)[
1
].version
device_ids.append(device_id)
_, verify_key = get_verify_key_from_cross_signing_key(
self_signing_key
)
device_ids.append(verify_key.version)

yield device_handler.notify_device_update(user_id, device_ids)
87 changes: 40 additions & 47 deletions synapse/storage/data_stores/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,13 @@ def get_devices_by_user(self, user_id):

@trace
@defer.inlineCallbacks
def get_devices_by_remote(self, destination, from_stream_id, limit):
"""Get stream of updates to send to remote servers
def get_device_updates_by_remote(self, destination, from_stream_id, limit):
"""Get a stream of device updates to send to the given remote server.

Args:
destination (str): The host the device updates are intended for
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
limit (int): Maximum number of device updates to return
Returns:
Deferred[tuple[int, list[tuple[string,dict]]]]:
current stream id (ie, the stream id of the last update included in the
Expand All @@ -119,8 +123,8 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
# stream_id; the rationale being that such a large device list update
# is likely an error.
updates = yield self.runInteraction(
"get_devices_by_remote",
self._get_devices_by_remote_txn,
"get_device_updates_by_remote",
self._get_device_updates_by_remote_txn,
destination,
from_stream_id,
now_stream_id,
Expand All @@ -131,7 +135,8 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
if not updates:
return now_stream_id, []

# get the cross-signing keys of the users the list
# get the cross-signing keys of the users in the list, so that we can
# determine which of the device changes were cross-signing keys
users = set(r[0] for r in updates)
master_key_by_user = {}
self_signing_key_by_user = {}
Expand All @@ -141,9 +146,12 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
key_id, verify_key = get_verify_key_from_cross_signing_key(
cross_signing_key
)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
master_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version,
"device_id": verify_key.version,
}

cross_signing_key = yield self.get_e2e_cross_signing_key(
Expand All @@ -155,7 +163,7 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
)
self_signing_key_by_user[user] = {
"key_info": cross_signing_key,
"pubkey": verify_key.version,
"device_id": verify_key.version,
}

# if we have exceeded the limit, we need to exclude any results with the
Expand All @@ -182,73 +190,58 @@ def get_devices_by_remote(self, destination, from_stream_id, limit):
# context which created the Edu.

query_map = {}
for update in updates:
if stream_id_cutoff is not None and update[2] >= stream_id_cutoff:
cross_signing_keys_by_user = {}
for user_id, device_id, update_stream_id, update_context in updates:
if stream_id_cutoff is not None and update_stream_id >= stream_id_cutoff:
# Stop processing updates
break

# skip over cross-signing keys
if (
update[0] in master_key_by_user
and update[1] == master_key_by_user[update[0]]["pubkey"]
) or (
update[0] in master_key_by_user
and update[1] == self_signing_key_by_user[update[0]]["pubkey"]
user_id in master_key_by_user
and device_id == master_key_by_user[user_id]["device_id"]
):
continue

key = (update[0], update[1])

update_context = update[3]
update_stream_id = update[2]

previous_update_stream_id, _ = query_map.get(key, (0, None))

if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.

# figure out which cross-signing keys were changed by intersecting the
# update list with the master/self-signing key by user maps
cross_signing_keys_by_user = {}
for user_id, device_id, stream, _opentracing_context in updates:
if device_id == master_key_by_user.get(user_id, {}).get("pubkey", None):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["master_key"] = master_key_by_user[user_id]["key_info"]
uhoreg marked this conversation as resolved.
Show resolved Hide resolved
elif device_id == self_signing_key_by_user.get(user_id, {}).get(
"pubkey", None
elif (
user_id in master_key_by_user
uhoreg marked this conversation as resolved.
Show resolved Hide resolved
and device_id == self_signing_key_by_user[user_id]["device_id"]
):
result = cross_signing_keys_by_user.setdefault(user_id, {})
result["self_signing_key"] = self_signing_key_by_user[user_id][
"key_info"
]
else:
key = (user_id, device_id)

cross_signing_results = []
previous_update_stream_id, _ = query_map.get(key, (0, None))

# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
cross_signing_results.append(("org.matrix.signing_key_update", result))
if update_stream_id > previous_update_stream_id:
query_map[key] = (update_stream_id, update_context)

# If we didn't find any updates with a stream_id lower than the cutoff, it
# means that there are more than limit updates all of which have the same
# steam_id.

# That should only happen if a client is spamming the server with new
# devices, in which case E2E isn't going to work well anyway. We'll just
# skip that stream_id and return an empty list, and continue with the next
# stream_id next time.
if not query_map and not cross_signing_results:
if not query_map and not cross_signing_keys_by_user:
return stream_id_cutoff, []

results = yield self._get_device_update_edus_by_remote(
destination, from_stream_id, query_map
)
results.extend(cross_signing_results)

# add the updated cross-signing keys to the results list
for user_id, result in iteritems(cross_signing_keys_by_user):
result["user_id"] = user_id
# FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
results.append(("org.matrix.signing_key_update", result))

return now_stream_id, results

def _get_devices_by_remote_txn(
def _get_device_updates_by_remote_txn(
self, txn, destination, from_stream_id, now_stream_id, limit
):
"""Return device update information for a given remote destination
Expand Down
4 changes: 2 additions & 2 deletions tests/handlers/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def make_homeserver(self, reactor, clock):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
"get_device_updates_by_remote",
# Bits that user_directory needs
"get_user_directory_stream_pos",
"get_current_state_deltas",
Expand Down Expand Up @@ -109,7 +109,7 @@ def prepare(self, reactor, clock, hs):
retry_timings_res
)

self.datastore.get_devices_by_remote.return_value = (0, [])
self.datastore.get_device_updates_by_remote.return_value = (0, [])

def get_received_txn_response(*args):
return defer.succeed(None)
Expand Down
12 changes: 6 additions & 6 deletions tests/storage/test_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_get_devices_by_user(self):
)

@defer.inlineCallbacks
def test_get_devices_by_remote(self):
def test_get_device_updates_by_remote(self):
device_ids = ["device_id1", "device_id2"]

# Add two device updates with a single stream_id
Expand All @@ -81,15 +81,15 @@ def test_get_devices_by_remote(self):
)

# Get all device updates ever meant for this remote
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"somehost", -1, limit=100
)

# Check original device_ids are contained within these updates
self._check_devices_in_updates(device_ids, device_updates)

@defer.inlineCallbacks
def test_get_devices_by_remote_limited(self):
def test_get_device_updates_by_remote_limited(self):
# Test breaking the update limit in 1, 101, and 1 device_id segments

# first add one device
Expand All @@ -115,20 +115,20 @@ def test_get_devices_by_remote_limited(self):
#

# first we should get a single update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", -1, limit=100
)
self._check_devices_in_updates(device_ids1, device_updates)

# Then we should get an empty list back as the 101 devices broke the limit
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self.assertEqual(len(device_updates), 0)

# The 101 devices should've been cleared, so we should now just get one device
# update
now_stream_id, device_updates = yield self.store.get_devices_by_remote(
now_stream_id, device_updates = yield self.store.get_device_updates_by_remote(
"someotherhost", now_stream_id, limit=100
)
self._check_devices_in_updates(device_ids3, device_updates)
Expand Down