Skip to content

Commit

Permalink
auth: Replace josepy with PyJWT
Browse files Browse the repository at this point in the history
  • Loading branch information
tonial committed Oct 19, 2024
1 parent 9bd27c7 commit 8554b50
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 105 deletions.
57 changes: 19 additions & 38 deletions mozilla_django_oidc/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import base64
import hashlib
import json
import logging

import inspect
import jwt
import requests
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.urls import reverse
from django.utils.encoding import force_bytes, smart_bytes, smart_str
from django.utils.encoding import force_bytes, smart_str
from django.utils.module_loading import import_string
from josepy.b64 import b64decode
from josepy.jwk import JWK
from josepy.jws import JWS, Header
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -127,10 +124,10 @@ def update_user(self, user, claims):

def _verify_jws(self, payload, key):
"""Verify the given JWS payload with the given key and return the payload"""
jws = JWS.from_compact(payload)
jws = jwt.get_unverified_header(payload)

try:
alg = jws.signature.combined.alg.name
alg = jws["alg"]
except KeyError:
msg = "No alg value found in header"
raise SuspiciousOperation(msg)
Expand All @@ -142,21 +139,18 @@ def _verify_jws(self, payload, key):
)
raise SuspiciousOperation(msg)

if isinstance(key, str):
# Use smart_bytes here since the key string comes from settings.
jwk = JWK.load(smart_bytes(key))
else:
# The key is a json returned from the IDP JWKS endpoint.
jwk = JWK.from_json(key)

if not jws.verify(jwk):
try:
return jwt.decode(payload, key, algorithms=alg, options={"verify_aud": False})
except jwt.DecodeError:
msg = "JWS token verification failed."
raise SuspiciousOperation(msg)

return jws.payload

def retrieve_matching_jwk(self, token):
"""Get the signing key by exploring the JWKS endpoint of the OP."""
"""Get the signing key by exploring the JWKS endpoint of the OP.
Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check
the algorithm in case of multiple jwk with the same kid.
"""
response_jwks = requests.get(
self.OIDC_OP_JWKS_ENDPOINT,
verify=self.get_settings("OIDC_VERIFY_SSL", True),
Expand All @@ -167,32 +161,29 @@ def retrieve_matching_jwk(self, token):
jwks = response_jwks.json()

# Compute the current header from the given token to find a match
jws = JWS.from_compact(token)
json_header = jws.signature.protected
header = Header.json_loads(json_header)
jws = jwt.get_unverified_header(token)

key = None
for jwk in jwks["keys"]:
if import_from_settings("OIDC_VERIFY_KID", True) and jwk[
"kid"
] != smart_str(header.kid):
] != smart_str(jws["kid"]):
continue
if "alg" in jwk and jwk["alg"] != smart_str(header.alg):
if "alg" in jwk and jwk["alg"] != smart_str(jws["alg"]):
continue
key = jwk
if key is None:
raise SuspiciousOperation("Could not find a valid JWKS.")
return key
return jwt.PyJWK(key)

def get_payload_data(self, token, key):
"""Helper method to get the payload of the JWT token."""
if self.get_settings("OIDC_ALLOW_UNSECURED_JWT", False):
header, payload_data, signature = token.split(b".")
header = json.loads(smart_str(b64decode(header)))
header = jwt.get_unverified_header(token)

# If config allows unsecured JWTs check the header and return the decoded payload
if "alg" in header and header["alg"] == "none":
return b64decode(payload_data)
return jwt.decode(token, options={"verify_signature": False})

# By default fallback to verify JWT signatures
return self._verify_jws(token, key)
Expand All @@ -201,7 +192,6 @@ def verify_token(self, token, **kwargs):
"""Validate the token signature."""
nonce = kwargs.get("nonce")

token = force_bytes(token)
if self.OIDC_RP_SIGN_ALGO.startswith("RS") or self.OIDC_RP_SIGN_ALGO.startswith(
"ES"
):
Expand All @@ -212,16 +202,7 @@ def verify_token(self, token, **kwargs):
else:
key = self.OIDC_RP_CLIENT_SECRET

payload_data = self.get_payload_data(token, key)

# The 'token' will always be a byte string since it's
# the result of base64.urlsafe_b64decode().
# The payload is always the result of base64.urlsafe_b64decode().
# In Python 3 and 2, that's always a byte string.
# In Python3.6, the json.loads() function can accept a byte string
# as it will automagically decode it to a unicode string before
# deserializing https://bugs.python.org/issue17909
payload = json.loads(payload_data.decode("utf-8"))
payload = self.get_payload_data(token, key)
token_nonce = payload.get("nonce")

if self.get_settings("OIDC_USE_NONCE", True) and nonce != token_nonce:
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

install_requirements = [
"Django >= 3.2",
"josepy",
"pyjwt",
"requests",
"cryptography",
Expand Down
Loading

0 comments on commit 8554b50

Please sign in to comment.