Skip to content

Commit

Permalink
improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-pujol committed Sep 23, 2022
1 parent b4cb377 commit 52ad753
Show file tree
Hide file tree
Showing 16 changed files with 164 additions and 40 deletions.
8 changes: 4 additions & 4 deletions jwskate/jwa/encryption/aescbchmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ def mac(
Returns:
the resulting MAC.
"""
if aad is None:
if aad is None: # pragma: no branch
aad = b""
al = BinaPy.from_int(len(aad) * 8, length=8, byteorder="big", signed=False)
hasher = hmac.HMAC(self.mac_key, self.hash_alg)
if not isinstance(ciphertext, bytes):
if not isinstance(ciphertext, bytes): # pragma: no branch
ciphertext = bytes(ciphertext)

for param in (aad, iv, ciphertext, al):
Expand All @@ -87,7 +87,7 @@ def encrypt(
Returns:
a tuple (encrypted_data, authentication_tag)
"""
if not isinstance(plaintext, bytes):
if not isinstance(plaintext, bytes): # pragma: no branch
plaintext = bytes(plaintext)

cipher = ciphers.Cipher(algorithms.AES(self.aes_key), modes.CBC(iv)).encryptor()
Expand Down Expand Up @@ -116,7 +116,7 @@ def decrypt(
Returns:
the decrypted data
"""
if not isinstance(ciphertext, bytes):
if not isinstance(ciphertext, bytes): # pragma: no branch
ciphertext = bytes(ciphertext)

mac = self.mac(ciphertext, iv=iv, aad=aad)
Expand Down
4 changes: 2 additions & 2 deletions jwskate/jwa/encryption/aesgcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def encrypt(
"""
if len(iv) * 8 != self.iv_size:
raise ValueError(f"Invalid IV size, must be {self.iv_size} bits")
if not isinstance(plaintext, bytes):
if not isinstance(plaintext, bytes): # pragma: no branch
plaintext = bytes(plaintext)
ciphertext_with_tag = BinaPy(aead.AESGCM(self.key).encrypt(iv, plaintext, aad))
ciphertext, tag = ciphertext_with_tag.cut_at(-self.tag_size)
Expand Down Expand Up @@ -60,7 +60,7 @@ def decrypt(
Raises:
ValueError: if the IV size is not appropriate
"""
if not isinstance(ciphertext, bytes):
if not isinstance(ciphertext, bytes): # pragma: no branch
ciphertext = bytes(ciphertext)

if len(iv) * 8 != self.iv_size:
Expand Down
2 changes: 1 addition & 1 deletion jwskate/jwa/key_mgmt/pbes2.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class BasePbes2(BaseKeyManagementAlg):
def __init__(self, password: Union[SupportsBytes, bytes, str]):
if isinstance(password, str):
password = password.encode("utf-8")
if not isinstance(password, bytes):
if not isinstance(password, bytes): # pragma: no branch
password = bytes(password)
self.password = password

Expand Down
13 changes: 11 additions & 2 deletions jwskate/jwa/signature/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ class BaseECSignatureAlg(
public_key_class = ec.EllipticCurvePublicKey
private_key_class = ec.EllipticCurvePrivateKey

@classmethod
def check_key(
cls, key: Union[ec.EllipticCurvePrivateKey, ec.EllipticCurvePublicKey]
) -> None: # noqa: D102
if key.curve.name != cls.curve.cryptography_curve.name:
raise ValueError(
f"This key is on curve {key.curve.name}. An EC key on curve {cls.curve.name} is expected."
)

def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.private_key_required() as key:
Expand All @@ -35,7 +44,7 @@ def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
def verify(
self, data: Union[bytes, SupportsBytes], signature: bytes
) -> bool: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.public_key_required() as key:
Expand Down
4 changes: 2 additions & 2 deletions jwskate/jwa/signature/eddsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class EdDsa(
description = __doc__

def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.private_key_required() as key:
Expand All @@ -34,7 +34,7 @@ def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
def verify(
self, data: Union[bytes, SupportsBytes], signature: bytes
) -> bool: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.public_key_required() as key:
Expand Down
4 changes: 2 additions & 2 deletions jwskate/jwa/signature/hmac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class BaseHMACSigAlg(BaseSymmetricAlg, BaseSignatureAlg):
min_key_size: int

def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

if self.read_only:
Expand All @@ -29,7 +29,7 @@ def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy: # noqa: D102
def verify(
self, data: Union[bytes, SupportsBytes], signature: bytes
) -> bool: # noqa: D102
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

candidate_signature = self.sign(data)
Expand Down
4 changes: 2 additions & 2 deletions jwskate/jwa/signature/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def sign(self, data: Union[bytes, SupportsBytes]) -> BinaPy:
if self.read_only:
raise NotImplementedError

if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.private_key_required() as key:
Expand All @@ -54,7 +54,7 @@ def verify(self, data: Union[bytes, SupportsBytes], signature: bytes) -> bool:
Returns:
`True` if the signature is valid, `False` otherwise
"""
if not isinstance(data, bytes):
if not isinstance(data, bytes): # pragma: no branch
data = bytes(data)

with self.public_key_required() as key:
Expand Down
4 changes: 2 additions & 2 deletions jwskate/jwk/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,15 +295,15 @@ def alg(self) -> Optional[str]:
the key alg
"""
alg = self.get("alg")
if alg is not None and not isinstance(alg, str):
if alg is not None and not isinstance(alg, str): # pragma: no branch
raise TypeError(f"Invalid alg type {type(alg)}", alg)
return alg

@property
def kid(self) -> Optional[str]:
"""Return the JWK key ID (kid), if present."""
kid = self.get("kid")
if kid is not None and not isinstance(kid, str):
if kid is not None and not isinstance(kid, str): # pragma: no branch
raise TypeError(f"invalid kid type {type(kid)}", kid)
return kid

Expand Down
17 changes: 8 additions & 9 deletions jwskate/jwk/okp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from jwskate.jwa import X448, X25519, Ed448, Ed25519, EdDsa, OKPCurve

from .. import EcdhEs, EcdhEs_A128KW, EcdhEs_A192KW, EcdhEs_A256KW
from .alg import UnsupportedAlg
from .base import Jwk, JwkParameter


Expand Down Expand Up @@ -307,22 +308,20 @@ def generate(
Returns:
the resulting OKPJwk
"""
if crv is None and alg is None:
raise ValueError(
"You must supply at least a Curve identifier (crv) or an Algorithm identifier (alg) "
"in order to generate an OKP JWK."
)
curve: Optional[OKPCurve] = None
if crv:
curve = cls.get_curve(crv)
elif alg:
if alg in cls.SIGNATURE_ALGORITHMS:
curve = Ed25519
elif alg in cls.KEY_MANAGEMENT_ALGORITHMS:
curve = X25519

if curve is None:
raise UnsupportedOKPCurve(crv)
else:
raise UnsupportedAlg(alg)
else:
raise ValueError(
"You must supply at least a Curve identifier (crv) or an Algorithm identifier (alg) "
"in order to generate an OKP JWK."
)

x, d = curve.generate()
return cls.private(crv=curve.name, x=x, d=d, alg=alg, **params)
2 changes: 1 addition & 1 deletion jwskate/jwk/rsa.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def prime_factors(self) -> Tuple[int, int]:
p, q = rsa.rsa_recover_prime_factors(
self.modulus, self.exponent, self.private_exponent
)
return p, q
return (p, q) if p < q else (q, p)
return (
BinaPy(self.p).decode_from("b64u").to_int(),
BinaPy(self.q).decode_from("b64u").to_int(),
Expand Down
8 changes: 4 additions & 4 deletions jwskate/token.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def alg(self) -> str:
AttributeError: if the `alg` header value is not a string
"""
alg = self.get_header("alg")
if alg is None or not isinstance(alg, str):
if alg is None or not isinstance(alg, str): # pragma: no branch
raise AttributeError("This token doesn't have a valid 'alg' header")
return alg

Expand All @@ -86,7 +86,7 @@ def kid(self) -> str:
AttributeError: if the `kid` header value is not a string
"""
kid = self.get_header("kid")
if kid is None or not isinstance(kid, str):
if kid is None or not isinstance(kid, str): # pragma: no branch
raise AttributeError("This token doesn't have a valid 'kid' header")
return kid

Expand All @@ -100,7 +100,7 @@ def typ(self) -> str:
AttributeError: if the `typ` header value is not a string
"""
typ = self.get_header("typ")
if typ is None or not isinstance(typ, str):
if typ is None or not isinstance(typ, str): # pragma: no branch
raise AttributeError("This token doesn't have a valid 'typ' header")
return typ

Expand All @@ -114,7 +114,7 @@ def cty(self) -> str:
AttributeError: if the `typ` header value is not a string
"""
cty = self.get_header("cty")
if cty is None or not isinstance(cty, str):
if cty is None or not isinstance(cty, str): # pragma: no branch
raise AttributeError("This token doesn't have a valid 'cty' header")
return cty

Expand Down
4 changes: 2 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 9 additions & 3 deletions tests/test_jwa.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Tests for the jwkskate.jwa submodule."""

import pytest
from binapy import BinaPy
from cryptography.hazmat.primitives.asymmetric import ec

from jwskate import Jwk
from jwskate.jwa import Aes128CbcHmacSha256, Aes192CbcHmacSha384, EcdhEs
from jwskate import ES256, Aes128CbcHmacSha256, Aes192CbcHmacSha384, EcdhEs, Jwk


def test_aes_128_hmac_sha256() -> None:
Expand Down Expand Up @@ -191,3 +191,9 @@ def test_ecdhes() -> None:
key_size=128,
)
assert BinaPy(bob_cek).to("b64u") == b"VqqN6vgjbSBcIijNcacQGg"


def test_ec_signature_invalid_size() -> None:
es256 = ES256(ec.generate_private_key(ec.SECP256R1()).public_key())
with pytest.raises(ValueError):
es256.verify(b"foo", b"bar")
38 changes: 34 additions & 4 deletions tests/test_jwk/test_okp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,45 @@
import pytest
from cryptography.hazmat.primitives.asymmetric import ed448, ed25519, x448, x25519

from jwskate import Jwk, JwsCompact, OKPJwk, UnsupportedOKPCurve
from jwskate import Jwk, JwsCompact, OKPJwk, UnsupportedAlg, UnsupportedOKPCurve


@pytest.mark.parametrize("curve", ["Ed25519", "Ed448", "X25519", "X448"])
def test_jwk_okp_generate(curve: str) -> None:
jwk = OKPJwk.generate(crv=curve, kid="myokpkey")
@pytest.mark.parametrize("crv", ["Ed25519", "Ed448", "X25519", "X448"])
def test_jwk_okp_generate_with_crv(crv: str) -> None:
jwk = OKPJwk.generate(crv=crv, kid="myokpkey")
assert jwk.kty == "OKP"
assert jwk.crv == crv
assert jwk.kid == "myokpkey"
assert "x" in jwk
assert "d" in jwk

assert jwk.supported_encryption_algorithms() == []


@pytest.mark.parametrize(
"alg", ["ECDH-ES", "ECDH-ES+A128KW", "ECDH-ES+A192KW", "ECDH-ES+A256KW"]
)
def test_jwk_okp_generate_with_alg(alg: str) -> None:
jwk = OKPJwk.generate(alg=alg, kid="myokpkey")
assert jwk.kty == "OKP"
assert jwk.crv == "X25519"
assert jwk.kid == "myokpkey"
assert "x" in jwk
assert "d" in jwk

assert jwk.supported_encryption_algorithms() == []


def test_generate_no_crv_no_alg() -> None:
with pytest.raises(ValueError):
OKPJwk.generate()


def test_generate_unsuppored_alg() -> None:
with pytest.raises(UnsupportedAlg):
OKPJwk.generate(alg="foo")


def test_okp_ed25519_sign() -> None:
jwk = Jwk(
{
Expand Down Expand Up @@ -96,3 +121,8 @@ def test_pem_key(crv: str) -> None:

with pytest.raises(ValueError):
assert Jwk.from_pem_key(public_pem, password) == public_jwk


def test_from_cryptography_key_unknown_type() -> None:
with pytest.raises(TypeError):
OKPJwk.from_cryptography_key("this is not a cryptography key")
Loading

0 comments on commit 52ad753

Please sign in to comment.