Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #2075 from matrix-org/erikj/cache_speed
Browse files Browse the repository at this point in the history
Speed up cached function access
  • Loading branch information
erikjohnston authored Mar 31, 2017
2 parents 350333a + 4d17add commit 9cee0ce
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 16 deletions.
7 changes: 2 additions & 5 deletions synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
from synapse.push.presentable_names import (
calculate_room_name, name_from_member_event
)
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred


@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.get_invited_rooms_for_user)(user_id),
preserve_fn(store.get_rooms_for_user)(user_id),
], consumeErrors=True))
invites = yield store.get_invited_rooms_for_user(user_id)
joins = yield store.get_rooms_for_user(user_id)

my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
Expand Down
7 changes: 6 additions & 1 deletion synapse/util/async.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def errback(f):
deferred.addCallbacks(callback, errback)

def observe(self):
"""Observe the underlying deferred.
Can return either a deferred if the underlying deferred is still pending
(or has failed), or the actual value. Callers may need to use maybeDeferred.
"""
if not self._result:
d = defer.Deferred()

Expand All @@ -101,7 +106,7 @@ def remove(r):
return d
else:
success, res = self._result
return defer.succeed(res) if success else defer.fail(res)
return res if success else defer.fail(res)

def observers(self):
return self._observers
Expand Down
42 changes: 35 additions & 7 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,20 @@ def __init__(self, orig, num_args, inlineCallbacks, cache_context=False):
)

self.num_args = num_args

# list of the names of the args used as the cache key
self.arg_names = all_args[1:num_args + 1]

# self.arg_defaults is a map of arg name to its default value for each
# argument that has a default value
if arg_spec.defaults:
self.arg_defaults = dict(zip(
all_args[-len(arg_spec.defaults):],
arg_spec.defaults
))
else:
self.arg_defaults = {}

if "cache_context" in self.arg_names:
raise Exception(
"cache_context arg cannot be included among the cache keys"
Expand Down Expand Up @@ -289,18 +301,31 @@ def __get__(self, obj, objtype=None):
iterable=self.iterable,
)

def get_cache_key(args, kwargs):
"""Given some args/kwargs return a generator that resolves into
the cache_key.
We loop through each arg name, looking up if its in the `kwargs`,
otherwise using the next argument in `args`. If there are no more
args then we try looking the arg name up in the defaults
"""
pos = 0
for nm in self.arg_names:
if nm in kwargs:
yield kwargs[nm]
elif pos < len(args):
yield args[pos]
pos += 1
else:
yield self.arg_defaults[nm]

@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)

# Add temp cache_context so inspect.getcallargs doesn't explode
if self.add_cache_context:
kwargs["cache_context"] = None

arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
cache_key = tuple(arg_dict[arg_nm] for arg_nm in self.arg_names)
cache_key = tuple(get_cache_key(args, kwargs))

# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
Expand Down Expand Up @@ -341,7 +366,10 @@ def onErr(f):
cache.set(cache_key, result_d, callback=invalidate_callback)
observer = result_d.observe()

return logcontext.make_deferred_yieldable(observer)
if isinstance(observer, defer.Deferred):
return logcontext.make_deferred_yieldable(observer)
else:
return observer

wrapped.invalidate = cache.invalidate
wrapped.invalidate_all = cache.invalidate_all
Expand Down
3 changes: 2 additions & 1 deletion synapse/visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def filter_events_for_clients(store, user_tuples, events, event_id_to_state):
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(store.who_forgot_in_room)(
defer.maybeDeferred(
preserve_fn(store.who_forgot_in_room),
room_id,
)
for room_id in frozenset(e.room_id for e in events)
Expand Down
2 changes: 1 addition & 1 deletion tests/storage/test__base.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def func(self, key):

a.func.prefill(("foo",), ObservableDeferred(d))

self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(a.func("foo"), d.result)
self.assertEquals(callcount[0], 0)

@defer.inlineCallbacks
Expand Down
38 changes: 38 additions & 0 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,41 @@ def do_lookup():
logcontext.LoggingContext.sentinel)

return d1

@defer.inlineCallbacks
def test_cache_default_args(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2=2, arg3=3):
return self.mock(arg1, arg2, arg3)

obj = Cls()

obj.mock.return_value = 'fish'
r = yield obj.fn(1, 2, 3)
self.assertEqual(r, 'fish')
obj.mock.assert_called_once_with(1, 2, 3)
obj.mock.reset_mock()

# a call with same params shouldn't call the mock again
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
obj.mock.assert_not_called()
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = 'chips'
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_called_once_with(2, 3, 3)
obj.mock.reset_mock()

# the two values should now be cached
r = yield obj.fn(1, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()
4 changes: 3 additions & 1 deletion tests/util/test_snapshot_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def test_get_set(self):
# before the cache expires returns a resolved deferred.
get_result_at_11 = self.cache.get(11, "key")
self.assertIsNotNone(get_result_at_11)
self.assertTrue(get_result_at_11.called)
if isinstance(get_result_at_11, Deferred):
# The cache may return the actual result rather than a deferred
self.assertTrue(get_result_at_11.called)

# Check that getting the key after the deferred has resolved
# after the cache expires returns None
Expand Down

0 comments on commit 9cee0ce

Please sign in to comment.