Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add cancellation support to @cached and @cachedList decorators (#…
Browse files Browse the repository at this point in the history
…12183)

These decorators mostly support cancellation already. Add cancellation
tests and fix use of finished logging contexts by delaying cancellation,
as suggested by @erikjohnston.

Signed-off-by: Sean Quah <[email protected]>
  • Loading branch information
squahtx authored Mar 14, 2022
1 parent 605d161 commit 2fcf4b3
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 2 deletions.
1 change: 1 addition & 0 deletions changelog.d/12183.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add cancellation support to `@cached` and `@cachedList` decorators.
11 changes: 11 additions & 0 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import delay_cancellation
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.lrucache import LruCache

Expand Down Expand Up @@ -350,6 +351,11 @@ def _wrapped(*args: Any, **kwargs: Any) -> Any:
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)

# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
ret = delay_cancellation(ret)

return make_deferred_yieldable(ret)

wrapped = cast(_CachedFunction, _wrapped)
Expand Down Expand Up @@ -510,6 +516,11 @@ def errback_all(f: Failure) -> None:
d = defer.gatherResults(cached_defers, consumeErrors=True).addCallbacks(
lambda _: results, unwrapFirstError
)
if missing:
# We started a new call to `self.orig`, so we must always wait for it to
# complete. Otherwise we might mark our current logging context as
# finished while `self.orig` is still using it in the background.
d = delay_cancellation(d)
return make_deferred_yieldable(d)
else:
return defer.succeed(results)
Expand Down
147 changes: 145 additions & 2 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from unittest import mock

from twisted.internet import defer, reactor
from twisted.internet.defer import Deferred
from twisted.internet.defer import CancelledError, Deferred

from synapse.api.errors import SynapseError
from synapse.logging.context import (
Expand All @@ -28,7 +28,7 @@
make_deferred_yieldable,
)
from synapse.util.caches import descriptors
from synapse.util.caches.descriptors import cached, lru_cache
from synapse.util.caches.descriptors import cached, cachedList, lru_cache

from tests import unittest
from tests.test_utils import get_awaitable_result
Expand Down Expand Up @@ -493,6 +493,74 @@ def func3(self, key, cache_context):
obj.invalidate()
top_invalidate.assert_called_once()

def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()

class Cls:
@cached()
async def fn(self, arg1):
await complete_lookup
return str(arg1)

obj = Cls()

d1 = obj.fn(123)
d2 = obj.fn(123)
self.assertFalse(d1.called)
self.assertFalse(d2.called)

# Cancel `d1`, which is the lookup that caused `fn` to run.
d1.cancel()

# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, "123")

def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()

class Cls:
inner_context_was_finished = False

@cached()
async def fn(self, arg1):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return str(arg1)

obj = Cls()

async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.fn(123)
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed

d = defer.ensureDeferred(do_lookup())
d.cancel()

complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)


class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
Expand Down Expand Up @@ -865,3 +933,78 @@ async def list_fn(self, args1, arg2):
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()

def test_cancel(self):
"""Test that cancelling a lookup does not cancel other lookups"""
complete_lookup: "Deferred[None]" = Deferred()

class Cls:
@cached()
def fn(self, arg1):
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args):
await complete_lookup
return {arg: str(arg) for arg in args}

obj = Cls()

d1 = obj.list_fn([123, 456])
d2 = obj.list_fn([123, 456, 789])
self.assertFalse(d1.called)
self.assertFalse(d2.called)

d1.cancel()

# `d2` should complete normally.
complete_lookup.callback(None)
self.failureResultOf(d1, CancelledError)
self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"})

def test_cancel_logcontexts(self):
"""Test that cancellation does not break logcontexts.
* The `CancelledError` must be raised with the correct logcontext.
* The inner lookup must not resume with a finished logcontext.
* The inner lookup must not restore a finished logcontext when done.
"""
complete_lookup: "Deferred[None]" = Deferred()

class Cls:
inner_context_was_finished = False

@cached()
def fn(self, arg1):
pass

@cachedList(cached_method_name="fn", list_name="args")
async def list_fn(self, args):
await make_deferred_yieldable(complete_lookup)
self.inner_context_was_finished = current_context().finished
return {arg: str(arg) for arg in args}

obj = Cls()

async def do_lookup():
with LoggingContext("c1") as c1:
try:
await obj.list_fn([123])
self.fail("No CancelledError thrown")
except CancelledError:
self.assertEqual(
current_context(),
c1,
"CancelledError was not raised with the correct logcontext",
)
# suppress the error and succeed

d = defer.ensureDeferred(do_lookup())
d.cancel()

complete_lookup.callback(None)
self.successResultOf(d)
self.assertFalse(
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)

0 comments on commit 2fcf4b3

Please sign in to comment.