Skip to content

Commit

Permalink
feat: avm support for public input columns (#5700)
Browse files Browse the repository at this point in the history
Adds support for public input columns as outlined in the following
hackmd: https://hackmd.io/8kkJo4RkRTG6mpwL8fOf3w?both
  • Loading branch information
Maddiaa0 authored May 10, 2024
1 parent 98d32f1 commit 8cf9168
Show file tree
Hide file tree
Showing 15 changed files with 1,042 additions and 0 deletions.
3 changes: 3 additions & 0 deletions barretenberg/cpp/pil/spike/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
## Spike machine

A spike machine for testing new PIL functionality
8 changes: 8 additions & 0 deletions barretenberg/cpp/pil/spike/spike.pil
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

namespace Spike(16);

pol constant first = [1] + [0]*;
pol commit x;
pol public kernel_inputs;

x - first = 0;
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@

#define Spike_DECLARE_VIEWS(index) \
using Accumulator = typename std::tuple_element<index, ContainerOverSubrelations>::type; \
using View = typename Accumulator::View; \
[[maybe_unused]] auto Spike_first = View(new_term.Spike_first); \
[[maybe_unused]] auto Spike_kernel_inputs = View(new_term.Spike_kernel_inputs); \
[[maybe_unused]] auto Spike_x = View(new_term.Spike_x);
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

#pragma once
#include "../../relation_parameters.hpp"
#include "../../relation_types.hpp"
#include "./declare_views.hpp"

namespace bb::Spike_vm {

template <typename FF> struct SpikeRow {
FF Spike_first{};
FF Spike_x{};
};

inline std::string get_relation_label_spike(int index)
{
switch (index) {}
return std::to_string(index);
}

template <typename FF_> class spikeImpl {
public:
using FF = FF_;

static constexpr std::array<size_t, 1> SUBRELATION_PARTIAL_LENGTHS{
2,
};

template <typename ContainerOverSubrelations, typename AllEntities>
void static accumulate(ContainerOverSubrelations& evals,
const AllEntities& new_term,
[[maybe_unused]] const RelationParameters<FF>&,
[[maybe_unused]] const FF& scaling_factor)
{

// Contribution 0
{
Spike_DECLARE_VIEWS(0);

auto tmp = (Spike_x - Spike_first);
tmp *= scaling_factor;
std::get<0>(evals) += tmp;
}
}
};

template <typename FF> using spike = Relation<spikeImpl<FF>>;

} // namespace bb::Spike_vm
8 changes: 8 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/generated/avm_flavor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2026,6 +2026,14 @@ class AvmFlavor {
*/
template <size_t LENGTH> using ProverUnivariates = AllEntities<bb::Univariate<FF, LENGTH>>;

/**
* @brief A container for univariates used during Protogalaxy folding and sumcheck with some of the computation
* optmistically ignored
* @details During folding and sumcheck, the prover evaluates the relations on these univariates.
*/
template <size_t LENGTH, size_t SKIP_COUNT>
using OptimisedProverUnivariates = AllEntities<bb::Univariate<FF, LENGTH, 0, SKIP_COUNT>>;

/**
* @brief A container for univariates produced during the hot loop in sumcheck.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
#include "./avm_verifier.hpp"
#include "barretenberg/commitment_schemes/zeromorph/zeromorph.hpp"
#include "barretenberg/numeric/bitop/get_msb.hpp"
#include "barretenberg/polynomials/polynomial.hpp"
#include "barretenberg/transcript/transcript.hpp"

namespace bb {

AvmVerifier::AvmVerifier(std::shared_ptr<Flavor::VerificationKey> verifier_key)
: key(verifier_key)
{}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@


// AUTOGENERATED FILE
#pragma once

#include "barretenberg/common/constexpr_utils.hpp"
#include "barretenberg/common/throw_or_abort.hpp"
#include "barretenberg/ecc/curves/bn254/fr.hpp"
#include "barretenberg/honk/proof_system/logderivative_library.hpp"
#include "barretenberg/relations/generic_lookup/generic_lookup_relation.hpp"
#include "barretenberg/relations/generic_permutation/generic_permutation_relation.hpp"
#include "barretenberg/stdlib_circuit_builders/circuit_builder_base.hpp"

#include "barretenberg/relations/generated/spike/spike.hpp"
#include "barretenberg/vm/generated/spike_flavor.hpp"

namespace bb {

template <typename FF> struct SpikeFullRow {
FF Spike_first{};
FF Spike_kernel_inputs{};
FF Spike_x{};
};

class SpikeCircuitBuilder {
public:
using Flavor = bb::SpikeFlavor;
using FF = Flavor::FF;
using Row = SpikeFullRow<FF>;

// TODO: template
using Polynomial = Flavor::Polynomial;
using ProverPolynomials = Flavor::ProverPolynomials;

static constexpr size_t num_fixed_columns = 3;
static constexpr size_t num_polys = 3;
std::vector<Row> rows;

void set_trace(std::vector<Row>&& trace) { rows = std::move(trace); }

ProverPolynomials compute_polynomials()
{
const auto num_rows = get_circuit_subgroup_size();
ProverPolynomials polys;

// Allocate mem for each column
for (auto& poly : polys.get_all()) {
poly = Polynomial(num_rows);
}

for (size_t i = 0; i < rows.size(); i++) {
polys.Spike_first[i] = rows[i].Spike_first;
polys.Spike_kernel_inputs[i] = rows[i].Spike_kernel_inputs;
polys.Spike_x[i] = rows[i].Spike_x;
}

return polys;
}

[[maybe_unused]] bool check_circuit()
{

auto polys = compute_polynomials();
const size_t num_rows = polys.get_polynomial_size();

const auto evaluate_relation = [&]<typename Relation>(const std::string& relation_name,
std::string (*debug_label)(int)) {
typename Relation::SumcheckArrayOfValuesOverSubrelations result;
for (auto& r : result) {
r = 0;
}
constexpr size_t NUM_SUBRELATIONS = result.size();

for (size_t i = 0; i < num_rows; ++i) {
Relation::accumulate(result, polys.get_row(i), {}, 1);

bool x = true;
for (size_t j = 0; j < NUM_SUBRELATIONS; ++j) {
if (result[j] != 0) {
std::string row_name = debug_label(static_cast<int>(j));
throw_or_abort(
format("Relation ", relation_name, ", subrelation index ", row_name, " failed at row ", i));
x = false;
}
}
if (!x) {
return false;
}
}
return true;
};

if (!evaluate_relation.template operator()<Spike_vm::spike<FF>>("spike", Spike_vm::get_relation_label_spike)) {
return false;
}

return true;
}

[[nodiscard]] size_t get_num_gates() const { return rows.size(); }

[[nodiscard]] size_t get_circuit_subgroup_size() const
{
const size_t num_rows = get_num_gates();
const auto num_rows_log2 = static_cast<size_t>(numeric::get_msb64(num_rows));
size_t num_rows_pow2 = 1UL << (num_rows_log2 + (1UL << num_rows_log2 == num_rows ? 0 : 1));
return num_rows_pow2;
}
};
} // namespace bb
86 changes: 86 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@


#include "./spike_composer.hpp"
#include "barretenberg/plonk_honk_shared/composer/composer_lib.hpp"
#include "barretenberg/plonk_honk_shared/composer/permutation_lib.hpp"
#include "barretenberg/vm/generated/spike_circuit_builder.hpp"
#include "barretenberg/vm/generated/spike_verifier.hpp"

namespace bb {

using Flavor = SpikeFlavor;
void SpikeComposer::compute_witness(CircuitConstructor& circuit)
{
if (computed_witness) {
return;
}

auto polynomials = circuit.compute_polynomials();

for (auto [key_poly, prover_poly] : zip_view(proving_key->get_all(), polynomials.get_unshifted())) {
ASSERT(flavor_get_label(*proving_key, key_poly) == flavor_get_label(polynomials, prover_poly));
key_poly = prover_poly;
}

computed_witness = true;
}

SpikeProver SpikeComposer::create_prover(CircuitConstructor& circuit_constructor)
{
compute_proving_key(circuit_constructor);
compute_witness(circuit_constructor);
compute_commitment_key(circuit_constructor.get_circuit_subgroup_size());

SpikeProver output_state(proving_key, proving_key->commitment_key);

return output_state;
}

SpikeVerifier SpikeComposer::create_verifier(CircuitConstructor& circuit_constructor)
{
auto verification_key = compute_verification_key(circuit_constructor);

SpikeVerifier output_state(verification_key);

auto pcs_verification_key = std::make_unique<VerifierCommitmentKey>();

output_state.pcs_verification_key = std::move(pcs_verification_key);

return output_state;
}

std::shared_ptr<Flavor::ProvingKey> SpikeComposer::compute_proving_key(CircuitConstructor& circuit_constructor)
{
if (proving_key) {
return proving_key;
}

// Initialize proving_key
{
const size_t subgroup_size = circuit_constructor.get_circuit_subgroup_size();
proving_key = std::make_shared<Flavor::ProvingKey>(subgroup_size, 0);
}

proving_key->contains_recursive_proof = false;

return proving_key;
}

std::shared_ptr<Flavor::VerificationKey> SpikeComposer::compute_verification_key(
CircuitConstructor& circuit_constructor)
{
if (verification_key) {
return verification_key;
}

if (!proving_key) {
compute_proving_key(circuit_constructor);
}

verification_key =
std::make_shared<Flavor::VerificationKey>(proving_key->circuit_size, proving_key->num_public_inputs);

return verification_key;
}

} // namespace bb
69 changes: 69 additions & 0 deletions barretenberg/cpp/src/barretenberg/vm/generated/spike_composer.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@


#pragma once

#include "barretenberg/plonk_honk_shared/composer/composer_lib.hpp"
#include "barretenberg/srs/global_crs.hpp"
#include "barretenberg/vm/generated/spike_circuit_builder.hpp"
#include "barretenberg/vm/generated/spike_prover.hpp"
#include "barretenberg/vm/generated/spike_verifier.hpp"

namespace bb {
class SpikeComposer {
public:
using Flavor = SpikeFlavor;
using CircuitConstructor = SpikeCircuitBuilder;
using ProvingKey = Flavor::ProvingKey;
using VerificationKey = Flavor::VerificationKey;
using PCS = Flavor::PCS;
using CommitmentKey = Flavor::CommitmentKey;
using VerifierCommitmentKey = Flavor::VerifierCommitmentKey;

// TODO: which of these will we really need
static constexpr std::string_view NAME_STRING = "Spike";
static constexpr size_t NUM_RESERVED_GATES = 0;
static constexpr size_t NUM_WIRES = Flavor::NUM_WIRES;

std::shared_ptr<ProvingKey> proving_key;
std::shared_ptr<VerificationKey> verification_key;

// The crs_factory holds the path to the srs and exposes methods to extract the srs elements
std::shared_ptr<bb::srs::factories::CrsFactory<Flavor::Curve>> crs_factory_;

// The commitment key is passed to the prover but also used herein to compute the verfication key commitments
std::shared_ptr<CommitmentKey> commitment_key;

std::vector<uint32_t> recursive_proof_public_input_indices;
bool contains_recursive_proof = false;
bool computed_witness = false;

SpikeComposer() { crs_factory_ = bb::srs::get_bn254_crs_factory(); }

SpikeComposer(std::shared_ptr<ProvingKey> p_key, std::shared_ptr<VerificationKey> v_key)
: proving_key(std::move(p_key))
, verification_key(std::move(v_key))
{}

SpikeComposer(SpikeComposer&& other) noexcept = default;
SpikeComposer(SpikeComposer const& other) noexcept = default;
SpikeComposer& operator=(SpikeComposer&& other) noexcept = default;
SpikeComposer& operator=(SpikeComposer const& other) noexcept = default;
~SpikeComposer() = default;

std::shared_ptr<ProvingKey> compute_proving_key(CircuitConstructor& circuit_constructor);
std::shared_ptr<VerificationKey> compute_verification_key(CircuitConstructor& circuit_constructor);

void compute_witness(CircuitConstructor& circuit_constructor);

SpikeProver create_prover(CircuitConstructor& circuit_constructor);
SpikeVerifier create_verifier(CircuitConstructor& circuit_constructor);

void add_table_column_selector_poly_to_proving_key(bb::polynomial& small, const std::string& tag);

void compute_commitment_key(size_t circuit_size)
{
proving_key->commitment_key = std::make_shared<CommitmentKey>(circuit_size);
};
};

} // namespace bb
Loading

0 comments on commit 8cf9168

Please sign in to comment.