Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Correct EcPairingCheck & EcSubgroupCheck return types #717

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions pyteal/ast/ec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,17 @@ def __init__(self, id: int, name: str, min_version: int) -> None:


class EcOperation(Expr):
def __init__(self, op: Op, curve: EllipticCurve, args: list[Expr]) -> None:
def __init__(
self, op: Op, curve: EllipticCurve, args: list[Expr], return_type: TealType
) -> None:
super().__init__()
self.op = op
assert curve in EllipticCurve
self.curve = curve
for arg in args:
require_type(arg, TealType.bytes)
self.args = args
self.return_type = return_type

def __teal__(self, options: "CompileOptions"):
verifyProgramVersion(
Expand All @@ -52,7 +55,7 @@ def __str__(self):
return f"(EcOperation {self.op} {self.curve} {self.args})"

def type_of(self):
return TealType.bytes
return self.return_type

def has_return(self):
return False
Expand All @@ -70,7 +73,7 @@ def EcAdd(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
An expression which evaluates to the sum of the two points on the given
curve.
"""
return EcOperation(Op.ec_add, curve, [a, b])
return EcOperation(Op.ec_add, curve, [a, b], TealType.bytes)


def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr:
Expand All @@ -86,7 +89,7 @@ def EcScalarMul(curve: EllipticCurve, point: Expr, scalar: Expr) -> Expr:
An expression which evaluates to the product of the point and scalar on
the given curve.
"""
return EcOperation(Op.ec_scalar_mul, curve, [point, scalar])
return EcOperation(Op.ec_scalar_mul, curve, [point, scalar], TealType.bytes)


def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Expand All @@ -102,7 +105,7 @@ def EcPairingCheck(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
point in `a` with its respective point in `b` is equal to the identity
element of the target group. Otherwise, evaluates to 0.
"""
return EcOperation(Op.ec_pairing_check, curve, [a, b])
return EcOperation(Op.ec_pairing_check, curve, [a, b], TealType.uint64)


def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Expand All @@ -117,7 +120,7 @@ def EcMultiScalarMul(curve: EllipticCurve, a: Expr, b: Expr) -> Expr:
Returns:
An expression that evaluates to curve point :code:`b_0a_0 + b_1a_1 + b_2a_2 + ... + b_Na_N`.
"""
return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b])
return EcOperation(Op.ec_multi_scalar_mul, curve, [a, b], TealType.bytes)


def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr:
Expand All @@ -132,7 +135,7 @@ def EcSubgroupCheck(curve: EllipticCurve, a: Expr) -> Expr:
subgroup of the curve (including the point at infinity) else 0. Program
fails if the point is not in the curve at all.
"""
return EcOperation(Op.ec_subgroup_check, curve, [a])
return EcOperation(Op.ec_subgroup_check, curve, [a], TealType.uint64)


def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr:
Expand All @@ -145,4 +148,4 @@ def EcMapTo(curve: EllipticCurve, a: Expr) -> Expr:
Returns:
An expression that evaluates to the mapped point.
"""
return EcOperation(Op.ec_map_to, curve, [a])
return EcOperation(Op.ec_map_to, curve, [a], TealType.bytes)
17 changes: 9 additions & 8 deletions pyteal/ast/ec_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,23 +10,24 @@
| Callable[[pt.EllipticCurve, pt.Expr, pt.Expr], pt.Expr],
pt.Op,
int,
pt.TealType,
]
] = [
(pt.EcAdd, pt.Op.ec_add, 2),
(pt.EcScalarMul, pt.Op.ec_scalar_mul, 2),
(pt.EcPairingCheck, pt.Op.ec_pairing_check, 2),
(pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2),
(pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1),
(pt.EcMapTo, pt.Op.ec_map_to, 1),
(pt.EcAdd, pt.Op.ec_add, 2, pt.TealType.bytes),
(pt.EcScalarMul, pt.Op.ec_scalar_mul, 2, pt.TealType.bytes),
(pt.EcPairingCheck, pt.Op.ec_pairing_check, 2, pt.TealType.uint64),
(pt.EcMultiScalarMul, pt.Op.ec_multi_scalar_mul, 2, pt.TealType.bytes),
(pt.EcSubgroupCheck, pt.Op.ec_subgroup_check, 1, pt.TealType.uint64),
(pt.EcMapTo, pt.Op.ec_map_to, 1, pt.TealType.bytes),
]


def test_EcOperation():
for operation, expected_op, num_args in OPERATIONS:
for operation, expected_op, num_args, expected_return_type in OPERATIONS:
for curve in pt.EllipticCurve:
args = [pt.Bytes(f"arg{i}") for i in range(num_args)]
expr = operation(curve, *args)
assert expr.type_of() == pt.TealType.bytes
assert expr.type_of() == expected_return_type

expected = pt.TealSimpleBlock(
[pt.TealOp(arg, pt.Op.byte, f'"arg{i}"') for i, arg in enumerate(args)]
Expand Down