Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Convert simple_select_one and simple_select_one_onecol to async #8162

Merged
merged 5 commits into from
Aug 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8162.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
36 changes: 30 additions & 6 deletions synapse/storage/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
Tuple,
TypeVar,
Union,
overload,
)

from prometheus_client import Histogram
from typing_extensions import Literal

from twisted.enterprise import adbapi
from twisted.internet import defer
Expand Down Expand Up @@ -1020,14 +1022,36 @@ def simple_upsert_many_txn_native_upsert(

return txn.execute_batch(sql, args)

def simple_select_one(
@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one",
) -> Dict[str, Any]:
...

@overload
async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one",
) -> Optional[Dict[str, Any]]:
...

async def simple_select_one(
self,
table: str,
keyvalues: Dict[str, Any],
retcols: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one",
) -> defer.Deferred:
) -> Optional[Dict[str, Any]]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning multiple columns from it.

Expand All @@ -1038,18 +1062,18 @@ def simple_select_one(
allow_none: If true, return None instead of failing if the SELECT
statement returns no rows
"""
return self.runInteraction(
return await self.runInteraction(
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)

def simple_select_one_onecol(
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: bool = False,
desc: str = "simple_select_one_onecol",
) -> defer.Deferred:
) -> Optional[Any]:
"""Executes a SELECT query on the named table, which is expected to
return a single row, returning a single column from it.

Expand All @@ -1061,7 +1085,7 @@ def simple_select_one_onecol(
statement returns no rows
desc: description of the transaction, for logging and metrics
"""
return self.runInteraction(
return await self.runInteraction(
desc,
self.simple_select_one_onecol_txn,
table,
Expand Down
14 changes: 8 additions & 6 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
Expand Down Expand Up @@ -47,19 +47,19 @@


class DeviceWorkerStore(SQLBaseStore):
def get_device(self, user_id: str, device_id: str):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.

Args:
user_id: The ID of the user which owns the device
device_id: The ID of the device to retrieve
Returns:
defer.Deferred for a dict containing the device information
A dict containing the device information
Raises:
StoreError: if the device is not found
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id, "hidden": False},
retcols=("user_id", "device_id", "display_name"),
Expand Down Expand Up @@ -656,11 +656,13 @@ def _get_all_device_list_changes_for_remotes(txn):
)

@cached(max_entries=10000)
def get_device_list_last_stream_id_for_remote(self, user_id: str):
async def get_device_list_last_stream_id_for_remote(
self, user_id: str
) -> Optional[Any]:
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
Expand Down
4 changes: 2 additions & 2 deletions synapse/storage/databases/main/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ async def get_association_from_room_alias(

return RoomAliasMapping(room_id, room_alias.to_string(), servers)

def get_room_alias_creator(self, room_alias):
return self.db_pool.simple_select_one_onecol(
async def get_room_alias_creator(self, room_alias: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="room_aliases",
keyvalues={"room_alias": room_alias},
retcol="creator",
Expand Down
8 changes: 4 additions & 4 deletions synapse/storage/databases/main/e2e_room_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,15 @@ def _get_e2e_room_keys_multi_txn(txn, user_id, version, room_keys):

return ret

def count_e2e_room_keys(self, user_id, version):
async def count_e2e_room_keys(self, user_id: str, version: str) -> int:
"""Get the number of keys in a backup version.

Args:
user_id (str): the user whose backup we're querying
version (str): the version ID of the backup we're querying about
user_id: the user whose backup we're querying
version: the version ID of the backup we're querying about
"""

return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="e2e_room_keys",
keyvalues={"user_id": user_id, "version": version},
retcol="COUNT(*)",
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,19 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):

super().process_replication_rows(stream_name, instance_name, token, rows)

def get_received_ts(self, event_id):
async def get_received_ts(self, event_id: str) -> Optional[int]:
"""Get received_ts (when it was persisted) for the event.

Raises an exception for unknown events.

Args:
event_id (str)
event_id: The event ID to query.

Returns:
Deferred[int|None]: Timestamp in milliseconds, or None for events
that were persisted before received_ts was implemented.
Timestamp in milliseconds, or None for events that were persisted
before received_ts was implemented.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="events",
keyvalues={"event_id": event_id},
retcol="received_ts",
Expand Down
18 changes: 11 additions & 7 deletions synapse/storage/databases/main/group_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
Expand All @@ -28,8 +28,8 @@


class GroupServerWorkerStore(SQLBaseStore):
def get_group(self, group_id):
return self.db_pool.simple_select_one(
async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="groups",
keyvalues={"group_id": group_id},
retcols=(
Expand Down Expand Up @@ -351,19 +351,23 @@ async def is_user_in_group(self, user_id: str, group_id: str) -> bool:
)
return bool(result)

def is_user_admin_in_group(self, group_id, user_id):
return self.db_pool.simple_select_one_onecol(
async def is_user_admin_in_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
return await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="is_admin",
allow_none=True,
desc="is_user_admin_in_group",
)

def is_user_invited_to_local_group(self, group_id, user_id):
async def is_user_invited_to_local_group(
self, group_id: str, user_id: str
) -> Optional[bool]:
"""Has the group server invited a user?
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"group_id": group_id, "user_id": user_id},
retcol="user_id",
Expand Down
13 changes: 9 additions & 4 deletions synapse/storage/databases/main/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool

Expand All @@ -37,12 +39,13 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(MediaRepositoryStore, self).__init__(database, db_conn, hs)

def get_local_media(self, media_id):
async def get_local_media(self, media_id: str) -> Optional[Dict[str, Any]]:
"""Get the metadata for a local piece of media

Returns:
None if the media_id doesn't exist.
"""
return self.db_pool.simple_select_one(
return await self.db_pool.simple_select_one(
"local_media_repository",
{"media_id": media_id},
(
Expand Down Expand Up @@ -191,8 +194,10 @@ def store_local_thumbnail(
desc="store_local_thumbnail",
)

def get_cached_remote_media(self, origin, media_id):
return self.db_pool.simple_select_one(
async def get_cached_remote_media(
self, origin, media_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
"remote_media_cache",
{"media_origin": origin, "media_id": media_id},
(
Expand Down
15 changes: 8 additions & 7 deletions synapse/storage/databases/main/monthly_active_users.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,18 @@ async def get_registered_reserved_users(self) -> List[str]:
return users

@cached(num_args=1)
def user_last_seen_monthly_active(self, user_id):
async def user_last_seen_monthly_active(self, user_id: str) -> int:
"""
Checks if a given user is part of the monthly active user group
Arguments:
user_id (str): user to add/update
Return:
Deferred[int] : timestamp since last seen, None if never seen
Checks if a given user is part of the monthly active user group

Arguments:
user_id: user to add/update

Return:
Timestamp since last seen, None if never seen
"""

return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
table="monthly_active_users",
keyvalues={"user_id": user_id},
retcol="timestamp",
Expand Down
17 changes: 10 additions & 7 deletions synapse/storage/databases/main/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional

from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.databases.main.roommember import ProfileInfo


class ProfileWorkerStore(SQLBaseStore):
async def get_profileinfo(self, user_localpart):
async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
try:
profile = await self.db_pool.simple_select_one(
table="profiles",
Expand All @@ -38,24 +39,26 @@ async def get_profileinfo(self, user_localpart):
avatar_url=profile["avatar_url"], display_name=profile["displayname"]
)

def get_profile_displayname(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_displayname(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="displayname",
desc="get_profile_displayname",
)

def get_profile_avatar_url(self, user_localpart):
return self.db_pool.simple_select_one_onecol(
async def get_profile_avatar_url(self, user_localpart: str) -> str:
return await self.db_pool.simple_select_one_onecol(
table="profiles",
keyvalues={"user_id": user_localpart},
retcol="avatar_url",
desc="get_profile_avatar_url",
)

def get_from_remote_profile_cache(self, user_id):
return self.db_pool.simple_select_one(
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
Expand Down
6 changes: 4 additions & 2 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@ def get_receipts_for_room(self, room_id, receipt_type):
)

@cached(num_args=3)
def get_last_receipt_event_id_for_user(self, user_id, room_id, receipt_type):
return self.db_pool.simple_select_one_onecol(
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_type: str
) -> Optional[str]:
return await self.db_pool.simple_select_one_onecol(
table="receipts_linearized",
keyvalues={
"room_id": room_id,
Expand Down
10 changes: 5 additions & 5 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import logging
import re
from typing import Awaitable, Dict, List, Optional
from typing import Any, Awaitable, Dict, List, Optional

from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
Expand Down Expand Up @@ -46,8 +46,8 @@ def __init__(self, database: DatabasePool, db_conn, hs):
)

@cached()
def get_user_by_id(self, user_id):
return self.db_pool.simple_select_one(
async def get_user_by_id(self, user_id: str) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="users",
keyvalues={"name": user_id},
retcols=[
Expand Down Expand Up @@ -1259,12 +1259,12 @@ def del_user_pending_deactivation(self, user_id):
desc="del_user_pending_deactivation",
)

def get_user_pending_deactivation(self):
async def get_user_pending_deactivation(self) -> Optional[str]:
"""
Gets one user from the table of users waiting to be parted from all the rooms
they're in.
"""
return self.db_pool.simple_select_one_onecol(
return await self.db_pool.simple_select_one_onecol(
"users_pending_deactivation",
keyvalues={},
retcol="user_id",
Expand Down
Loading