diff --git a/core/src/air/machine.rs b/core/src/air/machine.rs index f47d14a17..3b4b54a74 100644 --- a/core/src/air/machine.rs +++ b/core/src/air/machine.rs @@ -1,16 +1,11 @@ use std::marker::PhantomData; use p3_air::BaseAir; -use p3_field::{AbstractField, Field}; +use p3_field::Field; use p3_matrix::dense::RowMajorMatrix; pub use sphinx_derive::MachineAir; -use crate::{ - runtime::Program, - stark::{MachineRecord, PublicValued}, -}; - -use super::{PublicValues, Word}; +use crate::{runtime::Program, stark::MachineRecord}; /// A description of the events related to this AIR. pub trait WithEvents<'a>: Sized { @@ -24,9 +19,7 @@ pub trait WithEvents<'a>: Sized { /// Chip, as specified by its `WithEvents` trait implementation. /// /// The name is inspired by (but not conformant to) functional optics ( https://doi.org/10.1145/1232420.1232424 ) -/// -/// TODO: Figure out if the PublicValued bound should generalize. -pub trait EventLens WithEvents<'b>>: PublicValued { +pub trait EventLens WithEvents<'b>> { fn events(&self) -> >::Events; } @@ -75,16 +68,6 @@ where } } -impl<'a, T, R, F> PublicValued for Proj<'a, T, R, F> -where - T: for<'b> WithEvents<'b>, - R: EventLens, -{ - fn public_values(&self) -> PublicValues, FF> { - self.record.public_values() - } -} - //////////////// end of shenanigans destined for the derive macros. //////////////// /// An AIR that is part of a multi table AIR arithmetization. diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 570ab76f4..d612fff57 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -17,8 +17,12 @@ use crate::{ pub const NUM_ROWS: usize = 1 << 16; impl<'a, F: Field> WithEvents<'a> for ByteChip { - // the byte lookups - type Events = &'a HashMap>; + type Events = ( + // the byte lookups + &'a HashMap>, + // the public values + PublicValues, F>, + ); } impl MachineAir for ByteChip { @@ -57,10 +61,10 @@ impl MachineAir for ByteChip { NUM_BYTE_MULT_COLS, ); - let pv: PublicValues, F> = input.public_values(); + let (events, pv) = input.events(); let shard = pv.execution_shard.as_canonical_u32(); - for (lookup, mult) in input.events().get(&shard).unwrap_or(&HashMap::new()).iter() { + for (lookup, mult) in events.get(&shard).unwrap_or(&HashMap::new()).iter() { let row = if lookup.opcode != ByteOpcode::U16Range { ((lookup.b << 8) + lookup.c) as usize } else { diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index ce913d55b..91d8b01fc 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -1,6 +1,7 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; use std::array; +use std::marker::PhantomData; use p3_air::{Air, AirBuilder}; use p3_air::{AirBuilderWithPublicValues, BaseAir}; @@ -24,33 +25,39 @@ pub enum MemoryChipType { } /// A memory chip that can initialize or finalize values in memory. -pub struct MemoryChip { +pub struct MemoryChip { pub kind: MemoryChipType, + _marker: PhantomData, } -impl MemoryChip { +impl MemoryChip { /// Creates a new memory chip with a certain type. - pub const fn new(kind: MemoryChipType) -> Self { - Self { kind } + pub fn new(kind: MemoryChipType) -> Self { + Self { + kind, + _marker: PhantomData, + } } } -impl BaseAir for MemoryChip { +impl BaseAir for MemoryChip { fn width(&self) -> usize { NUM_MEMORY_INIT_COLS } } -impl<'a> WithEvents<'a> for MemoryChip { +impl<'a, F: 'a> WithEvents<'a> for MemoryChip { type Events = ( // initialize events &'a [MemoryInitializeFinalizeEvent], // finalize events &'a [MemoryInitializeFinalizeEvent], + // the public values + PublicValues, F>, ); } -impl MachineAir for MemoryChip { +impl MachineAir for MemoryChip { type Record = ExecutionRecord; type Program = Program; @@ -67,7 +74,7 @@ impl MachineAir for MemoryChip { input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let (mem_init_events, mem_final_events) = input.events(); + let (mem_init_events, mem_final_events, pv) = input.events(); let mut memory_events = match self.kind { MemoryChipType::Initialize => mem_init_events.to_vec(), @@ -75,14 +82,10 @@ impl MachineAir for MemoryChip { }; let previous_addr_bits = match self.kind { - MemoryChipType::Initialize => input - .public_values::() - .previous_init_addr_bits - .map(|f| f.as_canonical_u32()), - MemoryChipType::Finalize => input - .public_values::() - .previous_finalize_addr_bits - .map(|f| f.as_canonical_u32()), + MemoryChipType::Initialize => pv.previous_init_addr_bits.map(|f| f.as_canonical_u32()), + MemoryChipType::Finalize => { + pv.previous_finalize_addr_bits.map(|f| f.as_canonical_u32()) + } }; memory_events.sort_by_key(|event| event.addr); @@ -196,7 +199,7 @@ pub struct MemoryInitCols { pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::>(); -impl Air for MemoryChip +impl Air for MemoryChip where AB: AirBuilderWithPublicValues + BaseAirBuilder, { @@ -427,13 +430,13 @@ mod tests { runtime.run().unwrap(); let shard = runtime.record.clone(); - let chip: MemoryChip = MemoryChip::new(MemoryChipType::Initialize); + let chip: MemoryChip = MemoryChip::new(MemoryChipType::Initialize); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); println!("{:?}", trace.values); - let chip: MemoryChip = MemoryChip::new(MemoryChipType::Finalize); + let chip: MemoryChip = MemoryChip::new(MemoryChipType::Finalize); let trace: RowMajorMatrix = chip.generate_trace(&shard, &mut ExecutionRecord::default()); println!("{:?}", trace.values); diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index 791cde799..ff706e1f4 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -5,6 +5,7 @@ use p3_field::{AbstractField, PrimeField32}; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use std::collections::BTreeMap; +use std::marker::PhantomData; use sphinx_derive::AlignedBorrow; @@ -46,19 +47,19 @@ pub struct MemoryProgramMultCols { /// receives each row in the first shard. This prevents any of these addresses from being /// overwritten through the normal MemoryInit. #[derive(Default)] -pub struct MemoryProgramChip; +pub struct MemoryProgramChip(PhantomData); -impl MemoryProgramChip { - pub const fn new() -> Self { - Self {} +impl MemoryProgramChip { + pub fn new() -> Self { + Self(PhantomData) } } -impl<'a> WithEvents<'a> for MemoryProgramChip { - type Events = &'a BTreeMap; +impl<'a, F: 'a> WithEvents<'a> for MemoryProgramChip { + type Events = (&'a BTreeMap, PublicValues, F>); } -impl MachineAir for MemoryProgramChip { +impl MachineAir for MemoryProgramChip { type Record = ExecutionRecord; type Program = Program; @@ -112,9 +113,10 @@ impl MachineAir for MemoryProgramChip { input: &EL, _output: &mut ExecutionRecord, ) -> RowMajorMatrix { - let program_memory_addrs = input.events().keys().copied().collect::>(); + let (events, pv) = input.events(); + let program_memory_addrs = events.keys().copied().collect::>(); - let mult = if input.public_values::().shard == F::one() { + let mult = if pv.shard == F::one() { F::one() } else { F::zero() @@ -128,7 +130,7 @@ impl MachineAir for MemoryProgramChip { let cols: &mut MemoryProgramMultCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; cols.is_first_shard - .populate(input.public_values::().shard.as_canonical_u32() - 1); + .populate(pv.shard.as_canonical_u32() - 1); row }) .collect::>(); @@ -150,13 +152,13 @@ impl MachineAir for MemoryProgramChip { } } -impl BaseAir for MemoryProgramChip { +impl BaseAir for MemoryProgramChip { fn width(&self) -> usize { NUM_MEMORY_PROGRAM_MULT_COLS } } -impl Air for MemoryProgramChip +impl Air for MemoryProgramChip where AB: BaseAirBuilder + PairBuilder + AirBuilderWithPublicValues, { diff --git a/core/src/program/mod.rs b/core/src/program/mod.rs index 1c8f74bf4..532373dcd 100644 --- a/core/src/program/mod.rs +++ b/core/src/program/mod.rs @@ -3,6 +3,7 @@ use core::{ mem::size_of, }; use hashbrown::HashMap; +use std::marker::PhantomData; use p3_air::{Air, BaseAir, PairBuilder}; use p3_field::PrimeField; @@ -10,7 +11,7 @@ use p3_matrix::{dense::RowMajorMatrix, Matrix}; use sphinx_derive::AlignedBorrow; use crate::{ - air::{EventLens, MachineAir, ProgramAirBuilder, WithEvents}, + air::{EventLens, MachineAir, ProgramAirBuilder, PublicValues, WithEvents, Word}, cpu::{ columns::{InstructionCols, OpcodeSelectorCols}, CpuEvent, @@ -44,24 +45,26 @@ pub struct ProgramMultiplicityCols { /// A chip that implements addition for the opcodes ADD and ADDI. #[derive(Default)] -pub struct ProgramChip; +pub struct ProgramChip(PhantomData); -impl ProgramChip { - pub const fn new() -> Self { - Self {} +impl ProgramChip { + pub fn new() -> Self { + Self(PhantomData) } } -impl<'a> WithEvents<'a> for ProgramChip { +impl<'a, F: 'a> WithEvents<'a> for ProgramChip { type Events = ( // CPU events &'a [CpuEvent], // the Program &'a Program, + // the public values + PublicValues, F>, ); } -impl MachineAir for ProgramChip { +impl MachineAir for ProgramChip { type Record = ExecutionRecord; type Program = Program; @@ -120,7 +123,7 @@ impl MachineAir for ProgramChip { ) -> RowMajorMatrix { // Generate the trace rows for each event. - let (cpu_events, program) = input.events(); + let (cpu_events, program, pv) = input.events(); // Collect the number of times each instruction is called from the cpu events. // Store it as a map of PC -> count. let mut instruction_counts = HashMap::new(); @@ -141,7 +144,7 @@ impl MachineAir for ProgramChip { let pc = program.pc_base + (i as u32 * 4); let mut row = [F::zero(); NUM_PROGRAM_MULT_COLS]; let cols: &mut ProgramMultiplicityCols = row.as_mut_slice().borrow_mut(); - cols.shard = input.public_values().execution_shard; + cols.shard = pv.execution_shard; cols.multiplicity = F::from_canonical_usize(*instruction_counts.get(&pc).unwrap_or(&0)); row @@ -165,13 +168,13 @@ impl MachineAir for ProgramChip { } } -impl BaseAir for ProgramChip { +impl BaseAir for ProgramChip { fn width(&self) -> usize { NUM_PROGRAM_MULT_COLS } } -impl Air for ProgramChip +impl Air for ProgramChip where AB: ProgramAirBuilder + PairBuilder, { diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 4bf3b54c0..977e4ce65 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -183,7 +183,10 @@ impl EventLens for ExecutionRecord { impl EventLens> for ExecutionRecord { fn events(&self) -> as crate::air::WithEvents<'_>>::Events { - &self.byte_lookups + ( + &self.byte_lookups, + ::public_values(self), + ) } } @@ -193,21 +196,32 @@ impl EventLens for ExecutionRecord { } } -impl EventLens for ExecutionRecord { - fn events(&self) -> >::Events { - (&self.memory_initialize_events, &self.memory_finalize_events) +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + ( + &self.memory_initialize_events, + &self.memory_finalize_events, + ::public_values(self), + ) } } -impl EventLens for ExecutionRecord { - fn events(&self) -> >::Events { - &self.program.memory_image +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + ( + &self.program.memory_image, + ::public_values(self), + ) } } -impl EventLens for ExecutionRecord { - fn events(&self) -> >::Events { - (&self.cpu_events, &self.program) +impl EventLens> for ExecutionRecord { + fn events(&self) -> as crate::air::WithEvents<'_>>::Events { + ( + &self.cpu_events, + &self.program, + ::public_values(self), + ) } } diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index 7a2f08c46..765c5196a 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -51,7 +51,7 @@ pub(crate) mod riscv_chips { #[record_type = "crate::runtime::ExecutionRecord"] pub enum RiscvAir { /// An AIR that contains a preprocessed program table and a lookup for the instructions. - Program(ProgramChip), + Program(ProgramChip), /// An AIR for the RISC-V CPU. Each row represents a cpu cycle. Cpu(CpuChip), /// An AIR for the RISC-V Add and SUB instruction. @@ -71,11 +71,11 @@ pub enum RiscvAir { /// A lookup table for byte operations. ByteLookup(ByteChip), /// A table for initializing the memory state. - MemoryInit(MemoryChip), + MemoryInit(MemoryChip), /// A table for finalizing the memory state. - MemoryFinal(MemoryChip), + MemoryFinal(MemoryChip), /// A table for initializing the program memory. - ProgramMemory(MemoryProgramChip), + ProgramMemory(MemoryProgramChip), /// A precompile for sha256 extend. Sha256Extend(ShaExtendChip), /// A precompile for sha256 compress. @@ -125,7 +125,7 @@ impl RiscvAir { let mut chips = vec![]; let cpu = CpuChip; chips.push(RiscvAir::Cpu(cpu)); - let program = ProgramChip; + let program = ProgramChip::new(); chips.push(RiscvAir::Program(program)); let sha_extend = ShaExtendChip; chips.push(RiscvAir::Sha256Extend(sha_extend));