From aec69d2481e9ea1d8ea1c0ffce1706a65a7896a8 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 13 May 2022 12:35:31 +0100 Subject: [PATCH] Another batch of type annotations (#12726) --- changelog.d/12726.misc | 1 + mypy.ini | 21 ++++++++++++ synapse/handlers/e2e_keys.py | 29 ++++++---------- synapse/http/connectproxyclient.py | 39 +++++++++++++--------- synapse/http/proxyagent.py | 2 +- synapse/logging/_remote.py | 20 ++++++----- synapse/logging/formatter.py | 14 +++++--- synapse/logging/handlers.py | 4 +-- synapse/logging/scopecontextmanager.py | 28 ++++++++++++---- synapse/storage/background_updates.py | 19 ++++++++--- synapse/types.py | 46 ++++++++++++++++---------- 11 files changed, 144 insertions(+), 79 deletions(-) create mode 100644 changelog.d/12726.misc diff --git a/changelog.d/12726.misc b/changelog.d/12726.misc new file mode 100644 index 000000000000..b07e1b52ee7c --- /dev/null +++ b/changelog.d/12726.misc @@ -0,0 +1 @@ +Add type annotations to increase the number of modules passing `disallow-untyped-defs`. \ No newline at end of file diff --git a/mypy.ini b/mypy.ini index 8478dd9e510b..9ae7ad211c54 100644 --- a/mypy.ini +++ b/mypy.ini @@ -128,15 +128,30 @@ disallow_untyped_defs = True [mypy-synapse.http.federation.*] disallow_untyped_defs = True +[mypy-synapse.http.connectproxyclient] +disallow_untyped_defs = True + +[mypy-synapse.http.proxyagent] +disallow_untyped_defs = True + [mypy-synapse.http.request_metrics] disallow_untyped_defs = True [mypy-synapse.http.server] disallow_untyped_defs = True +[mypy-synapse.logging._remote] +disallow_untyped_defs = True + [mypy-synapse.logging.context] disallow_untyped_defs = True +[mypy-synapse.logging.formatter] +disallow_untyped_defs = True + +[mypy-synapse.logging.handlers] +disallow_untyped_defs = True + [mypy-synapse.metrics.*] disallow_untyped_defs = True @@ -166,6 +181,9 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.databases.background_updates] +disallow_untyped_defs = True + [mypy-synapse.storage.databases.main.account_data] disallow_untyped_defs = True @@ -232,6 +250,9 @@ disallow_untyped_defs = True [mypy-synapse.streams.*] disallow_untyped_defs = True +[mypy-synapse.types] +disallow_untyped_defs = True + [mypy-synapse.util.*] disallow_untyped_defs = True diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index d6714228ef41..e6c2cfb8c8e7 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -15,7 +15,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -1105,22 +1105,19 @@ async def _get_e2e_cross_signing_verify_key( # can request over federation raise NotFoundError("No %s key found for %s" % (key_type, user_id)) - ( - key, - key_id, - verify_key, - ) = await self._retrieve_cross_signing_keys_for_remote_user(user, key_type) - - if key is None: + cross_signing_keys = await self._retrieve_cross_signing_keys_for_remote_user( + user, key_type + ) + if cross_signing_keys is None: raise NotFoundError("No %s key found for %s" % (key_type, user_id)) - return key, key_id, verify_key + return cross_signing_keys async def _retrieve_cross_signing_keys_for_remote_user( self, user: UserID, desired_key_type: str, - ) -> Tuple[Optional[dict], Optional[str], Optional[VerifyKey]]: + ) -> Optional[Tuple[Dict[str, Any], str, VerifyKey]]: """Queries cross-signing keys for a remote user and saves them to the database Only the key specified by `key_type` will be returned, while all retrieved keys @@ -1146,12 +1143,10 @@ async def _retrieve_cross_signing_keys_for_remote_user( type(e), e, ) - return None, None, None + return None # Process each of the retrieved cross-signing keys - desired_key = None - desired_key_id = None - desired_verify_key = None + desired_key_data = None retrieved_device_ids = [] for key_type in ["master", "self_signing"]: key_content = remote_result.get(key_type + "_key") @@ -1196,9 +1191,7 @@ async def _retrieve_cross_signing_keys_for_remote_user( # If this is the desired key type, save it and its ID/VerifyKey if key_type == desired_key_type: - desired_key = key_content - desired_verify_key = verify_key - desired_key_id = key_id + desired_key_data = key_content, key_id, verify_key # At the same time, store this key in the db for subsequent queries await self.store.set_e2e_cross_signing_key( @@ -1212,7 +1205,7 @@ async def _retrieve_cross_signing_keys_for_remote_user( user.to_string(), retrieved_device_ids ) - return desired_key, desired_key_id, desired_verify_key + return desired_key_data def _check_cross_signing_key( diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py index 203e995bb77d..23a60af17184 100644 --- a/synapse/http/connectproxyclient.py +++ b/synapse/http/connectproxyclient.py @@ -14,15 +14,22 @@ import base64 import logging -from typing import Optional +from typing import Optional, Union import attr from zope.interface import implementer from twisted.internet import defer, protocol from twisted.internet.error import ConnectError -from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint +from twisted.internet.interfaces import ( + IAddress, + IConnector, + IProtocol, + IReactorCore, + IStreamClientEndpoint, +) from twisted.internet.protocol import ClientFactory, Protocol, connectionDone +from twisted.python.failure import Failure from twisted.web import http logger = logging.getLogger(__name__) @@ -81,14 +88,14 @@ def __init__( self._port = port self._proxy_creds = proxy_creds - def __repr__(self): + def __repr__(self) -> str: return "" % (self._proxy_endpoint,) # Mypy encounters a false positive here: it complains that ClientFactory # is incompatible with IProtocolFactory. But ClientFactory inherits from # Factory, which implements IProtocolFactory. So I think this is a bug # in mypy-zope. - def connect(self, protocolFactory: ClientFactory): # type: ignore[override] + def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]": # type: ignore[override] f = HTTPProxiedClientFactory( self._host, self._port, protocolFactory, self._proxy_creds ) @@ -125,10 +132,10 @@ def __init__( self.proxy_creds = proxy_creds self.on_connection: "defer.Deferred[None]" = defer.Deferred() - def startedConnecting(self, connector): + def startedConnecting(self, connector: IConnector) -> None: return self.wrapped_factory.startedConnecting(connector) - def buildProtocol(self, addr): + def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol": wrapped_protocol = self.wrapped_factory.buildProtocol(addr) if wrapped_protocol is None: raise TypeError("buildProtocol produced None instead of a Protocol") @@ -141,13 +148,13 @@ def buildProtocol(self, addr): self.proxy_creds, ) - def clientConnectionFailed(self, connector, reason): + def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None: logger.debug("Connection to proxy failed: %s", reason) if not self.on_connection.called: self.on_connection.errback(reason) return self.wrapped_factory.clientConnectionFailed(connector, reason) - def clientConnectionLost(self, connector, reason): + def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None: logger.debug("Connection to proxy lost: %s", reason) if not self.on_connection.called: self.on_connection.errback(reason) @@ -191,10 +198,10 @@ def __init__( ) self.http_setup_client.on_connected.addCallback(self.proxyConnected) - def connectionMade(self): + def connectionMade(self) -> None: self.http_setup_client.makeConnection(self.transport) - def connectionLost(self, reason=connectionDone): + def connectionLost(self, reason: Failure = connectionDone) -> None: if self.wrapped_protocol.connected: self.wrapped_protocol.connectionLost(reason) @@ -203,7 +210,7 @@ def connectionLost(self, reason=connectionDone): if not self.connected_deferred.called: self.connected_deferred.errback(reason) - def proxyConnected(self, _): + def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None: self.wrapped_protocol.makeConnection(self.transport) self.connected_deferred.callback(self.wrapped_protocol) @@ -213,7 +220,7 @@ def proxyConnected(self, _): if buf: self.wrapped_protocol.dataReceived(buf) - def dataReceived(self, data: bytes): + def dataReceived(self, data: bytes) -> None: # if we've set up the HTTP protocol, we can send the data there if self.wrapped_protocol.connected: return self.wrapped_protocol.dataReceived(data) @@ -243,7 +250,7 @@ def __init__( self.proxy_creds = proxy_creds self.on_connected: "defer.Deferred[None]" = defer.Deferred() - def connectionMade(self): + def connectionMade(self) -> None: logger.debug("Connected to proxy, sending CONNECT") self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port)) @@ -257,14 +264,14 @@ def connectionMade(self): self.endHeaders() - def handleStatus(self, version: bytes, status: bytes, message: bytes): + def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None: logger.debug("Got Status: %s %s %s", status, message, version) if status != b"200": raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}") - def handleEndHeaders(self): + def handleEndHeaders(self) -> None: logger.debug("End Headers") self.on_connected.callback(None) - def handleResponse(self, body): + def handleResponse(self, body: bytes) -> None: pass diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py index a16dde23807f..b2a50c910507 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py @@ -245,7 +245,7 @@ def http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, tls_options_factory: Optional[IPolicyForHTTPS], - **kwargs, + **kwargs: object, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy diff --git a/synapse/logging/_remote.py b/synapse/logging/_remote.py index 475756f1db64..5a61b21eaf7e 100644 --- a/synapse/logging/_remote.py +++ b/synapse/logging/_remote.py @@ -31,7 +31,11 @@ TCP4ClientEndpoint, TCP6ClientEndpoint, ) -from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint +from twisted.internet.interfaces import ( + IPushProducer, + IReactorTCP, + IStreamClientEndpoint, +) from twisted.internet.protocol import Factory, Protocol from twisted.internet.tcp import Connection from twisted.python.failure import Failure @@ -59,14 +63,14 @@ class LogProducer: _buffer: Deque[logging.LogRecord] _paused: bool = attr.ib(default=False, init=False) - def pauseProducing(self): + def pauseProducing(self) -> None: self._paused = True - def stopProducing(self): + def stopProducing(self) -> None: self._paused = True self._buffer = deque() - def resumeProducing(self): + def resumeProducing(self) -> None: # If we're already producing, nothing to do. self._paused = False @@ -102,8 +106,8 @@ def __init__( host: str, port: int, maximum_buffer: int = 1000, - level=logging.NOTSET, - _reactor=None, + level: int = logging.NOTSET, + _reactor: Optional[IReactorTCP] = None, ): super().__init__(level=level) self.host = host @@ -118,7 +122,7 @@ def __init__( if _reactor is None: from twisted.internet import reactor - _reactor = reactor + _reactor = reactor # type: ignore[assignment] try: ip = ip_address(self.host) @@ -139,7 +143,7 @@ def __init__( self._stopping = False self._connect() - def close(self): + def close(self) -> None: self._stopping = True self._service.stopService() diff --git a/synapse/logging/formatter.py b/synapse/logging/formatter.py index c0f12ecd15b8..c88b8ae5450f 100644 --- a/synapse/logging/formatter.py +++ b/synapse/logging/formatter.py @@ -16,6 +16,8 @@ import logging import traceback from io import StringIO +from types import TracebackType +from typing import Optional, Tuple, Type class LogFormatter(logging.Formatter): @@ -28,10 +30,14 @@ class LogFormatter(logging.Formatter): where it was caught are logged). """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def formatException(self, ei): + def formatException( + self, + ei: Tuple[ + Optional[Type[BaseException]], + Optional[BaseException], + Optional[TracebackType], + ], + ) -> str: sio = StringIO() (typ, val, tb) = ei diff --git a/synapse/logging/handlers.py b/synapse/logging/handlers.py index 478b5274942b..dec2a2c3dd1a 100644 --- a/synapse/logging/handlers.py +++ b/synapse/logging/handlers.py @@ -49,7 +49,7 @@ def __init__( ) self._flushing_thread.start() - def on_reactor_running(): + def on_reactor_running() -> None: self._reactor_started = True reactor_to_use: IReactorCore @@ -74,7 +74,7 @@ def shouldFlush(self, record: LogRecord) -> bool: else: return True - def _flush_periodically(self): + def _flush_periodically(self) -> None: """ Whilst this handler is active, flush the handler periodically. """ diff --git a/synapse/logging/scopecontextmanager.py b/synapse/logging/scopecontextmanager.py index d57e7c5324f8..a26a1a58e7d6 100644 --- a/synapse/logging/scopecontextmanager.py +++ b/synapse/logging/scopecontextmanager.py @@ -13,6 +13,8 @@ # limitations under the License.import logging import logging +from types import TracebackType +from typing import Optional, Type from opentracing import Scope, ScopeManager @@ -107,19 +109,26 @@ class _LogContextScope(Scope): and - if enter_logcontext was set - the logcontext is finished too. """ - def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close): + def __init__( + self, + manager: LogContextScopeManager, + span, + logcontext, + enter_logcontext: bool, + finish_on_close: bool, + ): """ Args: - manager (LogContextScopeManager): + manager: the manager that is responsible for this scope. span (Span): the opentracing span which this scope represents the local lifetime for. logcontext (LogContext): the logcontext to which this scope is attached. - enter_logcontext (Boolean): + enter_logcontext: if True the logcontext will be exited when the scope is finished - finish_on_close (Boolean): + finish_on_close: if True finish the span when the scope is closed """ super().__init__(manager, span) @@ -127,16 +136,21 @@ def __init__(self, manager, span, logcontext, enter_logcontext, finish_on_close) self._finish_on_close = finish_on_close self._enter_logcontext = enter_logcontext - def __exit__(self, exc_type, value, traceback): + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: if exc_type == twisted.internet.defer._DefGen_Return: # filter out defer.returnValue() calls exc_type = value = traceback = None super().__exit__(exc_type, value, traceback) - def __str__(self): + def __str__(self) -> str: return f"Scope<{self.span}>" - def close(self): + def close(self) -> None: active_scope = self.manager.active if active_scope is not self: logger.error( diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 08c6eabc6d1a..c2bbbb574e75 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -12,20 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from types import TracebackType from typing import ( TYPE_CHECKING, + Any, AsyncContextManager, Awaitable, Callable, Dict, Iterable, + List, Optional, + Type, ) import attr from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.types import Connection +from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict from synapse.util import Clock, json_encoder @@ -74,7 +78,12 @@ async def __aenter__(self) -> int: return self._update_duration_ms - async def __aexit__(self, *exc) -> None: + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: pass @@ -352,7 +361,7 @@ async def do_next_background_update(self, sleep: bool = True) -> bool: True if we have finished running all the background updates, otherwise False """ - def get_background_updates_txn(txn): + def get_background_updates_txn(txn: Cursor) -> List[Dict[str, Any]]: txn.execute( """ SELECT update_name, depends_on FROM background_updates @@ -469,7 +478,7 @@ def register_background_update_handler( self, update_name: str, update_handler: Callable[[JsonDict, int], Awaitable[int]], - ): + ) -> None: """Register a handler for doing a background update. The handler should take two arguments: @@ -603,7 +612,7 @@ def create_index_sqlite(conn: Connection) -> None: else: runner = create_index_sqlite - async def updater(progress, batch_size): + async def updater(progress: JsonDict, batch_size: int) -> int: if runner is not None: logger.info("Adding index %s to %s", index_name, table) await self.db_pool.runWithConnection(runner) diff --git a/synapse/types.py b/synapse/types.py index 9ac688b23b28..325332a6e00f 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -24,6 +24,7 @@ Mapping, Match, MutableMapping, + NoReturn, Optional, Set, Tuple, @@ -35,6 +36,7 @@ import attr from frozendict import frozendict from signedjson.key import decode_verify_key_bytes +from signedjson.types import VerifyKey from typing_extensions import TypedDict from unpaddedbase64 import decode_base64 from zope.interface import Interface @@ -55,6 +57,7 @@ if TYPE_CHECKING: from synapse.appservice.api import ApplicationService from synapse.storage.databases.main import DataStore, PurgeEventsStore + from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore # Define a state map type from type/state_key to T (usually an event ID or # event) @@ -114,7 +117,7 @@ class Requester: app_service: Optional["ApplicationService"] authenticated_entity: str - def serialize(self): + def serialize(self) -> Dict[str, Any]: """Converts self to a type that can be serialized as JSON, and then deserialized by `deserialize` @@ -132,7 +135,9 @@ def serialize(self): } @staticmethod - def deserialize(store, input): + def deserialize( + store: "ApplicationServiceWorkerStore", input: Dict[str, Any] + ) -> "Requester": """Converts a dict that was produced by `serialize` back into a Requester. @@ -236,10 +241,10 @@ class DomainSpecificString(metaclass=abc.ABCMeta): domain: str # Because this is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self: DS) -> DS: return self - def __deepcopy__(self, memo): + def __deepcopy__(self: DS, memo: Dict[str, object]) -> DS: return self @classmethod @@ -729,12 +734,14 @@ async def to_string(self, store: "DataStore") -> str: ) @property - def room_stream_id(self): + def room_stream_id(self) -> int: return self.room_key.stream - def copy_and_advance(self, key, new_value) -> "StreamToken": + def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": """Advance the given key in the token to a new value if and only if the new value is after the old value. + + :raises TypeError: if `key` is not the one of the keys tracked by a StreamToken. """ if key == "room_key": new_token = self.copy_and_replace( @@ -751,7 +758,7 @@ def copy_and_advance(self, key, new_value) -> "StreamToken": else: return self - def copy_and_replace(self, key, new_value) -> "StreamToken": + def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": return attr.evolve(self, **{key: new_value}) @@ -793,14 +800,14 @@ class ThirdPartyInstanceID: # Deny iteration because it will bite you if you try to create a singleton # set by: # users = set(user) - def __iter__(self): + def __iter__(self) -> NoReturn: raise ValueError("Attempted to iterate a %s" % (type(self).__name__,)) # Because this class is a frozen class, it is deeply immutable. - def __copy__(self): + def __copy__(self) -> "ThirdPartyInstanceID": return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Dict[str, object]) -> "ThirdPartyInstanceID": return self @classmethod @@ -852,25 +859,28 @@ def __bool__(self) -> bool: return bool(self.changed or self.left) -def get_verify_key_from_cross_signing_key(key_info): +def get_verify_key_from_cross_signing_key( + key_info: Mapping[str, Any] +) -> Tuple[str, VerifyKey]: """Get the key ID and signedjson verify key from a cross-signing key dict Args: - key_info (dict): a cross-signing key dict, which must have a "keys" + key_info: a cross-signing key dict, which must have a "keys" property that has exactly one item in it Returns: - (str, VerifyKey): the key ID and verify key for the cross-signing key + the key ID and verify key for the cross-signing key """ - # make sure that exactly one key is provided + # make sure that a `keys` field is provided if "keys" not in key_info: raise ValueError("Invalid key") keys = key_info["keys"] - if len(keys) != 1: - raise ValueError("Invalid key") - # and return that one key - for key_id, key_data in keys.items(): + # and that it contains exactly one key + if len(keys) == 1: + key_id, key_data = next(iter(keys.items())) return key_id, decode_verify_key_bytes(key_id, decode_base64(key_data)) + else: + raise ValueError("Invalid key") @attr.s(auto_attribs=True, frozen=True, slots=True)