Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Another batch of type annotations #12726

Merged
merged 11 commits into from
May 13, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ disallow_untyped_defs = True
[mypy-synapse.http.federation.*]
disallow_untyped_defs = True

[mypy-synapse.http.proxyagent]
disallow_untyped_defs = True

[mypy-synapse.http.request_metrics]
disallow_untyped_defs = True

Expand Down Expand Up @@ -166,6 +169,9 @@ disallow_untyped_defs = True
[mypy-synapse.state.*]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.background_updates]
disallow_untyped_defs = True

[mypy-synapse.storage.databases.main.account_data]
disallow_untyped_defs = True

Expand Down Expand Up @@ -232,6 +238,9 @@ disallow_untyped_defs = True
[mypy-synapse.streams.*]
disallow_untyped_defs = True

[mypy-synapse.types]
disallow_untyped_defs = True

[mypy-synapse.util.*]
disallow_untyped_defs = True

Expand Down
29 changes: 11 additions & 18 deletions synapse/handlers/e2e_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
Expand Down Expand Up @@ -1105,22 +1105,19 @@ async def _get_e2e_cross_signing_verify_key(
# can request over federation
raise NotFoundError("No %s key found for %s" % (key_type, user_id))

(
key,
key_id,
verify_key,
) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type)

if key is None:
cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user(
user, key_type
)
if cross_signing_keys is None:
raise NotFoundError("No %s key found for %s" % (key_type, user_id))

return key, key_id, verify_key
return cross_signing_keys

async def _retrieve_cross_signing_keys_for_remote_user(
self,
user: UserID,
desired_key_type: str,
) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]:
) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]:
"""Queries cross-signing keys for a remote user and saves them to the database

Only the key specified by `key_type` will be returned, while all retrieved keys
Expand All @@ -1146,12 +1143,10 @@ async def _retrieve_cross_signing_keys_for_remote_user(
type(e),
e,
)
return None, None, None
return None

# Process each of the retrieved cross-signing keys
desired_key = None
desired_key_id = None
desired_verify_key = None
desired_key_data = None
retrieved_device_ids = []
for key_type in ["master", "self_signing"]:
key_content = remote_result.get(key_type + "_key")
Expand Down Expand Up @@ -1196,9 +1191,7 @@ async def _retrieve_cross_signing_keys_for_remote_user(

# If this is the desired key type, save it and its ID/VerifyKey
if key_type == desired_key_type:
desired_key = key_content
desired_verify_key = verify_key
desired_key_id = key_id
desired_key_data = key_content, key_id, verify_key

# At the same time, store this key in the db for subsequent queries
await self.store.set_e2e_cross_signing_key(
Expand All @@ -1212,7 +1205,7 @@ async def _retrieve_cross_signing_keys_for_remote_user(
user.to_string(), retrieved_device_ids
)

return desired_key, desired_key_id, desired_verify_key
return desired_key_data


def _check_cross_signing_key(
Expand Down
2 changes: 1 addition & 1 deletion synapse/http/proxyagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def http_proxy_endpoint(
proxy: Optional[bytes],
reactor: IReactorCore,
tls_options_factory: Optional[IPolicyForHTTPS],
**kwargs,
**kwargs: object,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bit of a cop-out.

) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
"""Parses an http proxy setting and returns an endpoint for the proxy

Expand Down
28 changes: 21 additions & 7 deletions synapse/logging/scopecontextmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.import logging

import logging
from types import TracebackType
from typing import Optional, Type

from opentracing import Scope, ScopeManager

Expand Down Expand Up @@ -107,36 +109,48 @@ class _LogContextScope(Scope):
and - if enter_logcontext was set - the logcontext is finished too.
"""

def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close):
def __init__(
self,
manager: LogContextScopeManager,
span,
logcontext,
enter_logcontext: bool,
finish_on_close: bool,
):
"""
Args:
manager (LogContextScopeManager):
manager:
the manager that is responsible for this scope.
span (Span):
the opentracing span which this scope represents the local
lifetime for.
logcontext (LogContext):
the logcontext to which this scope is attached.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wasn't sure where these types came from, so this is just a drive-by.

enter_logcontext (Boolean):
enter_logcontext:
if True the logcontext will be exited when the scope is finished
finish_on_close (Boolean):
finish_on_close:
if True finish the span when the scope is closed
"""
super().__init__(manager, span)
self.logcontext = logcontext
self._finish_on_close = finish_on_close
self._enter_logcontext = enter_logcontext

def __exit__(self, exc_type, value, traceback):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if exc_type == twisted.internet.defer._DefGen_Return:
# filter out defer.returnValue() calls
exc_type = value = traceback = None
super().__exit__(exc_type, value, traceback)

def __str__(self):
def __str__(self) -> str:
return f"Scope<{self.span}>"

def close(self):
def close(self) -> None:
active_scope = self.manager.active
if active_scope is not self:
logger.error(
Expand Down
19 changes: 14 additions & 5 deletions synapse/storage/background_updates.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Type,
)

import attr

from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.types import Connection
from synapse.storage.types import Connection, Cursor
from synapse.types import JsonDict
from synapse.util import Clock, json_encoder

Expand Down Expand Up @@ -74,7 +78,12 @@ async def __aenter__(self) -> int:

return self._update_duration_ms

async def __aexit__(self, *exc) -> None:
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> None:
pass


Expand Down Expand Up @@ -352,7 +361,7 @@ async def do_next_background_update(self, sleep: bool = True) -> bool:
True if we have finished running all the background updates, otherwise False
"""

