Skip to content

Commit

Permalink
fix: Remove PublicValued bound from EventLens
Browse files Browse the repository at this point in the history
This refactor moves the `PublicValues` struct into the corresponding
`WithEvents` impls for each chip that needs access to it explicitly.
  • Loading branch information
wwared committed Sep 9, 2024
1 parent ee510ca commit 0174a5f
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 81 deletions.
23 changes: 3 additions & 20 deletions core/src/air/machine.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
use std::marker::PhantomData;

use p3_air::BaseAir;
use p3_field::{AbstractField, Field};
use p3_field::Field;
use p3_matrix::dense::RowMajorMatrix;
pub use sphinx_derive::MachineAir;

use crate::{
runtime::Program,
stark::{MachineRecord, PublicValued},
};

use super::{PublicValues, Word};
use crate::{runtime::Program, stark::MachineRecord};

/// A description of the events related to this AIR.
pub trait WithEvents<'a>: Sized {
Expand All @@ -24,9 +19,7 @@ pub trait WithEvents<'a>: Sized {
/// Chip, as specified by its `WithEvents` trait implementation.
///
/// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 )
///
/// TODO: Figure out if the PublicValued bound should generalize.
pub trait EventLens<T: for<'b> WithEvents<'b>>: PublicValued {
pub trait EventLens<T: for<'b> WithEvents<'b>> {
fn events(&self) -> <T as WithEvents<'_>>::Events;
}

Expand Down Expand Up @@ -75,16 +68,6 @@ where
}
}

impl<'a, T, R, F> PublicValued for Proj<'a, T, R, F>
where
T: for<'b> WithEvents<'b>,
R: EventLens<T>,
{
fn public_values<FF: AbstractField>(&self) -> PublicValues<Word<FF>, FF> {
self.record.public_values()
}
}

//////////////// end of shenanigans destined for the derive macros. ////////////////

