From 9bd40d1fd56e44ca997f24d95e32166328978003 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Thu, 7 May 2020 12:02:00 +0200 Subject: [PATCH] OIDC login: make the user attribute mapping async Also passes the token as parameter of the mapping provider Signed-off-by: Quentin Gliech --- synapse/handlers/oidc_handler.py | 18 +++++++++++++----- tests/handlers/test_oidc.py | 4 ++-- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5ad73100aceb..6a08fe2d1d5e 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -686,7 +686,7 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None: # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo) + user_id = await self._map_userinfo_to_user(userinfo, token) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -724,7 +724,7 @@ def _verify_expiry(self, caveat: str) -> bool: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo) -> str: + async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -738,6 +738,7 @@ async def _map_userinfo_to_user(self, userinfo: UserInfo) -> str: Args: userinfo: an object representing the user + token: a dict with the tokens obtained from the provider Raises: MappingException: if there was an error while mapping some properties @@ -767,7 +768,9 @@ async def _map_userinfo_to_user(self, userinfo: UserInfo) -> str: return registered_user_id try: - attributes = self._user_mapping_provider.map_user_attributes(userinfo) + attributes = await self._user_mapping_provider.map_user_attributes( + userinfo, token + ) except Exception as e: raise MappingException( "Could not extract user attributes from OIDC response: " + str(e) @@ -845,11 +848,14 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str: """ raise NotImplementedError() - def map_user_attributes(self, userinfo: UserInfo) -> UserAttribute: + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: """Map a ``UserInfo`` objects into user attributes. Args: userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider Returns: A dict containing the ``localpart`` and (optionally) the ``display_name`` @@ -919,7 +925,9 @@ def parse_config(config: dict) -> JinjaOidcMappingConfig: def get_remote_user_id(self, userinfo: UserInfo) -> str: return userinfo[self._config["subject_claim"]] - def map_user_attributes(self, userinfo: UserInfo) -> UserAttribute: + async def map_user_attributes( + self, userinfo: UserInfo, token: Token + ) -> UserAttribute: localpart = self._config["localpart_template"].render(user=userinfo).strip() display_name = None # type: Optional[str] diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 65141aeaf780..6b005f82406c 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -410,7 +410,7 @@ def test_callback(self): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -442,7 +442,7 @@ def test_callback(self): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo) + self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called()