Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Cache empty responses from /user/devices (#11587)
Browse files Browse the repository at this point in the history
If we've never made a request to a remote homeserver, we should cache the response---even if the response is "this user has no devices".
  • Loading branch information
David Robertson authored Jan 5, 2022
1 parent 0fb3dd0 commit 88a78c6
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 5 deletions.
1 change: 1 addition & 0 deletions changelog.d/11587.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a long-standing bug where Synapse wouldn't cache a response indicating that a remote user has no devices.
10 changes: 9 additions & 1 deletion synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,8 +948,16 @@ async def user_device_resync(
devices = []
ignore_devices = True
else:
prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
user_id
)
cached_devices = await self.store.get_cached_devices_for_user(user_id)
if cached_devices == {d["device_id"]: d for d in devices}:

# To ensure that a user with no devices is cached, we skip the resync only
# if we have a stream_id from previously writing a cache entry.
if prev_stream_id is not None and cached_devices == {
d["device_id"]: d for d in devices
}:
logging.info(
"Skipping device list resync for %s, as our cache matches already",
user_id,
Expand Down
8 changes: 6 additions & 2 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,7 @@ def _get_all_device_list_changes_for_remotes(txn):
@cached(max_entries=10000)
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
) -> Optional[str]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
Expand All @@ -729,7 +729,9 @@ async def get_device_list_last_stream_id_for_remote(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
)
async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
async def get_device_list_last_stream_id_for_remotes(
self, user_ids: Iterable[str]
) -> Dict[str, Optional[str]]:
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
Expand Down Expand Up @@ -1316,6 +1318,7 @@ def _update_remote_device_list_cache_entry_txn(
content: JsonDict,
stream_id: str,
) -> None:
"""Delete, update or insert a cache entry for this (user, device) pair."""
if content.get("deleted"):
self.db_pool.simple_delete_txn(
txn,
Expand Down Expand Up @@ -1375,6 +1378,7 @@ async def update_remote_device_list_cache(
def _update_remote_device_list_cache_txn(
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
) -> None:
"""Replace the list of cached devices for this user with the given list."""
self.db_pool.simple_delete_txn(
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
)
Expand Down
96 changes: 96 additions & 0 deletions tests/handlers/test_e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from unittest import mock

from parameterized import parameterized
from signedjson import key as key, sign as sign

from twisted.internet import defer
Expand All @@ -23,6 +25,7 @@
from synapse.api.errors import Codes, SynapseError

from tests import unittest
from tests.test_utils import make_awaitable


class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -765,6 +768,8 @@ def test_query_devices_remote_sync(self):
remote_user_id = "@test:other"
local_user_id = "@test:test"

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
Expand Down Expand Up @@ -831,3 +836,94 @@ def test_query_devices_remote_sync(self):
}
},
)

@parameterized.expand(
[
# The remote homeserver's response indicates that this user has 0/1/2 devices.
([],),
(["device_1"],),
(["device_1", "device_2"],),
]
)
def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
"""Test that requests for all of a remote user's devices are cached.
We do this by asserting that only one call over federation was made, and that
the two queries to the local homeserver produce the same response.
"""
local_user_id = "@test:test"
remote_user_id = "@test:other"
request_body = {"device_keys": {remote_user_id: []}}

response_devices = [
{
"device_id": device_id,
"keys": {
"algorithms": ["dummy"],
"device_id": device_id,
"keys": {f"dummy:{device_id}": "dummy"},
"signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
"unsigned": {},
"user_id": "@test:other",
},
}
for device_id in device_ids
]

response_body = {
"devices": response_devices,
"user_id": remote_user_id,
"stream_id": 12345, # an integer, according to the spec
}

e2e_handler = self.hs.get_e2e_keys_handler()

# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
mock_get_rooms = mock.patch.object(
self.store,
"get_rooms_for_user",
new_callable=mock.MagicMock,
return_value=make_awaitable(["some_room_id"]),
)
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
new_callable=mock.MagicMock,
return_value=make_awaitable(response_body),
)

with mock_get_rooms, mock_request as mocked_federation_request:
# Make the first query and sanity check it succeeds.
response_1 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)
self.assertEqual(response_1["failures"], {})

# We should have made a federation request to do so.
mocked_federation_request.assert_called_once()

# Reset the mock so we can prove we don't make a second federation request.
mocked_federation_request.reset_mock()

# Repeat the query.
response_2 = self.get_success(
e2e_handler.query_devices(
request_body,
timeout=10,
from_user_id=local_user_id,
from_device_id="some_device_id",
)
)
self.assertEqual(response_2["failures"], {})

# We should not have made a second federation request.
mocked_federation_request.assert_not_called()

# The two requests to the local homeserver should be identical.
self.assertEqual(response_1, response_2)
4 changes: 2 additions & 2 deletions tests/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import warnings
from asyncio import Future
from binascii import unhexlify
from typing import Any, Awaitable, Callable, TypeVar
from typing import Awaitable, Callable, TypeVar
from unittest.mock import Mock

import attr
Expand All @@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
raise Exception("awaitable has not yet completed")


def make_awaitable(result: Any) -> Awaitable[Any]:
def make_awaitable(result: TV) -> Awaitable[TV]:
"""
Makes an awaitable, suitable for mocking an `async` function.
This uses Futures as they can be awaited multiple times so can be returned
Expand Down

0 comments on commit 88a78c6

Please sign in to comment.