/// An AIR that is part of a multi table AIR arithmetization.
Expand Down
12 changes: 8 additions & 4 deletions core/src/bytes/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ use crate::{
pub const NUM_ROWS: usize = 1 << 16;

impl<'a, F: Field> WithEvents<'a> for ByteChip<F> {
// the byte lookups
type Events = &'a HashMap<u32, HashMap<ByteLookupEvent, usize>>;
type Events = (
// the byte lookups
&'a HashMap<u32, HashMap<ByteLookupEvent, usize>>,
// the public values
PublicValues<Word<F>, F>,
);
}

impl<F: PrimeField32> MachineAir<F> for ByteChip<F> {
Expand Down Expand Up @@ -57,10 +61,10 @@ impl<F: PrimeField32> MachineAir<F> for ByteChip<F> {
NUM_BYTE_MULT_COLS,
);

let pv: PublicValues<Word<F>, F> = input.public_values();
let (events, pv) = input.events();

let shard = pv.execution_shard.as_canonical_u32();
for (lookup, mult) in input.events().get(&shard).unwrap_or(&HashMap::new()).iter() {
for (lookup, mult) in events.get(&shard).unwrap_or(&HashMap::new()).iter() {
let row = if lookup.opcode != ByteOpcode::U16Range {
((lookup.b << 8) + lookup.c) as usize
} else {
Expand Down
41 changes: 22 additions & 19 deletions core/src/memory/global.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use std::array;
use std::marker::PhantomData;

use p3_air::{Air, AirBuilder};
use p3_air::{AirBuilderWithPublicValues, BaseAir};
Expand All @@ -24,33 +25,39 @@ pub enum MemoryChipType {
}

/// A memory chip that can initialize or finalize values in memory.
pub struct MemoryChip {
pub struct MemoryChip<F> {
pub kind: MemoryChipType,
_marker: PhantomData<F>,
}

impl MemoryChip {
impl<F> MemoryChip<F> {
/// Creates a new memory chip with a certain type.
pub const fn new(kind: MemoryChipType) -> Self {
Self { kind }
pub fn new(kind: MemoryChipType) -> Self {
Self {
kind,
_marker: PhantomData,
}
}
}

impl<F> BaseAir<F> for MemoryChip {
impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
fn width(&self) -> usize {
NUM_MEMORY_INIT_COLS
}
}

impl<'a> WithEvents<'a> for MemoryChip {
impl<'a, F: 'a> WithEvents<'a> for MemoryChip<F> {
type Events = (
// initialize events
&'a [MemoryInitializeFinalizeEvent],
// finalize events
&'a [MemoryInitializeFinalizeEvent],
// the public values
PublicValues<Word<F>, F>,
);
}

impl<F: PrimeField32> MachineAir<F> for MemoryChip {
impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
type Record = ExecutionRecord;

type Program = Program;
Expand All @@ -67,22 +74,18 @@ impl<F: PrimeField32> MachineAir<F> for MemoryChip {
input: &EL,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let (mem_init_events, mem_final_events) = input.events();
let (mem_init_events, mem_final_events, pv) = input.events();

let mut memory_events = match self.kind {
MemoryChipType::Initialize => mem_init_events.to_vec(),
MemoryChipType::Finalize => mem_final_events.to_vec(),
};

let previous_addr_bits = match self.kind {
MemoryChipType::Initialize => input
.public_values::<F>()
.previous_init_addr_bits
.map(|f| f.as_canonical_u32()),
MemoryChipType::Finalize => input
.public_values::<F>()
.previous_finalize_addr_bits
.map(|f| f.as_canonical_u32()),
MemoryChipType::Initialize => pv.previous_init_addr_bits.map(|f| f.as_canonical_u32()),
MemoryChipType::Finalize => {
pv.previous_finalize_addr_bits.map(|f| f.as_canonical_u32())
}
};

memory_events.sort_by_key(|event| event.addr);
Expand Down Expand Up @@ -196,7 +199,7 @@ pub struct MemoryInitCols<T> {

pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();

impl<AB> Air<AB> for MemoryChip
impl<AB> Air<AB> for MemoryChip<AB::F>
where
AB: AirBuilderWithPublicValues + BaseAirBuilder,
{
Expand Down Expand Up @@ -427,13 +430,13 @@ mod tests {
runtime.run().unwrap();
let shard = runtime.record.clone();

let chip: MemoryChip = MemoryChip::new(MemoryChipType::Initialize);
let chip: MemoryChip<BabyBear> = MemoryChip::new(MemoryChipType::Initialize);

let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&shard, &mut ExecutionRecord::default());
println!("{:?}", trace.values);

let chip: MemoryChip = MemoryChip::new(MemoryChipType::Finalize);
let chip: MemoryChip<BabyBear> = MemoryChip::new(MemoryChipType::Finalize);
let trace: RowMajorMatrix<BabyBear> =
chip.generate_trace(&shard, &mut ExecutionRecord::default());
println!("{:?}", trace.values);
Expand Down
26 changes: 14 additions & 12 deletions core/src/memory/program.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use p3_field::{AbstractField, PrimeField32};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use std::collections::BTreeMap;
use std::marker::PhantomData;

use sphinx_derive::AlignedBorrow;

Expand Down Expand Up @@ -46,19 +47,19 @@ pub struct MemoryProgramMultCols<T> {
/// receives each row in the first shard. This prevents any of these addresses from being
/// overwritten through the normal MemoryInit.
#[derive(Default)]
pub struct MemoryProgramChip;
pub struct MemoryProgramChip<F>(PhantomData<F>);

impl MemoryProgramChip {
pub const fn new() -> Self {
Self {}
impl<F> MemoryProgramChip<F> {
pub fn new() -> Self {
Self(PhantomData)
}
}

impl<'a> WithEvents<'a> for MemoryProgramChip {
type Events = &'a BTreeMap<u32, u32>;
impl<'a, F: 'a> WithEvents<'a> for MemoryProgramChip<F> {
type Events = (&'a BTreeMap<u32, u32>, PublicValues<Word<F>, F>);
}

impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip<F> {
type Record = ExecutionRecord;

type Program = Program;
Expand Down Expand Up @@ -112,9 +113,10 @@ impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
input: &EL,
_output: &mut ExecutionRecord,
) -> RowMajorMatrix<F> {
let program_memory_addrs = input.events().keys().copied().collect::<Vec<_>>();
let (events, pv) = input.events();
let program_memory_addrs = events.keys().copied().collect::<Vec<_>>();

let mult = if input.public_values::<F>().shard == F::one() {
let mult = if pv.shard == F::one() {
F::one()
} else {
F::zero()
Expand All @@ -128,7 +130,7 @@ impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut();
cols.multiplicity = mult;
cols.is_first_shard
.populate(input.public_values::<F>().shard.as_canonical_u32() - 1);
.populate(pv.shard.as_canonical_u32() - 1);
row
})
.collect::<Vec<_>>();
Expand All @@ -150,13 +152,13 @@ impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
}
}

impl<F> BaseAir<F> for MemoryProgramChip {
impl<F: Send + Sync> BaseAir<F> for MemoryProgramChip<F> {
fn width(&self) -> usize {
NUM_MEMORY_PROGRAM_MULT_COLS
}
}

impl<AB> Air<AB> for MemoryProgramChip
impl<AB> Air<AB> for MemoryProgramChip<AB::F>
where
AB: BaseAirBuilder + PairBuilder + AirBuilderWithPublicValues,
{
Expand Down
25 changes: 14 additions & 11 deletions core/src/program/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ use core::{
mem::size_of,
};
use hashbrown::HashMap;
use std::marker::PhantomData;

use p3_air::{Air, BaseAir, PairBuilder};
use p3_field::PrimeField;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sphinx_derive::AlignedBorrow;

use crate::{
air::{EventLens, MachineAir, ProgramAirBuilder, WithEvents},
air::{EventLens, MachineAir, ProgramAirBuilder, PublicValues, WithEvents, Word},
cpu::{
columns::{InstructionCols, OpcodeSelectorCols},
CpuEvent,
Expand Down Expand Up @@ -44,24 +45,26 @@ pub struct ProgramMultiplicityCols<T> {

/// A chip that implements addition for the opcodes ADD and ADDI.
#[derive(Default)]
pub struct ProgramChip;
pub struct ProgramChip<F>(PhantomData<F>);

impl ProgramChip {
pub const fn new() -> Self {
Self {}
impl<F> ProgramChip<F> {
pub fn new() -> Self {
Self(PhantomData)
}
}

impl<'a> WithEvents<'a> for ProgramChip {
impl<'a, F: 'a> WithEvents<'a> for ProgramChip<F> {
type Events = (
// CPU events
&'a [CpuEvent],
// the Program
&'a Program,
// the public values
PublicValues<Word<F>, F>,
);
}

impl<F: PrimeField> MachineAir<F> for ProgramChip {
impl<F: PrimeField> MachineAir<F> for ProgramChip<F> {
type Record = ExecutionRecord;

type Program = Program;
Expand Down Expand Up @@ -120,7 +123,7 @@ impl<F: PrimeField> MachineAir<F> for ProgramChip {
) -> RowMajorMatrix<F> {
// Generate the trace rows for each event.

let (cpu_events, program) = input.events();
let (cpu_events, program, pv) = input.events();
// Collect the number of times each instruction is called from the cpu events.
// Store it as a map of PC -> count.
let mut instruction_counts = HashMap::new();
Expand All @@ -141,7 +144,7 @@ impl<F: PrimeField> MachineAir<F> for ProgramChip {
let pc = program.pc_base + (i as u32 * 4);
let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS];
let cols: &mut ProgramMultiplicityCols<F> = row.as_mut_slice().borrow_mut();
cols.shard = input.public_values().execution_shard;
cols.shard = pv.execution_shard;
cols.multiplicity =
F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0));
row
Expand All @@ -165,13 +168,13 @@ impl<F: PrimeField> MachineAir<F> for ProgramChip {
}
}

impl<F> BaseAir<F> for ProgramChip {
impl<F: Send + Sync> BaseAir<F> for ProgramChip<F> {
fn width(&self) -> usize {
NUM_PROGRAM_MULT_COLS
}
}

impl<AB> Air<AB> for ProgramChip
impl<AB> Air<AB> for ProgramChip<AB::F>
where
AB: ProgramAirBuilder + PairBuilder,
{
Expand Down
34 changes: 24 additions & 10 deletions core/src/runtime/record.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,10 @@ impl EventLens<ShiftRightChip> for ExecutionRecord {

impl<F: Field> EventLens<ByteChip<F>> for ExecutionRecord {
fn events(&self) -> <ByteChip<F> as crate::air::WithEvents<'_>>::Events {
&self.byte_lookups
(
&self.byte_lookups,
<ExecutionRecord as PublicValued>::public_values(self),
)
}
}

Expand All @@ -193,21 +196,32 @@ impl EventLens<CpuChip> for ExecutionRecord {
}
}

impl EventLens<MemoryChip> for ExecutionRecord {
fn events(&self) -> <MemoryChip as crate::air::WithEvents<'_>>::Events {
(&self.memory_initialize_events, &self.memory_finalize_events)
impl<F: Field> EventLens<MemoryChip<F>> for ExecutionRecord {
fn events(&self) -> <MemoryChip<F> as crate::air::WithEvents<'_>>::Events {
(
&self.memory_initialize_events,
&self.memory_finalize_events,
<ExecutionRecord as PublicValued>::public_values(self),
)
}
}

impl EventLens<MemoryProgramChip> for ExecutionRecord {
fn events(&self) -> <MemoryProgramChip as crate::air::WithEvents<'_>>::Events {
&self.program.memory_image
impl<F: Field> EventLens<MemoryProgramChip<F>> for ExecutionRecord {
fn events(&self) -> <MemoryProgramChip<F> as crate::air::WithEvents<'_>>::Events {
(
&self.program.memory_image,
<ExecutionRecord as PublicValued>::public_values(self),
)
}
}

impl EventLens<ProgramChip> for ExecutionRecord {
fn events(&self) -> <ProgramChip as crate::air::WithEvents<'_>>::Events {
(&self.cpu_events, &self.program)
impl<F: Field> EventLens<ProgramChip<F>> for ExecutionRecord {
fn events(&self) -> <ProgramChip<F> as crate::air::WithEvents<'_>>::Events {
(
&self.cpu_events,
&self.program,
<ExecutionRecord as PublicValued>::public_values(self),
)
}
}

Expand Down
Loading

0 comments on commit 0174a5f

Please sign in to comment.