feat: 20-30% cost reduction in recursive ipa algorithm (#9420)
`eccvm_recursive_verifier_test` measurements (size-512 ECCVM recursive verification):

Old: 876,214
New: 678,751

The relative performance delta should be much greater for large ECCVM
instances, as this PR removes an O(n log n) algorithm.
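
A minimal standalone sketch of the algorithmic change (a toy integer stand-in for the field type, not the library code): the verifier's scalar vector `s`, where `s[i]` is the product of `u_inv[k - 1 - j]` over the set bits `j` of `i`, used to be built bit-by-bit per entry (k multiplications for each of the n = 2^k entries) and is now built round-by-round by doubling (one multiplication per new entry).

```cpp
#include <cstddef>
#include <vector>

// Toy stand-in for a field element; arithmetic is shown structurally only.
using Fr = unsigned long long;

// Old: O(n log n) multiplications - k conditional muls per entry.
std::vector<Fr> s_vec_naive(const std::vector<Fr>& u_inv)
{
    const std::size_t k = u_inv.size();
    std::vector<Fr> s(std::size_t(1) << k, 1);
    for (std::size_t i = 0; i < s.size(); ++i) {
        for (std::size_t j = 0; j < k; ++j) {
            if ((i >> j) & 1) {
                s[i] *= u_inv[k - 1 - j];
            }
        }
    }
    return s;
}

// New: O(n) multiplications - each round doubles the populated prefix via
//   s'[2j] = s[j], s'[2j + 1] = s[j] * u_inv[round].
// (The diff below ping-pongs between s_vec and a half-size temporaries
// buffer; this sketch expands in place from the back, which is equivalent.)
std::vector<Fr> s_vec_linear(const std::vector<Fr>& u_inv)
{
    const std::size_t k = u_inv.size();
    std::vector<Fr> s(std::size_t(1) << k, 1);
    std::size_t size = 1;
    for (std::size_t round = 0; round < k; ++round) {
        for (std::size_t j = size; j-- > 0;) {
            s[2 * j + 1] = s[j] * u_inv[round]; // read s[j] before overwriting it
            s[2 * j] = s[j];
        }
        size *= 2;
    }
    return s;
}
```

Both functions produce identical vectors; only the multiplication count differs.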

This PR resolves issues
[#857](AztecProtocol/barretenberg#857) and
[#1023](AztecProtocol/barretenberg#1023)
(single batch mul in IPA).

Re: [#1023](AztecProtocol/barretenberg#1023): the code still performs two
batch muls, but all remaining `operator*` calls have been folded into them.

It is not worth merging the two batch muls into one, as doing so would require
multiplying together a large number of scalars. In the recursive setting the
scalars are bigfield elements, and the extra `bigfield::operator*` cost
outweighs the saving from a single `batch_mul` call.
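
For reference, the rearrangement behind the final batch mul (matching the Step 8 comment in the diff), where $U$ is the generator scaled by the generator challenge and $G_0$ is the result of the SRS batch mul:

$$C' = C + f(\beta)\cdot U, \qquad C_0 = C' + \sum_{j \in [k]} u_j^{-1} L_j + \sum_{j \in [k]} u_j R_j \stackrel{?}{=} a_0 \cdot G_0 + a_0 b_0 \cdot U$$

Moving $C$ to the right-hand side yields a single MSM that the verifier checks against $-C$:

$$\sum_{j \in [k]} u_j^{-1} L_j + \sum_{j \in [k]} u_j R_j - a_0 \cdot G_0 - \big(f(\beta) + a_0 b_0\big)\cdot U \stackrel{?}{=} -C$$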

Additional improvements:

- Removed unnecessary uses of the `pow` operator in IPA; in the recursive
setting these were `stdlib::bigfield::pow` calls and very expensive (see the
squaring sketch after this list).

- Reduced the number of distinct multiplication calls in
`ipa::reduce_verify_internal`.

- The `cycle_scalar::cycle_scalar(stdlib::bigfield)` constructor now
constructs a `cycle_scalar` out of a bigfield element more efficiently. The
new method leverages the fact that `scalar.lo` and `scalar.hi` are implicitly
range-constrained, which removes redundant bigfield constructor and arithmetic
calls. Moreover, performing a scalar multiplication already applies a modular
reduction to its input, which makes the explicit call to
`validate_scalar_is_in_field` unnecessary (see the limb-regrouping sketch
after this list).
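
A minimal sketch of the `pow` removal pattern, with a toy integer stand-in for the field type (in the recursive setting `Fr` is a `stdlib::bigfield`, so every multiplication costs real constraints): rather than computing `challenge.pow(1 << i)` in each round, square a running accumulator once per round.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy stand-in for a field element (illustrative only; no modular reduction).
using Fr = std::uint64_t;

// b_0 = prod_{i in [k]} (1 + u_inv[k - 1 - i] * x^(2^i)), computed with one
// squaring per round instead of a pow(1 << i) call per round.
Fr compute_b_zero(const std::vector<Fr>& u_inv, Fr x)
{
    const std::size_t k = u_inv.size();
    Fr b_zero = 1;
    Fr challenge = x; // holds x^(2^i) at the start of round i
    for (std::size_t i = 0; i < k; ++i) {
        b_zero *= Fr(1) + u_inv[k - 1 - i] * challenge;
        if (i != k - 1) {
            challenge = challenge * challenge; // one multiplication, no pow call
        }
    }
    return b_zero;
}
```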
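
And a self-contained sketch of the limb regrouping inside the new `cycle_scalar` constructor, using toy constants (`NUM_LIMB_BITS = 17`, `LO_BITS = 32` in place of barretenberg's 68/128) so everything fits in `uint64_t`; the shift amounts mirror `limb_2_shift`/`limb_3_shift` in the diff, and the check at the end uses GCC/Clang's `__int128`. In the circuit version each slice additionally receives a range constraint, which is the only "expensive" part.

```cpp
#include <cassert>
#include <cstdint>

// Toy constants standing in for BigScalarField::NUM_LIMB_BITS = 68, LO_BITS = 128.
constexpr unsigned NUM_LIMB_BITS = 17;
constexpr unsigned LO_BITS = 32;

// Regroup four little-endian NUM_LIMB_BITS-bit limbs into a LO_BITS-bit `lo`
// and a `hi` holding the remaining bits. limb[1] straddles the LO_BITS
// boundary: its low (LO_BITS - NUM_LIMB_BITS) bits finish off `lo`, and the
// remainder starts `hi`.
void regroup(const uint64_t limb[4], uint64_t& lo, uint64_t& hi)
{
    constexpr unsigned lo_bits_in_limb_1 = LO_BITS - NUM_LIMB_BITS; // 15 here
    const uint64_t limb_1_lo = limb[1] & ((uint64_t(1) << lo_bits_in_limb_1) - 1);
    const uint64_t limb_1_hi = limb[1] >> lo_bits_in_limb_1;

    lo = limb[0] + (limb_1_lo << NUM_LIMB_BITS);
    hi = limb_1_hi + (limb[2] << (NUM_LIMB_BITS - lo_bits_in_limb_1)) +
         (limb[3] << (2 * NUM_LIMB_BITS - lo_bits_in_limb_1));
}

int main()
{
    // Limbs of a 4 x 17-bit value, each below 2^NUM_LIMB_BITS.
    const uint64_t limbs[4] = { 0x1ABCD, 0x0F0F0, 0x12345, 0x1FFFF };
    unsigned __int128 value = 0;
    for (int i = 3; i >= 0; --i) {
        value = (value << NUM_LIMB_BITS) | limbs[i];
    }
    uint64_t lo = 0;
    uint64_t hi = 0;
    regroup(limbs, lo, hi);
    assert(lo == static_cast<uint64_t>(value & 0xFFFFFFFFULL));
    assert(hi == static_cast<uint64_t>(value >> LO_BITS));
    return 0;
}
```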

---------
Co-authored-by: lucasxia01 <[email protected]>
zac-williamson and lucasxia01 authored Oct 29, 2024
1 parent 59810e0 commit a4bd3e1
Showing 5 changed files with 179 additions and 80 deletions.
131 changes: 71 additions & 60 deletions barretenberg/cpp/src/barretenberg/commitment_schemes/ipa/ipa.hpp
@@ -369,20 +369,29 @@ template <typename Curve_> class IPA {
// Construct vector s
std::vector<Fr> s_vec(poly_length, Fr::one());

// TODO(https://github.com/AztecProtocol/barretenberg/issues/857): This code is not efficient as it's
// O(nlogn). This can be optimized to be linear by computing a tree of products. It's very readable, so we're
// leaving it unoptimized for now.
parallel_for_heuristic(
poly_length,
[&](size_t i) {
for (size_t j = (log_poly_degree - 1); j != static_cast<size_t>(-1); j--) {
auto bit = (i >> j) & 1;
bool b = static_cast<bool>(bit);
if (b) {
s_vec[i] *= round_challenges_inv[log_poly_degree - 1 - j];
}
}
}, thread_heuristics::FF_MULTIPLICATION_COST * log_poly_degree);
std::vector<Fr> s_vec_temporaries(poly_length / 2);

Fr* previous_round_s = &s_vec_temporaries[0];
Fr* current_round_s = &s_vec[0];
// if the number of rounds is even, we need to swap these so that s_vec always contains the result
if ((log_poly_degree & 1) == 0)
{
std::swap(previous_round_s, current_round_s);
}
previous_round_s[0] = Fr(1);
for (size_t i = 0; i < log_poly_degree; ++i)
{
const size_t round_size = 1 << (i + 1);
const Fr round_challenge = round_challenges_inv[i];
parallel_for_heuristic(
round_size / 2,
[&](size_t j) {
current_round_s[j * 2] = previous_round_s[j];
current_round_s[j * 2 + 1] = previous_round_s[j] * round_challenge;
}, thread_heuristics::FF_MULTIPLICATION_COST * 2);
std::swap(current_round_s, previous_round_s);
}


std::span<const Commitment> srs_elements = vk->get_monomial_points();
if (poly_length * 2 > srs_elements.size()) {
@@ -454,28 +463,20 @@ template <typename Curve_> class IPA {
const Fr generator_challenge = transcript->template get_challenge<Fr>("IPA:generator_challenge");
auto builder = generator_challenge.get_context();

Commitment aux_generator = Commitment::one(builder) * generator_challenge;

const auto log_poly_degree = numeric::get_msb(static_cast<uint32_t>(poly_length));

// Step 3.
// Compute C' = C + f(\beta) ⋅ U
GroupElement C_prime = opening_claim.commitment + aux_generator * opening_claim.opening_pair.evaluation;

auto pippenger_size = 2 * log_poly_degree;
std::vector<Fr> round_challenges(log_poly_degree);
std::vector<Fr> round_challenges_inv(log_poly_degree);
std::vector<Commitment> msm_elements(pippenger_size);
std::vector<Fr> msm_scalars(pippenger_size);

// Step 4.
// Step 3.
// Receive all L_i and R_i and prepare for MSM
for (size_t i = 0; i < log_poly_degree; i++) {
std::string index = std::to_string(log_poly_degree - i - 1);
auto element_L = transcript->template receive_from_prover<Commitment>("IPA:L_" + index);
auto element_R = transcript->template receive_from_prover<Commitment>("IPA:R_" + index);
round_challenges[i] = transcript->template get_challenge<Fr>("IPA:round_challenge_" + index);

round_challenges_inv[i] = round_challenges[i].invert();

msm_elements[2 * i] = element_L;
@@ -484,63 +485,73 @@
msm_scalars[2 * i + 1] = round_challenges[i];
}

// Step 5.
// Compute C₀ = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j
GroupElement LR_sums = GroupElement::batch_mul(msm_elements, msm_scalars);

GroupElement C_zero = C_prime + LR_sums;

// Step 6.
// Step 4.
// Compute b_zero where b_zero can be computed using the polynomial:
// g(X) = ∏_{i ∈ [k]} (1 + u_{i-1}^{-1}.X^{2^{i-1}}).
// b_zero = g(evaluation) = ∏_{i ∈ [k]} (1 + u_{i-1}^{-1}. (evaluation)^{2^{i-1}})

Fr b_zero = Fr(1);
Fr challenge = opening_claim.opening_pair.challenge;
for (size_t i = 0; i < log_poly_degree; i++) {
b_zero *= Fr(1) + (round_challenges_inv[log_poly_degree - 1 - i] *
opening_claim.opening_pair.challenge.pow(1 << i));
b_zero *= Fr(1) + (round_challenges_inv[log_poly_degree - 1 - i] * challenge);
if (i != log_poly_degree - 1)
{
challenge = challenge * challenge;
}
}

// Step 7.

// Step 5.
// Construct vector s
// We implement a linear-time algorithm to optimally compute this vector
// Note: currently requires an extra vector of size `poly_length / 2` to cache temporaries
// this might be optimized if we care enough, but the size of this vector shouldn't be large relative to the builder polynomial sizes
std::vector<Fr> s_vec_temporaries(poly_length / 2);
std::vector<Fr> s_vec(poly_length);

// TODO(https://github.com/AztecProtocol/barretenberg/issues/857): This code is not efficient as it's
// O(nlogn). This can be optimized to be linear by computing a tree of products.
for (size_t i = 0; i < poly_length; i++) {
Fr s_vec_scalar = Fr(1);
for (size_t j = (log_poly_degree - 1); j != static_cast<size_t>(-1); j--) {
auto bit = (i >> j) & 1;
bool b = static_cast<bool>(bit);
if (b) {
s_vec_scalar *= round_challenges_inv[log_poly_degree - 1 - j];
}
Fr* previous_round_s = &s_vec_temporaries[0];
Fr* current_round_s = &s_vec[0];
// if the number of rounds is even, we need to swap these so that s_vec always contains the result
if ((log_poly_degree & 1) == 0)
{
std::swap(previous_round_s, current_round_s);
}
previous_round_s[0] = Fr(1);
for (size_t i = 0; i < log_poly_degree; ++i)
{
const size_t round_size = 1 << (i + 1);
const Fr round_challenge = round_challenges_inv[i];
for (size_t j = 0; j < round_size / 2; ++j)
{
current_round_s[j * 2] = previous_round_s[j];
current_round_s[j * 2 + 1] = previous_round_s[j] * round_challenge;
}
s_vec[i] = s_vec_scalar;
std::swap(current_round_s, previous_round_s);
}

auto srs_elements = vk->get_monomial_points();

// TODO(https://github.com/AztecProtocol/barretenberg/issues/1023): Unify the two batch_muls
// Step 6.
// Receive a₀ from the prover
const auto a_zero = transcript->template receive_from_prover<Fr>("IPA:a_0");

// Step 8.
// Step 7.
// Compute G₀
// Unlike the native verification function, the verifier commitment key only contains the SRS so we can apply
// batch_mul directly on it.
const std::vector<Commitment> srs_elements = vk->get_monomial_points();
Commitment G_zero = Commitment::batch_mul(srs_elements, s_vec);

// Step 9.
// Receive a₀ from the prover
auto a_zero = transcript->template receive_from_prover<Fr>("IPA:a_0");

// Step 10.
// Compute C_right
GroupElement right_hand_side = G_zero * a_zero + aux_generator * a_zero * b_zero;

// Step 11.
// Check if C_right == C₀
C_zero.assert_equal(right_hand_side);
return (C_zero.get_value() == right_hand_side.get_value());
// Step 8.
// Compute R = C' + ∑_{j ∈ [k]} u_j^{-1}L_j + ∑_{j ∈ [k]} u_jR_j - G₀ * a₀ - (f(\beta) + a₀ * b₀) ⋅ U
// This is a combination of several IPA relations into a large batch mul
// which should be equal to -C
msm_elements.emplace_back(-G_zero);
msm_elements.emplace_back(-Commitment::one(builder));
msm_scalars.emplace_back(a_zero);
msm_scalars.emplace_back(generator_challenge * a_zero.madd(b_zero, {opening_claim.opening_pair.evaluation}));
GroupElement ipa_relation = GroupElement::batch_mul(msm_elements, msm_scalars);
ipa_relation.assert_equal(-opening_claim.commitment);

return (ipa_relation.get_value() == -opening_claim.commitment.get_value());
}

public:
@@ -84,7 +84,7 @@ template <typename RecursiveFlavor> class ECCVMRecursiveTests : public ::testing
OuterBuilder outer_circuit;
RecursiveVerifier verifier{ &outer_circuit, verification_key };
verifier.verify_proof(proof);
info("Recursive Verifier: num gates = ", outer_circuit.num_gates);
info("Recursive Verifier: num gates = ", outer_circuit.get_estimated_num_finalized_gates());

// Check for a failure flag in the recursive verifier circuit
EXPECT_EQ(outer_circuit.failed(), false) << outer_circuit.err();
@@ -135,10 +135,10 @@ template <typename RecursiveFlavor> class ECCVMRecursiveTests : public ::testing
OuterBuilder outer_circuit;
RecursiveVerifier verifier{ &outer_circuit, verification_key };
verifier.verify_proof(proof);
info("Recursive Verifier: num gates = ", outer_circuit.num_gates);
info("Recursive Verifier: estimated num finalized gates = ", outer_circuit.get_estimated_num_finalized_gates());

// Check for a failure flag in the recursive verifier circuit
EXPECT_EQ(outer_circuit.failed(), true) << outer_circuit.err();
EXPECT_FALSE(CircuitChecker::check(outer_circuit));
}
};
using FlavorTypes = testing::Types<ECCVMRecursiveFlavor_<UltraCircuitBuilder>>;
@@ -113,17 +113,18 @@ TEST_F(GoblinRecursiveVerifierTests, ECCVMFailure)

// Tamper with the ECCVM proof
for (auto& val : proof.eccvm_proof) {
if (val > 0) { // tamper by finding the first non-zero value and incrementing it by 1
if (val > 0) { // tamper by finding the first non-zero value
// and incrementing it by 1
val += 1;
break;
}
}

Builder builder;
GoblinRecursiveVerifier verifier{ &builder, verifier_input };
verifier.verify(proof);

EXPECT_FALSE(CircuitChecker::check(builder));
EXPECT_DEBUG_DEATH(verifier.verify(proof), "(sumcheck_verified && batched_opening_verified)");
}

/**
@@ -678,23 +678,110 @@ typename cycle_group<Builder>::cycle_scalar cycle_group<Builder>::cycle_scalar::
template <typename Builder> cycle_group<Builder>::cycle_scalar::cycle_scalar(BigScalarField& scalar)
{
auto* ctx = get_context() ? get_context() : scalar.get_context();
const uint256_t value((scalar.get_value() % uint512_t(ScalarField::modulus)).lo);
const uint256_t value_lo = value.slice(0, LO_BITS);
const uint256_t value_hi = value.slice(LO_BITS, HI_BITS);

if (scalar.is_constant()) {
const uint256_t value((scalar.get_value() % uint512_t(ScalarField::modulus)).lo);
const uint256_t value_lo = value.slice(0, LO_BITS);
const uint256_t value_hi = value.slice(LO_BITS, HI_BITS);

lo = value_lo;
hi = value_hi;
// N.B. to be able to call assert equal, these cannot be constants
} else {
lo = witness_t(ctx, value_lo);
hi = witness_t(ctx, value_hi);
field_t zero = field_t(0);
zero.convert_constant_to_fixed_witness(ctx);
BigScalarField lo_big(lo, zero);
BigScalarField hi_big(hi, zero);
BigScalarField res = lo_big + hi_big * BigScalarField((uint256_t(1) << LO_BITS));
scalar.assert_equal(res);
validate_scalar_is_in_field();
// To efficiently convert a bigfield into a cycle scalar,
// we are going to explicitly rely on the fact that `scalar.lo` and `scalar.hi`
// are implicitly range-constrained to be 128 bits when they are converted into 4-bit lookup window slices

// First check: can the scalar actually fit into LO_BITS + HI_BITS?
// If it can, we can tolerate the scalar being > ScalarField::modulus, because performing a scalar mul
// implicitly performs a modular reduction
// If not, call `self_reduce` to subtract enough multiples of the modulus until the above condition is met
if (scalar.get_maximum_value() >= (uint512_t(1) << (LO_BITS + HI_BITS))) {
scalar.self_reduce();
}

field_t limb0 = scalar.binary_basis_limbs[0].element;
field_t limb1 = scalar.binary_basis_limbs[1].element;
field_t limb2 = scalar.binary_basis_limbs[2].element;
field_t limb3 = scalar.binary_basis_limbs[3].element;

// The general plan is as follows:
// 1. ensure limb0 contains no more than BigScalarField::NUM_LIMB_BITS bits
// 2. define limb1_lo = limb1.slice(0, LO_BITS - BigScalarField::NUM_LIMB_BITS)
// 3. define limb1_hi = limb1.slice(LO_BITS - BigScalarField::NUM_LIMB_BITS, <whatever maximum bound of limb1
// is>)
// 4. construct *this.lo out of limb0 and limb1_lo
// 5. construct *this.hi out of limb1_hi, limb2 and limb3
// This is a lot of logic, but very cheap on constraints.
// For fresh bignums that have come out of a MUL operation,
// the only "expensive" part is a size (LO_BITS - BigScalarField::NUM_LIMB_BITS) range check

// to convert into a cycle_scalar, we need to convert 4*68 bit limbs into 2*128 bit limbs
// we also need to ensure that the number of bits in cycle_scalar is < LO_BITS + HI_BITS
// note: we do not need to validate that the scalar is within the field modulus
// because performing a scalar multiplication implicitly performs a modular reduction (scalar multiplication
// in the ecc group reduces the scalar modulo the group order)

uint256_t limb1_max = scalar.binary_basis_limbs[1].maximum_value;

// Ensure that limb0 only contains at most NUM_LIMB_BITS. If it exceeds this value, slice off the excess and add
// it into limb1
if (scalar.binary_basis_limbs[0].maximum_value > BigScalarField::DEFAULT_MAXIMUM_LIMB) {
const uint256_t limb = limb0.get_value();
const uint256_t lo_v = limb.slice(0, BigScalarField::NUM_LIMB_BITS);
const uint256_t hi_v = limb >> BigScalarField::NUM_LIMB_BITS;
field_t lo = field_t::from_witness(ctx, lo_v);
field_t hi = field_t::from_witness(ctx, hi_v);

uint256_t hi_max = (scalar.binary_basis_limbs[0].maximum_value >> BigScalarField::NUM_LIMB_BITS);
const uint64_t hi_bits = hi_max.get_msb() + 1;
lo.create_range_constraint(BigScalarField::NUM_LIMB_BITS);
hi.create_range_constraint(static_cast<size_t>(hi_bits));
limb0.assert_equal(lo + hi * BigScalarField::shift_1);

limb1 += hi;
limb1_max += hi_max;
limb0 = lo;
}

// sanity check that limb[1] is the limb that contributes both to *this.lo and *this.hi
ASSERT((BigScalarField::NUM_LIMB_BITS * 2 > LO_BITS) && (BigScalarField::NUM_LIMB_BITS < LO_BITS));

// limb1 is the tricky one, as it contributes to both *this.lo and *this.hi.
// By this point, we know that limb1's value is bounded by limb1_max (its original maximum plus any carry
// sliced out of limb0). We need to slice this limb in two: the first slice is the low
// LO_BITS - BigScalarField::NUM_LIMB_BITS bits (which represent limb1's contribution to *this.lo), and the
// second slice represents limb1's contribution to *this.hi.
// Step 1: compute the max bit sizes of both slices
const size_t lo_bits_in_limb_1 = LO_BITS - BigScalarField::NUM_LIMB_BITS;
const size_t hi_bits_in_limb_1 = (static_cast<size_t>(limb1_max.get_msb()) + 1) - lo_bits_in_limb_1;

// Step 2: compute the witness values of both slices
const uint256_t limb_1 = limb1.get_value();
const uint256_t limb_1_hi_multiplicand = (uint256_t(1) << lo_bits_in_limb_1);
const uint256_t limb_1_hi_v = limb_1 >> lo_bits_in_limb_1;
const uint256_t limb_1_lo_v = limb_1 - (limb_1_hi_v << lo_bits_in_limb_1);

// Step 3: instantiate both slices as witnesses and validate their sum equals limb1
field_t limb_1_lo = field_t::from_witness(ctx, limb_1_lo_v);
field_t limb_1_hi = field_t::from_witness(ctx, limb_1_hi_v);
limb1.assert_equal(limb_1_hi * limb_1_hi_multiplicand + limb_1_lo);

// Step 4: apply range constraints to validate both slices represent the expected contributions to *this.lo and
// *this.hi
limb_1_lo.create_range_constraint(lo_bits_in_limb_1);
limb_1_hi.create_range_constraint(hi_bits_in_limb_1);

// construct *this.lo out of:
// a. `limb0` (the first NUM_LIMB_BITS bits of scalar)
// b. `limb_1_lo` (the first LO_BITS - NUM_LIMB_BITS) of limb1
lo = limb0 + (limb_1_lo * BigScalarField::shift_1);

const uint256_t limb_2_shift = uint256_t(1) << (BigScalarField::NUM_LIMB_BITS - lo_bits_in_limb_1);
const uint256_t limb_3_shift =
uint256_t(1) << ((BigScalarField::NUM_LIMB_BITS - lo_bits_in_limb_1) + BigScalarField::NUM_LIMB_BITS);

// construct *this.hi out of limb2, limb3 and the remaining term from limb1 not contributing to `lo`
hi = limb_1_hi.add_two(limb2 * limb_2_shift, limb3 * limb_3_shift);
}
};

@@ -485,8 +485,8 @@ class ECCOpQueue {

// Decompose point coordinates (Fq) into hi-lo chunks (Fr)
const size_t CHUNK_SIZE = 2 * DEFAULT_NON_NATIVE_FIELD_LIMB_BITS;
auto x_256 = uint256_t(point.x);
auto y_256 = uint256_t(point.y);
uint256_t x_256(point.x);
uint256_t y_256(point.y);
ultra_op.return_is_infinity = point.is_point_at_infinity();
ultra_op.x_lo = Fr(x_256.slice(0, CHUNK_SIZE));
ultra_op.x_hi = Fr(x_256.slice(CHUNK_SIZE, CHUNK_SIZE * 2));