Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert _base, profile, and _receipts handlers to async/await #7860

Merged
merged 5 commits into from
Jul 17, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/7860.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert _base, profile, and _receipts handlers to async/await.
7 changes: 2 additions & 5 deletions synapse/handlers/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

import synapse.state
import synapse.storage
import synapse.types
Expand Down Expand Up @@ -66,8 +64,7 @@ def __init__(self, hs):

self.event_builder_factory = hs.get_event_builder_factory()

@defer.inlineCallbacks
def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.

Args:
Expand Down Expand Up @@ -99,7 +96,7 @@ def ratelimit(self, requester, update=True, is_admin_redaction=False):
burst_count = self._rc_message.burst_count

# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
Expand Down
8 changes: 6 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,11 +488,15 @@ def create_event(

try:
if "displayname" not in content:
displayname = yield profile.get_displayname(target)
displayname = yield defer.ensureDeferred(
profile.get_displayname(target)
)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = yield profile.get_avatar_url(target)
avatar_url = yield defer.ensureDeferred(
profile.get_avatar_url(target)
)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
Expand Down
63 changes: 27 additions & 36 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

import logging

from twisted.internet import defer

from synapse.api.errors import (
AuthError,
Codes,
Expand Down Expand Up @@ -54,16 +52,15 @@ def __init__(self, hs):

self.user_directory_handler = hs.get_user_directory_handler()

@defer.inlineCallbacks
def get_profile(self, user_id):
async def get_profile(self, user_id):
target_user = UserID.from_string(user_id)

if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
Expand All @@ -74,7 +71,7 @@ def get_profile(self, user_id):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": user_id},
Expand All @@ -86,19 +83,18 @@ def get_profile(self, user_id):
except HttpResponseException as e:
raise e.to_synapse_error()

@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
async def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
"""
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
Expand All @@ -108,14 +104,13 @@ def get_profile_from_cache(self, user_id):

return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}

@defer.inlineCallbacks
def get_displayname(self, target_user):
async def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
except StoreError as e:
Expand All @@ -126,7 +121,7 @@ def get_displayname(self, target_user):
return displayname
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "displayname"},
Expand Down Expand Up @@ -189,11 +184,10 @@ async def set_displayname(

await self._update_join_states(requester, target_user)

@defer.inlineCallbacks
def get_avatar_url(self, target_user):
async def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
try:
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
Expand All @@ -203,7 +197,7 @@ def get_avatar_url(self, target_user):
return avatar_url
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "avatar_url"},
Expand Down Expand Up @@ -253,8 +247,7 @@ async def set_avatar_url(

await self._update_join_states(requester, target_user)

@defer.inlineCallbacks
def on_profile_query(self, args):
async def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
Expand All @@ -264,12 +257,12 @@ def on_profile_query(self, args):
response = {}
try:
if just_field is None or just_field == "displayname":
response["displayname"] = yield self.store.get_profile_displayname(
response["displayname"] = await self.store.get_profile_displayname(
user.localpart
)

if just_field is None or just_field == "avatar_url":
response["avatar_url"] = yield self.store.get_profile_avatar_url(
response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
)
except StoreError as e:
Expand Down Expand Up @@ -304,8 +297,7 @@ async def _update_join_states(self, requester, target_user):
"Failed to update join event for room %s - %s", room_id, str(e)
)

@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
async def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
Expand Down Expand Up @@ -337,8 +329,8 @@ def check_profile_query_allowed(self, target_user, requester=None):
return

try:
requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = await self.store.get_rooms_for_user(
target_user.to_string()
)

Expand Down Expand Up @@ -371,25 +363,24 @@ def _start_update_remote_profile_cache(self):
"Update remote profile", self._update_remote_profile_cache
)

@defer.inlineCallbacks
def _update_remote_profile_cache(self):
async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)

for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
continue

try:
profile = yield self.federation.make_query(
profile = await self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={"user_id": user_id},
Expand All @@ -398,7 +389,7 @@ def _update_remote_profile_cache(self):
except Exception:
logger.exception("Failed to get avatar_url")

yield self.store.update_remote_profile_cache(
await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
Expand All @@ -407,4 +398,4 @@ def _update_remote_profile_cache(self):
new_avatar = profile.get("avatar_url")

# We always hit update to update the last_check timestamp
yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
16 changes: 6 additions & 10 deletions synapse/handlers/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
# limitations under the License.
import logging

from twisted.internet import defer

from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
Expand Down Expand Up @@ -129,15 +127,14 @@ class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()

@defer.inlineCallbacks
def get_new_events(self, from_key, room_ids, **kwargs):
async def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key)
to_key = yield self.get_current_key()
to_key = self.get_current_key()

if from_key == to_key:
return [], to_key

events = yield self.store.get_linearized_receipts_for_rooms(
events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)

Expand All @@ -146,17 +143,16 @@ def get_new_events(self, from_key, room_ids, **kwargs):
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()

@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
async def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)

if config.to_key:
from_key = int(config.to_key)
else:
from_key = None

room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
room_ids = await self.store.get_rooms_for_user(user.to_string())
events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)

Expand Down
17 changes: 11 additions & 6 deletions tests/handlers/test_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def register_query_handler(query_type, handler):
def test_get_my_name(self):
yield self.store.set_profile_displayname(self.frank.localpart, "Frank")

displayname = yield self.handler.get_displayname(self.frank)
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.frank)
)

self.assertEquals("Frank", displayname)

Expand Down Expand Up @@ -140,7 +142,9 @@ def test_get_other_name(self):
{"displayname": "Alice"}
)

displayname = yield self.handler.get_displayname(self.alice)
displayname = yield defer.ensureDeferred(
self.handler.get_displayname(self.alice)
)

self.assertEquals(displayname, "Alice")
self.mock_federation.make_query.assert_called_with(
Expand All @@ -155,8 +159,10 @@ def test_incoming_fed_query(self):
yield self.store.create_profile("caroline")
yield self.store.set_profile_displayname("caroline", "Caroline")

response = yield self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
response = yield defer.ensureDeferred(
self.query_handlers["profile"](
{"user_id": "@caroline:test", "field": "displayname"}
)
)

self.assertEquals({"displayname": "Caroline"}, response)
Expand All @@ -166,8 +172,7 @@ def test_get_my_avatar(self):
yield self.store.set_profile_avatar_url(
self.frank.localpart, "http://my.server/me.png"
)

avatar_url = yield self.handler.get_avatar_url(self.frank)
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))

self.assertEquals("http://my.server/me.png", avatar_url)

Expand Down