def get_background_updates_txn(txn):
def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]:
txn.execute(
"""
SELECT update_name, depends_on FROM background_updates
Expand Down Expand Up @@ -469,7 +478,7 @@ def register_background_update_handler(
self,
update_name: str,
update_handler: Callable[[JsonDict, int], Awaitable[int]],
):
) -> None:
"""Register a handler for doing a background update.

The handler should take two arguments:
Expand Down Expand Up @@ -603,7 +612,7 @@ def create_index_sqlite(conn: Connection) -> None:
else:
runner = create_index_sqlite

async def updater(progress, batch_size):
async def updater(progress: JsonDict, batch_size: int) -> int:
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
await self.db_pool.runWithConnection(runner)
Expand Down
46 changes: 28 additions & 18 deletions synapse/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
Mapping,
Match,
MutableMapping,
NoReturn,
Optional,
Set,
Tuple,
Expand All @@ -35,6 +36,7 @@
import attr
from frozendict import frozendict
from signedjson.key import decode_verify_key_bytes
from signedjson.types import VerifyKey
from typing_extensions import TypedDict
from unpaddedbase64 import decode_base64
from zope.interface import Interface
Expand All @@ -55,6 +57,7 @@
if TYPE_CHECKING:
from synapse.appservice.api import ApplicationService
from synapse.storage.databases.main import DataStore, PurgeEventsStore
from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore

# Define a state map type from type/state_key to T (usually an event ID or
# event)
Expand Down Expand Up @@ -114,7 +117,7 @@ class Requester:
app_service: Optional["ApplicationService"]
authenticated_entity: str

def serialize(self):
def serialize(self) -> Dict[str, Any]:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`

Expand All @@ -132,7 +135,9 @@ def serialize(self):
}

@staticmethod
def deserialize(store, input):
def deserialize(
store: "ApplicationServiceWorkerStore", input: Dict[str, Any]
) -> "Requester":
"""Converts a dict that was produced by `serialize` back into a
Requester.

Expand Down Expand Up @@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta):
domain: str

# Because this is a frozen class, it is deeply immutable.
def __copy__(self):
def __copy__(self: DS) -> DS:
return self

def __deepcopy__(self, memo):
def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS:
return self

@classmethod
Expand Down Expand Up @@ -729,12 +734,14 @@ async def to_string(self, store: "DataStore") -> str:
)

@property
def room_stream_id(self):
def room_stream_id(self) -> int:
return self.room_key.stream

def copy_and_advance(self, key, new_value) -> "StreamToken":
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken":
"""Advance the given key in the token to a new value if and only if the
new value is after the old value.

:raises TypeError: if `key` is not the one of the keys tracked by a StreamToken.
"""
if key == "room_key":
new_token = self.copy_and_replace(
Expand All @@ -751,7 +758,7 @@ def copy_and_advance(self, key, new_value) -> "StreamToken":
else:
return self

def copy_and_replace(self, key, new_value) -> "StreamToken":
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken":
return attr.evolve(self, **{key: new_value})


Expand Down Expand Up @@ -793,14 +800,14 @@ class ThirdPartyInstanceID:
# Deny iteration because it will bite you if you try to create a singleton
# set by:
# users = set(user)
def __iter__(self):
def __iter__(self) -> NoReturn:
raise ValueError("Attempted to iterate a %s" % (type(self).__name__,))

# Because this class is a frozen class, it is deeply immutable.
def __copy__(self):
def __copy__(self) -> "ThirdPartyInstanceID":
return self

def __deepcopy__(self, memo):
def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID":
return self

@classmethod
Expand Down Expand Up @@ -852,25 +859,28 @@ def __bool__(self) -> bool:
return bool(self.changed or self.left)


def get_verify_key_from_cross_signing_key(key_info):
def get_verify_key_from_cross_signing_key(
key_info: Dict[str, Any]
) -> Tuple[str, VerifyKey]:
"""Get the key ID and signedjson verify key from a cross-signing key dict

Args:
key_info (dict): a cross-signing key dict, which must have a "keys"
key_info: a cross-signing key dict, which must have a "keys"
property that has exactly one item in it

Returns:
(str, VerifyKey): the key ID and verify key for the cross-signing key
the key ID and verify key for the cross-signing key
"""
# make sure that exactly one key is provided
# make sure that a `keys` field is provided
if "keys" not in key_info:
raise ValueError("Invalid key")
keys = key_info["keys"]
if len(keys) != 1:
raise ValueError("Invalid key")
# and return that one key
for key_id, key_data in keys.items():
# and that it contains exactly one key
if len(keys) == 1:
key_id, key_data = keys.popitem()
return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data))
else:
raise ValueError("Invalid key")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is meant to be equivalent. Mypy doesn't like the previous version: it can't see that the loop on -872 will always run exactly once.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This appears to have broken a test!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh. If keys is empty, then before this change we would return None. Now we explicitly error.

Either the comments were incorrect, or else the implementation was buggy.

Copy link
Contributor Author

@DMRobertson DMRobertson May 12, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hasn't changed since its initial implementation! See #5769

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Every call site that I can see expects this function to return a 2-tuple.



@attr.s(auto_attribs=True, frozen=True, slots=True)
Expand Down