Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add LDE bench and salt support #37

Merged
merged 12 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e
resolver = "2"

[workspace.dependencies]
cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "2a7c42d29ee72d7c2c2da9378ae816384c43cdec" }
cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "547192b2ef42dc7519435059c86f88431b8de999" }
ahash = { version = "0.8.7", default-features = false, features = [
"compile-time-rng",
] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`.
Expand Down
16 changes: 0 additions & 16 deletions field/src/fft.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
use alloc::vec::Vec;
use core::cmp::{max, min};

#[cfg(feature = "cuda")]
use cryptography_cuda::{ntt, types::NTTInputOutputOrder};
use plonky2_util::{log2_strict, reverse_index_bits_in_place};
use unroll::unroll_for_loops;

Expand Down Expand Up @@ -34,20 +32,6 @@ pub fn fft_root_table<F: Field>(n: usize) -> FftRootTable<F> {
root_table
}

/// Transforms `input` in place, choosing the backend from the field type.
///
/// When `F::CUDA_SUPPORT` is true the work is handed to the CUDA `ntt`
/// routine; otherwise everything is forwarded to `fft_dispatch_cpu`.
/// `zero_factor` and `root_table` are only consumed by the CPU path.
#[allow(dead_code)]
#[cfg(feature = "cuda")]
fn fft_dispatch_gpu<F: Field>(
    input: &mut [F],
    zero_factor: Option<usize>,
    root_table: Option<&FftRootTable<F>>,
) {
    if !F::CUDA_SUPPORT {
        fft_dispatch_cpu(input, zero_factor, root_table)
    } else {
        // NOTE(review): the leading `0` presumably selects the GPU device and
        // `NN` presumably means natural-order input/output — confirm against
        // the `cryptography_cuda` API.
        ntt(0, input, NTTInputOutputOrder::NN)
    }
}

fn fft_dispatch_cpu<F: Field>(
input: &mut [F],
zero_factor: Option<usize>,
Expand Down
6 changes: 5 additions & 1 deletion plonky2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["rc"] }
static_assertions = { workspace = true }
unroll = { workspace = true }
web-time = { version = "1.0.0", optional = true }
once_cell = { version = "1.18.0" }
once_cell = { version = "1.20.2" }
papi-bindings = { version = "0.5.2" }

# Local dependencies
Expand Down Expand Up @@ -80,6 +80,10 @@ harness = false
name = "ffts"
harness = false

[[bench]]
name = "lde"
harness = false

[[bench]]
name = "hashing"
harness = false
Expand Down
59 changes: 59 additions & 0 deletions plonky2/benches/lde.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
mod allocator;

use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
#[cfg(feature = "cuda")]
use cryptography_cuda::init_cuda_degree_rs;
use plonky2::field::extension::Extendable;
use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::field::polynomial::PolynomialCoeffs;
use plonky2::fri::oracle::PolynomialBatch;
use plonky2::hash::hash_types::RichField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::util::timing::TimingTree;
use tynm::type_name;

/// Benchmarks `PolynomialBatch::from_coeffs` on batches of random polynomials.
///
/// For each LDE size `2^13`, `2^14` and `2^15`, a batch of 16 random
/// coefficient vectors of length `lde_size >> RATE_BITS` is generated once,
/// outside the timed closure. Every timed iteration clones that batch (the
/// callee takes it by value) and builds a `PolynomialBatch` with
/// `rate_bits = RATE_BITS`, blinding disabled, cap height 1 and no
/// precomputed FFT root table.
///
/// The benchmark group is named `lde<F>` after the field type and each
/// measurement is labelled with its LDE size.
pub(crate) fn bench_batch_lde<
    F: RichField + Extendable<D>,
    C: GenericConfig<D, F = F>,
    const D: usize,
>(
    c: &mut Criterion,
) {
    const RATE_BITS: usize = 3;
    // Number of polynomials per batch; independent of the LDE size.
    const BATCH_SIZE: usize = 1 << 4;

    // `benchmark_group` takes `Into<String>`, so hand over the owned string
    // instead of borrowing it and forcing a second copy.
    let mut group = c.benchmark_group(format!("lde<{}>", type_name::<F>()));

    // NOTE(review): 16 is presumably the largest log2 domain size the CUDA
    // backend must prepare for (>= the largest `size_log` below) — confirm
    // against `cryptography_cuda`.
    #[cfg(feature = "cuda")]
    init_cuda_degree_rs(16);

    for size_log in [13, 14, 15] {
        let orig_size = 1 << (size_log - RATE_BITS);
        let lde_size = 1 << size_log;

        group.bench_with_input(BenchmarkId::from_parameter(lde_size), &lde_size, |b, _| {
            // Input generation stays outside `b.iter` so it is not measured.
            let polynomials: Vec<PolynomialCoeffs<F>> = (0..BATCH_SIZE)
                .map(|_| PolynomialCoeffs::new(F::rand_vec(orig_size)))
                .collect();
            let mut timing = TimingTree::new("lde", log::Level::Error);
            b.iter(|| {
                PolynomialBatch::<F, C, D>::from_coeffs(
                    polynomials.clone(),
                    RATE_BITS,
                    false,
                    1,
                    &mut timing,
                    None,
                )
            });
        });
    }

    // Close the group explicitly so the summary is produced here rather than
    // implicitly when the group is dropped.
    group.finish();
}

/// Criterion entry point: runs the batch-LDE benchmark for the Goldilocks
/// field with the Poseidon configuration and extension degree 2.
fn criterion_benchmark(c: &mut Criterion) {
    type BenchField = GoldilocksField;
    type BenchConfig = PoseidonGoldilocksConfig;
    const EXTENSION_DEGREE: usize = 2;

    bench_batch_lde::<BenchField, BenchConfig, EXTENSION_DEGREE>(c);
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
56 changes: 44 additions & 12 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,22 @@ use crate::util::reducing::ReducingFactor;
use crate::util::timing::TimingTree;
use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose};

/// Shared flag recording whether `init_gpu` has already initialized CUDA:
/// it starts at `0` and `init_gpu` sets it to `1` after calling
/// `init_cuda_rs`. The mutex also serializes concurrent callers of
/// `init_gpu`. Only compiled for test/doctest builds with the `cuda` feature.
#[cfg(all(feature = "cuda", any(test, doctest)))]
pub static GPU_INIT: once_cell::sync::Lazy<std::sync::Arc<std::sync::Mutex<u64>>> =
once_cell::sync::Lazy::new(|| std::sync::Arc::new(std::sync::Mutex::new(0)));

/// Initializes CUDA at most once per process (test/doctest builds only).
///
/// The `GPU_INIT` lock is held for the whole call, so a concurrent caller
/// blocks until the first one has finished `init_cuda_rs`; every later call
/// observes a non-zero flag and returns without doing anything.
///
/// # Panics
///
/// Panics if the `GPU_INIT` mutex is poisoned.
#[cfg(all(feature = "cuda", any(test, doctest)))]
fn init_gpu() {
    let mut initialized = GPU_INIT.lock().unwrap();
    if *initialized != 0 {
        return;
    }

    println!("Init GPU!");
    cryptography_cuda::init_cuda_rs();
    *initialized = 1;
}

/// Four (~64 bit) field elements gives ~128 bit security.
pub const SALT_SIZE: usize = 4;

Expand Down Expand Up @@ -192,10 +208,17 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
timing: &mut TimingTree,
fft_root_table: Option<&FftRootTable<F>>,
) -> Self {
let pols = polynomials.len();
let degree = polynomials[0].len();
let log_n = log2_strict(degree);

if log_n + rate_bits > 1 && polynomials.len() > 0 {
#[cfg(any(test, doctest))]
init_gpu();

if log_n + rate_bits > 1
&& polynomials.len() > 0
&& pols * (1 << (log_n + rate_bits)) < (1 << 31)
{
let _num_gpus: usize = std::env::var("NUM_OF_GPUS")
.expect("NUM_OF_GPUS should be set")
.parse()
Expand Down Expand Up @@ -232,17 +255,17 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

#[cfg(feature = "cuda")]
pub fn from_coeffs_gpu(
fn from_coeffs_gpu(
polynomials: &[PolynomialCoeffs<F>],
rate_bits: usize,
_blinding: bool,
blinding: bool,
cap_height: usize,
timing: &mut TimingTree,
_fft_root_table: Option<&FftRootTable<F>>,
log_n: usize,
_degree: usize,
) -> MerkleTree<F, <C as GenericConfig<D>>::Hasher> {
// let salt_size = if blinding { SALT_SIZE } else { 0 };
let salt_size = if blinding { SALT_SIZE } else { 0 };
// println!("salt_size: {:?}", salt_size);
let output_domain_size = log_n + rate_bits;

Expand All @@ -255,8 +278,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
let total_num_of_fft = polynomials.len();
// println!("total_num_of_fft: {:?}", total_num_of_fft);

let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size
let total_num_input_elements = total_num_of_fft * (1 << log_n);
let total_num_output_elements = total_num_of_fft * (1 << output_domain_size);
let total_num_output_elements = num_of_cols * (1 << output_domain_size);

let mut gpu_input: Vec<F> = polynomials
.into_iter()
Expand All @@ -270,6 +294,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
cfg_lde.are_outputs_on_device = true;
cfg_lde.with_coset = true;
cfg_lde.is_multi_gpu = true;
cfg_lde.salt_size = salt_size as u32;

let mut device_output_data: HostOrDeviceSlice<'_, F> =
HostOrDeviceSlice::cuda_malloc(0 as i32, total_num_output_elements).unwrap();
Expand Down Expand Up @@ -302,7 +327,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

let mut cfg_trans = TransposeConfig::default();
cfg_trans.batches = total_num_of_fft as u32;
cfg_trans.batches = num_of_cols as u32;
cfg_trans.are_inputs_on_device = true;
cfg_trans.are_outputs_on_device = true;

Expand All @@ -327,10 +352,14 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
MerkleTree::new_from_gpu_leaves(
&device_transpose_data,
1 << output_domain_size,
total_num_of_fft,
num_of_cols,
cap_height
)
);

drop(device_transpose_data);
drop(device_output_data);

mt
}

Expand All @@ -340,6 +369,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
blinding: bool,
fft_root_table: Option<&FftRootTable<F>>,
) -> Vec<Vec<F>> {
#[cfg(all(feature = "cuda", any(test, doctest)))]
init_gpu();

let degree = polynomials[0].len();
#[cfg(all(feature = "cuda", feature = "batch"))]
let log_n = log2_strict(degree) + rate_bits;
Expand Down Expand Up @@ -443,11 +475,11 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
println!("collect data from gpu used: {:?}", start.elapsed());
r
})
// .chain(
// (0..salt_size)
// .into_par_iter()
// .map(|_| F::rand_vec(degree << rate_bits)),
// )
.chain(
(0..salt_size)
.into_par_iter()
.map(|_| F::rand_vec(degree << rate_bits)),
)
.collect();
println!("real lde elapsed: {:?}", start_lde.elapsed());
return ret;
Expand Down
2 changes: 1 addition & 1 deletion plonky2/src/gates/gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use core::ops::Range;
use std::sync::Arc;

use hashbrown::HashMap;
use serde::{ Serialize, Serializer};
use serde::{Serialize, Serializer};

use crate::field::batch_util::batch_multiply_inplace;
use crate::field::extension::{Extendable, FieldExtension};
Expand Down
6 changes: 1 addition & 5 deletions plonky2/src/gates/low_degree_interpolation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,7 @@ impl<F: RichField + Extendable<D>, const D: usize> Gate<F, D> for LowDegreeInter
fn id(&self) -> String {
format!("{self:?}<D={D}>")
}
fn serialize(
&self,
dst: &mut Vec<u8>,
_common_data: &CommonCircuitData<F, D>,
) -> IoResult<()> {
fn serialize(&self, dst: &mut Vec<u8>, _common_data: &CommonCircuitData<F, D>) -> IoResult<()> {
dst.write_usize(self.subgroup_bits)?;
Ok(())
}
Expand Down
56 changes: 45 additions & 11 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,26 @@ const MSB_: i64 = 0x8000000000000000u64 as i64;
const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
const P8_N_: i64 = 0xFFFFFFFF;

/// Eight-lane copies of the scalar Goldilocks constants, laid out for aligned
/// 512-bit loads.
///
/// The struct is 64-byte aligned and every field is exactly 64 bytes
/// (`[i64; 8]`), so each field starts on a 64-byte boundary, which is the
/// alignment `_mm512_load_si512` requires of its address.
// Field names mirror the upper-case scalar constants, hence the lint allowance.
#[allow(non_snake_case)]
#[repr(align(64))]
struct FieldConstants {
// Filled with `MSB_` in every lane by `FC`.
MSB_V: [i64; 8],
// Filled with `P8_` in every lane by `FC`.
P8_V: [i64; 8],
// Filled with `P8_N_` in every lane by `FC`.
P8_N_V: [i64; 8],
}

/// Lane-broadcast values of `MSB_`, `P8_` and `P8_N_`, read by the AVX-512
/// helpers in this file through aligned loads.
const FC: FieldConstants = FieldConstants {
    MSB_V: [MSB_; 8],
    P8_V: [P8_; 8],
    P8_N_V: [P8_N_; 8],
};

/// Flips the most significant bit of each of the eight 64-bit lanes of `a`
/// by XORing every lane with `MSB_`.
#[allow(dead_code)]
#[inline(always)]
pub fn shift_avx512(a: &__m512i) -> __m512i {
    // SAFETY: `FC` is 64-byte aligned and each of its fields is 64 bytes, so
    // the pointer handed to the aligned load is 64-byte aligned and covers a
    // full 512-bit read.
    // NOTE(review): assumes AVX-512F is enabled wherever this module is
    // compiled — confirm against the module's `cfg` gate.
    unsafe {
        let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::<i32>());
        _mm512_xor_si512(*a, msb)
    }
}
Expand All @@ -20,27 +35,31 @@ pub fn shift_avx512(a: &__m512i) -> __m512i {
/// Brings every 64-bit lane of `a` into canonical range.
///
/// Lanes that compare `>= P8_` as unsigned integers get `P8_N_` added with
/// wrap-around, which equals subtracting the modulus because
/// `P8_ + P8_N_ = 2^64`; all other lanes are returned unchanged.
#[inline(always)]
pub fn to_canonical_avx512(a: &__m512i) -> __m512i {
    // SAFETY: `FC` is 64-byte aligned and each of its fields is 64 bytes, so
    // both aligned loads read from 64-byte-aligned, fully initialized memory.
    // NOTE(review): assumes AVX-512F is enabled wherever this module is
    // compiled — confirm against the module's `cfg` gate.
    unsafe {
        let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::<i32>());
        let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
        // One mask bit per lane that is still >= the modulus.
        let result_mask = _mm512_cmpge_epu64_mask(*a, p8);
        _mm512_mask_add_epi64(*a, result_mask, *a, p8_n)
    }
}

#[inline(always)]
pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let c0 = _mm512_add_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
_mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
}
}

#[inline(always)]
pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
// let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::<i32>());
let c0 = _mm512_sub_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*b, *a);
_mm512_mask_add_epi64(c0, result_mask, c0, p8)
Expand All @@ -50,11 +69,25 @@ pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
#[inline(always)]
pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let c_hh = _mm512_srli_epi64(*c_h, 32);
let c1 = sub_avx512_b_c(c_l, &c_hh);
let c1 = sub_avx512(c_l, &c_hh);
let c2 = _mm512_mul_epu32(*c_h, p8_n);
add_avx512_b_c(&c1, &c2)
add_avx512(&c1, &c2)
}
}

/// Reduces the 96-bit value `c_h * 2^64 + c_l` to one 64-bit lane congruent
/// to it modulo the Goldilocks prime, independently for each of the eight
/// lanes.
///
/// Precondition: every lane of `c_h` is below `2^32`. Only the low 32 bits of
/// each `c_h` lane take part in the multiplication (`_mm512_mul_epu32`).
///
/// Because `2^64 ≡ 2^32 - 1 (mod p)`, the result is
/// `c_l + c_h * (2^32 - 1)`, with `add_avx512` folding a carry out of bit 63
/// back in as another `2^32 - 1`. The result is not necessarily canonical.
#[inline(always)]
pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
    // SAFETY: `FC` is 64-byte aligned and each of its fields is 64 bytes, so
    // the aligned load reads from 64-byte-aligned, fully initialized memory.
    // NOTE(review): assumes AVX-512F is enabled wherever this module is
    // compiled — confirm against the module's `cfg` gate.
    unsafe {
        let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
        let c2 = _mm512_mul_epu32(*c_h, p_n);
        // `add_avx512` detects the carry with an *unsigned* compare, so the
        // operands must be passed unshifted. Flipping the MSB of `c_l` first
        // (the trick needed with signed-only compares) would turn that check
        // into a signed-overflow test and mis-correct lanes such as
        // `c_l = 2^63 - 1, c_h = 1`.
        add_avx512(c_l, &c2)
    }
}

Expand All @@ -69,7 +102,8 @@ pub fn mult_avx512_128(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) {
let c_ll = _mm512_mul_epu32(*a, *b);
let c_ll_h = _mm512_srli_epi64(c_ll, 32);
let r0 = _mm512_add_epi64(c_hl, c_ll_h);
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let r0_l = _mm512_and_si512(r0, p8_n);
let r0_h = _mm512_srli_epi64(r0, 32);
let r1 = _mm512_add_epi64(c_lh, r0_l);
Expand Down
2 changes: 1 addition & 1 deletion plonky2/src/hash/arch/x86_64/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod goldilocks_avx512;
pub mod poseidon2_goldilocks_avx2;
#[cfg(target_feature = "avx2")]
pub mod poseidon_bn128_avx2;
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))]
#[cfg(target_feature = "avx2")]
pub mod poseidon_goldilocks_avx2;
#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
pub mod poseidon_goldilocks_avx512;
Loading