Update cryptography_cuda reference, fix AVX2 issues #36

Merged · 21 commits · Sep 27, 2024
11 changes: 10 additions & 1 deletion CHANGELOG.md
@@ -5,8 +5,17 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
## [Unreleased]

## [0.2.3] - 2024-04-16

- Code refactoring ([#1558](https://github.com/0xPolygonZero/plonky2/pull/1558))
- Simplify types: remove option from CTL filters ([#1567](https://github.com/0xPolygonZero/plonky2/pull/1567))
- Add stdarch_x86_avx512 feature ([#1566](https://github.com/0xPolygonZero/plonky2/pull/1566))

## [0.2.2] - 2024-03-21

### Changed
- Fix CTLs with exactly two looking tables ([#1555](https://github.com/0xPolygonZero/plonky2/pull/1555))
- Make Starks without constraints provable ([#1552](https://github.com/0xPolygonZero/plonky2/pull/1552))

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e
resolver = "2"

[workspace.dependencies]
cryptography_cuda = { git = "ssh://[email protected]/okx/cryptography_cuda.git", rev = "73261c1420670cb371b359929df578edbb3b6a62" }
cryptography_cuda = { git = "ssh://[email protected]/okx/cryptography_cuda.git", rev = "f2ed17c3086b9ca538272974e42b47e4bf7970e2" }
ahash = { version = "0.8.7", default-features = false, features = [
"compile-time-rng",
] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`.
9 changes: 2 additions & 7 deletions field/Cargo.toml
@@ -1,13 +1,8 @@
[package]
name = "plonky2_field"
description = "Finite field arithmetic"
version = "0.2.0"
authors = [
"Daniel Lubarov <[email protected]>",
"William Borgeaud <[email protected]>",
"Jacqueline Nabaglo <[email protected]>",
"Hamish Ivey-Law <[email protected]>",
]
version = "0.2.2"
authors = ["Daniel Lubarov <[email protected]>", "William Borgeaud <[email protected]>", "Jacqueline Nabaglo <[email protected]>", "Hamish Ivey-Law <[email protected]>"]
edition.workspace = true
license.workspace = true
homepage.workspace = true
6 changes: 6 additions & 0 deletions field/src/fft.rs
@@ -53,6 +53,7 @@ fn fft_dispatch_cpu<F: Field>(
zero_factor: Option<usize>,
root_table: Option<&FftRootTable<F>>,
) {
/*
if root_table.is_some() {
return fft_classic(input, zero_factor.unwrap_or(0), root_table.unwrap());
} else {
@@ -68,6 +69,11 @@ fn fft_dispatch_cpu<F: Field>(

return fft_classic(input, zero_factor.unwrap_or(0), computed.as_ref());
};
*/
let computed_root_table = root_table.is_none().then(|| fft_root_table(input.len()));
let used_root_table = root_table.or(computed_root_table.as_ref()).unwrap();

fft_classic(input, zero_factor.unwrap_or(0), used_root_table);
}

#[inline]
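The replacement above computes a root table only when the caller did not pass one, using `bool::then` to build the fallback lazily and `Option::or` to pick whichever table is available. A minimal, self-contained sketch of that pattern, with a hypothetical `RootTable` alias and `compute_root_table` helper standing in for the plonky2 types:

```rust
// Sketch only: `RootTable` and `compute_root_table` are illustrative, not the plonky2 API.
type RootTable = Vec<Vec<u64>>;

fn compute_root_table(n: usize) -> RootTable {
    // Placeholder for the real twiddle-factor computation.
    vec![vec![0; n]]
}

fn dispatch(input_len: usize, root_table: Option<&RootTable>) {
    // `bool::then` yields `Some(computed)` only when no table was supplied.
    let computed = root_table.is_none().then(|| compute_root_table(input_len));
    // Prefer the caller's table; otherwise borrow the freshly computed one.
    let used: &RootTable = root_table.or(computed.as_ref()).unwrap();
    assert!(!used.is_empty());
}

fn main() {
    dispatch(8, None); // table computed on demand
    let precomputed = compute_root_table(8);
    dispatch(8, Some(&precomputed)); // caller-supplied table reused
}
```

Borrowing the computed table via `as_ref()` keeps both branches at the same type (`Option<&RootTable>`), so the caller's table is never cloned.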
1 change: 1 addition & 0 deletions field/src/lib.rs
@@ -4,6 +4,7 @@
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(missing_debug_implementations)]
#![feature(specialization)]
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]
#![cfg_attr(not(test), no_std)]
#![cfg(not(test))]
extern crate alloc;
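The new crate attribute enables the nightly `stdarch_x86_avx512` feature only when compiling for x86_64, so non-x86 targets are unaffected (later in this diff, `plonky2/src/lib.rs` enables the same feature unconditionally). A rough sketch of how such a gate pairs with arch-specific code paths, assuming a nightly toolchain on x86_64; the function names are illustrative, not the actual plonky2 layout:

```rust
// Only x86_64 builds see the nightly feature gate; other targets build on stable.
#![cfg_attr(target_arch = "x86_64", feature(stdarch_x86_avx512))]

// AVX-512 path: compiled only for x86_64 with avx512f enabled at build time.
#[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
fn sum8(xs: &[u64; 8]) -> u64 {
    // A real implementation would use core::arch::x86_64 AVX-512 intrinsics here.
    xs.iter().sum()
}

// Portable fallback for every other target.
#[cfg(not(all(target_arch = "x86_64", target_feature = "avx512f")))]
fn sum8(xs: &[u64; 8]) -> u64 {
    xs.iter().sum()
}

fn main() {
    println!("{}", sum8(&[1; 8]));
}
```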
10 changes: 3 additions & 7 deletions plonky2/Cargo.toml
@@ -1,12 +1,8 @@
[package]
name = "plonky2"
description = "Recursive SNARKs based on PLONK and FRI"
version = "0.2.0"
authors = [
"Daniel Lubarov <[email protected]>",
"William Borgeaud <[email protected]>",
"Nicholas Ward <[email protected]>",
]
version = "0.2.2"
authors = ["Daniel Lubarov <[email protected]>", "William Borgeaud <[email protected]>", "Nicholas Ward <[email protected]>"]
readme = "README.md"
edition.workspace = true
license.workspace = true
@@ -45,7 +41,7 @@ once_cell = { version = "1.18.0" }
papi-bindings = { version = "0.5.2" }

# Local dependencies
plonky2_field = { version = "0.2.0", path = "../field", default-features = false }
plonky2_field = { version = "0.2.2", path = "../field", default-features = false }
plonky2_maybe_rayon = { version = "0.2.0", path = "../maybe_rayon", default-features = false }
plonky2_util = { version = "0.2.0", path = "../util", default-features = false }
cryptography_cuda = { workspace = true, optional = true }
1 change: 0 additions & 1 deletion plonky2/src/gates/lookup.rs
@@ -5,7 +5,6 @@ use alloc::{
vec,
vec::Vec,
};
use core::usize;

use itertools::Itertools;
use keccak_hash::keccak;
1 change: 0 additions & 1 deletion plonky2/src/gates/lookup_table.rs
@@ -6,7 +6,6 @@ use alloc::{
vec,
vec::Vec,
};
use core::usize;
#[cfg(feature = "std")]
use std::sync::Arc;

1 change: 1 addition & 0 deletions plonky2/src/hash/arch/x86_64/poseidon2_goldilocks_avx2.rs
@@ -72,6 +72,7 @@ where
}
}

#[allow(dead_code)]
#[inline(always)]
pub fn matmul_internal_avx<F>(
state: &mut [F; SPONGE_WIDTH],
34 changes: 15 additions & 19 deletions plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs
@@ -1095,13 +1095,18 @@ mod tests {
13281191951274694749u64 as i64,
13281191951274694749u64 as i64,
);
let exp: [u64; 4] = [
0xE0842DFEFB3AC8EEu64,
0xE0842DFEFB3AC8EEu64,
0xE0842DFEFB3AC8EEu64,
0xE0842DFEFB3AC8EEu64,
];

let r = _mm256_add_epi64(ct1, ct2);
let mut a: [u64; 4] = [0; 4];
_mm256_store_si256(a.as_mut_ptr().cast::<__m256i>(), r);
println!("{:?}", a);
let x = 2896914383306846353u64 + 13281191951274694749u64;
println!("{:?}", x);
let mut vr: [u64; 4] = [0; 4];
_mm256_storeu_si256(vr.as_mut_ptr().cast::<__m256i>(), r);
println!("{:X?}", vr);
assert_eq!(vr, exp);
}
Ok(())
}
@@ -1147,28 +1152,19 @@
#[test]
fn test_bn128_sub64() -> Result<()> {
unsafe {
let a = _mm256_set_epi64x(
4i64,
7i64,
0xFFFFFFFFFFFFFFFFu64 as i64,
4291643747455737684u64 as i64,
);
let b = _mm256_set_epi64x(7i64, 4i64, 0x0i64, 3486998266802970665u64 as i64);
let a = _mm256_set_epi64x(4i64, 7i64, 0xFFFFFFFFFFFFFFFFu64 as i64, 0x0u64 as i64);
let b = _mm256_set_epi64x(7i64, 4i64, 0x0i64, 0xFFFFFFFFFFFFFFFFu64 as i64);
let bin = _mm256_set_epi64x(0, 0, 0, 0);

let res = [
0xFFFFFFFFFFFFFFFFu64,
0xFFFFFFFFFFFFFFFFu64,
3u64,
0xFFFFFFFFFFFFFFFDu64,
];
let res = [0x1u64, 0xFFFFFFFFFFFFFFFFu64, 3u64, 0xFFFFFFFFFFFFFFFDu64];

let bout = [1u64, 0u64, 0u64, 1u64];

let mut v: [u64; 4] = [0; 4];
let (c1, c2) = sub64(&a, &b, &bin);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), c1);
println!(" Res: {:?}", v);
println!("Res: {:X?}", v);
println!("Exp: {:X?}", res);
assert_eq!(v, res);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), c2);
println!("Cout: {:X?}", v);
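The test fixes above replace the aligned `_mm256_store_si256` with the unaligned `_mm256_storeu_si256` and assert against expected values instead of only printing. A stack-allocated `[u64; 4]` is not guaranteed to be 32-byte aligned, so the aligned store can fault. A standalone sketch of the corrected pattern (illustrative values, not the test's BN128 constants), which needs an AVX2-capable x86_64 build:

```rust
#[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
fn add_lanes() {
    use core::arch::x86_64::*;

    unsafe {
        let a = _mm256_set_epi64x(1, 2, 3, 4);
        let b = _mm256_set_epi64x(10, 20, 30, 40);
        let r = _mm256_add_epi64(a, b);

        // Unaligned store: safe for a plain stack array, unlike _mm256_store_si256.
        let mut out: [u64; 4] = [0; 4];
        _mm256_storeu_si256(out.as_mut_ptr().cast::<__m256i>(), r);

        // _mm256_set_epi64x takes lanes from high to low, so lane 0 holds 4 + 40.
        assert_eq!(out, [44, 33, 22, 11]);
    }
}

fn main() {
    #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))]
    add_lanes();
}
```

Because `_mm256_set_epi64x` takes its arguments from the highest lane down, the stored array reads back in reverse order of the constructor arguments, which is why the expected arrays in the fixed tests look "flipped" relative to the `set` calls.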
2 changes: 1 addition & 1 deletion plonky2/src/hash/poseidon2.rs
@@ -9,7 +9,7 @@ use plonky2_field::goldilocks_field::GoldilocksField;
use super::arch::x86_64::goldilocks_avx2::sbox_avx;
#[cfg(target_feature = "avx2")]
use super::arch::x86_64::poseidon2_goldilocks_avx2::{
add_rc_avx, internal_layer_avx, matmul_internal_avx, permute_mut_avx,
add_rc_avx, internal_layer_avx, permute_mut_avx,
};
use super::hash_types::{HashOutTarget, NUM_HASH_OUT_ELTS};
use crate::field::extension::Extendable;
4 changes: 4 additions & 0 deletions plonky2/src/hash/poseidon_bn128_ops.rs
@@ -4550,6 +4550,7 @@ pub struct PoseidonBN128NativePermutation<F> {
}

impl<F: RichField> PoseidonBN128NativePermutation<F> {
#[allow(dead_code)]
#[inline]
fn exp5state(self, state: &mut [ElementBN128; 5]) {
state[0].exp5();
@@ -4559,6 +4560,7 @@ impl<F: RichField> PoseidonBN128NativePermutation<F> {
state[4].exp5();
}

#[allow(dead_code)]
#[inline]
fn ark(self, state: &mut [ElementBN128; 5], c: [[u64; 4]; 100], it: usize) {
for i in 0..5 {
@@ -4567,6 +4569,7 @@ impl<F: RichField> PoseidonBN128NativePermutation<F> {
}
}

#[allow(dead_code)]
#[inline]
fn mix(self, state: &mut [ElementBN128; 5], m: [[[u64; 4]; 5]; 5]) {
let mut new_state: [ElementBN128; 5] = [ElementBN128::zero(); 5];
@@ -4584,6 +4587,7 @@ impl<F: RichField> PoseidonBN128NativePermutation<F> {
}
}

#[allow(dead_code)]
pub fn permute_fn(&self, input: [u64; 12]) -> [u64; 12] {
#[cfg(feature = "papi")]
let mut event_set = init_papi();
2 changes: 1 addition & 1 deletion plonky2/src/lib.rs
@@ -3,7 +3,7 @@
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(missing_debug_implementations)]
#![cfg_attr(not(feature = "std"), no_std)]
// #![feature(stdarch_x86_avx512)]
#![feature(stdarch_x86_avx512)]

// #[cfg(not(feature = "std"))]
pub extern crate alloc;
4 changes: 2 additions & 2 deletions starky/Cargo.toml
@@ -1,7 +1,7 @@
[package]
name = "starky"
description = "Implementation of STARKs"
version = "0.2.1"
version = "0.4.0"
authors = ["Daniel Lubarov <[email protected]>", "William Borgeaud <[email protected]>"]
readme = "README.md"
edition.workspace = true
@@ -26,7 +26,7 @@ log = { workspace = true }
num-bigint = { version = "0.4.3", default-features = false }

# Local dependencies
plonky2 = { version = "0.2.0", path = "../plonky2", default-features = false }
plonky2 = { version = "0.2.2", path = "../plonky2", default-features = false }
plonky2_maybe_rayon = { version = "0.2.0", path = "../maybe_rayon", default-features = false }
plonky2_util = { version = "0.2.0", path = "../util", default-features = false }

58 changes: 14 additions & 44 deletions starky/src/cross_table_lookup.rs
@@ -67,12 +67,12 @@ pub type TableIdx = usize;
pub struct TableWithColumns<F: Field> {
table: TableIdx,
columns: Vec<Column<F>>,
filter: Option<Filter<F>>,
filter: Filter<F>,
}

impl<F: Field> TableWithColumns<F> {
/// Generates a new `TableWithColumns` given a `table` index, a linear combination of columns `columns` and a `filter`.
pub fn new(table: TableIdx, columns: Vec<Column<F>>, filter: Option<Filter<F>>) -> Self {
pub fn new(table: TableIdx, columns: Vec<Column<F>>, filter: Filter<F>) -> Self {
Self {
table,
columns,
@@ -163,7 +163,7 @@ pub struct CtlZData<'a, F: Field> {
pub(crate) columns: Vec<&'a [Column<F>]>,
/// Vector of filter columns for the current table.
/// Each filter evaluates to either 1 or 0.
pub(crate) filter: Vec<Option<Filter<F>>>,
pub(crate) filter: Vec<Filter<F>>,
}

impl<'a, F: Field> CtlZData<'a, F> {
@@ -173,7 +173,7 @@ impl<'a, F: Field> CtlZData<'a, F> {
z: PolynomialValues<F>,
challenge: GrandProductChallenge<F>,
columns: Vec<&'a [Column<F>]>,
filter: Vec<Option<Filter<F>>>,
filter: Vec<Filter<F>>,
) -> Self {
Self {
helper_columns,
@@ -404,7 +404,7 @@ fn ctl_helper_zs_cols<F: Field, const N: usize>(
.map(|(table, group)| {
let columns_filters = group
.map(|table| (&table.columns[..], &table.filter))
.collect::<Vec<(&[Column<F>], &Option<Filter<F>>)>>();
.collect::<Vec<(&[Column<F>], &Filter<F>)>>();
(
table,
partial_sums(
@@ -484,7 +484,7 @@ where
/// Column linear combinations of the `CrossTableLookup`s.
pub(crate) columns: Vec<&'a [Column<F>]>,
/// Filter that evaluates to either 1 or 0.
pub(crate) filter: Vec<Option<Filter<F>>>,
pub(crate) filter: Vec<Filter<F>>,
}

impl<'a, F: RichField + Extendable<D>, const D: usize>
@@ -682,16 +682,8 @@ pub(crate) fn eval_cross_table_lookup_checks<F, FE, P, S, const D: usize, const
let combin0 = challenges.combine(&evals[0]);
let combin1 = challenges.combine(&evals[1]);

let f0 = if let Some(filter0) = &filter[0] {
filter0.eval_filter(local_values, next_values)
} else {
P::ONES
};
let f1 = if let Some(filter1) = &filter[1] {
filter1.eval_filter(local_values, next_values)
} else {
P::ONES
};
let f0 = filter[0].eval_filter(local_values, next_values);
let f1 = filter[1].eval_filter(local_values, next_values);

consumer
.constraint_last_row(combin0 * combin1 * *local_z - f0 * combin1 - f1 * combin0);
@@ -700,11 +692,7 @@
);
} else {
let combin0 = challenges.combine(&evals[0]);
let f0 = if let Some(filter0) = &filter[0] {
filter0.eval_filter(local_values, next_values)
} else {
P::ONES
};
let f0 = filter[0].eval_filter(local_values, next_values);
consumer.constraint_last_row(combin0 * *local_z - f0);
consumer.constraint_transition(combin0 * (*local_z - *next_z) - f0);
}
@@ -726,7 +714,7 @@
/// Column linear combinations of the `CrossTableLookup`s.
pub(crate) columns: Vec<Vec<Column<F>>>,
/// Filter that evaluates to either 1 or 0.
pub(crate) filter: Vec<Option<Filter<F>>>,
pub(crate) filter: Vec<Filter<F>>,
}

impl<'a, F: Field, const D: usize> CtlCheckVarsTarget<F, D> {
@@ -856,8 +844,6 @@ pub(crate) fn eval_cross_table_lookup_checks_circuit<
let local_values = vars.get_local_values();
let next_values = vars.get_next_values();

let one = builder.one_extension();

for lookup_vars in ctl_vars {
let CtlCheckVarsTarget {
helper_columns,
@@ -906,16 +892,8 @@
let combin0 = challenges.combine_circuit(builder, &evals[0]);
let combin1 = challenges.combine_circuit(builder, &evals[1]);

let f0 = if let Some(filter0) = &filter[0] {
filter0.eval_filter_circuit(builder, local_values, next_values)
} else {
one
};
let f1 = if let Some(filter1) = &filter[1] {
filter1.eval_filter_circuit(builder, local_values, next_values)
} else {
one
};
let f0 = filter[0].eval_filter_circuit(builder, local_values, next_values);
let f1 = filter[1].eval_filter_circuit(builder, local_values, next_values);

let combined = builder.mul_sub_extension(combin1, *local_z, f1);
let combined = builder.mul_extension(combined, combin0);
@@ -928,11 +906,7 @@
consumer.constraint_last_row(builder, constr);
} else {
let combin0 = challenges.combine_circuit(builder, &evals[0]);
let f0 = if let Some(filter0) = &filter[0] {
filter0.eval_filter_circuit(builder, local_values, next_values)
} else {
one
};
let f0 = filter[0].eval_filter_circuit(builder, local_values, next_values);

let constr = builder.mul_sub_extension(combin0, *local_z, f0);
consumer.constraint_last_row(builder, constr);
@@ -1121,11 +1095,7 @@ pub mod debug_utils {
) {
let trace = &trace_poly_values[table.table];
for i in 0..trace[0].len() {
let filter = if let Some(combin) = &table.filter {
combin.eval_table(trace, i)
} else {
F::ONE
};
let filter = table.filter.eval_table(trace, i);
if filter.is_one() {
let row = table
.columns
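This file tracks the upstream change "Simplify types: remove option from CTL filters" listed in the CHANGELOG: the `filter` fields go from `Option<Filter<F>>` to `Filter<F>`, and every `if let Some(filter) ... else ONES` branch collapses into a direct `eval_filter` call. A toy sketch of why that simplification works, using hypothetical types rather than starky's actual `Filter` API:

```rust
// Illustrative only: a filter value that can itself encode the "always on" case,
// so callers never need to branch on Option.
#[derive(Clone, Debug)]
struct Filter {
    // (column index, coefficient) pairs; empty means the filter is always 1.
    terms: Vec<(usize, i64)>,
}

impl Filter {
    /// Filter that evaluates to 1 on every row (the old `None` case).
    fn always_on() -> Self {
        Filter { terms: Vec::new() }
    }

    fn eval(&self, row: &[i64]) -> i64 {
        if self.terms.is_empty() {
            1
        } else {
            self.terms.iter().map(|&(col, c)| c * row[col]).sum()
        }
    }
}

fn main() {
    let row = vec![0, 1, 7];
    // Before: `Option<Filter>` forced something like `filter.map_or(1, |f| f.eval(&row))`.
    // After: the default is just another Filter value.
    assert_eq!(Filter::always_on().eval(&row), 1);
    assert_eq!(Filter { terms: vec![(2, 3)] }.eval(&row), 21);
}
```

In starky the constant-1 default is presumably provided through the real `Filter` constructors; the point is only that once a filter value can represent "always on" itself, the `Option` wrapper and its `P::ONES` / `builder.one_extension()` fallbacks in this file become unnecessary.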