Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace josepy with PyJWT #543

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 20 additions & 38 deletions mozilla_django_oidc/auth.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import base64
import hashlib
import json
import logging

import inspect
import jwt
import requests
from django.contrib.auth import get_user_model
from django.contrib.auth.backends import ModelBackend
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.urls import reverse
from django.utils.encoding import force_bytes, smart_bytes, smart_str
from django.utils.encoding import force_bytes, smart_str
from django.utils.module_loading import import_string
from josepy.b64 import b64decode
from josepy.jwk import JWK
from josepy.jws import JWS, Header
from requests.auth import HTTPBasicAuth
from requests.exceptions import HTTPError

Expand Down Expand Up @@ -127,10 +124,10 @@ def update_user(self, user, claims):

def _verify_jws(self, payload, key):
"""Verify the given JWS payload with the given key and return the payload"""
jws = JWS.from_compact(payload)
jws = jwt.get_unverified_header(payload)

try:
alg = jws.signature.combined.alg.name
alg = jws["alg"]
except KeyError:
msg = "No alg value found in header"
raise SuspiciousOperation(msg)
Expand All @@ -142,21 +139,19 @@ def _verify_jws(self, payload, key):
)
raise SuspiciousOperation(msg)

if isinstance(key, str):
# Use smart_bytes here since the key string comes from settings.
jwk = JWK.load(smart_bytes(key))
else:
# The key is a json returned from the IDP JWKS endpoint.
jwk = JWK.from_json(key)

if not jws.verify(jwk):
try:
# Maybe add a settings to enforce audiance validation
return jwt.decode(payload, key, algorithms=alg, options={"verify_aud": False})
except jwt.DecodeError:
msg = "JWS token verification failed."
raise SuspiciousOperation(msg)

return jws.payload

def retrieve_matching_jwk(self, token):
"""Get the signing key by exploring the JWKS endpoint of the OP."""
"""Get the signing key by exploring the JWKS endpoint of the OP.

Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check
the algorithm in case of multiple jwk with the same kid.
"""
response_jwks = requests.get(
self.OIDC_OP_JWKS_ENDPOINT,
verify=self.get_settings("OIDC_VERIFY_SSL", True),
Expand All @@ -167,32 +162,29 @@ def retrieve_matching_jwk(self, token):
jwks = response_jwks.json()

# Compute the current header from the given token to find a match
jws = JWS.from_compact(token)
json_header = jws.signature.protected
header = Header.json_loads(json_header)
jws = jwt.get_unverified_header(token)

key = None
for jwk in jwks["keys"]:
if import_from_settings("OIDC_VERIFY_KID", True) and jwk[
"kid"
] != smart_str(header.kid):
] != smart_str(jws["kid"]):
continue
if "alg" in jwk and jwk["alg"] != smart_str(header.alg):
if "alg" in jwk and jwk["alg"] != smart_str(jws["alg"]):
continue
key = jwk
if key is None:
raise SuspiciousOperation("Could not find a valid JWKS.")
return key
return jwt.PyJWK(key)

def get_payload_data(self, token, key):
"""Helper method to get the payload of the JWT token."""
if self.get_settings("OIDC_ALLOW_UNSECURED_JWT", False):
header, payload_data, signature = token.split(b".")
header = json.loads(smart_str(b64decode(header)))
header = jwt.get_unverified_header(token)

# If config allows unsecured JWTs check the header and return the decoded payload
if "alg" in header and header["alg"] == "none":
return b64decode(payload_data)
return jwt.decode(token, options={"verify_signature": False})

# By default fallback to verify JWT signatures
return self._verify_jws(token, key)
Expand All @@ -201,7 +193,6 @@ def verify_token(self, token, **kwargs):
"""Validate the token signature."""
nonce = kwargs.get("nonce")

token = force_bytes(token)
if self.OIDC_RP_SIGN_ALGO.startswith("RS") or self.OIDC_RP_SIGN_ALGO.startswith(
"ES"
):
Expand All @@ -212,16 +203,7 @@ def verify_token(self, token, **kwargs):
else:
key = self.OIDC_RP_CLIENT_SECRET

payload_data = self.get_payload_data(token, key)

# The 'token' will always be a byte string since it's
# the result of base64.urlsafe_b64decode().
# The payload is always the result of base64.urlsafe_b64decode().
# In Python 3 and 2, that's always a byte string.
# In Python3.6, the json.loads() function can accept a byte string
# as it will automagically decode it to a unicode string before
# deserializing https://bugs.python.org/issue17909
payload = json.loads(payload_data.decode("utf-8"))
payload = self.get_payload_data(token, key)
token_nonce = payload.get("nonce")

if self.get_settings("OIDC_USE_NONCE", True) and nonce != token_nonce:
Expand Down
27 changes: 13 additions & 14 deletions mozilla_django_oidc/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import logging
import time
import warnings
from base64 import urlsafe_b64decode, urlsafe_b64encode
from hashlib import sha256
from urllib.request import parse_http_list, parse_keqv_list

# Make it obvious that these aren't the usual base64 functions
import josepy.b64
from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.utils.encoding import force_bytes

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -57,16 +57,12 @@ def is_authenticated(user):

def base64_url_encode(bytes_like_obj):
"""Return a URL-Safe, base64 encoded version of bytes_like_obj

Implements base64urlencode as described in
https://datatracker.ietf.org/doc/html/rfc7636#appendix-A
This function is not used by the OpenID client; it's just for testing PKCE related functions.
"""

s = josepy.b64.b64encode(bytes_like_obj).decode("ascii") # base64 encode
# the josepy base64 encoder (strips '='s padding) automatically
s = s.replace("+", "-") # 62nd char of encoding
s = s.replace("/", "_") # 63rd char of encoding

s = urlsafe_b64encode(force_bytes(bytes_like_obj)).decode('utf-8')
s = s.rstrip("=")
return s


Expand All @@ -78,11 +74,14 @@ def base64_url_decode(string_like_obj):
"""
s = string_like_obj

s = s.replace("_", "/") # 63rd char of encoding
s = s.replace("-", "+") # 62nd char of encoding
b = josepy.b64.b64decode(s) # josepy base64 encoder (decodes without '='s padding)

return b
size = len(s) % 4
if size == 2:
s += '=='
elif size == 3:
s += '='
elif size != 0:
raise ValueError('Invalid base64 string')
return urlsafe_b64decode(s.encode('utf-8'))


def generate_code_challenge(code_verifier, method):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

install_requirements = [
"Django >= 3.2",
"josepy",
"pyjwt",
"requests",
"cryptography",
]
Expand Down
Loading