diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 586038078f86..f1a3f9d1df5a 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -51,12 +51,14 @@ def read_config(self, config, **kwargs): "client_auth_method", "client_secret_basic" ) self.oidc_scopes = oidc_config.get("scopes", ["openid"]) + self.oidc_uses_userinfo = oidc_config.get("uses_userinfo", False) self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") self.oidc_token_endpoint = oidc_config.get("token_endpoint") self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_subject_claim = oidc_config.get("subject_claim", "sub") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_merge_with_existing_users = oidc_config.get("merge_with_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -118,6 +120,11 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #scopes: ["openid"] + # always use userinfo endpoint. This is required for providers that don't include user + # information in the token response, e.g. Gitlab. + # + #uses_userinfo: false + # the oauth2 authorization endpoint. Required if provider discovery is disabled. # #authorization_endpoint: "https://accounts.example.com/oauth2/auth" @@ -140,6 +147,10 @@ def generate_config_section(self, config_dir_path, server_name, **kwargs): # #skip_verification: false + # if user already exists, add oidc token to that account instead of failing. Defaults to false. + # + #merge_with_existing_users: false + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4ba8c7fda502..7d475ded2ce0 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: def __init__(self, hs: HomeServer): self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] + self._uses_userinfo_config = hs.config.oidc_uses_userinfo # type: bool self._client_auth = ClientAuth( hs.config.oidc_client_id, hs.config.oidc_client_secret, @@ -112,6 +113,7 @@ def __init__(self, hs: HomeServer): hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._merge_with_existing_users = hs.config.oidc_merge_with_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -224,8 +226,7 @@ def _uses_userinfo(self) -> bool: ``access_token`` with the ``userinfo_endpoint``. """ - # Maybe that should be user-configurable and not inferred? - return "openid" not in self._scopes + return self._uses_userinfo_config or "openid" not in self._scopes async def load_metadata(self) -> OpenIDProviderMetadata: """Load and validate the provider metadata. @@ -884,17 +885,20 @@ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: user_id = UserID(localpart, self._hostname) if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + if self._merge_with_existing_users: + registered_user_id = user_id.to_string() + else: + # This mxid is taken + raise MappingException( + "mxid '{}' is already taken".format(user_id.to_string()) + ) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, default_display_name=attributes["display_name"], ) - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, )