From 6f925b729321a09144ab11a4d11808ca57f53683 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 27 Sep 2024 16:30:25 +0800 Subject: [PATCH 01/16] update crypto_cuda ref --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 4ee77074f2..68dfb213f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e resolver = "2" [workspace.dependencies] -cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "2a7c42d29ee72d7c2c2da9378ae816384c43cdec" } +cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "51e363b7074ad48edad09fad5d936c81b075f1a7" } ahash = { version = "0.8.7", default-features = false, features = [ "compile-time-rng", ] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`. From 067d4a3bc2a0ceb36da39869988190a595ac4340 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Mon, 30 Sep 2024 16:11:47 +0800 Subject: [PATCH 02/16] add salt support in oracle.rs --- plonky2/src/fri/oracle.rs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 443810fcde..1e8db6c994 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -192,10 +192,11 @@ impl, C: GenericConfig, const D: usize> timing: &mut TimingTree, fft_root_table: Option<&FftRootTable>, ) -> Self { + let pols = polynomials.len(); let degree = polynomials[0].len(); let log_n = log2_strict(degree); - if log_n + rate_bits > 1 && polynomials.len() > 0 { + if log_n + rate_bits > 1 && polynomials.len() > 0 && pols * (1 << (log_n + rate_bits)) < (1 << 31) { let _num_gpus: usize = std::env::var("NUM_OF_GPUS") .expect("NUM_OF_GPUS should be set") .parse() @@ -235,14 +236,14 @@ impl, C: GenericConfig, const D: usize> pub fn from_coeffs_gpu( polynomials: &[PolynomialCoeffs], rate_bits: usize, - _blinding: bool, + blinding: bool, cap_height: usize, timing: &mut TimingTree, _fft_root_table: Option<&FftRootTable>, log_n: usize, _degree: usize, ) -> MerkleTree>::Hasher> { - // let salt_size = if blinding { SALT_SIZE } else { 0 }; + let salt_size = if blinding { SALT_SIZE } else { 0 }; // println!("salt_size: {:?}", salt_size); let output_domain_size = log_n + rate_bits; @@ -255,8 +256,9 @@ impl, C: GenericConfig, const D: usize> let total_num_of_fft = polynomials.len(); // println!("total_num_of_fft: {:?}", total_num_of_fft); + let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size let total_num_input_elements = total_num_of_fft * (1 << log_n); - let total_num_output_elements = total_num_of_fft * (1 << output_domain_size); + let total_num_output_elements = num_of_cols * (1 << output_domain_size); let mut gpu_input: Vec = polynomials .into_iter() @@ -270,6 +272,7 @@ impl, C: GenericConfig, const D: usize> cfg_lde.are_outputs_on_device = true; cfg_lde.with_coset = true; cfg_lde.is_multi_gpu = true; + cfg_lde.salt_size = salt_size as u32; let mut device_output_data: HostOrDeviceSlice<'_, F> = HostOrDeviceSlice::cuda_malloc(0 as i32, total_num_output_elements).unwrap(); @@ -302,7 +305,7 @@ impl, C: GenericConfig, const D: usize> } let mut cfg_trans = TransposeConfig::default(); - cfg_trans.batches = total_num_of_fft as u32; + cfg_trans.batches = num_of_cols as u32; cfg_trans.are_inputs_on_device = true; cfg_trans.are_outputs_on_device = true; @@ -327,10 +330,14 @@ impl, C: GenericConfig, const D: usize> MerkleTree::new_from_gpu_leaves( &device_transpose_data, 1 << output_domain_size, - total_num_of_fft, + num_of_cols, cap_height ) ); + + drop(device_transpose_data); + drop(device_output_data); + mt } @@ -443,11 +450,11 @@ impl, C: GenericConfig, const D: usize> println!("collect data from gpu used: {:?}", start.elapsed()); r }) - // .chain( - // (0..salt_size) - // .into_par_iter() - // .map(|_| F::rand_vec(degree << rate_bits)), - // ) + .chain( + (0..salt_size) + .into_par_iter() + .map(|_| F::rand_vec(degree << rate_bits)), + ) .collect(); println!("real lde elapsed: {:?}", start_lde.elapsed()); return ret; From 56c42f6ace7b1ae05c6f45ad5797a46368eaec98 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Wed, 2 Oct 2024 15:46:31 +0800 Subject: [PATCH 03/16] add LDE benchmarks --- plonky2/Cargo.toml | 4 ++++ plonky2/benches/lde.rs | 46 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 50 insertions(+) create mode 100644 plonky2/benches/lde.rs diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index 94715189b3..791a64c922 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -80,6 +80,10 @@ harness = false name = "ffts" harness = false +[[bench]] +name = "lde" +harness = false + [[bench]] name = "hashing" harness = false diff --git a/plonky2/benches/lde.rs b/plonky2/benches/lde.rs new file mode 100644 index 0000000000..c1a702f931 --- /dev/null +++ b/plonky2/benches/lde.rs @@ -0,0 +1,46 @@ +mod allocator; + +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use plonky2::field::extension::Extendable; +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::field::polynomial::PolynomialCoeffs; +use plonky2::fri::oracle::PolynomialBatch; +use plonky2::hash::hash_types::RichField; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; +use plonky2::util::timing::TimingTree; +use tynm::type_name; +#[cfg(feature = "cuda")] +use cryptography_cuda::{init_cuda_degree_rs}; + +pub(crate) fn bench_batch_lde, C: GenericConfig, const D: usize>(c: &mut Criterion) +{ + const RATE_BITS: usize = 3; + + let mut group = c.benchmark_group(&format!("lde<{}>", type_name::())); + + #[cfg(feature = "cuda")] + init_cuda_degree_rs(16); + + for size_log in [13, 14, 15] { + let orig_size = 1 << (size_log - RATE_BITS); + let lde_size = 1 << size_log; + let batch_size = 1 << 4; + + group.bench_with_input(BenchmarkId::from_parameter(lde_size), &lde_size, |b, _| { + let polynomials: Vec> = (0..batch_size).into_iter().map(|_i| { + PolynomialCoeffs::new(F::rand_vec(orig_size)) + }).collect(); + let mut timing = TimingTree::new("lde", log::Level::Error); + b.iter(|| { + PolynomialBatch::::from_coeffs(polynomials.clone(), RATE_BITS, false, 1, &mut timing, None) + }); + }); + } +} + +fn criterion_benchmark(c: &mut Criterion) { + bench_batch_lde::(c); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); From 37d0b68ae955c30cf86ef7519a484985b818a525 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Wed, 2 Oct 2024 15:51:26 +0800 Subject: [PATCH 04/16] cargo fmt --- plonky2/benches/lde.rs | 29 +++++--- plonky2/src/fri/oracle.rs | 13 ++-- plonky2/src/gates/gate.rs | 2 +- plonky2/src/gates/low_degree_interpolation.rs | 6 +- plonky2/src/hash/poseidon_bn128.rs | 3 +- plonky2/src/hash/poseidon_goldilocks.rs | 66 +++++++++++++++---- plonky2/src/plonk/config.rs | 2 +- .../util/serialization/gate_serialization.rs | 26 ++++---- plonky2/src/util/serialization/mod.rs | 2 +- 9 files changed, 102 insertions(+), 47 deletions(-) diff --git a/plonky2/benches/lde.rs b/plonky2/benches/lde.rs index c1a702f931..465c60846c 100644 --- a/plonky2/benches/lde.rs +++ b/plonky2/benches/lde.rs @@ -1,6 +1,8 @@ mod allocator; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +#[cfg(feature = "cuda")] +use cryptography_cuda::init_cuda_degree_rs; use plonky2::field::extension::Extendable; use plonky2::field::goldilocks_field::GoldilocksField; use plonky2::field::polynomial::PolynomialCoeffs; @@ -9,11 +11,14 @@ use plonky2::hash::hash_types::RichField; use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use plonky2::util::timing::TimingTree; use tynm::type_name; -#[cfg(feature = "cuda")] -use cryptography_cuda::{init_cuda_degree_rs}; -pub(crate) fn bench_batch_lde, C: GenericConfig, const D: usize>(c: &mut Criterion) -{ +pub(crate) fn bench_batch_lde< + F: RichField + Extendable, + C: GenericConfig, + const D: usize, +>( + c: &mut Criterion, +) { const RATE_BITS: usize = 3; let mut group = c.benchmark_group(&format!("lde<{}>", type_name::())); @@ -27,12 +32,20 @@ pub(crate) fn bench_batch_lde, C: GenericConfig> = (0..batch_size).into_iter().map(|_i| { - PolynomialCoeffs::new(F::rand_vec(orig_size)) - }).collect(); + let polynomials: Vec> = (0..batch_size) + .into_iter() + .map(|_i| PolynomialCoeffs::new(F::rand_vec(orig_size))) + .collect(); let mut timing = TimingTree::new("lde", log::Level::Error); b.iter(|| { - PolynomialBatch::::from_coeffs(polynomials.clone(), RATE_BITS, false, 1, &mut timing, None) + PolynomialBatch::::from_coeffs( + polynomials.clone(), + RATE_BITS, + false, + 1, + &mut timing, + None, + ) }); }); } diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 1e8db6c994..2fbfa9f33a 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -196,7 +196,10 @@ impl, C: GenericConfig, const D: usize> let degree = polynomials[0].len(); let log_n = log2_strict(degree); - if log_n + rate_bits > 1 && polynomials.len() > 0 && pols * (1 << (log_n + rate_bits)) < (1 << 31) { + if log_n + rate_bits > 1 + && polynomials.len() > 0 + && pols * (1 << (log_n + rate_bits)) < (1 << 31) + { let _num_gpus: usize = std::env::var("NUM_OF_GPUS") .expect("NUM_OF_GPUS should be set") .parse() @@ -256,7 +259,7 @@ impl, C: GenericConfig, const D: usize> let total_num_of_fft = polynomials.len(); // println!("total_num_of_fft: {:?}", total_num_of_fft); - let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size + let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size let total_num_input_elements = total_num_of_fft * (1 << log_n); let total_num_output_elements = num_of_cols * (1 << output_domain_size); @@ -451,9 +454,9 @@ impl, C: GenericConfig, const D: usize> r }) .chain( - (0..salt_size) - .into_par_iter() - .map(|_| F::rand_vec(degree << rate_bits)), + (0..salt_size) + .into_par_iter() + .map(|_| F::rand_vec(degree << rate_bits)), ) .collect(); println!("real lde elapsed: {:?}", start_lde.elapsed()); diff --git a/plonky2/src/gates/gate.rs b/plonky2/src/gates/gate.rs index 3e1f29742d..2d41aa7ffb 100644 --- a/plonky2/src/gates/gate.rs +++ b/plonky2/src/gates/gate.rs @@ -8,7 +8,7 @@ use core::ops::Range; use std::sync::Arc; use hashbrown::HashMap; -use serde::{ Serialize, Serializer}; +use serde::{Serialize, Serializer}; use crate::field::batch_util::batch_multiply_inplace; use crate::field::extension::{Extendable, FieldExtension}; diff --git a/plonky2/src/gates/low_degree_interpolation.rs b/plonky2/src/gates/low_degree_interpolation.rs index 9fe5421701..9407003d86 100644 --- a/plonky2/src/gates/low_degree_interpolation.rs +++ b/plonky2/src/gates/low_degree_interpolation.rs @@ -84,11 +84,7 @@ impl, const D: usize> Gate for LowDegreeInter fn id(&self) -> String { format!("{self:?}") } - fn serialize( - &self, - dst: &mut Vec, - _common_data: &CommonCircuitData, - ) -> IoResult<()> { + fn serialize(&self, dst: &mut Vec, _common_data: &CommonCircuitData) -> IoResult<()> { dst.write_usize(self.subgroup_bits)?; Ok(()) } diff --git a/plonky2/src/hash/poseidon_bn128.rs b/plonky2/src/hash/poseidon_bn128.rs index 60053b3960..fcb8bae00c 100644 --- a/plonky2/src/hash/poseidon_bn128.rs +++ b/plonky2/src/hash/poseidon_bn128.rs @@ -262,7 +262,8 @@ mod tests { use plonky2_field::types::Field; use super::PoseidonBN128Hash; - use crate::{hash::poseidon::PoseidonHash, plonk::config::{GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}}; + use crate::hash::poseidon::PoseidonHash; + use crate::plonk::config::{GenericConfig, GenericHashOut, Hasher, PoseidonGoldilocksConfig}; #[test] fn test_poseidon_bn128_hash_no_pad() -> Result<()> { diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index bfd5c59365..164ec6d633 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -510,12 +510,27 @@ mod tests { F::from_canonical_u64(0), F::from_canonical_u64(0), F::from_canonical_u64(0), - F::from_canonical_u64(0) + F::from_canonical_u64(0), ]; let output = F::poseidon(input); - let expected_out: [u64;12] = [ - 7211848465497282123, 8334407123774112207, 4858661444170722461, 8419634888969461752, 8365439750915196882, 13994809114733475841, 8086590873907410085, 17222247664612180184, 2859807231239647069, 1588164466493087886, 10963846266850921292, 10092827555303260923 - ]; let expected_out = expected_out.iter().map(|x| F::from_canonical_u64(*x)).collect::>(); + let expected_out: [u64; 12] = [ + 7211848465497282123, + 8334407123774112207, + 4858661444170722461, + 8419634888969461752, + 8365439750915196882, + 13994809114733475841, + 8086590873907410085, + 17222247664612180184, + 2859807231239647069, + 1588164466493087886, + 10963846266850921292, + 10092827555303260923, + ]; + let expected_out = expected_out + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(); assert_eq!(output.to_vec(), expected_out); let input: [F; 12] = [ @@ -530,17 +545,33 @@ mod tests { F::from_canonical_u64(0), F::from_canonical_u64(0), F::from_canonical_u64(0), - F::from_canonical_u64(0) + F::from_canonical_u64(0), ]; let output = F::poseidon(input); - let expected_out: [u64;12] = [11994017978598211037, 7557030840175886847, 2132360640983728466, 4344091215078417239, 5401009700429511129, 2034618959601429994, 11010409655003603569, 8592131210799925716, 8985230087572094046, 12365839308703522999, 6320659093029715449, 16143392566362192896]; - let expected_out = expected_out.iter().map(|x| F::from_canonical_u64(*x)).collect::>(); + let expected_out: [u64; 12] = [ + 11994017978598211037, + 7557030840175886847, + 2132360640983728466, + 4344091215078417239, + 5401009700429511129, + 2034618959601429994, + 11010409655003603569, + 8592131210799925716, + 8985230087572094046, + 12365839308703522999, + 6320659093029715449, + 16143392566362192896, + ]; + let expected_out = expected_out + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(); assert_eq!(output.to_vec(), expected_out); } #[test] fn test_hash_no_pad_gl() { - let inputs: [u64; 32] =[ + let inputs: [u64; 32] = [ 9972144316416239374, 7195869958086994472, 12805395537960412263, @@ -572,13 +603,24 @@ mod tests { 2150999602305437005, 9103462636082953981, 16341057499572706412, - 842265247111451937 + 842265247111451937, ]; - let inputs = inputs.iter().map(|x| F::from_canonical_u64(*x)).collect::>(); + let inputs = inputs + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(); let output = PoseidonHash::hash_no_pad(&inputs); - let expected_out: [u64;4] = [8197835875512527937, 7109417654116018994, 18237163116575285904, 17017896878738047012]; - let expected_out = expected_out.iter().map(|x| F::from_canonical_u64(*x)).collect::>(); + let expected_out: [u64; 4] = [ + 8197835875512527937, + 7109417654116018994, + 18237163116575285904, + 17017896878738047012, + ]; + let expected_out = expected_out + .iter() + .map(|x| F::from_canonical_u64(*x)) + .collect::>(); assert_eq!(output.elements.to_vec(), expected_out); } } diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index f33fbd4f19..4b5ddceada 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -11,7 +11,7 @@ use alloc::{vec, vec::Vec}; use core::fmt::Debug; use serde::de::DeserializeOwned; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; use crate::field::extension::quadratic::QuadraticExtension; use crate::field::extension::{Extendable, FieldExtension}; diff --git a/plonky2/src/util/serialization/gate_serialization.rs b/plonky2/src/util/serialization/gate_serialization.rs index 76d1eeb8f2..5de2179879 100644 --- a/plonky2/src/util/serialization/gate_serialization.rs +++ b/plonky2/src/util/serialization/gate_serialization.rs @@ -100,10 +100,10 @@ pub mod default { use crate::gates::base_sum::BaseSumGate; use crate::gates::constant::ConstantGate; use crate::gates::coset_interpolation::CosetInterpolationGate; - use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::exponentiation::ExponentiationGate; use crate::gates::lookup::LookupGate; use crate::gates::lookup_table::LookupTableGate; + use crate::gates::low_degree_interpolation::LowDegreeInterpolationGate; use crate::gates::multiplication_extension::MulExtensionGate; use crate::gates::noop::NoopGate; use crate::gates::poseidon::PoseidonGate; @@ -160,18 +160,18 @@ mod test { use crate::plonk::circuit_data::{CircuitConfig, CommonCircuitData, VerifierOnlyCircuitData}; use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; use crate::util::serialization::DefaultGateSerializer; - + #[test] fn test_gate_serialization() { const D: usize = 2; type C = PoseidonGoldilocksConfig; type F = >::F; - + let config = CircuitConfig::standard_recursion_config(); let mut builder = CircuitBuilder::::new(config); - + let targets: [Target; 4] = core::array::from_fn(|_| builder.add_virtual_target()); // (0..4).map(|).collect(); - + builder.hash_n_to_hash_no_pad::(targets.to_vec()); let mut pw = PartialWitness::new(); @@ -179,28 +179,28 @@ mod test { pw.set_target(targets[1], F::ONE); pw.set_target(targets[2], F::ZERO); pw.set_target(targets[3], F::ONE); - + let data = builder.build::(); - + let common: CommonCircuitData = data.common; let gate_serializer = DefaultGateSerializer; let common_data_bytes = common .to_bytes(&gate_serializer) - .map_err(|_| anyhow::Error::msg("CommonCircuitData serialization failed.")).unwrap(); - + .map_err(|_| anyhow::Error::msg("CommonCircuitData serialization failed.")) + .unwrap(); let recoverred_common_data = CommonCircuitData::::from_bytes(common_data_bytes, &gate_serializer) - .map_err(|_| anyhow::Error::msg("CommonCircuitData deserialization failed.")).unwrap(); + .map_err(|_| anyhow::Error::msg("CommonCircuitData deserialization failed.")) + .unwrap(); assert_eq!(common, recoverred_common_data); let vd = data.verifier_only; let vd_str = serde_json::to_string(&vd).unwrap(); - let vd_recoverred : VerifierOnlyCircuitData = serde_json::from_str(&vd_str).unwrap(); + let vd_recoverred: VerifierOnlyCircuitData = serde_json::from_str(&vd_str).unwrap(); assert_eq!(vd, vd_recoverred); - } -} \ No newline at end of file +} diff --git a/plonky2/src/util/serialization/mod.rs b/plonky2/src/util/serialization/mod.rs index 8c3658aa6c..4d5b8e7fa7 100644 --- a/plonky2/src/util/serialization/mod.rs +++ b/plonky2/src/util/serialization/mod.rs @@ -11,7 +11,7 @@ use core::fmt::{Debug, Display, Formatter}; use core::mem::size_of; use core::ops::Range; #[cfg(feature = "std")] -use std::{collections::BTreeMap}; +use std::collections::BTreeMap; pub use gate_serialization::default::DefaultGateSerializer; pub use gate_serialization::GateSerializer; From 57a6c34143368b0004cc9b1f1cf3beb1698f15c4 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Wed, 2 Oct 2024 15:54:11 +0800 Subject: [PATCH 05/16] cargo fmt rustc 1.83.0 --- starky/src/fibonacci_stark.rs | 3 ++- starky/src/permutation_stark.rs | 3 ++- starky/src/unconstrained_stark.rs | 3 ++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/starky/src/fibonacci_stark.rs b/starky/src/fibonacci_stark.rs index 7aa40b6ed9..0ebc30ef38 100644 --- a/starky/src/fibonacci_stark.rs +++ b/starky/src/fibonacci_stark.rs @@ -61,7 +61,8 @@ const FIBONACCI_COLUMNS: usize = 2; const FIBONACCI_PUBLIC_INPUTS: usize = 3; impl, const D: usize> Stark for FibonacciStark { - type EvaluationFrame = StarkFrame + type EvaluationFrame + = StarkFrame where FE: FieldExtension, P: PackedField; diff --git a/starky/src/permutation_stark.rs b/starky/src/permutation_stark.rs index 62290b658d..998665d1b2 100644 --- a/starky/src/permutation_stark.rs +++ b/starky/src/permutation_stark.rs @@ -55,7 +55,8 @@ const PERM_COLUMNS: usize = 3; const PERM_PUBLIC_INPUTS: usize = 1; impl, const D: usize> Stark for PermutationStark { - type EvaluationFrame = StarkFrame + type EvaluationFrame + = StarkFrame where FE: FieldExtension, P: PackedField; diff --git a/starky/src/unconstrained_stark.rs b/starky/src/unconstrained_stark.rs index 2f93c25556..a6bd7ea8cd 100644 --- a/starky/src/unconstrained_stark.rs +++ b/starky/src/unconstrained_stark.rs @@ -45,7 +45,8 @@ const COLUMNS: usize = 2; const PUBLIC_INPUTS: usize = 0; impl, const D: usize> Stark for UnconstrainedStark { - type EvaluationFrame = StarkFrame + type EvaluationFrame + = StarkFrame where FE: FieldExtension, P: PackedField; From 92a5413a258d8732f753ea48eafd013a550ff74c Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Mon, 14 Oct 2024 12:32:23 +0800 Subject: [PATCH 06/16] fix tests to run under --feature=cuda --- field/src/fft.rs | 16 --------- plonky2/Cargo.toml | 2 +- plonky2/src/fri/oracle.rs | 23 +++++++++++- plonky2/src/hash/merkle_tree.rs | 42 ++++++++++++++-------- plonky2/src/plonk/circuit_builder.rs | 53 +++++++++++++++++++++++++++- 5 files changed, 103 insertions(+), 33 deletions(-) diff --git a/field/src/fft.rs b/field/src/fft.rs index 09c9104422..700af4fb54 100644 --- a/field/src/fft.rs +++ b/field/src/fft.rs @@ -1,8 +1,6 @@ use alloc::vec::Vec; use core::cmp::{max, min}; -#[cfg(feature = "cuda")] -use cryptography_cuda::{ntt, types::NTTInputOutputOrder}; use plonky2_util::{log2_strict, reverse_index_bits_in_place}; use unroll::unroll_for_loops; @@ -34,20 +32,6 @@ pub fn fft_root_table(n: usize) -> FftRootTable { root_table } -#[allow(dead_code)] -#[cfg(feature = "cuda")] -fn fft_dispatch_gpu( - input: &mut [F], - zero_factor: Option, - root_table: Option<&FftRootTable>, -) { - if F::CUDA_SUPPORT { - return ntt(0, input, NTTInputOutputOrder::NN); - } else { - return fft_dispatch_cpu(input, zero_factor, root_table); - } -} - fn fft_dispatch_cpu( input: &mut [F], zero_factor: Option, diff --git a/plonky2/Cargo.toml b/plonky2/Cargo.toml index 791a64c922..7de17ec942 100644 --- a/plonky2/Cargo.toml +++ b/plonky2/Cargo.toml @@ -37,7 +37,7 @@ serde = { workspace = true, features = ["rc"] } static_assertions = { workspace = true } unroll = { workspace = true } web-time = { version = "1.0.0", optional = true } -once_cell = { version = "1.18.0" } +once_cell = { version = "1.20.2" } papi-bindings = { version = "0.5.2" } # Local dependencies diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 2fbfa9f33a..99a82ec6b5 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -27,6 +27,21 @@ use crate::util::reducing::ReducingFactor; use crate::util::timing::TimingTree; use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose}; +#[cfg(all(feature = "cuda", any(test, doctest)))] +pub static GPU_INIT: once_cell::sync::Lazy>> = once_cell::sync::Lazy::new(|| std::sync::Arc::new(std::sync::Mutex::new(0))); + +#[cfg(all(feature = "cuda", any(test, doctest)))] +fn init_gpu() { + use cryptography_cuda::init_cuda_rs; + + let mut init = GPU_INIT.lock().unwrap(); + if *init == 0 { + println!("Init GPU!"); + init_cuda_rs(); + *init = 1; + } +} + /// Four (~64 bit) field elements gives ~128 bit security. pub const SALT_SIZE: usize = 4; @@ -196,6 +211,9 @@ impl, C: GenericConfig, const D: usize> let degree = polynomials[0].len(); let log_n = log2_strict(degree); + #[cfg(any(test, doctest))] + init_gpu(); + if log_n + rate_bits > 1 && polynomials.len() > 0 && pols * (1 << (log_n + rate_bits)) < (1 << 31) @@ -236,7 +254,7 @@ impl, C: GenericConfig, const D: usize> } #[cfg(feature = "cuda")] - pub fn from_coeffs_gpu( + fn from_coeffs_gpu( polynomials: &[PolynomialCoeffs], rate_bits: usize, blinding: bool, @@ -350,6 +368,9 @@ impl, C: GenericConfig, const D: usize> blinding: bool, fft_root_table: Option<&FftRootTable>, ) -> Vec> { + #[cfg(all(feature = "cuda", any(test, doctest)))] + init_gpu(); + let degree = polynomials[0].len(); #[cfg(all(feature = "cuda", feature = "batch"))] let log_n = log2_strict(degree) + rate_bits; diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 9ffe3850ac..fe78b94b8f 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -379,20 +379,24 @@ fn fill_digests_buf_gpu_ptr>( let stream1 = CudaStream::create().unwrap(); let stream2 = CudaStream::create().unwrap(); - gpu_digests_buf - .copy_to_host_ptr_async( - digests_buf.as_mut_ptr() as *mut core::ffi::c_void, - digests_size, - &stream1, - ) - .expect("copy digests"); - gpu_cap_buf - .copy_to_host_ptr_async( - cap_buf.as_mut_ptr() as *mut core::ffi::c_void, - caps_size, - &stream2, - ) - .expect("copy caps"); + if digests_buf.len() != 0 { + gpu_digests_buf + .copy_to_host_ptr_async( + digests_buf.as_mut_ptr() as *mut core::ffi::c_void, + digests_size, + &stream1, + ) + .expect("copy digests"); + } + if cap_buf.len() != 0 { + gpu_cap_buf + .copy_to_host_ptr_async( + cap_buf.as_mut_ptr() as *mut core::ffi::c_void, + caps_size, + &stream2, + ) + .expect("copy caps"); + } stream1.synchronize().expect("cuda sync"); stream2.synchronize().expect("cuda sync"); stream1.destroy().expect("cuda stream destroy"); @@ -545,6 +549,16 @@ impl> MerkleTree { leaf_len: usize, cap_height: usize, ) -> Self { + // special case + if leaf_len <= H::HASH_SIZE / 8 || H::HASHER_TYPE == HasherType::Keccak { + let mut host_leaves: Vec = vec![F::ZERO; leaves_len * leaf_len]; + leaves_gpu_ptr + .copy_to_host(host_leaves.as_mut_slice(), leaves_len * leaf_len) + .expect("copy to host error"); + return Self::new_from_1d(host_leaves, leaf_len, cap_height); + } + + // general case let log2_leaves_len = log2_strict(leaves_len); assert!( cap_height <= log2_leaves_len, diff --git a/plonky2/src/plonk/circuit_builder.rs b/plonky2/src/plonk/circuit_builder.rs index 2010a6c974..c8d339d6fa 100644 --- a/plonky2/src/plonk/circuit_builder.rs +++ b/plonky2/src/plonk/circuit_builder.rs @@ -90,7 +90,7 @@ pub struct LookupWire { /// /// # Usage /// -/// ```rust +/// ```ignore /// use plonky2::plonk::circuit_data::CircuitConfig; /// use plonky2::iop::witness::PartialWitness; /// use plonky2::plonk::circuit_builder::CircuitBuilder; @@ -1289,3 +1289,54 @@ impl, const D: usize> CircuitBuilder { circuit_data.verifier_data() } } + +#[cfg(test)] +mod tests { + use crate::field::types::Field; + use crate::iop::witness::PartialWitness; + use crate::plonk::circuit_builder::CircuitBuilder; + use crate::plonk::circuit_data::CircuitConfig; + use crate::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; + + // this is the code at line 93 + #[test] + fn test_builder() { + // Define parameters for this circuit + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let config = CircuitConfig::standard_recursion_config(); + let mut builder = CircuitBuilder::::new(config); + + // Build a circuit for the statement: "I know the 100th term + // of the Fibonacci sequence, starting from 0 and 1". + let initial_a = builder.constant(F::ZERO); + let initial_b = builder.constant(F::ONE); + let mut prev_target = initial_a; + let mut cur_target = initial_b; + for _ in 0..99 { + // Encode an addition of the two previous terms + let temp = builder.add(prev_target, cur_target); + // Shift the two previous terms with the new value + prev_target = cur_target; + cur_target = temp; + } + + // The only public input is the result (which is generated). + builder.register_public_input(cur_target); + + // Build the circuit + let circuit_data = builder.build::(); + + // Now compute the witness and generate a proof + let pw = PartialWitness::new(); + + // There are no public inputs to register, as the only one + // will be generated while proving the statement. + let proof = circuit_data.prove(pw).unwrap(); + + // Verify the proof + assert!(circuit_data.verify(proof).is_ok()); + } +} From 2b034f1cd92d347bd78b7da2edc0c2c02e0b7df9 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Mon, 14 Oct 2024 12:32:39 +0800 Subject: [PATCH 07/16] update crypto cuda ref --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 68dfb213f7..a11b080104 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["field", "maybe_rayon", "plonky2", "starky", "util", "gen", "u32", "e resolver = "2" [workspace.dependencies] -cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "51e363b7074ad48edad09fad5d936c81b075f1a7" } +cryptography_cuda = { git = "ssh://git@github.com/okx/cryptography_cuda.git", rev = "547192b2ef42dc7519435059c86f88431b8de999" } ahash = { version = "0.8.7", default-features = false, features = [ "compile-time-rng", ] } # NOTE: Be sure to keep this version the same as the dependency in `hashbrown`. From 49813c81d9196402e1354ad55510bd1e1df3476d Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 15 Oct 2024 11:32:48 +0800 Subject: [PATCH 08/16] poseidon avx512 implementation --- .../src/hash/arch/x86_64/goldilocks_avx512.rs | 56 +- plonky2/src/hash/arch/x86_64/mod.rs | 2 +- .../arch/x86_64/poseidon_goldilocks_avx2.rs | 8 +- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 1491 ++++++++++++++++- plonky2/src/hash/merkle_tree.rs | 269 ++- plonky2/src/hash/poseidon.rs | 37 +- plonky2/src/hash/poseidon_goldilocks.rs | 7 +- 7 files changed, 1816 insertions(+), 54 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index 6c37bec0b4..ce86ce67de 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -7,11 +7,26 @@ const MSB_: i64 = 0x8000000000000000u64 as i64; const P8_: i64 = 0xFFFFFFFF00000001u64 as i64; const P8_N_: i64 = 0xFFFFFFFF; +#[allow(non_snake_case)] +#[repr(align(64))] +struct FieldConstants { + MSB_V: [i64; 8], + P8_V: [i64; 8], + P8_N_V: [i64; 8], +} + +const FC: FieldConstants = FieldConstants { + MSB_V: [MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_], + P8_V: [P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_], + P8_N_V: [P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_], +}; + #[allow(dead_code)] #[inline(always)] pub fn shift_avx512(a: &__m512i) -> __m512i { unsafe { - let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_); + // let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_); + let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::()); _mm512_xor_si512(*a, msb) } } @@ -20,17 +35,20 @@ pub fn shift_avx512(a: &__m512i) -> __m512i { #[inline(always)] pub fn to_canonical_avx512(a: &__m512i) -> __m512i { unsafe { - let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_); - let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + // let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_); + // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::()); + let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); let result_mask = _mm512_cmpge_epu64_mask(*a, p8); _mm512_mask_add_epi64(*a, result_mask, *a, p8_n) } } #[inline(always)] -pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i { +pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { - let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); let c0 = _mm512_add_epi64(*a, *b); let result_mask = _mm512_cmpgt_epu64_mask(*a, c0); _mm512_mask_add_epi64(c0, result_mask, c0, p8_n) @@ -38,9 +56,10 @@ pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i { } #[inline(always)] -pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i { +pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { - let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_); + // let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_); + let p8 = _mm512_load_si512(FC.P8_V.as_ptr().cast::()); let c0 = _mm512_sub_epi64(*a, *b); let result_mask = _mm512_cmpgt_epu64_mask(*b, *a); _mm512_mask_add_epi64(c0, result_mask, c0, p8) @@ -50,11 +69,25 @@ pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i { #[inline(always)] pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { unsafe { - let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); let c_hh = _mm512_srli_epi64(*c_h, 32); - let c1 = sub_avx512_b_c(c_l, &c_hh); + let c1 = sub_avx512(c_l, &c_hh); let c2 = _mm512_mul_epu32(*c_h, p8_n); - add_avx512_b_c(&c1, &c2) + add_avx512(&c1, &c2) + } +} + +// Here we suppose c_h < 2^32 +#[inline(always)] +pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { + unsafe { + let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::()); + let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); + let c_ls = _mm512_xor_si512(*c_l, msb); + let c2 = _mm512_mul_epu32(*c_h, p_n); + let c_s = add_avx512(&c_ls, &c2); + _mm512_xor_si512(c_s, msb) } } @@ -69,7 +102,8 @@ pub fn mult_avx512_128(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) { let c_ll = _mm512_mul_epu32(*a, *b); let c_ll_h = _mm512_srli_epi64(c_ll, 32); let r0 = _mm512_add_epi64(c_hl, c_ll_h); - let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); + let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); let r0_l = _mm512_and_si512(r0, p8_n); let r0_h = _mm512_srli_epi64(r0, 32); let r1 = _mm512_add_epi64(c_lh, r0_l); diff --git a/plonky2/src/hash/arch/x86_64/mod.rs b/plonky2/src/hash/arch/x86_64/mod.rs index 28b49ce53c..cb21e20f40 100644 --- a/plonky2/src/hash/arch/x86_64/mod.rs +++ b/plonky2/src/hash/arch/x86_64/mod.rs @@ -10,7 +10,7 @@ pub mod goldilocks_avx512; pub mod poseidon2_goldilocks_avx2; #[cfg(target_feature = "avx2")] pub mod poseidon_bn128_avx2; -#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))] +#[cfg(target_feature = "avx2")] pub mod poseidon_goldilocks_avx2; #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] pub mod poseidon_goldilocks_avx512; diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs index 0564a73de8..db86c76b98 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -218,7 +218,7 @@ const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12]; 12] = [ ], ]; -const FAST_PARTIAL_ROUND_W_HATS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] = [ +pub const FAST_PARTIAL_ROUND_W_HATS: [[u64; 12 - 1]; N_PARTIAL_ROUNDS] = [ [ 0x3d999c961b7c63b0, 0x814e82efcd172529, @@ -818,9 +818,9 @@ const FAST_PARTIAL_ROUND_VS: [[u64; 12]; N_PARTIAL_ROUNDS] = [ ], ]; -const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16]; -const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)]; -const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2]; +pub(crate) const MDS_FREQ_BLOCK_ONE: [i64; 3] = [16, 32, 16]; +pub(crate) const MDS_FREQ_BLOCK_TWO: [(i64, i64); 3] = [(2, -1), (-4, 1), (16, 1)]; +pub(crate) const MDS_FREQ_BLOCK_THREE: [i64; 3] = [-1, -8, 2]; #[allow(dead_code)] #[inline(always)] diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 0ab9dc9146..be920c065c 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -2,11 +2,19 @@ use core::arch::x86_64::*; use unroll::unroll_for_loops; +use super::poseidon_goldilocks_avx2::{ + MDS_FREQ_BLOCK_ONE, MDS_FREQ_BLOCK_THREE, MDS_FREQ_BLOCK_TWO, +}; use crate::field::types::PrimeField64; use crate::hash::arch::x86_64::goldilocks_avx512::*; +use crate::hash::arch::x86_64::poseidon_goldilocks_avx2::FAST_PARTIAL_ROUND_W_HATS; +use crate::hash::hash_types::{HashOut, RichField}; use crate::hash::poseidon::{ - Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_WIDTH, + add_u160_u128, reduce_u160, Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, + N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_RATE, SPONGE_WIDTH, }; +use crate::hash::poseidon_goldilocks::poseidon12_mds::block2; +use crate::plonk::config::GenericHashOut; #[allow(dead_code)] const MDS_MATRIX_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; @@ -29,6 +37,33 @@ const FAST_PARTIAL_FIRST_ROUND_CONSTANT: [u64; 12] = [ 0xc33448feadc78f0c, ]; +const FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512: [u64; 24] = [ + 0x3cc3f892184df408, + 0xe993fd841e7e97f1, + 0xf2831d3575f0f3af, + 0xd2500e0a350994ca, + 0x3cc3f892184df408, + 0xe993fd841e7e97f1, + 0xf2831d3575f0f3af, + 0xd2500e0a350994ca, + 0xc5571f35d7288633, + 0x91d89c5184109a02, + 0xf37f925d04e5667b, + 0x2d6e448371955a69, + 0xc5571f35d7288633, + 0x91d89c5184109a02, + 0xf37f925d04e5667b, + 0x2d6e448371955a69, + 0x740ef19ce01398a1, + 0x694d24c0752fdf45, + 0x60936af96ee2f148, + 0xc33448feadc78f0c, + 0x740ef19ce01398a1, + 0x694d24c0752fdf45, + 0x60936af96ee2f148, + 0xc33448feadc78f0c, +]; + const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ 0x74cb2e819ae421ab, 0xd2559d2370e7f663, @@ -54,6 +89,7 @@ const FAST_PARTIAL_ROUND_CONSTANTS: [u64; N_PARTIAL_ROUNDS] = [ 0x0, ]; +#[allow(unused)] const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12]; 12] = [ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [ @@ -212,6 +248,884 @@ const FAST_PARTIAL_ROUND_INITIAL_MATRIX: [[u64; 12]; 12] = [ ], ]; +const FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512: [[u64; 24]; 12] = [ + [ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + ], + [ + 0, + 0x80772dc2645b280b, + 0xdc927721da922cf8, + 0xc1978156516879ad, + 0, + 0x80772dc2645b280b, + 0xdc927721da922cf8, + 0xc1978156516879ad, + 0x90e80c591f48b603, + 0x3a2432625475e3ae, + 0x00a2d4321cca94fe, + 0x77736f524010c932, + 0x90e80c591f48b603, + 0x3a2432625475e3ae, + 0x00a2d4321cca94fe, + 0x77736f524010c932, + 0x904d3f2804a36c54, + 0xbf9b39e28a16f354, + 0x3a1ded54a6cd058b, + 0x42392870da5737cf, + 0x904d3f2804a36c54, + 0xbf9b39e28a16f354, + 0x3a1ded54a6cd058b, + 0x42392870da5737cf, + ], + [ + 0, + 0xe796d293a47a64cb, + 0xb124c33152a2421a, + 0x0ee5dc0ce131268a, + 0, + 0xe796d293a47a64cb, + 0xb124c33152a2421a, + 0x0ee5dc0ce131268a, + 0xa9032a52f930fae6, + 0x7e33ca8c814280de, + 0xad11180f69a8c29e, + 0xc75ac6d5b5a10ff3, + 0xa9032a52f930fae6, + 0x7e33ca8c814280de, + 0xad11180f69a8c29e, + 0xc75ac6d5b5a10ff3, + 0xf0674a8dc5a387ec, + 0xb36d43120eaa5e2b, + 0x6f232aab4b533a25, + 0x3a1ded54a6cd058b, + 0xf0674a8dc5a387ec, + 0xb36d43120eaa5e2b, + 0x6f232aab4b533a25, + 0x3a1ded54a6cd058b, + ], + [ + 0, + 0xdcedab70f40718ba, + 0x14a4a64da0b2668f, + 0x4715b8e5ab34653b, + 0, + 0xdcedab70f40718ba, + 0x14a4a64da0b2668f, + 0x4715b8e5ab34653b, + 0x1e8916a99c93a88e, + 0xbba4b5d86b9a3b2c, + 0xe76649f9bd5d5c2e, + 0xaf8e2518a1ece54d, + 0x1e8916a99c93a88e, + 0xbba4b5d86b9a3b2c, + 0xe76649f9bd5d5c2e, + 0xaf8e2518a1ece54d, + 0xdcda1344cdca873f, + 0xcd080204256088e5, + 0xb36d43120eaa5e2b, + 0xbf9b39e28a16f354, + 0xdcda1344cdca873f, + 0xcd080204256088e5, + 0xb36d43120eaa5e2b, + 0xbf9b39e28a16f354, + ], + [ + 0, + 0xf4a437f2888ae909, + 0xc537d44dc2875403, + 0x7f68007619fd8ba9, + 0, + 0xf4a437f2888ae909, + 0xc537d44dc2875403, + 0x7f68007619fd8ba9, + 0xa4911db6a32612da, + 0x2f7e9aade3fdaec1, + 0xe7ffd578da4ea43d, + 0x43a608e7afa6b5c2, + 0xa4911db6a32612da, + 0x2f7e9aade3fdaec1, + 0xe7ffd578da4ea43d, + 0x43a608e7afa6b5c2, + 0xca46546aa99e1575, + 0xdcda1344cdca873f, + 0xf0674a8dc5a387ec, + 0x904d3f2804a36c54, + 0xca46546aa99e1575, + 0xdcda1344cdca873f, + 0xf0674a8dc5a387ec, + 0x904d3f2804a36c54, + ], + [ + 0, + 0xf97abba0dffb6c50, + 0x5e40f0c9bb82aab5, + 0x5996a80497e24a6b, + 0, + 0xf97abba0dffb6c50, + 0x5e40f0c9bb82aab5, + 0x5996a80497e24a6b, + 0x07084430a7307c9a, + 0xad2f570a5b8545aa, + 0xab7f81fef4274770, + 0xcb81f535cf98c9e9, + 0x07084430a7307c9a, + 0xad2f570a5b8545aa, + 0xab7f81fef4274770, + 0xcb81f535cf98c9e9, + 0x43a608e7afa6b5c2, + 0xaf8e2518a1ece54d, + 0xc75ac6d5b5a10ff3, + 0x77736f524010c932, + 0x43a608e7afa6b5c2, + 0xaf8e2518a1ece54d, + 0xc75ac6d5b5a10ff3, + 0x77736f524010c932, + ], + [ + 0, + 0x7f8e41e0b0a6cdff, + 0x4b1ba8d40afca97d, + 0x623708f28fca70e8, + 0, + 0x7f8e41e0b0a6cdff, + 0x4b1ba8d40afca97d, + 0x623708f28fca70e8, + 0xbf150dc4914d380f, + 0xc26a083554767106, + 0x753b8b1126665c22, + 0xab7f81fef4274770, + 0xbf150dc4914d380f, + 0xc26a083554767106, + 0x753b8b1126665c22, + 0xab7f81fef4274770, + 0xe7ffd578da4ea43d, + 0xe76649f9bd5d5c2e, + 0xad11180f69a8c29e, + 0x00a2d4321cca94fe, + 0xe7ffd578da4ea43d, + 0xe76649f9bd5d5c2e, + 0xad11180f69a8c29e, + 0x00a2d4321cca94fe, + ], + [ + 0, + 0x726af914971c1374, + 0x1d7f8a2cce1a9d00, + 0x18737784700c75cd, + 0, + 0x726af914971c1374, + 0x1d7f8a2cce1a9d00, + 0x18737784700c75cd, + 0x7fb45d605dd82838, + 0x862361aeab0f9b6e, + 0xc26a083554767106, + 0xad2f570a5b8545aa, + 0x7fb45d605dd82838, + 0x862361aeab0f9b6e, + 0xc26a083554767106, + 0xad2f570a5b8545aa, + 0x2f7e9aade3fdaec1, + 0xbba4b5d86b9a3b2c, + 0x7e33ca8c814280de, + 0x3a2432625475e3ae, + 0x2f7e9aade3fdaec1, + 0xbba4b5d86b9a3b2c, + 0x7e33ca8c814280de, + 0x3a2432625475e3ae, + ], + [ + 0, + 0x64dd936da878404d, + 0x4db9a2ead2bd7262, + 0xbe2e19f6d07f1a83, + 0, + 0x64dd936da878404d, + 0x4db9a2ead2bd7262, + 0xbe2e19f6d07f1a83, + 0x02290fe23c20351a, + 0x7fb45d605dd82838, + 0xbf150dc4914d380f, + 0x07084430a7307c9a, + 0x02290fe23c20351a, + 0x7fb45d605dd82838, + 0xbf150dc4914d380f, + 0x07084430a7307c9a, + 0xa4911db6a32612da, + 0x1e8916a99c93a88e, + 0xa9032a52f930fae6, + 0x90e80c591f48b603, + 0xa4911db6a32612da, + 0x1e8916a99c93a88e, + 0xa9032a52f930fae6, + 0x90e80c591f48b603, + ], + [ + 0, + 0x85418a9fef8a9890, + 0xd8a2eb7ef5e707ad, + 0xbfe85ababed2d882, + 0, + 0x85418a9fef8a9890, + 0xd8a2eb7ef5e707ad, + 0xbfe85ababed2d882, + 0xbe2e19f6d07f1a83, + 0x18737784700c75cd, + 0x623708f28fca70e8, + 0x5996a80497e24a6b, + 0xbe2e19f6d07f1a83, + 0x18737784700c75cd, + 0x623708f28fca70e8, + 0x5996a80497e24a6b, + 0x7f68007619fd8ba9, + 0x4715b8e5ab34653b, + 0x0ee5dc0ce131268a, + 0xc1978156516879ad, + 0x7f68007619fd8ba9, + 0x4715b8e5ab34653b, + 0x0ee5dc0ce131268a, + 0xc1978156516879ad, + ], + [ + 0, + 0x156048ee7a738154, + 0x91f7562377e81df5, + 0xd8a2eb7ef5e707ad, + 0, + 0x156048ee7a738154, + 0x91f7562377e81df5, + 0xd8a2eb7ef5e707ad, + 0x4db9a2ead2bd7262, + 0x1d7f8a2cce1a9d00, + 0x4b1ba8d40afca97d, + 0x5e40f0c9bb82aab5, + 0x4db9a2ead2bd7262, + 0x1d7f8a2cce1a9d00, + 0x4b1ba8d40afca97d, + 0x5e40f0c9bb82aab5, + 0xc537d44dc2875403, + 0x14a4a64da0b2668f, + 0xb124c33152a2421a, + 0xdc927721da922cf8, + 0xc537d44dc2875403, + 0x14a4a64da0b2668f, + 0xb124c33152a2421a, + 0xdc927721da922cf8, + ], + [ + 0, + 0xd841e8ef9dde8ba0, + 0x156048ee7a738154, + 0x85418a9fef8a9890, + 0, + 0xd841e8ef9dde8ba0, + 0x156048ee7a738154, + 0x85418a9fef8a9890, + 0x64dd936da878404d, + 0x726af914971c1374, + 0x7f8e41e0b0a6cdff, + 0xf97abba0dffb6c50, + 0x64dd936da878404d, + 0x726af914971c1374, + 0x7f8e41e0b0a6cdff, + 0xf97abba0dffb6c50, + 0xf4a437f2888ae909, + 0xdcedab70f40718ba, + 0xe796d293a47a64cb, + 0x80772dc2645b280b, + 0xf4a437f2888ae909, + 0xdcedab70f40718ba, + 0xe796d293a47a64cb, + 0x80772dc2645b280b, + ], +]; + +#[rustfmt::skip] +pub const ALL_ROUND_CONSTANTS_AVX512: [u64; 2 * SPONGE_WIDTH * N_ROUNDS] = [ + 0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0xf6760a4803427d7, 0xb585f766f2144405, 0x7746a55f43921ad7, 0xb2fb0d31cee799b4, 0xf6760a4803427d7, + 0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d, 0xe10d666650f4e012, 0x8cae14cb07d09bf1, 0xd438539c95f63e9f, 0xef781c7ce35b4c3d, + 0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76, 0xcdc4a239b0c44426, 0x277fa208bf337bff, 0xe17653a29da578a1, 0xc54302f225db2c76, + 0x86287821f722c881, 0x59cd1a8a41c18e55, 0xc3b919ad495dc574, 0xa484c4c5ef6a0781, 0x86287821f722c881, 0x59cd1a8a41c18e55, 0xc3b919ad495dc574, 0xa484c4c5ef6a0781, + 0x308bbd23dc5416cc, 0x6e4a40c18f30c09c, 0x9a2eedb70d8f8cfa, 0xe360c6e0ae486f38, 0x308bbd23dc5416cc, 0x6e4a40c18f30c09c, 0x9a2eedb70d8f8cfa, 0xe360c6e0ae486f38, + 0xd5c7718fbfc647fb, 0xc35eae071903ff0b, 0x849c2656969c4be7, 0xc0572c8c08cbbbad, 0xd5c7718fbfc647fb, 0xc35eae071903ff0b, 0x849c2656969c4be7, 0xc0572c8c08cbbbad, + 0xe9fa634a21de0082, 0xf56f6d48959a600d, 0xf7d713e806391165, 0x8297132b32825daf, 0xe9fa634a21de0082, 0xf56f6d48959a600d, 0xf7d713e806391165, 0x8297132b32825daf, + 0xad6805e0e30b2c8a, 0xac51d9f5fcf8535e, 0x502ad7dc18c2ad87, 0x57a1550c110b3041, 0xad6805e0e30b2c8a, 0xac51d9f5fcf8535e, 0x502ad7dc18c2ad87, 0x57a1550c110b3041, + 0x66bbd30e6ce0e583, 0xda2abef589d644e, 0xf061274fdb150d61, 0x28b8ec3ae9c29633, 0x66bbd30e6ce0e583, 0xda2abef589d644e, 0xf061274fdb150d61, 0x28b8ec3ae9c29633, + 0x92a756e67e2b9413, 0x70e741ebfee96586, 0x19d5ee2af82ec1c, 0x6f6f2ed772466352, 0x92a756e67e2b9413, 0x70e741ebfee96586, 0x19d5ee2af82ec1c, 0x6f6f2ed772466352, + 0x7cf416cfe7e14ca1, 0x61df517b86a46439, 0x85dc499b11d77b75, 0x4b959b48b9c10733, 0x7cf416cfe7e14ca1, 0x61df517b86a46439, 0x85dc499b11d77b75, 0x4b959b48b9c10733, + 0xe8be3e5da8043e57, 0xf5c0bc1de6da8699, 0x40b12cbf09ef74bf, 0xa637093ecb2ad631, 0xe8be3e5da8043e57, 0xf5c0bc1de6da8699, 0x40b12cbf09ef74bf, 0xa637093ecb2ad631, + 0x3cc3f892184df408, 0x2e479dc157bf31bb, 0x6f49de07a6234346, 0x213ce7bede378d7b, 0x3cc3f892184df408, 0x2e479dc157bf31bb, 0x6f49de07a6234346, 0x213ce7bede378d7b, + 0x5b0431345d4dea83, 0xa2de45780344d6a1, 0x7103aaf94a7bf308, 0x5326fc0d97279301, 0x5b0431345d4dea83, 0xa2de45780344d6a1, 0x7103aaf94a7bf308, 0x5326fc0d97279301, + 0xa9ceb74fec024747, 0x27f8ec88bb21b1a3, 0xfceb4fda1ded0893, 0xfac6ff1346a41675, 0xa9ceb74fec024747, 0x27f8ec88bb21b1a3, 0xfceb4fda1ded0893, 0xfac6ff1346a41675, + 0x7131aa45268d7d8c, 0x9351036095630f9f, 0xad535b24afc26bfb, 0x4627f5c6993e44be, 0x7131aa45268d7d8c, 0x9351036095630f9f, 0xad535b24afc26bfb, 0x4627f5c6993e44be, + 0x645cf794b8f1cc58, 0x241c70ed0af61617, 0xacb8e076647905f1, 0x3737e9db4c4f474d, 0x645cf794b8f1cc58, 0x241c70ed0af61617, 0xacb8e076647905f1, 0x3737e9db4c4f474d, + 0xe7ea5e33e75fffb6, 0x90dee49fc9bfc23a, 0xd1b1edf76bc09c92, 0xb65481ba645c602, 0xe7ea5e33e75fffb6, 0x90dee49fc9bfc23a, 0xd1b1edf76bc09c92, 0xb65481ba645c602, + 0x99ad1aab0814283b, 0x438a7c91d416ca4d, 0xb60de3bcc5ea751c, 0xc99cab6aef6f58bc, 0x99ad1aab0814283b, 0x438a7c91d416ca4d, 0xb60de3bcc5ea751c, 0xc99cab6aef6f58bc, + 0x69a5ed92a72ee4ff, 0x5e7b329c1ed4ad71, 0x5fc0ac0800144885, 0x32db829239774eca, 0x69a5ed92a72ee4ff, 0x5e7b329c1ed4ad71, 0x5fc0ac0800144885, 0x32db829239774eca, + 0xade699c5830f310, 0x7cc5583b10415f21, 0x85df9ed2e166d64f, 0x6604df4fee32bcb1, 0xade699c5830f310, 0x7cc5583b10415f21, 0x85df9ed2e166d64f, 0x6604df4fee32bcb1, + 0xeb84f608da56ef48, 0xda608834c40e603d, 0x8f97fe408061f183, 0xa93f485c96f37b89, 0xeb84f608da56ef48, 0xda608834c40e603d, 0x8f97fe408061f183, 0xa93f485c96f37b89, + 0x6704e8ee8f18d563, 0xcee3e9ac1e072119, 0x510d0e65e2b470c1, 0xf6323f486b9038f0, 0x6704e8ee8f18d563, 0xcee3e9ac1e072119, 0x510d0e65e2b470c1, 0xf6323f486b9038f0, + 0xb508cdeffa5ceef, 0xf2417089e4fb3cbd, 0x60e75c2890d15730, 0xa6217d8bf660f29c, 0xb508cdeffa5ceef, 0xf2417089e4fb3cbd, 0x60e75c2890d15730, 0xa6217d8bf660f29c, + 0x7159cd30c3ac118e, 0x839b4e8fafead540, 0xd3f3e5e82920adc, 0x8f7d83bddee7bba8, 0x7159cd30c3ac118e, 0x839b4e8fafead540, 0xd3f3e5e82920adc, 0x8f7d83bddee7bba8, + 0x780f2243ea071d06, 0xeb915845f3de1634, 0xd19e120d26b6f386, 0x16ee53a7e5fecc6, 0x780f2243ea071d06, 0xeb915845f3de1634, 0xd19e120d26b6f386, 0x16ee53a7e5fecc6, + 0xcb5fd54e7933e477, 0xacb8417879fd449f, 0x9c22190be7f74732, 0x5d693c1ba3ba3621, 0xcb5fd54e7933e477, 0xacb8417879fd449f, 0x9c22190be7f74732, 0x5d693c1ba3ba3621, + 0xdcef0797c2b69ec7, 0x3d639263da827b13, 0xe273fd971bc8d0e7, 0x418f02702d227ed5, 0xdcef0797c2b69ec7, 0x3d639263da827b13, 0xe273fd971bc8d0e7, 0x418f02702d227ed5, + 0x8c25fda3b503038c, 0x2cbaed4daec8c07c, 0x5f58e6afcdd6ddc2, 0x284650ac5e1b0eba, 0x8c25fda3b503038c, 0x2cbaed4daec8c07c, 0x5f58e6afcdd6ddc2, 0x284650ac5e1b0eba, + 0x635b337ee819dab5, 0x9f9a036ed4f2d49f, 0xb93e260cae5c170e, 0xb0a7eae879ddb76d, 0x635b337ee819dab5, 0x9f9a036ed4f2d49f, 0xb93e260cae5c170e, 0xb0a7eae879ddb76d, + 0xd0762cbc8ca6570c, 0x34c6efb812b04bf5, 0x40bf0ab5fa14c112, 0xb6b570fc7c5740d3, 0xd0762cbc8ca6570c, 0x34c6efb812b04bf5, 0x40bf0ab5fa14c112, 0xb6b570fc7c5740d3, + 0x5a27b9002de33454, 0xb1a5b165b6d2b2d2, 0x8722e0ace9d1be22, 0x788ee3b37e5680fb, 0x5a27b9002de33454, 0xb1a5b165b6d2b2d2, 0x8722e0ace9d1be22, 0x788ee3b37e5680fb, + 0x14a726661551e284, 0x98b7672f9ef3b419, 0xbb93ae776bb30e3a, 0x28fd3b046380f850, 0x14a726661551e284, 0x98b7672f9ef3b419, 0xbb93ae776bb30e3a, 0x28fd3b046380f850, + 0x30a4680593258387, 0x337dc00c61bd9ce1, 0xd5eca244c7a4ff1d, 0x7762638264d279bd, 0x30a4680593258387, 0x337dc00c61bd9ce1, 0xd5eca244c7a4ff1d, 0x7762638264d279bd, + 0xc1e434bedeefd767, 0x299351a53b8ec22, 0xb2d456e4ad251b80, 0x3e9ed1fda49cea0b, 0xc1e434bedeefd767, 0x299351a53b8ec22, 0xb2d456e4ad251b80, 0x3e9ed1fda49cea0b, + 0x2972a92ba450bed8, 0x20216dd77be493de, 0xadffe8cf28449ec6, 0x1c4dbb1c4c27d243, 0x2972a92ba450bed8, 0x20216dd77be493de, 0xadffe8cf28449ec6, 0x1c4dbb1c4c27d243, + 0x15a16a8a8322d458, 0x388a128b7fd9a609, 0x2300e5d6baedf0fb, 0x2f63aa8647e15104, 0x15a16a8a8322d458, 0x388a128b7fd9a609, 0x2300e5d6baedf0fb, 0x2f63aa8647e15104, + 0xf1c36ce86ecec269, 0x27181125183970c9, 0xe584029370dca96d, 0x4d9bbc3e02f1cfb2, 0xf1c36ce86ecec269, 0x27181125183970c9, 0xe584029370dca96d, 0x4d9bbc3e02f1cfb2, + 0xea35bc29692af6f8, 0x18e21b4beabb4137, 0x1e3b9fc625b554f4, 0x25d64362697828fd, 0xea35bc29692af6f8, 0x18e21b4beabb4137, 0x1e3b9fc625b554f4, 0x25d64362697828fd, + 0x5a3f1bb1c53a9645, 0xdb7f023869fb8d38, 0xb462065911d4e1fc, 0x49c24ae4437d8030, 0x5a3f1bb1c53a9645, 0xdb7f023869fb8d38, 0xb462065911d4e1fc, 0x49c24ae4437d8030, + 0xd793862c112b0566, 0xaadd1106730d8feb, 0xc43b6e0e97b0d568, 0xe29024c18ee6fca2, 0xd793862c112b0566, 0xaadd1106730d8feb, 0xc43b6e0e97b0d568, 0xe29024c18ee6fca2, + 0x5e50c27535b88c66, 0x10383f20a4ff9a87, 0x38e8ee9d71a45af8, 0xdd5118375bf1a9b9, 0x5e50c27535b88c66, 0x10383f20a4ff9a87, 0x38e8ee9d71a45af8, 0xdd5118375bf1a9b9, + 0x775005982d74d7f7, 0x86ab99b4dde6c8b0, 0xb1204f603f51c080, 0xef61ac8470250ecf, 0x775005982d74d7f7, 0x86ab99b4dde6c8b0, 0xb1204f603f51c080, 0xef61ac8470250ecf, + 0x1bbcd90f132c603f, 0xcd1dabd964db557, 0x11a3ae5beb9d1ec9, 0xf755bfeea585d11d, 0x1bbcd90f132c603f, 0xcd1dabd964db557, 0x11a3ae5beb9d1ec9, 0xf755bfeea585d11d, + 0xa3b83250268ea4d7, 0x516306f4927c93af, 0xddb4ac49c9efa1da, 0x64bb6dec369d4418, 0xa3b83250268ea4d7, 0x516306f4927c93af, 0xddb4ac49c9efa1da, 0x64bb6dec369d4418, + 0xf9cc95c22b4c1fcc, 0x8d37f755f4ae9f6, 0xeec49b613478675b, 0xf143933aed25e0b0, 0xf9cc95c22b4c1fcc, 0x8d37f755f4ae9f6, 0xeec49b613478675b, 0xf143933aed25e0b0, + 0xe4c5dd8255dfc622, 0xe7ad7756f193198e, 0x92c2318b87fff9cb, 0x739c25f8fd73596d, 0xe4c5dd8255dfc622, 0xe7ad7756f193198e, 0x92c2318b87fff9cb, 0x739c25f8fd73596d, + 0x5636cac9f16dfed0, 0xdd8f909a938e0172, 0xc6401fe115063f5b, 0x8ad97b33f1ac1455, 0x5636cac9f16dfed0, 0xdd8f909a938e0172, 0xc6401fe115063f5b, 0x8ad97b33f1ac1455, + 0xc49366bb25e8513, 0x784d3d2f1698309, 0x530fb67ea1809a81, 0x410492299bb01f49, 0xc49366bb25e8513, 0x784d3d2f1698309, 0x530fb67ea1809a81, 0x410492299bb01f49, + 0x139542347424b9ac, 0x9cb0bd5ea1a1115e, 0x2e3f615c38f49a1, 0x985d4f4a9c5291ef, 0x139542347424b9ac, 0x9cb0bd5ea1a1115e, 0x2e3f615c38f49a1, 0x985d4f4a9c5291ef, + 0x775b9feafdcd26e7, 0x304265a6384f0f2d, 0x593664c39773012c, 0x4f0a2e5fb028f2ce, 0x775b9feafdcd26e7, 0x304265a6384f0f2d, 0x593664c39773012c, 0x4f0a2e5fb028f2ce, + 0xdd611f1000c17442, 0xd8185f9adfea4fd0, 0xef87139ca9a3ab1e, 0x3ba71336c34ee133, 0xdd611f1000c17442, 0xd8185f9adfea4fd0, 0xef87139ca9a3ab1e, 0x3ba71336c34ee133, + 0x7d3a455d56b70238, 0x660d32e130182684, 0x297a863f48cd1f43, 0x90e0a736a751ebb7, 0x7d3a455d56b70238, 0x660d32e130182684, 0x297a863f48cd1f43, 0x90e0a736a751ebb7, + 0x549f80ce550c4fd3, 0xf73b2922f38bd64, 0x16bf1f73fb7a9c3f, 0x6d1f5a59005bec17, 0x549f80ce550c4fd3, 0xf73b2922f38bd64, 0x16bf1f73fb7a9c3f, 0x6d1f5a59005bec17, + 0x2ff876fa5ef97c4, 0xc5cb72a2a51159b0, 0x8470f39d2d5c900e, 0x25abb3f1d39fcb76, 0x2ff876fa5ef97c4, 0xc5cb72a2a51159b0, 0x8470f39d2d5c900e, 0x25abb3f1d39fcb76, + 0x23eb8cc9b372442f, 0xd687ba55c64f6364, 0xda8d9e90fd8ff158, 0xe3cbdc7d2fe45ea7, 0x23eb8cc9b372442f, 0xd687ba55c64f6364, 0xda8d9e90fd8ff158, 0xe3cbdc7d2fe45ea7, + 0xb9a8c9b3aee52297, 0xc0d28a5c10960bd3, 0x45d7ac9b68f71a34, 0xeeb76e397069e804, 0xb9a8c9b3aee52297, 0xc0d28a5c10960bd3, 0x45d7ac9b68f71a34, 0xeeb76e397069e804, + 0x3d06c8bd1514e2d9, 0x9c9c98207cb10767, 0x65700b51aedfb5ef, 0x911f451539869408, 0x3d06c8bd1514e2d9, 0x9c9c98207cb10767, 0x65700b51aedfb5ef, 0x911f451539869408, + 0x7ae6849fbc3a0ec6, 0x3bb340eba06afe7e, 0xb46e9d8b682ea65e, 0x8dcf22f9a3b34356, 0x7ae6849fbc3a0ec6, 0x3bb340eba06afe7e, 0xb46e9d8b682ea65e, 0x8dcf22f9a3b34356, + 0x77bdaeda586257a7, 0xf19e400a5104d20d, 0xc368a348e46d950f, 0x9ef1cd60e679f284, 0x77bdaeda586257a7, 0xf19e400a5104d20d, 0xc368a348e46d950f, 0x9ef1cd60e679f284, + 0xe89cd854d5d01d33, 0x5cd377dc8bb882a2, 0xa7b0fb7883eee860, 0x7684403ec392950d, 0xe89cd854d5d01d33, 0x5cd377dc8bb882a2, 0xa7b0fb7883eee860, 0x7684403ec392950d, + 0x5fa3f06f4fed3b52, 0x8df57ac11bc04831, 0x2db01efa1e1e1897, 0x54846de4aadb9ca2, 0x5fa3f06f4fed3b52, 0x8df57ac11bc04831, 0x2db01efa1e1e1897, 0x54846de4aadb9ca2, + 0xba6745385893c784, 0x541d496344d2c75b, 0xe909678474e687fe, 0xdfe89923f6c9c2ff, 0xba6745385893c784, 0x541d496344d2c75b, 0xe909678474e687fe, 0xdfe89923f6c9c2ff, + 0xece5a71e0cfedc75, 0x5ff98fd5d51fe610, 0x83e8941918964615, 0x5922040b47f150c1, 0xece5a71e0cfedc75, 0x5ff98fd5d51fe610, 0x83e8941918964615, 0x5922040b47f150c1, + 0xf97d750e3dd94521, 0x5080d4c2b86f56d7, 0xa7de115b56c78d70, 0x6a9242ac87538194, 0xf97d750e3dd94521, 0x5080d4c2b86f56d7, 0xa7de115b56c78d70, 0x6a9242ac87538194, + 0xf7856ef7f9173e44, 0x2265fc92feb0dc09, 0x17dfc8e4f7ba8a57, 0x9001a64209f21db8, 0xf7856ef7f9173e44, 0x2265fc92feb0dc09, 0x17dfc8e4f7ba8a57, 0x9001a64209f21db8, + 0x90004c1371b893c5, 0xb932b7cf752e5545, 0xa0b1df81b6fe59fc, 0x8ef1dd26770af2c2, 0x90004c1371b893c5, 0xb932b7cf752e5545, 0xa0b1df81b6fe59fc, 0x8ef1dd26770af2c2, + 0x541a4f9cfbeed35, 0x9e61106178bfc530, 0xb3767e80935d8af2, 0x98d5782065af06, 0x541a4f9cfbeed35, 0x9e61106178bfc530, 0xb3767e80935d8af2, 0x98d5782065af06, + 0x31d191cd5c1466c7, 0x410fefafa319ac9d, 0xbdf8f242e316c4ab, 0x9e8cd55b57637ed0, 0x31d191cd5c1466c7, 0x410fefafa319ac9d, 0xbdf8f242e316c4ab, 0x9e8cd55b57637ed0, + 0xde122bebe9a39368, 0x4d001fd58f002526, 0xca6637000eb4a9f8, 0x2f2339d624f91f78, 0xde122bebe9a39368, 0x4d001fd58f002526, 0xca6637000eb4a9f8, 0x2f2339d624f91f78, + 0x6d1a7918c80df518, 0xdf9a4939342308e9, 0xebc2151ee6c8398c, 0x3cc2ba8a1116515, 0x6d1a7918c80df518, 0xdf9a4939342308e9, 0xebc2151ee6c8398c, 0x3cc2ba8a1116515, + 0xd341d037e840cf83, 0x387cb5d25af4afcc, 0xbba2515f22909e87, 0x7248fe7705f38e47, 0xd341d037e840cf83, 0x387cb5d25af4afcc, 0xbba2515f22909e87, 0x7248fe7705f38e47, + 0x4d61e56a525d225a, 0x262e963c8da05d3d, 0x59e89b094d220ec2, 0x55d5b52b78b9c5e, 0x4d61e56a525d225a, 0x262e963c8da05d3d, 0x59e89b094d220ec2, 0x55d5b52b78b9c5e, + 0x82b27eb33514ef99, 0xd30094ca96b7ce7b, 0xcf5cb381cd0a1535, 0xfeed4db6919e5a7c, 0x82b27eb33514ef99, 0xd30094ca96b7ce7b, 0xcf5cb381cd0a1535, 0xfeed4db6919e5a7c, + 0x41703f53753be59f, 0x5eeea940fcde8b6f, 0x4cd1f1b175100206, 0x4a20358574454ec0, 0x41703f53753be59f, 0x5eeea940fcde8b6f, 0x4cd1f1b175100206, 0x4a20358574454ec0, + 0x1478d361dbbf9fac, 0x6f02dc07d141875c, 0x296a202ed8e556a2, 0x2afd67999bf32ee5, 0x1478d361dbbf9fac, 0x6f02dc07d141875c, 0x296a202ed8e556a2, 0x2afd67999bf32ee5, + 0x7acfd96efa95491d, 0x6798ba0c0abb2c6d, 0x34c6f57b26c92122, 0x5736e1bad206b5de, 0x7acfd96efa95491d, 0x6798ba0c0abb2c6d, 0x34c6f57b26c92122, 0x5736e1bad206b5de, + 0x20057d2a0056521b, 0x3dea5bd5d0578bd7, 0x16e50d897d4634ac, 0x29bff3ecb9b7a6e3, 0x20057d2a0056521b, 0x3dea5bd5d0578bd7, 0x16e50d897d4634ac, 0x29bff3ecb9b7a6e3, + 0x475cd3205a3bdcde, 0x18a42105c31b7e88, 0x23e7414af663068, 0x15147108121967d7, 0x475cd3205a3bdcde, 0x18a42105c31b7e88, 0x23e7414af663068, 0x15147108121967d7, + 0xe4a3dff1d7d6fef9, 0x1a8d1a588085737, 0x11b4c74eda62beef, 0xe587cc0d69a73346, 0xe4a3dff1d7d6fef9, 0x1a8d1a588085737, 0x11b4c74eda62beef, 0xe587cc0d69a73346, + 0x1ff7327017aa2a6e, 0x594e29c42473d06b, 0xf6f31db1899b12d5, 0xc02ac5e47312d3ca, 0x1ff7327017aa2a6e, 0x594e29c42473d06b, 0xf6f31db1899b12d5, 0xc02ac5e47312d3ca, + 0xe70201e960cb78b8, 0x6f90ff3b6a65f108, 0x42747a7245e7fa84, 0xd1f507e43ab749b2, 0xe70201e960cb78b8, 0x6f90ff3b6a65f108, 0x42747a7245e7fa84, 0xd1f507e43ab749b2, + 0x1c86d265f15750cd, 0x3996ce73dd832c1c, 0x8e7fba02983224bd, 0xba0dec7103255dd4, 0x1c86d265f15750cd, 0x3996ce73dd832c1c, 0x8e7fba02983224bd, 0xba0dec7103255dd4, + 0x9e9cbd781628fc5b, 0xdae8645996edd6a5, 0xdebe0853b1a1d378, 0xa49229d24d014343, 0x9e9cbd781628fc5b, 0xdae8645996edd6a5, 0xdebe0853b1a1d378, 0xa49229d24d014343, + 0x7be5b9ffda905e1c, 0xa3c95eaec244aa30, 0x230bca8f4df0544, 0x4135c2bebfe148c6, 0x7be5b9ffda905e1c, 0xa3c95eaec244aa30, 0x230bca8f4df0544, 0x4135c2bebfe148c6, + 0x166fc0cc438a3c72, 0x3762b59a8ae83efa, 0xe8928a4c89114750, 0x2a440b51a4945ee5, 0x166fc0cc438a3c72, 0x3762b59a8ae83efa, 0xe8928a4c89114750, 0x2a440b51a4945ee5, + 0x80cefd2b7d99ff83, 0xbb9879c6e61fd62a, 0x6e7c8f1a84265034, 0x164bb2de1bbeddc8, 0x80cefd2b7d99ff83, 0xbb9879c6e61fd62a, 0x6e7c8f1a84265034, 0x164bb2de1bbeddc8, + 0xf3c12fe54d5c653b, 0x40b9e922ed9771e2, 0x551f5b0fbe7b1840, 0x25032aa7c4cb1811, 0xf3c12fe54d5c653b, 0x40b9e922ed9771e2, 0x551f5b0fbe7b1840, 0x25032aa7c4cb1811, + 0xaaed34074b164346, 0x8ffd96bbf9c9c81d, 0x70fc91eb5937085c, 0x7f795e2a5f915440, 0xaaed34074b164346, 0x8ffd96bbf9c9c81d, 0x70fc91eb5937085c, 0x7f795e2a5f915440, + 0x4543d9df5476d3cb, 0xf172d73e004fc90d, 0xdfd1c4febcc81238, 0xbc8dfb627fe558fc, 0x4543d9df5476d3cb, 0xf172d73e004fc90d, 0xdfd1c4febcc81238, 0xbc8dfb627fe558fc, +]; + +const FAST_PARTIAL_ROUND_VS_AVX512: [[u64; 24]; N_PARTIAL_ROUNDS] = [ + [ + 0x0, + 0x94877900674181c3, + 0xc6c67cc37a2a2bbd, + 0xd667c2055387940f, + 0x0, + 0x94877900674181c3, + 0xc6c67cc37a2a2bbd, + 0xd667c2055387940f, + 0xba63a63e94b5ff0, + 0x99460cc41b8f079f, + 0x7ff02375ed524bb3, + 0xea0870b47a8caf0e, + 0xba63a63e94b5ff0, + 0x99460cc41b8f079f, + 0x7ff02375ed524bb3, + 0xea0870b47a8caf0e, + 0xabcad82633b7bc9d, + 0x3b8d135261052241, + 0xfb4515f5e5b0d539, + 0x3ee8011c2b37f77c, + 0xabcad82633b7bc9d, + 0x3b8d135261052241, + 0xfb4515f5e5b0d539, + 0x3ee8011c2b37f77c, + ], + [ + 0x0, + 0xadef3740e71c726, + 0xa37bf67c6f986559, + 0xc6b16f7ed4fa1b00, + 0x0, + 0xadef3740e71c726, + 0xa37bf67c6f986559, + 0xc6b16f7ed4fa1b00, + 0x6a065da88d8bfc3c, + 0x4cabc0916844b46f, + 0x407faac0f02e78d1, + 0x7a786d9cf0852cf, + 0x6a065da88d8bfc3c, + 0x4cabc0916844b46f, + 0x407faac0f02e78d1, + 0x7a786d9cf0852cf, + 0x42433fb6949a629a, + 0x891682a147ce43b0, + 0x26cfd58e7b003b55, + 0x2bbf0ed7b657acb3, + 0x42433fb6949a629a, + 0x891682a147ce43b0, + 0x26cfd58e7b003b55, + 0x2bbf0ed7b657acb3, + ], + [ + 0x0, + 0x481ac7746b159c67, + 0xe367de32f108e278, + 0x73f260087ad28bec, + 0x0, + 0x481ac7746b159c67, + 0xe367de32f108e278, + 0x73f260087ad28bec, + 0x5cfc82216bc1bdca, + 0xcaccc870a2663a0e, + 0xdb69cd7b4298c45d, + 0x7bc9e0c57243e62d, + 0x5cfc82216bc1bdca, + 0xcaccc870a2663a0e, + 0xdb69cd7b4298c45d, + 0x7bc9e0c57243e62d, + 0x3cc51c5d368693ae, + 0x366b4e8cc068895b, + 0x2bd18715cdabbca4, + 0xa752061c4f33b8cf, + 0x3cc51c5d368693ae, + 0x366b4e8cc068895b, + 0x2bd18715cdabbca4, + 0xa752061c4f33b8cf, + ], + [ + 0x0, + 0xb22d2432b72d5098, + 0x9e18a487f44d2fe4, + 0x4b39e14ce22abd3c, + 0x0, + 0xb22d2432b72d5098, + 0x9e18a487f44d2fe4, + 0x4b39e14ce22abd3c, + 0x9e77fde2eb315e0d, + 0xca5e0385fe67014d, + 0xc2cb99bf1b6bddb, + 0x99ec1cd2a4460bfe, + 0x9e77fde2eb315e0d, + 0xca5e0385fe67014d, + 0xc2cb99bf1b6bddb, + 0x99ec1cd2a4460bfe, + 0x8577a815a2ff843f, + 0x7d80a6b4fd6518a5, + 0xeb6c67123eab62cb, + 0x8f7851650eca21a5, + 0x8577a815a2ff843f, + 0x7d80a6b4fd6518a5, + 0xeb6c67123eab62cb, + 0x8f7851650eca21a5, + ], + [ + 0x0, + 0x11ba9a1b81718c2a, + 0x9f7d798a3323410c, + 0xa821855c8c1cf5e5, + 0x0, + 0x11ba9a1b81718c2a, + 0x9f7d798a3323410c, + 0xa821855c8c1cf5e5, + 0x535e8d6fac0031b2, + 0x404e7c751b634320, + 0xa729353f6e55d354, + 0x4db97d92e58bb831, + 0x535e8d6fac0031b2, + 0x404e7c751b634320, + 0xa729353f6e55d354, + 0x4db97d92e58bb831, + 0xb53926c27897bf7d, + 0x965040d52fe115c5, + 0x9565fa41ebd31fd7, + 0xaae4438c877ea8f4, + 0xb53926c27897bf7d, + 0x965040d52fe115c5, + 0x9565fa41ebd31fd7, + 0xaae4438c877ea8f4, + ], + [ + 0x0, + 0x37f4e36af6073c6e, + 0x4edc0918210800e9, + 0xc44998e99eae4188, + 0x0, + 0x37f4e36af6073c6e, + 0x4edc0918210800e9, + 0xc44998e99eae4188, + 0x9f4310d05d068338, + 0x9ec7fe4350680f29, + 0xc5b2c1fdc0b50874, + 0xa01920c5ef8b2ebe, + 0x9f4310d05d068338, + 0x9ec7fe4350680f29, + 0xc5b2c1fdc0b50874, + 0xa01920c5ef8b2ebe, + 0x59fa6f8bd91d58ba, + 0x8bfc9eb89b515a82, + 0xbe86a7a2555ae775, + 0xcbb8bbaa3810babf, + 0x59fa6f8bd91d58ba, + 0x8bfc9eb89b515a82, + 0xbe86a7a2555ae775, + 0xcbb8bbaa3810babf, + ], + [ + 0x0, + 0x577f9a9e7ee3f9c2, + 0x88c522b949ace7b1, + 0x82f07007c8b72106, + 0x0, + 0x577f9a9e7ee3f9c2, + 0x88c522b949ace7b1, + 0x82f07007c8b72106, + 0x8283d37c6675b50e, + 0x98b074d9bbac1123, + 0x75c56fb7758317c1, + 0xfed24e206052bc72, + 0x8283d37c6675b50e, + 0x98b074d9bbac1123, + 0x75c56fb7758317c1, + 0xfed24e206052bc72, + 0x26d7c3d1bc07dae5, + 0xf88c5e441e28dbb4, + 0x4fe27f9f96615270, + 0x514d4ba49c2b14fe, + 0x26d7c3d1bc07dae5, + 0xf88c5e441e28dbb4, + 0x4fe27f9f96615270, + 0x514d4ba49c2b14fe, + ], + [ + 0x0, + 0xf02a3ac068ee110b, + 0xa3630dafb8ae2d7, + 0xce0dc874eaf9b55c, + 0x0, + 0xf02a3ac068ee110b, + 0xa3630dafb8ae2d7, + 0xce0dc874eaf9b55c, + 0x9a95f6cff5b55c7e, + 0x626d76abfed00c7b, + 0xa0c1cf1251c204ad, + 0xdaebd3006321052c, + 0x9a95f6cff5b55c7e, + 0x626d76abfed00c7b, + 0xa0c1cf1251c204ad, + 0xdaebd3006321052c, + 0x3d4bd48b625a8065, + 0x7f1e584e071f6ed2, + 0x720574f0501caed3, + 0xe3260ba93d23540a, + 0x3d4bd48b625a8065, + 0x7f1e584e071f6ed2, + 0x720574f0501caed3, + 0xe3260ba93d23540a, + ], + [ + 0x0, + 0xab1cbd41d8c1e335, + 0x9322ed4c0bc2df01, + 0x51c3c0983d4284e5, + 0x0, + 0xab1cbd41d8c1e335, + 0x9322ed4c0bc2df01, + 0x51c3c0983d4284e5, + 0x94178e291145c231, + 0xfd0f1a973d6b2085, + 0xd427ad96e2b39719, + 0x8a52437fecaac06b, + 0x94178e291145c231, + 0xfd0f1a973d6b2085, + 0xd427ad96e2b39719, + 0x8a52437fecaac06b, + 0xdc20ee4b8c4c9a80, + 0xa2c98e9549da2100, + 0x1603fe12613db5b6, + 0xe174929433c5505, + 0xdc20ee4b8c4c9a80, + 0xa2c98e9549da2100, + 0x1603fe12613db5b6, + 0xe174929433c5505, + ], + [ + 0x0, + 0x3d4eab2b8ef5f796, + 0xcfff421583896e22, + 0x4143cb32d39ac3d9, + 0x0, + 0x3d4eab2b8ef5f796, + 0xcfff421583896e22, + 0x4143cb32d39ac3d9, + 0x22365051b78a5b65, + 0x6f7fd010d027c9b6, + 0xd9dd36fba77522ab, + 0xa44cf1cb33e37165, + 0x22365051b78a5b65, + 0x6f7fd010d027c9b6, + 0xd9dd36fba77522ab, + 0xa44cf1cb33e37165, + 0x3fc83d3038c86417, + 0xc4588d418e88d270, + 0xce1320f10ab80fe2, + 0xdb5eadbbec18de5d, + 0x3fc83d3038c86417, + 0xc4588d418e88d270, + 0xce1320f10ab80fe2, + 0xdb5eadbbec18de5d, + ], + [ + 0x0, + 0x1183dfce7c454afd, + 0x21cea4aa3d3ed949, + 0xfce6f70303f2304, + 0x0, + 0x1183dfce7c454afd, + 0x21cea4aa3d3ed949, + 0xfce6f70303f2304, + 0x19557d34b55551be, + 0x4c56f689afc5bbc9, + 0xa1e920844334f944, + 0xbad66d423d2ec861, + 0x19557d34b55551be, + 0x4c56f689afc5bbc9, + 0xa1e920844334f944, + 0xbad66d423d2ec861, + 0xf318c785dc9e0479, + 0x99e2032e765ddd81, + 0x400ccc9906d66f45, + 0xe1197454db2e0dd9, + 0xf318c785dc9e0479, + 0x99e2032e765ddd81, + 0x400ccc9906d66f45, + 0xe1197454db2e0dd9, + ], + [ + 0x0, + 0x84d1ecc4d53d2ff1, + 0xd8af8b9ceb4e11b6, + 0x335856bb527b52f4, + 0x0, + 0x84d1ecc4d53d2ff1, + 0xd8af8b9ceb4e11b6, + 0x335856bb527b52f4, + 0xc756f17fb59be595, + 0xc0654e4ea5553a78, + 0x9e9a46b61f2ea942, + 0x14fc8b5b3b809127, + 0xc756f17fb59be595, + 0xc0654e4ea5553a78, + 0x9e9a46b61f2ea942, + 0x14fc8b5b3b809127, + 0xd7009f0f103be413, + 0x3e0ee7b7a9fb4601, + 0xa74e888922085ed7, + 0xe80a7cde3d4ac526, + 0xd7009f0f103be413, + 0x3e0ee7b7a9fb4601, + 0xa74e888922085ed7, + 0xe80a7cde3d4ac526, + ], + [ + 0x0, + 0x238aa6daa612186d, + 0x9137a5c630bad4b4, + 0xc7db3817870c5eda, + 0x0, + 0x238aa6daa612186d, + 0x9137a5c630bad4b4, + 0xc7db3817870c5eda, + 0x217e4f04e5718dc9, + 0xcae814e2817bd99d, + 0xe3292e7ab770a8ba, + 0x7bb36ef70b6b9482, + 0x217e4f04e5718dc9, + 0xcae814e2817bd99d, + 0xe3292e7ab770a8ba, + 0x7bb36ef70b6b9482, + 0x3c7835fb85bca2d3, + 0xfe2cdf8ee3c25e86, + 0x61b3915ad7274b20, + 0xeab75ca7c918e4ef, + 0x3c7835fb85bca2d3, + 0xfe2cdf8ee3c25e86, + 0x61b3915ad7274b20, + 0xeab75ca7c918e4ef, + ], + [ + 0x0, + 0xd6e15ffc055e154e, + 0xec67881f381a32bf, + 0xfbb1196092bf409c, + 0x0, + 0xd6e15ffc055e154e, + 0xec67881f381a32bf, + 0xfbb1196092bf409c, + 0xdc9d2e07830ba226, + 0x698ef3245ff7988, + 0x194fae2974f8b576, + 0x7a5d9bea6ca4910e, + 0xdc9d2e07830ba226, + 0x698ef3245ff7988, + 0x194fae2974f8b576, + 0x7a5d9bea6ca4910e, + 0x7aebfea95ccdd1c9, + 0xf9bd38a67d5f0e86, + 0xfa65539de65492d8, + 0xf0dfcbe7653ff787, + 0x7aebfea95ccdd1c9, + 0xf9bd38a67d5f0e86, + 0xfa65539de65492d8, + 0xf0dfcbe7653ff787, + ], + [ + 0x0, + 0xbd87ad390420258, + 0xad8617bca9e33c8, + 0xc00ad377a1e2666, + 0x0, + 0xbd87ad390420258, + 0xad8617bca9e33c8, + 0xc00ad377a1e2666, + 0xac6fc58b3f0518f, + 0xc0cc8a892cc4173, + 0xc210accb117bc21, + 0xb73630dbb46ca18, + 0xac6fc58b3f0518f, + 0xc0cc8a892cc4173, + 0xc210accb117bc21, + 0xb73630dbb46ca18, + 0xc8be4920cbd4a54, + 0xbfe877a21be1690, + 0xae790559b0ded81, + 0xbf50db2f8d6ce31, + 0xc8be4920cbd4a54, + 0xbfe877a21be1690, + 0xae790559b0ded81, + 0xbf50db2f8d6ce31, + ], + [ + 0x0, + 0xcf29427ff7c58, + 0xbd9b3cf49eec8, + 0xd1dc8aa81fb26, + 0x0, + 0xcf29427ff7c58, + 0xbd9b3cf49eec8, + 0xd1dc8aa81fb26, + 0xbc792d5c394ef, + 0xd2ae0b2266453, + 0xd413f12c496c1, + 0xc84128cfed618, + 0xbc792d5c394ef, + 0xd2ae0b2266453, + 0xd413f12c496c1, + 0xc84128cfed618, + 0xdb5ebd48fc0d4, + 0xd1b77326dcb90, + 0xbeb0ccc145421, + 0xd10e5b22b11d1, + 0xdb5ebd48fc0d4, + 0xd1b77326dcb90, + 0xbeb0ccc145421, + 0xd10e5b22b11d1, + ], + [ + 0x0, + 0xe24c99adad8, + 0xcf389ed4bc8, + 0xe580cbf6966, + 0x0, + 0xe24c99adad8, + 0xcf389ed4bc8, + 0xe580cbf6966, + 0xcde5fd7e04f, + 0xe63628041b3, + 0xe7e81a87361, + 0xdabe78f6d98, + 0xcde5fd7e04f, + 0xe63628041b3, + 0xe7e81a87361, + 0xdabe78f6d98, + 0xefb14cac554, + 0xe5574743b10, + 0xd05709f42c1, + 0xe4690c96af1, + 0xefb14cac554, + 0xe5574743b10, + 0xd05709f42c1, + 0xe4690c96af1, + ], + [ + 0x0, + 0xf7157bc98, + 0xe3006d948, + 0xfa65811e6, + 0x0, + 0xf7157bc98, + 0xe3006d948, + 0xfa65811e6, + 0xe0d127e2f, + 0xfc18bfe53, + 0xfd002d901, + 0xeed6461d8, + 0xe0d127e2f, + 0xfc18bfe53, + 0xfd002d901, + 0xeed6461d8, + 0x1068562754, + 0xfa0236f50, + 0xe3af13ee1, + 0xfa460f6d1, + 0x1068562754, + 0xfa0236f50, + 0xe3af13ee1, + 0xfa460f6d1, + ], + [ + 0x0, 0x11131738, 0xf56d588, 0x11050f86, 0x0, 0x11131738, 0xf56d588, 0x11050f86, 0xf848f4f, + 0x111527d3, 0x114369a1, 0x106f2f38, 0xf848f4f, 0x111527d3, 0x114369a1, 0x106f2f38, + 0x11e2ca94, 0x110a29f0, 0xfa9f5c1, 0x10f625d1, 0x11e2ca94, 0x110a29f0, 0xfa9f5c1, + 0x10f625d1, + ], + [ + 0x0, 0x11f718, 0x10b6c8, 0x134a96, 0x0, 0x11f718, 0x10b6c8, 0x134a96, 0x10cf7f, 0x124d03, + 0x13f8a1, 0x117c58, 0x10cf7f, 0x124d03, 0x13f8a1, 0x117c58, 0x132c94, 0x134fc0, 0x10a091, + 0x128961, 0x132c94, 0x134fc0, 0x10a091, 0x128961, + ], + [ + 0x0, 0x1300, 0x1750, 0x114e, 0x0, 0x1300, 0x1750, 0x114e, 0x131f, 0x167b, 0x1371, 0x1230, + 0x131f, 0x167b, 0x1371, 0x1230, 0x182c, 0x1368, 0xf31, 0x15c9, 0x182c, 0x1368, 0xf31, + 0x15c9, + ], + [ + 0x0, 0x14, 0x22, 0x12, 0x0, 0x14, 0x22, 0x12, 0x27, 0xd, 0xd, 0x1c, 0x27, 0xd, 0xd, 0x1c, + 0x2, 0x10, 0x29, 0xf, 0x2, 0x10, 0x29, 0xf, + ], +]; + +#[allow(unused)] #[inline(always)] #[unroll_for_loops] fn mds_partial_layer_init_avx(state: &mut [F; SPONGE_WIDTH]) @@ -247,8 +1161,8 @@ where ); let m0 = mult_avx512(&sr512, &t0); let m1 = mult_avx512(&sr512, &t1); - r0 = add_avx512_b_c(&r0, &m0); - r1 = add_avx512_b_c(&r1, &m1); + r0 = add_avx512(&r0, &m0); + r1 = add_avx512(&r1, &m1); } _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); @@ -256,6 +1170,7 @@ where } } +#[allow(unused)] #[inline(always)] #[unroll_for_loops] fn partial_first_constant_layer_avx(state: &mut [F; SPONGE_WIDTH]) @@ -275,8 +1190,8 @@ where ); let mut s0 = _mm512_loadu_si512((state[0..8]).as_ptr().cast::()); let mut s1 = _mm512_loadu_si512((state[4..12]).as_ptr().cast::()); - s0 = add_avx512_b_c(&s0, &c0); - s1 = add_avx512_b_c(&s1, &c1); + s0 = add_avx512(&s0, &c0); + s1 = add_avx512(&s1, &c1); _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), s1); } @@ -294,7 +1209,383 @@ where x3 * x4 } -pub fn poseidon_avx512(input: &[F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] +#[inline(always)] +unsafe fn fft2_real_avx512(x0: &__m512i, x1: &__m512i) -> (__m512i, __m512i) { + let y0 = _mm512_add_epi64(*x0, *x1); + let y1 = _mm512_sub_epi64(*x0, *x1); + (y0, y1) +} + +#[inline(always)] +unsafe fn fft4_real_avx512( + x0: &__m512i, + x1: &__m512i, + x2: &__m512i, + x3: &__m512i, +) -> (__m512i, __m512i, __m512i, __m512i) { + let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let (z0, z2) = fft2_real_avx512(x0, x2); + let (z1, z3) = fft2_real_avx512(x1, x3); + let y0 = _mm512_add_epi64(z0, z1); + let y2 = _mm512_sub_epi64(z0, z1); + let y3 = _mm512_sub_epi64(zeros, z3); + (y0, z2, y3, y2) +} + +#[inline(always)] +unsafe fn ifft2_real_unreduced_avx512(y0: &__m512i, y1: &__m512i) -> (__m512i, __m512i) { + let x0 = _mm512_add_epi64(*y0, *y1); + let x1 = _mm512_sub_epi64(*y0, *y1); + (x0, x1) +} + +#[inline(always)] +unsafe fn ifft4_real_unreduced_avx512( + y: (__m512i, (__m512i, __m512i), __m512i), +) -> (__m512i, __m512i, __m512i, __m512i) { + let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let z0 = _mm512_add_epi64(y.0, y.2); + let z1 = _mm512_sub_epi64(y.0, y.2); + let z2 = y.1 .0; + let z3 = _mm512_sub_epi64(zeros, y.1 .1); + let (x0, x2) = ifft2_real_unreduced_avx512(&z0, &z2); + let (x1, x3) = ifft2_real_unreduced_avx512(&z1, &z3); + (x0, x1, x2, x3) +} + +#[inline] +pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) { + /* + * a and b are signed 4 x i64. Suppose a and b represent only one i64, then: + * - (test 1): if a < 2^63 and b < 2^63 (this means a >= 0 and b >= 0) => sum does not overflow => cout = 0 + * - if a >= 2^63 and b >= 2^63 => sum overflows so sum = a + b and cout = 1 + * - (test 2): if (a < 2^63 and b >= 2^63) or (a >= 2^63 and b < 2^63) + * - (test 3): if a + b < 2^64 (this means a + b is negative in signed representation) => no overflow so cout = 0 + * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1 + */ + let ones = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1); + let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let r = _mm512_add_epi64(*a, *b); + let ma = _mm512_cmpgt_epi64_mask(zeros, *a); + let mb = _mm512_cmpgt_epi64_mask(zeros, *b); + let mc = _mm512_cmpgt_epi64_mask(zeros, r); + let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); + let co = _mm512_mask_blend_epi64(m, zeros, ones); + (r, co) +} + +#[inline] +pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { + let r = _mm512_mul_epu32(*a, *b); + let ah = _mm512_srli_epi64(*a, 32); + let bh = _mm512_srli_epi64(*b, 32); + let r1 = _mm512_mul_epu32(*a, bh); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + let r1 = _mm512_mul_epu32(ah, *b); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + r +} + +#[inline(always)] +unsafe fn block1_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { + let x0 = _mm512_permutex_epi64(*x, 0x0); + let x1 = _mm512_permutex_epi64(*x, 0x55); + let x2 = _mm512_permutex_epi64(*x, 0xAA); + + let f0 = _mm512_set_epi64(0, y[2], y[1], y[0], 0, y[2], y[1], y[0]); + let f1 = _mm512_set_epi64(0, y[1], y[0], y[2], 0, y[1], y[0], y[2]); + let f2 = _mm512_set_epi64(0, y[0], y[2], y[1], 0, y[0], y[2], y[1]); + + let t0 = mul64_no_overflow_avx512(&x0, &f0); + let t1 = mul64_no_overflow_avx512(&x1, &f1); + let t2 = mul64_no_overflow_avx512(&x2, &f2); + + let t0 = _mm512_add_epi64(t0, t1); + _mm512_add_epi64(t0, t2) +} + +#[inline(always)] +unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) { + let mut vxr: [i64; 8] = [0; 8]; + let mut vxi: [i64; 8] = [0; 8]; + _mm512_storeu_si512(vxr.as_mut_ptr().cast::(), *xr); + _mm512_storeu_si512(vxi.as_mut_ptr().cast::(), *xi); + let x1: [(i64, i64); 3] = [(vxr[0], vxi[0]), (vxr[1], vxi[1]), (vxr[2], vxi[2])]; + let x2: [(i64, i64); 3] = [(vxr[4], vxi[4]), (vxr[5], vxi[5]), (vxr[6], vxi[6])]; + let b1 = block2(x1, y); + let b2 = block2(x2, y); + vxr = [b1[0].0, b1[1].0, b1[2].0, 0, b2[0].0, b2[1].0, b2[2].0, 0]; + vxi = [b1[0].1, b1[1].1, b1[2].1, 0, b2[0].1, b2[1].1, b2[2].1, 0]; + let rr = _mm512_loadu_si512(vxr.as_ptr().cast::()); + let ri = _mm512_loadu_si512(vxi.as_ptr().cast::()); + (rr, ri) +} + +#[inline(always)] +unsafe fn block3_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { + let x0 = _mm512_permutex_epi64(*x, 0x0); + let x1 = _mm512_permutex_epi64(*x, 0x55); + let x2 = _mm512_permutex_epi64(*x, 0xAA); + + let f0 = _mm512_set_epi64(0, y[2], y[1], y[0], 0, y[2], y[1], y[0]); + let f1 = _mm512_set_epi64(0, y[1], y[0], -y[2], 0, y[1], y[0], -y[2]); + let f2 = _mm512_set_epi64(0, y[0], -y[2], -y[1], 0, y[0], -y[2], -y[1]); + + let t0 = mul64_no_overflow_avx512(&x0, &f0); + let t1 = mul64_no_overflow_avx512(&x1, &f1); + let t2 = mul64_no_overflow_avx512(&x2, &f2); + + let t0 = _mm512_add_epi64(t0, t1); + _mm512_add_epi64(t0, t2) +} + +#[inline] +unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) { + /* + // Alternative code using store and set. + let mut s: [i64; 12] = [0; 12]; + _mm256_storeu_si256(s[0..4].as_mut_ptr().cast::<__m256i>(), *s0); + _mm256_storeu_si256(s[4..8].as_mut_ptr().cast::<__m256i>(), *s1); + _mm256_storeu_si256(s[8..12].as_mut_ptr().cast::<__m256i>(), *s2); + let f0 = _mm256_set_epi64x(0, s[2], s[1], s[0]); + let f1 = _mm256_set_epi64x(0, s[5], s[4], s[3]); + let f2 = _mm256_set_epi64x(0, s[8], s[7], s[6]); + let f3 = _mm256_set_epi64x(0, s[11], s[10], s[9]); + */ + + // Alternative code using permute and blend (it is faster). + let f0 = *s0; + let f11 = _mm512_permutex_epi64(*s0, 0x3); + let f12 = _mm512_permutex_epi64(*s1, 0x10); + let f1 = _mm512_mask_blend_epi64(0x66, f11, f12); + let f21 = _mm512_permutex_epi64(*s1, 0xE); + let f22 = _mm512_permutex_epi64(*s2, 0x0); + let f2 = _mm512_mask_blend_epi64(0x44, f21, f22); + let f3 = _mm512_permutex_epi64(*s2, 0x39); + + let (u0, u1, u2, u3) = fft4_real_avx512(&f0, &f1, &f2, &f3); + + // let [v0, v4, v8] = block1_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); + // [u[0], u[1], u[2]] are all in u0 + let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); + + // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + + // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); + // [u[0], u[1], u[2]] are all in u3 + let f3 = block3_avx512(&u3, MDS_FREQ_BLOCK_THREE); + + let (r0, r3, r6, r9) = ifft4_real_unreduced_avx512((f0, (f1, f2), f3)); + let t = _mm512_permutex_epi64(r3, 0x0); + *s0 = _mm512_mask_blend_epi64(0x88, r0, t); + let t1 = _mm512_permutex_epi64(r3, 0x9); + let t2 = _mm512_permutex_epi64(r6, 0x40); + *s1 = _mm512_mask_blend_epi64(0xCC, t1, t2); + let t1 = _mm512_permutex_epi64(r6, 0x2); + let t2 = _mm512_permutex_epi64(r9, 0x90); + *s2 = _mm512_mask_blend_epi64(0xEE, t1, t2); +} + +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_layer_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) { + let mask = _mm512_set_epi64( + 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, + 0xFFFFFFFF, + ); + let mut sl0 = _mm512_and_si512(*s0, mask); + let mut sl1 = _mm512_and_si512(*s1, mask); + let mut sl2 = _mm512_and_si512(*s2, mask); + let mut sh0 = _mm512_srli_epi64(*s0, 32); + let mut sh1 = _mm512_srli_epi64(*s1, 32); + let mut sh2 = _mm512_srli_epi64(*s2, 32); + + mds_multiply_freq_avx512(&mut sl0, &mut sl1, &mut sl2); + mds_multiply_freq_avx512(&mut sh0, &mut sh1, &mut sh2); + + let shl0 = _mm512_slli_epi64(sh0, 32); + let shl1 = _mm512_slli_epi64(sh1, 32); + let shl2 = _mm512_slli_epi64(sh2, 32); + let shh0 = _mm512_srli_epi64(sh0, 32); + let shh1 = _mm512_srli_epi64(sh1, 32); + let shh2 = _mm512_srli_epi64(sh2, 32); + + let (rl0, c0) = add64_no_carry_avx512(&sl0, &shl0); + let (rh0, _) = add64_no_carry_avx512(&shh0, &c0); + let r0 = reduce_avx512_96_64(&rh0, &rl0); + + let (rl1, c1) = add64_no_carry_avx512(&sl1, &shl1); + let (rh1, _) = add64_no_carry_avx512(&shh1, &c1); + *s1 = reduce_avx512_96_64(&rh1, &rl1); + + let (rl2, c2) = add64_no_carry_avx512(&sl2, &shl2); + let (rh2, _) = add64_no_carry_avx512(&shh2, &c2); + *s2 = reduce_avx512_96_64(&rh2, &rl2); + + let rl = _mm512_slli_epi64(*s0, 3); // * 8 (low part) + let rh = _mm512_srli_epi64(*s0, 61); // * 8 (high part, only 3 bits) + let rx = reduce_avx512_96_64(&rh, &rl); + let rx = add_avx512(&r0, &rx); + *s0 = _mm512_mask_blend_epi64(0x11, r0, rx); +} + +#[unroll_for_loops] +unsafe fn mds_partial_layer_init_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) +where + F: PrimeField64, +{ + let mut result = [F::ZERO; 2 * SPONGE_WIDTH]; + let res0 = *s0; + + let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r1 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r2 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); + for r in 1..12 { + let sr = match r { + 1 => _mm512_permutex_epi64(*s0, 0x55), + 2 => _mm512_permutex_epi64(*s0, 0xAA), + 3 => _mm512_permutex_epi64(*s0, 0xFF), + 4 => _mm512_permutex_epi64(*s1, 0x0), + 5 => _mm512_permutex_epi64(*s1, 0x55), + 6 => _mm512_permutex_epi64(*s1, 0xAA), + 7 => _mm512_permutex_epi64(*s1, 0xFF), + 8 => _mm512_permutex_epi64(*s2, 0x0), + 9 => _mm512_permutex_epi64(*s2, 0x55), + 10 => _mm512_permutex_epi64(*s2, 0xAA), + 11 => _mm512_permutex_epi64(*s2, 0xFF), + _ => _mm512_permutex_epi64(*s0, 0x55), + }; + let t0 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8]) + .as_ptr() + .cast::(), + ); + let t1 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16]) + .as_ptr() + .cast::(), + ); + let t2 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24]) + .as_ptr() + .cast::(), + ); + let m0 = mult_avx512(&sr, &t0); + let m1 = mult_avx512(&sr, &t1); + let m2 = mult_avx512(&sr, &t2); + r0 = add_avx512(&r0, &m0); + r1 = add_avx512(&r1, &m1); + r2 = add_avx512(&r2, &m2); + } + *s0 = _mm512_mask_blend_epi64(0x11, r0, res0); + *s1 = r1; + *s2 = r2; +} + +#[inline(always)] +#[unroll_for_loops] +unsafe fn mds_partial_layer_fast_avx512( + s0: &mut __m512i, + s1: &mut __m512i, + s2: &mut __m512i, + state: &mut [F; 2 * SPONGE_WIDTH], + r: usize, +) where + F: PrimeField64, +{ + let mut d_sum1 = (0u128, 0u32); // u160 accumulator + let mut d_sum2 = (0u128, 0u32); // u160 accumulator + for i in 1..4 { + let t = FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; + let si1 = state[i].to_noncanonical_u64() as u128; + let si2 = state[i + 4].to_noncanonical_u64() as u128; + d_sum1 = add_u160_u128(d_sum1, si1 * t); + d_sum2 = add_u160_u128(d_sum2, si2 * t); + } + for i in 4..8 { + let t = FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; + let si1 = state[i + 4].to_noncanonical_u64() as u128; + let si2 = state[i + 8].to_noncanonical_u64() as u128; + d_sum1 = add_u160_u128(d_sum1, si1 * t); + d_sum2 = add_u160_u128(d_sum2, si2 * t); + } + for i in 8..12 { + let t = FAST_PARTIAL_ROUND_W_HATS[r][i - 1] as u128; + let si1 = state[i + 8].to_noncanonical_u64() as u128; + let si2 = state[i + 12].to_noncanonical_u64() as u128; + d_sum1 = add_u160_u128(d_sum1, si1 * t); + d_sum2 = add_u160_u128(d_sum2, si2 * t); + } + // 1st + let x0_1 = state[0].to_noncanonical_u64() as u128; + let mds0to0_1 = (MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]) as u128; + d_sum1 = add_u160_u128(d_sum1, x0_1 * mds0to0_1); + let d1 = reduce_u160::(d_sum1); + // 2nd + let x0_2 = state[4].to_noncanonical_u64() as u128; + let mds0to0_2 = (MDS_MATRIX_CIRC[0] + MDS_MATRIX_DIAG[0]) as u128; + d_sum2 = add_u160_u128(d_sum2, x0_2 * mds0to0_2); + let d2 = reduce_u160::(d_sum2); + + // result = [d] concat [state[0] * v + state[shift up by 1]] + let ss0 = _mm512_set_epi64( + state[4].to_noncanonical_u64() as i64, + state[4].to_noncanonical_u64() as i64, + state[4].to_noncanonical_u64() as i64, + state[4].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + state[0].to_noncanonical_u64() as i64, + ); + let rc0 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_VS_AVX512[r][0..8]) + .as_ptr() + .cast::(), + ); + let rc1 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_VS_AVX512[r][8..16]) + .as_ptr() + .cast::(), + ); + let rc2 = _mm512_loadu_si512( + (&FAST_PARTIAL_ROUND_VS_AVX512[r][16..24]) + .as_ptr() + .cast::(), + ); + let (mh, ml) = mult_avx512_128(&ss0, &rc0); + let m = reduce_avx512_128_64(&mh, &ml); + let r0 = add_avx512(s0, &m); + let d0 = _mm512_set_epi64( + 0, + 0, + 0, + d2.to_canonical_u64() as i64, + 0, + 0, + 0, + d1.to_canonical_u64() as i64, + ); + *s0 = _mm512_mask_blend_epi64(0x11, r0, d0); + + let (mh, ml) = mult_avx512_128(&ss0, &rc1); + let m = reduce_avx512_128_64(&mh, &ml); + *s1 = add_avx512(s1, &m); + + let (mh, ml) = mult_avx512_128(&ss0, &rc2); + let m = reduce_avx512_128_64(&mh, &ml); + *s2 = add_avx512(s2, &m); + + _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), *s0); + _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), *s1); + _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), *s2); +} + +#[allow(unused)] +pub fn poseidon_avx512_single(input: &[F; SPONGE_WIDTH]) -> [F; SPONGE_WIDTH] where F: PrimeField64 + Poseidon, { @@ -313,8 +1604,8 @@ where .unwrap(); let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); - let ss0 = add_avx512_b_c(&s0, &rc0); - let ss1 = add_avx512_b_c(&s1, &rc1); + let ss0 = add_avx512(&s0, &rc0); + let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); @@ -346,8 +1637,8 @@ where .unwrap(); let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); - let ss0 = add_avx512_b_c(&s0, &rc0); - let ss1 = add_avx512_b_c(&s1, &rc1); + let ss0 = add_avx512(&s0, &rc0); + let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); @@ -364,3 +1655,183 @@ where }; *state } + +pub fn poseidon_avx512_double(input: &[F; 2 * SPONGE_WIDTH]) -> [F; 2 * SPONGE_WIDTH] +where + F: PrimeField64 + Poseidon, +{ + let mut state: [F; 24] = input.clone(); + state[0..4].copy_from_slice(&input[0..4]); + state[4..8].copy_from_slice(&input[12..16]); + state[8..12].copy_from_slice(&input[4..8]); + state[12..16].copy_from_slice(&input[16..20]); + state[16..20].copy_from_slice(&input[8..12]); + state[20..24].copy_from_slice(&input[20..24]); + + let mut round_ctr = 0; + + unsafe { + // load state + let mut s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); + let mut s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::()); + let mut s2 = _mm512_loadu_si512((&state[16..24]).as_ptr().cast::()); + + for _ in 0..HALF_N_FULL_ROUNDS { + let rc: &[u64; 24] = &ALL_ROUND_CONSTANTS_AVX512[2 * SPONGE_WIDTH * round_ctr..] + [..2 * SPONGE_WIDTH] + .try_into() + .unwrap(); + let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let ss0 = add_avx512(&s0, &rc0); + let ss1 = add_avx512(&s1, &rc1); + let ss2 = add_avx512(&s2, &rc2); + s0 = sbox_avx512_one(&ss0); + s1 = sbox_avx512_one(&ss1); + s2 = sbox_avx512_one(&ss2); + mds_layer_avx512(&mut s0, &mut s1, &mut s2); + round_ctr += 1; + } + + // this does partial_first_constant_layer_avx(&mut state); + let c0 = _mm512_loadu_si512( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[0..8]) + .as_ptr() + .cast::(), + ); + let c1 = _mm512_loadu_si512( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[8..16]) + .as_ptr() + .cast::(), + ); + let c2 = _mm512_loadu_si512( + (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[16..24]) + .as_ptr() + .cast::(), + ); + s0 = add_avx512(&s0, &c0); + s1 = add_avx512(&s1, &c1); + s2 = add_avx512(&s2, &c2); + + mds_partial_layer_init_avx512::(&mut s0, &mut s1, &mut s2); + + _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + + for i in 0..N_PARTIAL_ROUNDS { + state[0] = sbox_monomial(state[0]); + state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]); + state[4] = sbox_monomial(state[4]); + state[4] = state[4].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]); + mds_partial_layer_fast_avx512(&mut s0, &mut s1, &mut s2, &mut state, i); + } + round_ctr += N_PARTIAL_ROUNDS; + + // here state is already loaded in s0, s1, s2 + // Self::full_rounds(&mut state, &mut round_ctr); + for _ in 0..HALF_N_FULL_ROUNDS { + let rc: &[u64; 24] = &ALL_ROUND_CONSTANTS_AVX512[2 * SPONGE_WIDTH * round_ctr..] + [..2 * SPONGE_WIDTH] + .try_into() + .unwrap(); + let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let ss0 = add_avx512(&s0, &rc0); + let ss1 = add_avx512(&s1, &rc1); + let ss2 = add_avx512(&s2, &rc2); + s0 = sbox_avx512_one(&ss0); + s1 = sbox_avx512_one(&ss1); + s2 = sbox_avx512_one(&ss2); + mds_layer_avx512(&mut s0, &mut s1, &mut s2); + round_ctr += 1; + } + + // store state + _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + + debug_assert_eq!(round_ctr, N_ROUNDS); + }; + + let mut new_state: [F; 24] = state.clone(); + new_state[0..4].copy_from_slice(&state[0..4]); + new_state[4..8].copy_from_slice(&state[8..12]); + new_state[8..12].copy_from_slice(&state[16..20]); + new_state[12..16].copy_from_slice(&state[4..8]); + new_state[16..20].copy_from_slice(&state[12..16]); + new_state[20..24].copy_from_slice(&state[20..24]); + new_state +} + +pub fn hash_leaf_avx512(inputs: &[F], leaf_size: usize) -> Vec> +where + F: RichField, +{ + if leaf_size <= 4 { + let mut inputs_bytes1 = vec![0u8; 32]; + let mut inputs_bytes2 = vec![0u8; 32]; + for i in 0..inputs.len() { + inputs_bytes1[i * 8..(i + 1) * 8] + .copy_from_slice(&inputs[i].to_canonical_u64().to_le_bytes()); + inputs_bytes2[i * 8..(i + 1) * 8] + .copy_from_slice(&inputs[i + leaf_size].to_canonical_u64().to_le_bytes()); + } + return vec![ + HashOut::from_bytes(&inputs_bytes1), + HashOut::from_bytes(&inputs_bytes2), + ]; + } + + let mut state: [F; 24] = [F::ZERO; 24]; + + // Absorb all input chunks. + let mut idx1 = 0; + let mut idx2 = leaf_size; + let loops = if leaf_size % SPONGE_RATE == 0 { + leaf_size / SPONGE_RATE + } else { + leaf_size / SPONGE_RATE + 1 + }; + for _ in 0..loops { + let end1 = if idx1 + SPONGE_RATE > leaf_size { + leaf_size + } else { + idx1 + SPONGE_RATE + }; + let end2 = if idx2 + SPONGE_RATE > inputs.len() { + inputs.len() + } else { + idx2 + SPONGE_RATE + }; + let end = end1 - idx1; + state[0..end].copy_from_slice(&inputs[idx1..end1]); + state[12..12 + end].copy_from_slice(&inputs[idx2..end2]); + state = poseidon_avx512_double(&state); + idx1 += SPONGE_RATE; + idx2 += SPONGE_RATE; + } + + // Squeeze until we have the desired number of outputs. + let output1 = vec![state[0], state[1], state[2], state[3]]; + let output2 = vec![state[12], state[13], state[14], state[15]]; + vec![HashOut::from_vec(output1), HashOut::from_vec(output2)] +} + +pub fn hash_two_avx512(h1: &Vec, h2: &Vec, h3: &Vec, h4: &Vec) -> Vec> +where + F: RichField, +{ + let mut state: [F; 24] = [F::ZERO; 24]; + state[0..4].copy_from_slice(&h1); + state[4..8].copy_from_slice(&h2); + state[12..16].copy_from_slice(&h3); + state[16..20].copy_from_slice(&h4); + state = poseidon_avx512_double(&state); + let output1 = vec![state[0], state[1], state[2], state[3]]; + let output2 = vec![state[12], state[13], state[14], state[15]]; + vec![HashOut::from_vec(output1), HashOut::from_vec(output2)] +} diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index fe78b94b8f..9bb93dec5f 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -23,6 +23,8 @@ use once_cell::sync::Lazy; use plonky2_maybe_rayon::*; use serde::{Deserialize, Serialize}; +#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] +use crate::hash::arch::x86_64::poseidon_goldilocks_avx512::{hash_leaf_avx512, hash_two_avx512}; use crate::hash::hash_types::RichField; #[cfg(feature = "cuda")] use crate::hash::hash_types::NUM_HASH_OUT_ELTS; @@ -257,6 +259,140 @@ fn fill_digests_buf>( */ } +#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] +fn fill_subtree_poseidon_avx512>( + digests_buf: &mut [MaybeUninit], + leaves: &[F], + leaf_size: usize, +) -> H::Hash { + let leaves_count = leaves.len() / leaf_size; + + // if one leaf => return its hash + if leaves_count == 1 { + let hash = H::hash_or_noop(leaves); + digests_buf[0].write(hash); + return hash; + } + // if two leaves => return their concat hash + if leaves_count == 2 { + let h = hash_leaf_avx512(leaves, leaf_size); + let hash_left = H::Hash::from_bytes(&h[0].to_bytes()); + let hash_right = H::Hash::from_bytes(&h[1].to_bytes()); + digests_buf[0].write(hash_left); + digests_buf[1].write(hash_right); + return H::two_to_one(hash_left, hash_right); + } + + assert_eq!(leaves_count, digests_buf.len() / 2 + 1); + + // leaves first - we can do all in parallel + let (_, digests_leaves) = digests_buf.split_at_mut(digests_buf.len() - leaves_count); + digests_leaves + .par_chunks_mut(2) + .into_par_iter() + .enumerate() + .for_each(|(chunk_idx, digests)| { + let (_, r) = leaves.split_at(2 * chunk_idx * leaf_size); + let (leaves2, _) = r.split_at(2 * leaf_size); + let h = hash_leaf_avx512(leaves2, leaf_size); + let h1 = H::Hash::from_bytes(&h[0].to_bytes()); + let h2 = H::Hash::from_bytes(&h[1].to_bytes()); + digests[0].write(h1); + digests[1].write(h2); + }); + + // internal nodes - we can do in parallel per level + let mut last_index = digests_buf.len() - leaves_count; + + for level_log in range(1, log2_strict(leaves_count)).rev() { + let level_size = 1 << level_log; + let (_, digests_slice) = digests_buf.split_at_mut(last_index - level_size); + let (digests_slice, next_digests) = digests_slice.split_at_mut(level_size); + + digests_slice + .par_chunks_mut(2) + .into_par_iter() + .enumerate() + .for_each(|(chunk_idx, digests)| { + let idx = last_index - level_size + 2 * chunk_idx; + let left_idx1 = 2 * (idx + 1) - last_index; + let right_idx1 = left_idx1 + 1; + let left_idx2 = right_idx1 + 1; + let right_idx2 = left_idx2 + 1; + + unsafe { + let left_digest1 = next_digests[left_idx1].assume_init().to_vec(); + let right_digest1 = next_digests[right_idx1].assume_init().to_vec(); + let left_digest2 = next_digests[left_idx2].assume_init().to_vec(); + let right_digest2 = next_digests[right_idx2].assume_init().to_vec(); + + let h = hash_two_avx512( + &left_digest1, + &right_digest1, + &left_digest2, + &right_digest2, + ); + let h1 = H::Hash::from_bytes(&h[0].to_bytes()); + let h2 = H::Hash::from_bytes(&h[1].to_bytes()); + digests[0].write(h1); + digests[1].write(h2); + } + }); + last_index -= level_size; + } + + // return cap hash + let hash: >::Hash; + unsafe { + let left_digest = digests_buf[0].assume_init(); + let right_digest = digests_buf[1].assume_init(); + hash = H::two_to_one(left_digest, right_digest); + } + hash +} + +#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] +fn fill_digests_buf_poseidon_avx515>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &Vec, + leaf_size: usize, + cap_height: usize, +) { + let leaves_count = leaves.len() / leaf_size; + if digests_buf.is_empty() { + debug_assert_eq!(cap_buf.len(), leaves_count); + cap_buf + .par_chunks_mut(2) + .into_par_iter() + .enumerate() + .for_each(|(leaf_idx, cap_buf)| { + let (_, r) = leaves.split_at(2 * leaf_idx * leaf_size); + let (lv, _) = r.split_at(2 * leaf_size); + let h = hash_leaf_avx512(lv, leaf_size); + cap_buf[0].write(H::Hash::from_bytes(&h[0].to_bytes())); + cap_buf[1].write(H::Hash::from_bytes(&h[1].to_bytes())); + }); + return; + } + + let subtree_digests_len = digests_buf.len() >> cap_height; + let subtree_leaves_len = leaves_count >> cap_height; + let digests_chunks = digests_buf.par_chunks_exact_mut(subtree_digests_len); + let leaves_chunks = leaves.par_chunks_exact(subtree_leaves_len * leaf_size); + assert_eq!(digests_chunks.len(), cap_buf.len()); + assert_eq!(digests_chunks.len(), leaves_chunks.len()); + digests_chunks.zip(cap_buf).zip(leaves_chunks).for_each( + |((subtree_digests, subtree_cap), subtree_leaves)| { + subtree_cap.write(fill_subtree_poseidon_avx512::( + subtree_digests, + subtree_leaves, + leaf_size, + )); + }, + ); +} + #[cfg(feature = "cuda")] fn fill_digests_buf_gpu>( digests_buf: &mut [MaybeUninit], @@ -455,7 +591,10 @@ fn fill_digests_buf_meta>( } } -#[cfg(not(feature = "cuda"))] +#[cfg(all( + not(feature = "cuda"), + not(all(target_feature = "avx2", target_feature = "avx512dq")) +))] fn fill_digests_buf_meta>( digests_buf: &mut [MaybeUninit], cap_buf: &mut [MaybeUninit], @@ -466,6 +605,29 @@ fn fill_digests_buf_meta>( fill_digests_buf::(digests_buf, cap_buf, leaves, leaf_size, cap_height); } +#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] +fn fill_digests_buf_meta>( + digests_buf: &mut [MaybeUninit], + cap_buf: &mut [MaybeUninit], + leaves: &Vec, + leaf_size: usize, + cap_height: usize, +) { + use crate::plonk::config::HasherType; + + if leaf_size <= H::HASH_SIZE / 8 || H::HASHER_TYPE != HasherType::Poseidon { + fill_digests_buf::(digests_buf, cap_buf, leaves, leaf_size, cap_height); + } else { + fill_digests_buf_poseidon_avx515::( + digests_buf, + cap_buf, + leaves, + leaf_size, + cap_height, + ); + } +} + impl> MerkleTree { pub fn new_from_1d(leaves_1d: Vec, leaf_size: usize, cap_height: usize) -> Self { let leaves_len = leaves_1d.len() / leaf_size; @@ -478,10 +640,10 @@ impl> MerkleTree { ); let num_digests = 2 * (leaves_len - (1 << cap_height)); - let mut digests = Vec::with_capacity(num_digests); + let mut digests: Vec<>::Hash> = Vec::with_capacity(num_digests); let len_cap = 1 << cap_height; - let mut cap = Vec::with_capacity(len_cap); + let mut cap: Vec<>::Hash> = Vec::with_capacity(len_cap); let digests_buf = capacity_up_to_mut(&mut digests, num_digests); let cap_buf = capacity_up_to_mut(&mut cap, len_cap); @@ -864,10 +1026,15 @@ mod tests { GenericConfig, KeccakGoldilocksConfig, Poseidon2GoldilocksConfig, PoseidonGoldilocksConfig, }; - fn random_data(n: usize, k: usize) -> Vec> { + fn random_data_2d(n: usize, k: usize) -> Vec> { (0..n).map(|_| F::rand_vec(k)).collect() } + #[allow(unused)] + fn random_data_1d(n: usize, k: usize) -> Vec { + F::rand_vec(k * n) + } + fn verify_all_leaves< F: RichField + Extendable, C: GenericConfig, @@ -891,12 +1058,12 @@ mod tests { let n = 1 << log_n; let k = 7; - let mut leaves = random_data::(n, k); + let mut leaves = random_data_2d::(n, k); let mut mt1 = MerkleTree::>::Hasher>::new_from_2d(leaves.clone(), cap_h); - let tmp = random_data::(1, k); + let tmp = random_data_2d::(1, k); leaves[0] = tmp[0].clone(); let mt2 = MerkleTree::>::Hasher>::new_from_2d(leaves, cap_h); @@ -946,8 +1113,8 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let raw_leaves: Vec> = random_data::(leaves_count, leaf_size); - let vals: Vec> = random_data::(end_index - start_index, leaf_size); + let raw_leaves: Vec> = random_data_2d::(leaves_count, leaf_size); + let vals: Vec> = random_data_2d::(end_index - start_index, leaf_size); let mut leaves1_1d: Vec = raw_leaves.into_iter().flatten().collect(); let leaves2_1d: Vec = leaves1_1d.clone(); @@ -1029,8 +1196,8 @@ mod tests { type C = PoseidonGoldilocksConfig; type F = >::F; - let raw_leaves: Vec> = random_data::(leaves_count, leaf_size); - let vals: Vec> = random_data::(end_index - start_index, leaf_size); + let raw_leaves: Vec> = random_data_2d::(leaves_count, leaf_size); + let vals: Vec> = random_data_2d::(end_index - start_index, leaf_size); let mut leaves1_1d: Vec = raw_leaves.into_iter().flatten().collect(); let leaves2_1d: Vec = leaves1_1d.clone(); @@ -1112,7 +1279,7 @@ mod tests { let log_n = 8; let cap_height = log_n + 1; // Should panic if `cap_height > len_n`. - let leaves = random_data::(1 << log_n, 7); + let leaves = random_data_2d::(1 << log_n, 7); let _ = MerkleTree::>::Hasher>::new_from_2d(leaves, cap_height); } @@ -1124,7 +1291,7 @@ mod tests { let log_n = 8; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); verify_all_leaves::(leaves, log_n)?; @@ -1178,7 +1345,7 @@ mod tests { let log_n = 12; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); verify_all_leaves::(leaves, 1)?; @@ -1194,7 +1361,7 @@ mod tests { let log_n = 14; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); let leaves_1d: Vec = leaves.into_iter().flatten().collect(); let mut gpu_data: HostOrDeviceSlice<'_, F> = @@ -1216,7 +1383,7 @@ mod tests { let log_n = 12; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); verify_all_leaves::(leaves, 1)?; @@ -1231,7 +1398,7 @@ mod tests { let log_n = 12; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); verify_all_leaves::(leaves, 1)?; @@ -1246,10 +1413,78 @@ mod tests { let log_n = 12; let n = 1 << log_n; - let leaves = random_data::(n, 7); + let leaves = random_data_2d::(n, 7); verify_all_leaves::(leaves, 1)?; Ok(()) } + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + fn check_consistency>( + leaves: &Vec, + leaves_len: usize, + leaf_size: usize, + cap_height: usize, + ) { + println!("Check for height: {:?}", cap_height); + // no AVX + let num_digests = 2 * (leaves_len - (1 << cap_height)); + let mut digests: Vec<>::Hash> = Vec::with_capacity(num_digests); + let len_cap = 1 << cap_height; + let mut cap: Vec<>::Hash> = Vec::with_capacity(len_cap); + let digests_buf = capacity_up_to_mut(&mut digests, num_digests); + let cap_buf = capacity_up_to_mut(&mut cap, len_cap); + fill_digests_buf::(digests_buf, cap_buf, &leaves.clone(), leaf_size, cap_height); + unsafe { + digests.set_len(num_digests); + cap.set_len(len_cap); + } + // AVX512 + let mut digests_avx512: Vec<>::Hash> = Vec::with_capacity(num_digests); + let mut cap_avx512: Vec<>::Hash> = Vec::with_capacity(len_cap); + let digests_buf_avx512 = capacity_up_to_mut(&mut digests_avx512, num_digests); + let cap_buf_avx512 = capacity_up_to_mut(&mut cap_avx512, len_cap); + fill_digests_buf_poseidon_avx515::( + digests_buf_avx512, + cap_buf_avx512, + &leaves, + leaf_size, + cap_height, + ); + unsafe { + digests_avx512.set_len(num_digests); + cap_avx512.set_len(len_cap); + } + + digests + .into_iter() + .zip(digests_avx512) + .for_each(|(d1, d2)| { + assert_eq!(d1, d2); + }); + cap.into_iter().zip(cap_avx512).for_each(|(d1, d2)| { + assert_eq!(d1, d2); + }); + } + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + #[test] + fn test_merkle_trees_poseidon_g64_avx512_consistency() -> Result<()> { + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + type H = >::Hasher; + + let leaf_size = 16; + let log_n = 10; + let leaves_len = 1 << log_n; + let leaves: Vec = random_data_1d::(leaves_len, leaf_size); + + for cap_height in 0..log_n { + check_consistency::(&leaves, leaves_len, leaf_size, cap_height); + } + + Ok(()) + } } diff --git a/plonky2/src/hash/poseidon.rs b/plonky2/src/hash/poseidon.rs index b1cfb20d63..65884d8286 100644 --- a/plonky2/src/hash/poseidon.rs +++ b/plonky2/src/hash/poseidon.rs @@ -8,10 +8,8 @@ use core::fmt::Debug; use plonky2_field::packed::PackedField; use unroll::unroll_for_loops; -#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))] +#[cfg(target_feature = "avx2")] use super::arch::x86_64::poseidon_goldilocks_avx2::poseidon_avx; -#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] -use super::arch::x86_64::poseidon_goldilocks_avx512::poseidon_avx512; use super::hash_types::HashOutTarget; use crate::field::extension::{Extendable, FieldExtension}; use crate::field::types::{Field, PrimeField64}; @@ -783,17 +781,11 @@ pub trait Poseidon: PrimeField64 { } #[inline] - #[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))] + #[cfg(all(target_feature = "avx2"))] fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { poseidon_avx(&input) } - #[inline] - #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] - fn poseidon(input: [Self; SPONGE_WIDTH]) -> [Self; SPONGE_WIDTH] { - poseidon_avx512(&input) - } - // For testing only, to ensure that various tricks are correct. #[inline] fn partial_rounds_naive(state: &mut [Self; SPONGE_WIDTH], round_ctr: &mut usize) { @@ -988,4 +980,29 @@ pub(crate) mod test_helpers { assert_eq!(output[i], output_naive[i]); } } + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + pub(crate) fn check_test_vectors_avx512( + test_vectors: Vec<([u64; SPONGE_WIDTH], [u64; SPONGE_WIDTH])>, + ) where + F: Poseidon, + { + use crate::hash::arch::x86_64::poseidon_goldilocks_avx512::poseidon_avx512_double; + + println!("Checking test vectors with AVX512 Poseidon implementation..."); + + for (input_, expected_output_) in test_vectors.into_iter() { + let mut input = [F::ZERO; 2 * SPONGE_WIDTH]; + for i in 0..SPONGE_WIDTH { + input[i] = F::from_canonical_u64(input_[i]); + input[i + SPONGE_WIDTH] = F::from_canonical_u64(input_[i]); + } + let output = poseidon_avx512_double::(&input); + for i in 0..SPONGE_WIDTH { + let ex_output = F::from_canonical_u64(expected_output_[i]); + assert_eq!(output[i], ex_output); + assert_eq!(output[i + SPONGE_WIDTH], ex_output); + } + } + } } diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index 164ec6d633..f78c3f4a24 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -449,6 +449,8 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField as F; use crate::field::types::{Field, PrimeField64}; + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + use crate::hash::poseidon::test_helpers::check_test_vectors_avx512; use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors}; use crate::hash::poseidon::{Poseidon, PoseidonHash}; use crate::plonk::config::Hasher; @@ -488,7 +490,10 @@ mod tests { 0xfcc781b0ce382bf2, 0x934c69ff3ed14ba5, 0x504688a5996e8f13, 0x401f3f2ed524a2ba, ]), ]; - check_test_vectors::(test_vectors12); + check_test_vectors::(test_vectors12.clone()); + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + check_test_vectors_avx512::(test_vectors12); } #[test] From c0d2b9283b7744cb536a5028c1d193eb096338c4 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 15 Oct 2024 11:33:19 +0800 Subject: [PATCH 09/16] cargo fmt --- plonky2/src/fri/oracle.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/plonky2/src/fri/oracle.rs b/plonky2/src/fri/oracle.rs index 99a82ec6b5..ca945da163 100644 --- a/plonky2/src/fri/oracle.rs +++ b/plonky2/src/fri/oracle.rs @@ -28,7 +28,8 @@ use crate::util::timing::TimingTree; use crate::util::{log2_strict, reverse_bits, reverse_index_bits_in_place, transpose}; #[cfg(all(feature = "cuda", any(test, doctest)))] -pub static GPU_INIT: once_cell::sync::Lazy>> = once_cell::sync::Lazy::new(|| std::sync::Arc::new(std::sync::Mutex::new(0))); +pub static GPU_INIT: once_cell::sync::Lazy>> = + once_cell::sync::Lazy::new(|| std::sync::Arc::new(std::sync::Mutex::new(0))); #[cfg(all(feature = "cuda", any(test, doctest)))] fn init_gpu() { From 0a6626eaffda2073908a64b3e0add89e66b85c9f Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 15 Oct 2024 13:51:47 +0800 Subject: [PATCH 10/16] minor perf optimizations for avx512 --- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 147 +++++++++++++----- plonky2/src/hash/hash_types.rs | 15 ++ plonky2/src/hash/merkle_tree.rs | 24 +-- plonky2/src/plonk/config.rs | 1 + 4 files changed, 137 insertions(+), 50 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index be920c065c..04e835e37d 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -8,13 +8,12 @@ use super::poseidon_goldilocks_avx2::{ use crate::field::types::PrimeField64; use crate::hash::arch::x86_64::goldilocks_avx512::*; use crate::hash::arch::x86_64::poseidon_goldilocks_avx2::FAST_PARTIAL_ROUND_W_HATS; -use crate::hash::hash_types::{HashOut, RichField}; +use crate::hash::hash_types::RichField; use crate::hash::poseidon::{ add_u160_u128, reduce_u160, Poseidon, ALL_ROUND_CONSTANTS, HALF_N_FULL_ROUNDS, N_PARTIAL_ROUNDS, N_ROUNDS, SPONGE_RATE, SPONGE_WIDTH, }; use crate::hash::poseidon_goldilocks::poseidon12_mds::block2; -use crate::plonk::config::GenericHashOut; #[allow(dead_code)] const MDS_MATRIX_CIRC: [u64; 12] = [17, 15, 41, 16, 2, 28, 13, 13, 39, 18, 34, 20]; @@ -1223,7 +1222,7 @@ unsafe fn fft4_real_avx512( x2: &__m512i, x3: &__m512i, ) -> (__m512i, __m512i, __m512i, __m512i) { - let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let zeros = _mm512_xor_si512(*x0, *x0); // faster 0 let (z0, z2) = fft2_real_avx512(x0, x2); let (z1, z3) = fft2_real_avx512(x1, x3); let y0 = _mm512_add_epi64(z0, z1); @@ -1243,7 +1242,7 @@ unsafe fn ifft2_real_unreduced_avx512(y0: &__m512i, y1: &__m512i) -> (__m512i, _ unsafe fn ifft4_real_unreduced_avx512( y: (__m512i, (__m512i, __m512i), __m512i), ) -> (__m512i, __m512i, __m512i, __m512i) { - let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let zeros = _mm512_xor_si512(y.0, y.0); // faster 0 let z0 = _mm512_add_epi64(y.0, y.2); let z1 = _mm512_sub_epi64(y.0, y.2); let z2 = y.1 .0; @@ -1264,7 +1263,7 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1 */ let ones = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1); - let zeros = _mm512_set_epi64(0, 0, 0, 0, 0, 0, 0, 0); + let zeros = _mm512_xor_si512(*a, *a); // faster 0 let r = _mm512_add_epi64(*a, *b); let ma = _mm512_cmpgt_epi64_mask(zeros, *a); let mb = _mm512_cmpgt_epi64_mask(zeros, *b); @@ -1276,16 +1275,7 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 #[inline] pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { - let r = _mm512_mul_epu32(*a, *b); - let ah = _mm512_srli_epi64(*a, 32); - let bh = _mm512_srli_epi64(*b, 32); - let r1 = _mm512_mul_epu32(*a, bh); - let r1 = _mm512_slli_epi64(r1, 32); - let r = _mm512_add_epi64(r, r1); - let r1 = _mm512_mul_epu32(ah, *b); - let r1 = _mm512_slli_epi64(r1, 32); - let r = _mm512_add_epi64(r, r1); - r + _mm512_mullo_epi64(*a, *b) } #[inline(always)] @@ -1323,6 +1313,95 @@ unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m5 (rr, ri) } +#[allow(dead_code)] +#[inline(always)] +unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) { + let yr = _mm512_set_epi64(0, y[2].0, y[1].0, y[0].0, 0, y[2].0, y[1].0, y[0].0); + let yi = _mm512_set_epi64(0, y[2].1, y[1].1, y[0].1, 0, y[2].1, y[1].1, y[0].1); + let ys = _mm512_add_epi64(yr, yi); + let xs = _mm512_add_epi64(*xr, *xi); + + // z0 + // z0r = dif2[0] + prod[1] - sum[1] + prod[2] - sum[2] + // z0i = prod[0] - sum[0] + dif1[1] + dif1[2] + let yy = _mm512_permutex_epi64(yr, 0x18); + let mr_z0 = mul64_no_overflow_avx512(xr, &yy); + let yy = _mm512_permutex_epi64(yi, 0x18); + let mi_z0 = mul64_no_overflow_avx512(xi, &yy); + let sum = _mm512_add_epi64(mr_z0, mi_z0); + let dif1 = _mm512_sub_epi64(mi_z0, mr_z0); + let dif2 = _mm512_sub_epi64(mr_z0, mi_z0); + let yy = _mm512_permutex_epi64(ys, 0x18); + let prod = mul64_no_overflow_avx512(&xs, &yy); + let dif3 = _mm512_sub_epi64(prod, sum); + let dif3perm1 = _mm512_permutex_epi64(dif3, 0x1); + let dif3perm2 = _mm512_permutex_epi64(dif3, 0x2); + let z0r = _mm512_add_epi64(dif2, dif3perm1); + let z0r = _mm512_add_epi64(z0r, dif3perm2); + let dif1perm1 = _mm512_permutex_epi64(dif1, 0x1); + let dif1perm2 = _mm512_permutex_epi64(dif1, 0x2); + let z0i = _mm512_add_epi64(dif3, dif1perm1); + let z0i = _mm512_add_epi64(z0i, dif1perm2); + let mask = _mm512_set_epi64(0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64); + let z0r = _mm512_and_si512(z0r, mask); + let z0i = _mm512_and_si512(z0i, mask); + + // z1 + // z1r = dif2[0] + dif2[1] + prod[2] - sum[2]; + // z1i = prod[0] - sum[0] + prod[1] - sum[1] + dif1[2]; + let yy = _mm512_permutex_epi64(yr, 0x21); + let mr_z1 = mul64_no_overflow_avx512(xr, &yy); + let yy = _mm512_permutex_epi64(yi, 0x21); + let mi_z1 = mul64_no_overflow_avx512(xi, &yy); + let sum = _mm512_add_epi64(mr_z1, mi_z1); + let dif1 = _mm512_sub_epi64(mi_z1, mr_z1); + let dif2 = _mm512_sub_epi64(mr_z1, mi_z1); + let yy = _mm512_permutex_epi64(ys, 0x21); + let prod = mul64_no_overflow_avx512(&xs, &yy); + let dif3 = _mm512_sub_epi64(prod, sum); + let dif2perm = _mm512_permutex_epi64(dif2, 0x0); + let dif3perm = _mm512_permutex_epi64(dif3, 0x8); + let z1r = _mm512_add_epi64(dif2, dif2perm); + let z1r = _mm512_add_epi64(z1r, dif3perm); + let dif3perm = _mm512_permutex_epi64(dif3, 0x0); + let dif1perm = _mm512_permutex_epi64(dif1, 0x8); + let z1i = _mm512_add_epi64(dif3, dif3perm); + let z1i = _mm512_add_epi64(z1i, dif1perm); + let mask = _mm512_set_epi64(0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0); + let z1r = _mm512_and_si512(z1r, mask); + let z1i = _mm512_and_si512(z1i, mask); + + // z2 + // z2r = dif2[0] + dif2[1] + dif2[2]; + // z2i = prod[0] - sum[0] + prod[1] - sum[1] + prod[2] - sum[2] + let yy = _mm512_permutex_epi64(yr, 0x6); + let mr_z2 = mul64_no_overflow_avx512(xr, &yy); + let yy = _mm512_permutex_epi64(yi, 0x6); + let mi_z2 = mul64_no_overflow_avx512(xi, &yy); + let sum = _mm512_add_epi64(mr_z2, mi_z2); + let dif2 = _mm512_sub_epi64(mr_z2, mi_z2); + let yy = _mm512_permutex_epi64(ys, 0x6); + let prod = mul64_no_overflow_avx512(&xs, &yy); + let dif3 = _mm512_sub_epi64(prod, sum); + let dif2perm1 = _mm512_permutex_epi64(dif2, 0x0); + let dif2perm2 = _mm512_permutex_epi64(dif2, 0x10); + let z2r = _mm512_add_epi64(dif2, dif2perm1); + let z2r = _mm512_add_epi64(z2r, dif2perm2); + let dif3perm1 = _mm512_permutex_epi64(dif3, 0x0); + let dif3perm2 = _mm512_permutex_epi64(dif3, 0x10); + let z2i = _mm512_add_epi64(dif3, dif3perm1); + let z2i = _mm512_add_epi64(z2i, dif3perm2); + let mask = _mm512_set_epi64(0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0); + let z2r = _mm512_and_si512(z2r, mask); + let z2i = _mm512_and_si512(z2i, mask); + + let zr = _mm512_or_si512(z0r, z1r); + let zr = _mm512_or_si512(zr, z2r); + let zi = _mm512_or_si512(z0i, z1i); + let zi = _mm512_or_si512(zi, z2i); + (zr, zi) +} + #[inline(always)] unsafe fn block3_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { let x0 = _mm512_permutex_epi64(*x, 0x0); @@ -1372,7 +1451,8 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); - let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + // let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); // [u[0], u[1], u[2]] are all in u3 @@ -1767,28 +1847,23 @@ where new_state } -pub fn hash_leaf_avx512(inputs: &[F], leaf_size: usize) -> Vec> +pub fn hash_leaf_avx512(inputs: &[F], leaf_size: usize) -> (Vec, Vec) where F: RichField, { + // special case if leaf_size <= 4 { - let mut inputs_bytes1 = vec![0u8; 32]; - let mut inputs_bytes2 = vec![0u8; 32]; - for i in 0..inputs.len() { - inputs_bytes1[i * 8..(i + 1) * 8] - .copy_from_slice(&inputs[i].to_canonical_u64().to_le_bytes()); - inputs_bytes2[i * 8..(i + 1) * 8] - .copy_from_slice(&inputs[i + leaf_size].to_canonical_u64().to_le_bytes()); - } - return vec![ - HashOut::from_bytes(&inputs_bytes1), - HashOut::from_bytes(&inputs_bytes2), - ]; + let mut h1 = vec![F::ZERO; 4]; + let mut h2 = vec![F::ZERO; 4]; + h1.copy_from_slice(&inputs[0..leaf_size]); + h2.copy_from_slice(&inputs[leaf_size..2 * leaf_size]); + return (h1, h2); } + // general case let mut state: [F; 24] = [F::ZERO; 24]; - // Absorb all input chunks. + // absorb all input chunks of size SPONGE_RATE let mut idx1 = 0; let mut idx2 = leaf_size; let loops = if leaf_size % SPONGE_RATE == 0 { @@ -1815,13 +1890,11 @@ where idx2 += SPONGE_RATE; } - // Squeeze until we have the desired number of outputs. - let output1 = vec![state[0], state[1], state[2], state[3]]; - let output2 = vec![state[12], state[13], state[14], state[15]]; - vec![HashOut::from_vec(output1), HashOut::from_vec(output2)] + // return 2 hashes of 4 elements each + (vec![state[0], state[1], state[2], state[3]], vec![state[12], state[13], state[14], state[15]]) } -pub fn hash_two_avx512(h1: &Vec, h2: &Vec, h3: &Vec, h4: &Vec) -> Vec> +pub fn hash_two_avx512(h1: &Vec, h2: &Vec, h3: &Vec, h4: &Vec) -> (Vec, Vec) where F: RichField, { @@ -1831,7 +1904,5 @@ where state[12..16].copy_from_slice(&h3); state[16..20].copy_from_slice(&h4); state = poseidon_avx512_double(&state); - let output1 = vec![state[0], state[1], state[2], state[3]]; - let output2 = vec![state[12], state[13], state[14], state[15]]; - vec![HashOut::from_vec(output1), HashOut::from_vec(output2)] + (vec![state[0], state[1], state[2], state[3]], vec![state[12], state[13], state[14], state[15]]) } diff --git a/plonky2/src/hash/hash_types.rs b/plonky2/src/hash/hash_types.rs index 992f36197e..bab65c8ae2 100644 --- a/plonky2/src/hash/hash_types.rs +++ b/plonky2/src/hash/hash_types.rs @@ -104,6 +104,12 @@ impl GenericHashOut for HashOut { fn to_vec(&self) -> Vec { self.elements.to_vec() } + + fn from_vec(vec: &[F]) -> Self { + HashOut { + elements: vec.try_into().unwrap(), + } + } } impl Default for HashOut { @@ -190,6 +196,15 @@ impl GenericHashOut for BytesHash { }) .collect() } + + fn from_vec(vec: &[F]) -> Self { + let mut bytes = [0; N]; + for (i, &x) in vec.iter().enumerate() { + let arr = x.to_canonical_u64().to_le_bytes(); + bytes[i * 8..(i + 1) * 8].copy_from_slice(&arr); + } + Self(bytes) + } } impl Serialize for BytesHash { diff --git a/plonky2/src/hash/merkle_tree.rs b/plonky2/src/hash/merkle_tree.rs index 9bb93dec5f..d1ca38fb01 100644 --- a/plonky2/src/hash/merkle_tree.rs +++ b/plonky2/src/hash/merkle_tree.rs @@ -275,9 +275,9 @@ fn fill_subtree_poseidon_avx512>( } // if two leaves => return their concat hash if leaves_count == 2 { - let h = hash_leaf_avx512(leaves, leaf_size); - let hash_left = H::Hash::from_bytes(&h[0].to_bytes()); - let hash_right = H::Hash::from_bytes(&h[1].to_bytes()); + let (h1, h2) = hash_leaf_avx512(leaves, leaf_size); + let hash_left = H::Hash::from_vec(&h1); + let hash_right = H::Hash::from_vec(&h2); digests_buf[0].write(hash_left); digests_buf[1].write(hash_right); return H::two_to_one(hash_left, hash_right); @@ -294,9 +294,9 @@ fn fill_subtree_poseidon_avx512>( .for_each(|(chunk_idx, digests)| { let (_, r) = leaves.split_at(2 * chunk_idx * leaf_size); let (leaves2, _) = r.split_at(2 * leaf_size); - let h = hash_leaf_avx512(leaves2, leaf_size); - let h1 = H::Hash::from_bytes(&h[0].to_bytes()); - let h2 = H::Hash::from_bytes(&h[1].to_bytes()); + let (h1, h2) = hash_leaf_avx512(leaves2, leaf_size); + let h1 = H::Hash::from_vec(&h1); + let h2 = H::Hash::from_vec(&h2); digests[0].write(h1); digests[1].write(h2); }); @@ -326,14 +326,14 @@ fn fill_subtree_poseidon_avx512>( let left_digest2 = next_digests[left_idx2].assume_init().to_vec(); let right_digest2 = next_digests[right_idx2].assume_init().to_vec(); - let h = hash_two_avx512( + let (h1, h2) = hash_two_avx512( &left_digest1, &right_digest1, &left_digest2, &right_digest2, ); - let h1 = H::Hash::from_bytes(&h[0].to_bytes()); - let h2 = H::Hash::from_bytes(&h[1].to_bytes()); + let h1 = H::Hash::from_vec(&h1); + let h2 = H::Hash::from_vec(&h2); digests[0].write(h1); digests[1].write(h2); } @@ -369,9 +369,9 @@ fn fill_digests_buf_poseidon_avx515>( .for_each(|(leaf_idx, cap_buf)| { let (_, r) = leaves.split_at(2 * leaf_idx * leaf_size); let (lv, _) = r.split_at(2 * leaf_size); - let h = hash_leaf_avx512(lv, leaf_size); - cap_buf[0].write(H::Hash::from_bytes(&h[0].to_bytes())); - cap_buf[1].write(H::Hash::from_bytes(&h[1].to_bytes())); + let (h1, h2) = hash_leaf_avx512(lv, leaf_size); + cap_buf[0].write(H::Hash::from_vec(&h1)); + cap_buf[1].write(H::Hash::from_vec(&h2)); }); return; } diff --git a/plonky2/src/plonk/config.rs b/plonky2/src/plonk/config.rs index 4b5ddceada..731ab4c44f 100644 --- a/plonky2/src/plonk/config.rs +++ b/plonky2/src/plonk/config.rs @@ -39,6 +39,7 @@ pub trait GenericHashOut: fn from_bytes(bytes: &[u8]) -> Self; fn to_vec(&self) -> Vec; + fn from_vec(vec: &[F]) -> Self; } /// Trait for hash functions. From 3acd1f0565541f7be3f6c955b3dc1cbf576b4381 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 15 Oct 2024 13:55:33 +0800 Subject: [PATCH 11/16] fix warning --- plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 04e835e37d..4a7f8f9ad9 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1296,6 +1296,7 @@ unsafe fn block1_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { _mm512_add_epi64(t0, t2) } +#[allow(unused)] #[inline(always)] unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) { let mut vxr: [i64; 8] = [0; 8]; From 3e2cc3f7b408785fd148cd0ff6cb9db42cb5c640 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Tue, 15 Oct 2024 13:56:59 +0800 Subject: [PATCH 12/16] fix cargo fmt --- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 43 ++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 4a7f8f9ad9..9dc7c70898 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1343,7 +1343,16 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm2 = _mm512_permutex_epi64(dif1, 0x2); let z0i = _mm512_add_epi64(dif3, dif1perm1); let z0i = _mm512_add_epi64(z0i, dif1perm2); - let mask = _mm512_set_epi64(0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64); + let mask = _mm512_set_epi64( + 0, + 0, + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + 0, + 0, + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + ); let z0r = _mm512_and_si512(z0r, mask); let z0i = _mm512_and_si512(z0i, mask); @@ -1368,7 +1377,16 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm = _mm512_permutex_epi64(dif1, 0x8); let z1i = _mm512_add_epi64(dif3, dif3perm); let z1i = _mm512_add_epi64(z1i, dif1perm); - let mask = _mm512_set_epi64(0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0); + let mask = _mm512_set_epi64( + 0, + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + 0, + 0, + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + 0, + ); let z1r = _mm512_and_si512(z1r, mask); let z1i = _mm512_and_si512(z1i, mask); @@ -1392,7 +1410,16 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif3perm2 = _mm512_permutex_epi64(dif3, 0x10); let z2i = _mm512_add_epi64(dif3, dif3perm1); let z2i = _mm512_add_epi64(z2i, dif3perm2); - let mask = _mm512_set_epi64(0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0, 0, 0xFFFFFFFFFFFFFFFFu64 as i64, 0, 0); + let mask = _mm512_set_epi64( + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + 0, + 0, + 0, + 0xFFFFFFFFFFFFFFFFu64 as i64, + 0, + 0, + ); let z2r = _mm512_and_si512(z2r, mask); let z2i = _mm512_and_si512(z2i, mask); @@ -1892,7 +1919,10 @@ where } // return 2 hashes of 4 elements each - (vec![state[0], state[1], state[2], state[3]], vec![state[12], state[13], state[14], state[15]]) + ( + vec![state[0], state[1], state[2], state[3]], + vec![state[12], state[13], state[14], state[15]], + ) } pub fn hash_two_avx512(h1: &Vec, h2: &Vec, h3: &Vec, h4: &Vec) -> (Vec, Vec) @@ -1905,5 +1935,8 @@ where state[12..16].copy_from_slice(&h3); state[16..20].copy_from_slice(&h4); state = poseidon_avx512_double(&state); - (vec![state[0], state[1], state[2], state[3]], vec![state[12], state[13], state[14], state[15]]) + ( + vec![state[0], state[1], state[2], state[3]], + vec![state[12], state[13], state[14], state[15]], + ) } From 4c640df4ca2cbb2f881b384c8f74bc740e342b29 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 17 Oct 2024 15:05:01 +0800 Subject: [PATCH 13/16] fix avx512 poseidon issue --- .../src/hash/arch/x86_64/goldilocks_avx2.rs | 13 +- .../src/hash/arch/x86_64/goldilocks_avx512.rs | 31 ++- .../hash/arch/x86_64/poseidon_bn128_avx2.rs | 21 -- .../arch/x86_64/poseidon_goldilocks_avx2.rs | 13 +- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 38 +++- plonky2/src/hash/poseidon_goldilocks.rs | 196 ++++++++++++++---- 6 files changed, 235 insertions(+), 77 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs index 7be01a9cb8..9df0f2be12 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs @@ -41,9 +41,16 @@ pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i { #[inline(always)] pub fn add_avx(a: &__m256i, b: &__m256i) -> __m256i { - let a_sc = shift_avx(a); - // let a_sc = toCanonical_avx_s(&a_s); - add_avx_a_sc(&a_sc, b) + unsafe { + let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1); + let a_sc = _mm256_xor_si256(*a, msb); + let c0_s = _mm256_add_epi64(a_sc, *b); + let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1); + let mask_ = _mm256_cmpgt_epi64(a_sc, c0_s); + let corr_ = _mm256_and_si256(mask_, p_n); + let c_s = _mm256_add_epi64(c0_s, corr_); + _mm256_xor_si256(c_s, msb) + } } #[inline(always)] diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index ce86ce67de..adb3ca702b 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -46,15 +46,37 @@ pub fn to_canonical_avx512(a: &__m512i) -> __m512i { #[inline(always)] pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i { + /* unsafe { // let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_); - let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); + let p8_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let c0 = _mm512_add_epi64(*a, *b); let result_mask = _mm512_cmpgt_epu64_mask(*a, c0); _mm512_mask_add_epi64(c0, result_mask, c0, p8_n) } + */ + unsafe { + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let a_sc = _mm512_xor_si512(*a, msb); + let c0_s = _mm512_add_epi64(a_sc, *b); + let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); + let mask_ = _mm512_cmpgt_epi64_mask(a_sc, c0_s); + let c_s = _mm512_mask_add_epi64(c0_s, mask_, c0_s, p_n); + _mm512_xor_si512(c_s, msb) + } } +#[inline(always)] +pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i { + unsafe { + let corr = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); + let c0_s = _mm512_add_epi64(*a_s, *b_small); + let mask_ = _mm512_cmpgt_epi64_mask(*a_s, c0_s); + _mm512_mask_add_epi64(c0_s, mask_, c0_s, corr) + } +} + + #[inline(always)] pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { @@ -82,11 +104,12 @@ pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { #[inline(always)] pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i { unsafe { - let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::()); - let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::()); + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let c_ls = _mm512_xor_si512(*c_l, msb); let c2 = _mm512_mul_epu32(*c_h, p_n); - let c_s = add_avx512(&c_ls, &c2); + let c_s = add_avx512_s_b_small(&c_ls, &c2); + // let c_s = add_avx512(&c_ls, &c2); _mm512_xor_si512(c_s, msb) } } diff --git a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs index 04ce1262bf..2f7039f147 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs @@ -84,19 +84,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) { let zeros = _mm256_set_epi64x(0, 0, 0, 0); let (r1, b1) = sub64_no_borrow(a, b); - // TODO - delete - /* - let mut v = [0i64; 4]; - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *a); - println!("a: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *b); - println!("b: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r1); - println!("r: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), b1); - println!("b: {:?}", v); - */ - let m1 = _mm256_cmpeq_epi64(*bin, ones); let m2 = _mm256_cmpeq_epi64(r1, zeros); let m = _mm256_and_si256(m1, m2); @@ -104,14 +91,6 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) { let r = _mm256_sub_epi64(r1, *bin); let bo = _mm256_or_si256(bo, b1); - // TODO - delete - /* - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r); - println!("r: {:?}", v); - _mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), bo); - println!("b: {:?}", v); - */ - (r, bo) } diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs index db86c76b98..d10fe828ce 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -1164,7 +1164,7 @@ unsafe fn mds_layer_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { let (rl0, c0) = add64_no_carry(&sl0, &shl0); let (rh0, _) = add64_no_carry(&shh0, &c0); let r0 = reduce_avx_128_64(&rh0, &rl0); - + let (rl1, c1) = add64_no_carry(&sl1, &shl1); let (rh1, _) = add64_no_carry(&shh1, &c1); *s1 = reduce_avx_128_64(&rh1, &rl1); @@ -1393,7 +1393,7 @@ where F: PrimeField64 + Poseidon, { let mut state = &mut input.clone(); - let mut round_ctr = 0; + let mut round_ctr = 0; unsafe { // load state @@ -1410,12 +1410,13 @@ where let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); let ss0 = add_avx(&s0, &rc0); let ss1 = add_avx(&s1, &rc1); - let ss2 = add_avx(&s2, &rc2); + let ss2 = add_avx(&s2, &rc2); + (s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2); mds_layer_avx(&mut s0, &mut s1, &mut s2); - round_ctr += 1; + round_ctr += 1; } - + // this does partial_first_constant_layer_avx(&mut state); let c0 = _mm256_loadu_si256( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) @@ -1441,7 +1442,7 @@ where _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); - + for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]); diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 9dc7c70898..86bc98d60d 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1271,11 +1271,36 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); let co = _mm512_mask_blend_epi64(m, zeros, ones); (r, co) + /* + let mut va: [u64; 8] = [0; 8]; + let mut vb: [u64; 8] = [0; 8]; + let mut vr: [u64; 8] = [0; 8]; + let mut vc: [u64; 8] = [0; 8]; + _mm512_storeu_epi64(va.as_mut_ptr().cast::(), *a); + _mm512_storeu_epi64(vb.as_mut_ptr().cast::(), *b); + for i in 0..8 { + vr[i] = va[i].wrapping_add(vb[i]); + vc[i] = if vr[i] < va[i] { 1 } else { 0 }; + } + let r = _mm512_loadu_epi64(vr.as_ptr().cast::()); + let c = _mm512_loadu_epi64(vc.as_ptr().cast::()); + (r, c) + */ } #[inline] pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { - _mm512_mullo_epi64(*a, *b) + // _mm512_mullo_epi64(*a, *b) + let r = _mm512_mul_epu32(*a, *b); + let ah = _mm512_srli_epi64(*a, 32); + let bh = _mm512_srli_epi64(*b, 32); + let r1 = _mm512_mul_epu32(*a, bh); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + let r1 = _mm512_mul_epu32(ah, *b); + let r1 = _mm512_slli_epi64(r1, 32); + let r = _mm512_add_epi64(r, r1); + r } #[inline(always)] @@ -1479,8 +1504,8 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); - // let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); - let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + // let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); // [u[0], u[1], u[2]] are all in u3 @@ -1795,6 +1820,7 @@ where let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); + s0 = sbox_avx512_one(&ss0); s1 = sbox_avx512_one(&ss1); s2 = sbox_avx512_one(&ss2); @@ -1900,13 +1926,13 @@ where leaf_size / SPONGE_RATE + 1 }; for _ in 0..loops { - let end1 = if idx1 + SPONGE_RATE > leaf_size { + let end1 = if idx1 + SPONGE_RATE >= leaf_size { leaf_size } else { idx1 + SPONGE_RATE }; - let end2 = if idx2 + SPONGE_RATE > inputs.len() { - inputs.len() + let end2 = if idx2 + SPONGE_RATE >= 2 * leaf_size { + 2 * leaf_size } else { idx2 + SPONGE_RATE }; diff --git a/plonky2/src/hash/poseidon_goldilocks.rs b/plonky2/src/hash/poseidon_goldilocks.rs index f78c3f4a24..fa3752629d 100644 --- a/plonky2/src/hash/poseidon_goldilocks.rs +++ b/plonky2/src/hash/poseidon_goldilocks.rs @@ -450,6 +450,8 @@ mod tests { use crate::field::goldilocks_field::GoldilocksField as F; use crate::field::types::{Field, PrimeField64}; #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + use crate::hash::arch::x86_64::poseidon_goldilocks_avx512::hash_leaf_avx512; + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] use crate::hash::poseidon::test_helpers::check_test_vectors_avx512; use crate::hash::poseidon::test_helpers::{check_consistency, check_test_vectors}; use crate::hash::poseidon::{Poseidon, PoseidonHash}; @@ -488,6 +490,12 @@ mod tests { [0xa89280105650c4ec, 0xab542d53860d12ed, 0x5704148e9ccab94f, 0xd3a826d4b62da9f5, 0x8a7a6ca87892574f, 0xc7017e1cad1a674e, 0x1f06668922318e34, 0xa3b203bc8102676f, 0xfcc781b0ce382bf2, 0x934c69ff3ed14ba5, 0x504688a5996e8f13, 0x401f3f2ed524a2ba, ]), + ([0xf2cc0ce426e7eddd, 0x91ad40f14cfdcb78, 0xc516c642346aabc, 0xa79a0411d96de0, + 0xf256c881b6167069, 0x5c767aa6354a647b, 0x79a821313415b9dc, 0xf083bc2f276b99e1, + 0x9aa0ac0171df5ac7, 0xc3c705daf69d66e0, 0x3b0468abe66c5ed, 0xdcf835c4d4cffd73, ], + [0x96d91d333e5e038d, 0x114395c7cfb7e18f, 0x19b1ea99556391ff, 0xd53855a776b4582a, + 0x378d8ea4ffbb7545, 0x168319892eff226a, 0x5f09f06508283bd, 0xb92d599c947cc2f1, + 0xf078fc732200e4d4, 0xcaf95e4285f3099d, 0x8532be1f10f23cd0, 0xc3260991186909ff, ]) ]; check_test_vectors::(test_vectors12.clone()); @@ -576,40 +584,144 @@ mod tests { #[test] fn test_hash_no_pad_gl() { - let inputs: [u64; 32] = [ - 9972144316416239374, - 7195869958086994472, - 12805395537960412263, - 6755149769410714396, - 16592921959755212957, - 1370750654791741308, - 11186995120529280354, - 288690570896506034, - 2896720011649362435, - 13870686984275550055, - 12288026009924247278, - 15608864109019511973, - 15690944173815210604, - 17535150735055770942, - 4265223756233917229, - 17236464151311603291, - 15180455466814482598, - 12377438429067983442, - 11274960245127600167, - 5684300978461808754, - 1918159483831849502, - 15340265949423289730, - 181633163915570313, - 12684059848091546996, - 10060377187090493210, - 13523019938818230572, - 16846214147461656883, - 13560222746484567233, - 2150999602305437005, - 9103462636082953981, - 16341057499572706412, - 842265247111451937, + let inputs = [ + 0xb8f463d7cb4f24f6, + 0xe94ad9aba668af65, + 0x4a31c8cee787786a, + 0x7f8ed7050aeadcf9, + 0x516c34f52a5c8b14, + 0x542c22306722b175, + 0x6feba1eb9030ecb9, + 0xe103d491fa784080, + 0x31d9a62ea39f4ec9, + 0xbf0ccc95d9b4c697, + 0x5a9d230167523b2e, + 0x7ff277e12091d2f2, + 0xf2af521b9537abf3, + 0xe39e815313da5c12, + 0xe5feaa1e4f46b87b, + 0x76b772a9e6eda11c, + 0x9005e1c8fbf27eed, + 0x78ea9242b53108ac, + 0x5561d33040b6affb, + 0x61ded48ffee1f243, + 0xebbe0c4034afb9e5, + 0x7973d462ab14d331, + 0x76a23e459a0849b, + 0x9fa93d23d8b84515, + 0x1e19bba2ce8042dd, + 0xb1159302625b71a3, + 0x792e2e4171fd7e83, + 0xc9088b032be7eff0, + 0x6540b29fbec19cb2, + 0x8c4f849dd68f4cdc, + 0xb91969b7cfcd1ec8, + 0x4d450eff6a3b0c7c, + 0xcace16a8345de56e, + 0xe5bac07b93e1f0e2, + 0x35088bde4f1bd3a9, + 0x2e0bd8e257386e40, + 0xed67fe1bd44680f0, + 0x887a32a6049105f, + 0x3ae86d4d60b87a67, + 0x665a656a217edacf, + 0x2eb451b933acbd2d, + 0x63876760e9570fb4, + 0x2b11da28eb95d7d6, + 0x138ea36659579c0a, + 0x457f674d92cfcd72, + 0xba4b8ffc7287142d, + 0x2b9bd3cd64e65cb6, + 0x2780e8b0e66848e8, + 0xe18303c5010835a4, + 0x6c4e379aba35e21e, + 0xf9c3f2f33320d9cd, + 0x82429ba2d6263c9a, + 0x11e81115fa995e88, + 0x75a7fb5681cd15e4, + 0xa54b2a0b6d57e340, + 0x884b3d9cc9b7f720, + 0xdac1b985f5b0ff19, + 0x5938c0405a01dbd4, + 0x13fb2d9399c3ef2e, + 0xeaed82d3706dccec, + 0xf8d853012e56f7fb, + 0xa4c639bbaf484525, + 0xe3b35501c21797ba, + 0x1a645013fcb5e3a0, + 0xf2eb2337ba169178, + 0xcc94fd9269c7d33, + 0x82a9aaa398b13f1, + 0xe9b5ecbe6576234, + 0x252287d7ed9ec792, + 0x30629bee322f17cc, + 0x9ae26078f44e8afb, + 0xabdc35ac8f527136, + 0x4b2a3be4ef4c231f, + 0x23074d5363eeba58, + 0x75cfe940f6967c16, + 0xfb185a23f6225406, + 0xda8a21bd2ba64cc3, + 0xd623bde11eb8c989, + 0x76201928e4523ba3, + 0x1c20cb194495b643, + 0x3e70ce2fddc52451, + 0x86c698ca61fdae8e, + 0x9855dd30ad0c1309, + 0x271541a781755737, + 0x209b4ccf7db16277, + 0xff27cae2771d1d8c, + 0xd7795488a7bfe6ee, + 0x9cf1875ec535778e, + 0x9fad94c126427390, + 0x199b482c029f3d9d, + 0x92ae2055bb3f6d6, + 0x29d6100b44167374, + 0x88e8c8ffdefe0f33, + 0xa3d8d929ea748a62, + 0xd5dbe1a3d99e113d, + 0x438639f8f0e3ff25, + 0xf2cc0ce426e7eddd, + 0x91ad40f14cfdcb78, + 0xc516c642346aabc, + 0xa79a0411d96de0, + 0xf256c881b6167069, + 0x5c767aa6354a647b, + 0x79a821313415b9dc, + 0xf083bc2f276b99e1, + 0x9d47fc86eb2de7c2, + 0x3370a8711a678a03, + 0x1572c8a8bf872b26, + 0xdbb7de1fc45360a1, + 0x5f87c0fe24bafdd4, + 0x2f6a5784207d118a, + 0x640c588afcf0cc14, + 0xe609f3cbb7cb015, + 0x8e4907544019be80, + 0xde2f553ac4ab68c3, + 0x29cd0d2800262365, + 0x3bf736a6fbc14ce2, + 0xab059c3c3cba4912, + 0xe609e14997bd2f5c, + 0x694189d934ff1f8d, + 0x54570348f45e3a9, + 0x90ef5b98b0a08a34, + 0x1b09b93749616de8, + 0x89be3144389d48c1, + 0xdaa7e268d0fd82d8, + 0xc46956b67fa89c61, + 0xec88a7133e4fefc, + 0xe41596ca682069f4, + 0x297f55e46472431b, + 0x33ada14fd813218d, + 0x22c57ca5e77249ad, + 0x4e2f2c7cc99f2d47, + 0x78d11ba2efc7556f, + 0xdfc98976b6e3ad0d, + 0x59d88f72bf5ad1d8, + 0x19ca05690b8e1ad9, ]; + let inputs = inputs .iter() .map(|x| F::from_canonical_u64(*x)) @@ -617,15 +729,25 @@ mod tests { let output = PoseidonHash::hash_no_pad(&inputs); let expected_out: [u64; 4] = [ - 8197835875512527937, - 7109417654116018994, - 18237163116575285904, - 17017896878738047012, + 0xc19dccf6ec4f3df3, + 0x1bf0d65af6925451, + 0xee9dbf2c8dcad9a2, + 0xae46323715f528a1, ]; let expected_out = expected_out .iter() .map(|x| F::from_canonical_u64(*x)) .collect::>(); assert_eq!(output.elements.to_vec(), expected_out); + + #[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))] + { + let mut dleaf: Vec = vec![F::from_canonical_u64(0); 2 * inputs.len()]; + dleaf[0..inputs.len()].copy_from_slice(&inputs); + dleaf[inputs.len()..2 * inputs.len()].copy_from_slice(&inputs); + let (h1, h2) = hash_leaf_avx512(dleaf.as_slice(), inputs.len()); + assert_eq!(h1, expected_out); + assert_eq!(h2, expected_out); + } } } From 7b892617d7f6fb6b9ec84a9b0b6247d4ea35f33d Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Thu, 17 Oct 2024 18:08:30 +0800 Subject: [PATCH 14/16] fix mul issue in avx512 --- .../src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 86bc98d60d..4a20c49202 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1290,7 +1290,8 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 #[inline] pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { - // _mm512_mullo_epi64(*a, *b) + /* + // long version let r = _mm512_mul_epu32(*a, *b); let ah = _mm512_srli_epi64(*a, 32); let bh = _mm512_srli_epi64(*b, 32); @@ -1301,6 +1302,8 @@ pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { let r1 = _mm512_slli_epi64(r1, 32); let r = _mm512_add_epi64(r, r1); r + */ + _mm512_mullo_epi64(*a, *b) } #[inline(always)] @@ -1504,8 +1507,8 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut let f0 = block1_avx512(&u0, MDS_FREQ_BLOCK_ONE); // let [v1, v5, v9] = block2([(u[0], v[0]), (u[1], v[1]), (u[2], v[2])], MDS_FREQ_BLOCK_TWO); - let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); - // let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + // let (f1, f2) = block2_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); + let (f1, f2) = block2_full_avx512(&u1, &u2, MDS_FREQ_BLOCK_TWO); // let [v2, v6, v10] = block3_avx([u[0], u[1], u[2]], MDS_FREQ_BLOCK_ONE); // [u[0], u[1], u[2]] are all in u3 From 8cc4ca276fd18e51c696d6df2db20875994286c9 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 18 Oct 2024 10:25:54 +0800 Subject: [PATCH 15/16] optimize avx512 code --- .../src/hash/arch/x86_64/goldilocks_avx512.rs | 13 +- .../arch/x86_64/poseidon_goldilocks_avx512.rs | 274 ++++++++---------- 2 files changed, 122 insertions(+), 165 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index adb3ca702b..e67818e102 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -6,19 +6,22 @@ use crate::hash::hash_types::RichField; const MSB_: i64 = 0x8000000000000000u64 as i64; const P8_: i64 = 0xFFFFFFFF00000001u64 as i64; const P8_N_: i64 = 0xFFFFFFFF; +const ONE_: i64 = 1; #[allow(non_snake_case)] #[repr(align(64))] -struct FieldConstants { - MSB_V: [i64; 8], - P8_V: [i64; 8], - P8_N_V: [i64; 8], +pub(crate) struct FieldConstants { + pub(crate) MSB_V: [i64; 8], + pub(crate) P8_V: [i64; 8], + pub(crate) P8_N_V: [i64; 8], + pub(crate) ONE_V: [i64; 8], } -const FC: FieldConstants = FieldConstants { +pub(crate) const FC: FieldConstants = FieldConstants { MSB_V: [MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_], P8_V: [P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_], P8_N_V: [P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_], + ONE_V: [ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_], }; #[allow(dead_code)] diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs index 4a20c49202..282c19a44e 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx512.rs @@ -1134,8 +1134,8 @@ where let mut result = [F::ZERO; SPONGE_WIDTH]; let res0 = state[0]; unsafe { - let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r1 = _mm512_loadu_si512((&mut result[4..12]).as_mut_ptr().cast::()); + let mut r0 = _mm512_loadu_epi64((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r1 = _mm512_loadu_epi64((&mut result[4..12]).as_mut_ptr().cast::()); for r in 1..12 { let sr512 = _mm512_set_epi64( @@ -1148,23 +1148,23 @@ where state[r].to_canonical_u64() as i64, state[r].to_canonical_u64() as i64, ); - let t0 = _mm512_loadu_si512( + let t0 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let t1 = _mm512_loadu_si512( + let t1 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_INITIAL_MATRIX[r][4..12]) .as_ptr() - .cast::(), + .cast::(), ); let m0 = mult_avx512(&sr512, &t0); let m1 = mult_avx512(&sr512, &t1); r0 = add_avx512(&r0, &m0); r1 = add_avx512(&r1, &m1); } - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); state[0] = res0; } } @@ -1177,22 +1177,22 @@ where F: PrimeField64, { unsafe { - let c0 = _mm512_loadu_si512( + let c0 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let c1 = _mm512_loadu_si512( + let c1 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[4..12]) .as_ptr() - .cast::(), + .cast::(), ); - let mut s0 = _mm512_loadu_si512((state[0..8]).as_ptr().cast::()); - let mut s1 = _mm512_loadu_si512((state[4..12]).as_ptr().cast::()); + let mut s0 = _mm512_loadu_epi64((state[0..8]).as_ptr().cast::()); + let mut s1 = _mm512_loadu_epi64((state[4..12]).as_ptr().cast::()); s0 = add_avx512(&s0, &c0); s1 = add_avx512(&s1, &c1); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), s1); } } @@ -1262,30 +1262,16 @@ pub unsafe fn add64_no_carry_avx512(a: &__m512i, b: &__m512i) -> (__m512i, __m51 * - (test 3): if a + b < 2^64 (this means a + b is negative in signed representation) => no overflow so cout = 0 * - (test 3): if a + b >= 2^64 (this means a + b becomes positive in signed representation, that is, a + b >= 0) => there is overflow so cout = 1 */ - let ones = _mm512_set_epi64(1, 1, 1, 1, 1, 1, 1, 1); + let ones = _mm512_load_epi64(FC.ONE_V.as_ptr().cast::()); let zeros = _mm512_xor_si512(*a, *a); // faster 0 let r = _mm512_add_epi64(*a, *b); let ma = _mm512_cmpgt_epi64_mask(zeros, *a); let mb = _mm512_cmpgt_epi64_mask(zeros, *b); let mc = _mm512_cmpgt_epi64_mask(zeros, r); - let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); + // let m = (ma & mb) | (!mc & ((!ma & mb) | (ma & !mb))); + let m = (ma & mb) | (!mc & (ma ^ mb)); let co = _mm512_mask_blend_epi64(m, zeros, ones); (r, co) - /* - let mut va: [u64; 8] = [0; 8]; - let mut vb: [u64; 8] = [0; 8]; - let mut vr: [u64; 8] = [0; 8]; - let mut vc: [u64; 8] = [0; 8]; - _mm512_storeu_epi64(va.as_mut_ptr().cast::(), *a); - _mm512_storeu_epi64(vb.as_mut_ptr().cast::(), *b); - for i in 0..8 { - vr[i] = va[i].wrapping_add(vb[i]); - vc[i] = if vr[i] < va[i] { 1 } else { 0 }; - } - let r = _mm512_loadu_epi64(vr.as_ptr().cast::()); - let c = _mm512_loadu_epi64(vc.as_ptr().cast::()); - (r, c) - */ } #[inline] @@ -1303,6 +1289,7 @@ pub unsafe fn mul64_no_overflow_avx512(a: &__m512i, b: &__m512i) -> __m512i { let r = _mm512_add_epi64(r, r1); r */ + // short version _mm512_mullo_epi64(*a, *b) } @@ -1329,16 +1316,16 @@ unsafe fn block1_avx512(x: &__m512i, y: [i64; 3]) -> __m512i { unsafe fn block2_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> (__m512i, __m512i) { let mut vxr: [i64; 8] = [0; 8]; let mut vxi: [i64; 8] = [0; 8]; - _mm512_storeu_si512(vxr.as_mut_ptr().cast::(), *xr); - _mm512_storeu_si512(vxi.as_mut_ptr().cast::(), *xi); + _mm512_storeu_epi64(vxr.as_mut_ptr().cast::(), *xr); + _mm512_storeu_epi64(vxi.as_mut_ptr().cast::(), *xi); let x1: [(i64, i64); 3] = [(vxr[0], vxi[0]), (vxr[1], vxi[1]), (vxr[2], vxi[2])]; let x2: [(i64, i64); 3] = [(vxr[4], vxi[4]), (vxr[5], vxi[5]), (vxr[6], vxi[6])]; let b1 = block2(x1, y); let b2 = block2(x2, y); vxr = [b1[0].0, b1[1].0, b1[2].0, 0, b2[0].0, b2[1].0, b2[2].0, 0]; vxi = [b1[0].1, b1[1].1, b1[2].1, 0, b2[0].1, b2[1].1, b2[2].1, 0]; - let rr = _mm512_loadu_si512(vxr.as_ptr().cast::()); - let ri = _mm512_loadu_si512(vxi.as_ptr().cast::()); + let rr = _mm512_loadu_epi64(vxr.as_ptr().cast::()); + let ri = _mm512_loadu_epi64(vxi.as_ptr().cast::()); (rr, ri) } @@ -1371,18 +1358,9 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm2 = _mm512_permutex_epi64(dif1, 0x2); let z0i = _mm512_add_epi64(dif3, dif1perm1); let z0i = _mm512_add_epi64(z0i, dif1perm2); - let mask = _mm512_set_epi64( - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - ); - let z0r = _mm512_and_si512(z0r, mask); - let z0i = _mm512_and_si512(z0i, mask); + let zeros = _mm512_xor_si512(z0r, z0r); + let z0r = _mm512_mask_blend_epi64(0x11, zeros, z0r); + let z0i = _mm512_mask_blend_epi64(0x11, zeros, z0i); // z1 // z1r = dif2[0] + dif2[1] + prod[2] - sum[2]; @@ -1405,18 +1383,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif1perm = _mm512_permutex_epi64(dif1, 0x8); let z1i = _mm512_add_epi64(dif3, dif3perm); let z1i = _mm512_add_epi64(z1i, dif1perm); - let mask = _mm512_set_epi64( - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - ); - let z1r = _mm512_and_si512(z1r, mask); - let z1i = _mm512_and_si512(z1i, mask); + let z1r = _mm512_mask_blend_epi64(0x22, zeros, z1r); + let z1i = _mm512_mask_blend_epi64(0x22, zeros, z1i); // z2 // z2r = dif2[0] + dif2[1] + dif2[2]; @@ -1438,18 +1406,8 @@ unsafe fn block2_full_avx512(xr: &__m512i, xi: &__m512i, y: [(i64, i64); 3]) -> let dif3perm2 = _mm512_permutex_epi64(dif3, 0x10); let z2i = _mm512_add_epi64(dif3, dif3perm1); let z2i = _mm512_add_epi64(z2i, dif3perm2); - let mask = _mm512_set_epi64( - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - 0, - 0xFFFFFFFFFFFFFFFFu64 as i64, - 0, - 0, - ); - let z2r = _mm512_and_si512(z2r, mask); - let z2i = _mm512_and_si512(z2i, mask); + let z2r = _mm512_mask_blend_epi64(0x44, zeros, z2r); + let z2i = _mm512_mask_blend_epi64(0x44, zeros, z2i); let zr = _mm512_or_si512(z0r, z1r); let zr = _mm512_or_si512(zr, z2r); @@ -1528,10 +1486,7 @@ unsafe fn mds_multiply_freq_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut #[inline(always)] #[unroll_for_loops] unsafe fn mds_layer_avx512(s0: &mut __m512i, s1: &mut __m512i, s2: &mut __m512i) { - let mask = _mm512_set_epi64( - 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, - 0xFFFFFFFF, - ); + let mask = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); let mut sl0 = _mm512_and_si512(*s0, mask); let mut sl1 = _mm512_and_si512(*s1, mask); let mut sl2 = _mm512_and_si512(*s2, mask); @@ -1573,48 +1528,48 @@ unsafe fn mds_partial_layer_init_avx512(s0: &mut __m512i, s1: &mut __m512i, s where F: PrimeField64, { - let mut result = [F::ZERO; 2 * SPONGE_WIDTH]; let res0 = *s0; - - let mut r0 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r1 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); - let mut r2 = _mm512_loadu_si512((&mut result[0..8]).as_mut_ptr().cast::()); + let mut r0 = _mm512_xor_epi64(res0, res0); + let mut r1 = r0; + let mut r2 = r0; for r in 1..12 { - let sr = match r { - 1 => _mm512_permutex_epi64(*s0, 0x55), - 2 => _mm512_permutex_epi64(*s0, 0xAA), - 3 => _mm512_permutex_epi64(*s0, 0xFF), - 4 => _mm512_permutex_epi64(*s1, 0x0), - 5 => _mm512_permutex_epi64(*s1, 0x55), - 6 => _mm512_permutex_epi64(*s1, 0xAA), - 7 => _mm512_permutex_epi64(*s1, 0xFF), - 8 => _mm512_permutex_epi64(*s2, 0x0), - 9 => _mm512_permutex_epi64(*s2, 0x55), - 10 => _mm512_permutex_epi64(*s2, 0xAA), - 11 => _mm512_permutex_epi64(*s2, 0xFF), - _ => _mm512_permutex_epi64(*s0, 0x55), - }; - let t0 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8]) - .as_ptr() - .cast::(), - ); - let t1 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16]) - .as_ptr() - .cast::(), - ); - let t2 = _mm512_loadu_si512( - (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24]) - .as_ptr() - .cast::(), - ); - let m0 = mult_avx512(&sr, &t0); - let m1 = mult_avx512(&sr, &t1); - let m2 = mult_avx512(&sr, &t2); - r0 = add_avx512(&r0, &m0); - r1 = add_avx512(&r1, &m1); - r2 = add_avx512(&r2, &m2); + if r < 12 { + let sr = match r { + 1 => _mm512_permutex_epi64(*s0, 0x55), + 2 => _mm512_permutex_epi64(*s0, 0xAA), + 3 => _mm512_permutex_epi64(*s0, 0xFF), + 4 => _mm512_permutex_epi64(*s1, 0x0), + 5 => _mm512_permutex_epi64(*s1, 0x55), + 6 => _mm512_permutex_epi64(*s1, 0xAA), + 7 => _mm512_permutex_epi64(*s1, 0xFF), + 8 => _mm512_permutex_epi64(*s2, 0x0), + 9 => _mm512_permutex_epi64(*s2, 0x55), + 10 => _mm512_permutex_epi64(*s2, 0xAA), + 11 => _mm512_permutex_epi64(*s2, 0xFF), + _ => _mm512_permutex_epi64(*s0, 0x55), + }; + let t0 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][0..8]) + .as_ptr() + .cast::(), + ); + let t1 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][8..16]) + .as_ptr() + .cast::(), + ); + let t2 = _mm512_loadu_epi64( + (&FAST_PARTIAL_ROUND_INITIAL_MATRIX_AVX512[r][16..24]) + .as_ptr() + .cast::(), + ); + let m0 = mult_avx512(&sr, &t0); + let m1 = mult_avx512(&sr, &t1); + let m2 = mult_avx512(&sr, &t2); + r0 = add_avx512(&r0, &m0); + r1 = add_avx512(&r1, &m1); + r2 = add_avx512(&r2, &m2); + } } *s0 = _mm512_mask_blend_epi64(0x11, r0, res0); *s1 = r1; @@ -1677,20 +1632,20 @@ unsafe fn mds_partial_layer_fast_avx512( state[0].to_noncanonical_u64() as i64, state[0].to_noncanonical_u64() as i64, ); - let rc0 = _mm512_loadu_si512( + let rc0 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let rc1 = _mm512_loadu_si512( + let rc1 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][8..16]) .as_ptr() - .cast::(), + .cast::(), ); - let rc2 = _mm512_loadu_si512( + let rc2 = _mm512_loadu_epi64( (&FAST_PARTIAL_ROUND_VS_AVX512[r][16..24]) .as_ptr() - .cast::(), + .cast::(), ); let (mh, ml) = mult_avx512_128(&ss0, &rc0); let m = reduce_avx512_128_64(&mh, &ml); @@ -1715,9 +1670,9 @@ unsafe fn mds_partial_layer_fast_avx512( let m = reduce_avx512_128_64(&mh, &ml); *s2 = add_avx512(s2, &m); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), *s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), *s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), *s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), *s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), *s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), *s2); } #[allow(unused)] @@ -1732,22 +1687,22 @@ where // Self::full_rounds(&mut state, &mut round_ctr); for _ in 0..HALF_N_FULL_ROUNDS { // load state - let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::()); + let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::()); let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); *state = ::mds_layer(&state); round_ctr += 1; @@ -1765,22 +1720,22 @@ where // Self::full_rounds(&mut state, &mut round_ctr); for _ in 0..HALF_N_FULL_ROUNDS { // load state - let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let s1 = _mm512_loadu_si512((&state[4..12]).as_ptr().cast::()); + let s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let s1 = _mm512_loadu_epi64((&state[4..12]).as_ptr().cast::()); let rc: &[u64; 12] = &ALL_ROUND_CONSTANTS[SPONGE_WIDTH * round_ctr..][..SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[4..12]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[4..12]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let r0 = sbox_avx512_one(&ss0); let r1 = sbox_avx512_one(&ss1); // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), r0); - _mm512_storeu_si512((state[4..12]).as_mut_ptr().cast::(), r1); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), r0); + _mm512_storeu_epi64((state[4..12]).as_mut_ptr().cast::(), r1); *state = ::mds_layer(&state); // mds_layer_avx::(&mut s0, &mut s1, &mut s2); @@ -1808,22 +1763,21 @@ where unsafe { // load state - let mut s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::()); - let mut s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::()); - let mut s2 = _mm512_loadu_si512((&state[16..24]).as_ptr().cast::()); + let mut s0 = _mm512_loadu_epi64((&state[0..8]).as_ptr().cast::()); + let mut s1 = _mm512_loadu_epi64((&state[8..16]).as_ptr().cast::()); + let mut s2 = _mm512_loadu_epi64((&state[16..24]).as_ptr().cast::()); for _ in 0..HALF_N_FULL_ROUNDS { let rc: &[u64; 24] = &ALL_ROUND_CONSTANTS_AVX512[2 * SPONGE_WIDTH * round_ctr..] [..2 * SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); - let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); - s0 = sbox_avx512_one(&ss0); s1 = sbox_avx512_one(&ss1); s2 = sbox_avx512_one(&ss2); @@ -1832,20 +1786,20 @@ where } // this does partial_first_constant_layer_avx(&mut state); - let c0 = _mm512_loadu_si512( + let c0 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[0..8]) .as_ptr() - .cast::(), + .cast::(), ); - let c1 = _mm512_loadu_si512( + let c1 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[8..16]) .as_ptr() - .cast::(), + .cast::(), ); - let c2 = _mm512_loadu_si512( + let c2 = _mm512_loadu_epi64( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT_AVX512[16..24]) .as_ptr() - .cast::(), + .cast::(), ); s0 = add_avx512(&s0, &c0); s1 = add_avx512(&s1, &c1); @@ -1853,9 +1807,9 @@ where mds_partial_layer_init_avx512::(&mut s0, &mut s1, &mut s2); - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), s2); for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); @@ -1873,9 +1827,9 @@ where [..2 * SPONGE_WIDTH] .try_into() .unwrap(); - let rc0 = _mm512_loadu_si512((&rc[0..8]).as_ptr().cast::()); - let rc1 = _mm512_loadu_si512((&rc[8..16]).as_ptr().cast::()); - let rc2 = _mm512_loadu_si512((&rc[16..24]).as_ptr().cast::()); + let rc0 = _mm512_loadu_epi64((&rc[0..8]).as_ptr().cast::()); + let rc1 = _mm512_loadu_epi64((&rc[8..16]).as_ptr().cast::()); + let rc2 = _mm512_loadu_epi64((&rc[16..24]).as_ptr().cast::()); let ss0 = add_avx512(&s0, &rc0); let ss1 = add_avx512(&s1, &rc1); let ss2 = add_avx512(&s2, &rc2); @@ -1887,9 +1841,9 @@ where } // store state - _mm512_storeu_si512((state[0..8]).as_mut_ptr().cast::(), s0); - _mm512_storeu_si512((state[8..16]).as_mut_ptr().cast::(), s1); - _mm512_storeu_si512((state[16..24]).as_mut_ptr().cast::(), s2); + _mm512_storeu_epi64((state[0..8]).as_mut_ptr().cast::(), s0); + _mm512_storeu_epi64((state[8..16]).as_mut_ptr().cast::(), s1); + _mm512_storeu_epi64((state[16..24]).as_mut_ptr().cast::(), s2); debug_assert_eq!(round_ctr, N_ROUNDS); }; From 6b4bce20bce2aebcc9a1d143fc1fc3d3aa425f06 Mon Sep 17 00:00:00 2001 From: Dumi Loghin Date: Fri, 18 Oct 2024 10:26:19 +0800 Subject: [PATCH 16/16] cargo fmt --- plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs | 3 +-- .../src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs index e67818e102..dd305d5c8e 100644 --- a/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs +++ b/plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs @@ -59,7 +59,7 @@ pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i { } */ unsafe { - let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); + let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::()); let a_sc = _mm512_xor_si512(*a, msb); let c0_s = _mm512_add_epi64(a_sc, *b); let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::()); @@ -79,7 +79,6 @@ pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i { } } - #[inline(always)] pub fn sub_avx512(a: &__m512i, b: &__m512i) -> __m512i { unsafe { diff --git a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs index d10fe828ce..301e043446 100644 --- a/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs +++ b/plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs @@ -1164,7 +1164,7 @@ unsafe fn mds_layer_avx(s0: &mut __m256i, s1: &mut __m256i, s2: &mut __m256i) { let (rl0, c0) = add64_no_carry(&sl0, &shl0); let (rh0, _) = add64_no_carry(&shh0, &c0); let r0 = reduce_avx_128_64(&rh0, &rl0); - + let (rl1, c1) = add64_no_carry(&sl1, &shl1); let (rh1, _) = add64_no_carry(&shh1, &c1); *s1 = reduce_avx_128_64(&rh1, &rl1); @@ -1393,7 +1393,7 @@ where F: PrimeField64 + Poseidon, { let mut state = &mut input.clone(); - let mut round_ctr = 0; + let mut round_ctr = 0; unsafe { // load state @@ -1410,13 +1410,13 @@ where let rc2 = _mm256_loadu_si256((&rc[8..12]).as_ptr().cast::<__m256i>()); let ss0 = add_avx(&s0, &rc0); let ss1 = add_avx(&s1, &rc1); - let ss2 = add_avx(&s2, &rc2); + let ss2 = add_avx(&s2, &rc2); (s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2); mds_layer_avx(&mut s0, &mut s1, &mut s2); - round_ctr += 1; + round_ctr += 1; } - + // this does partial_first_constant_layer_avx(&mut state); let c0 = _mm256_loadu_si256( (&FAST_PARTIAL_FIRST_ROUND_CONSTANT[0..4]) @@ -1442,7 +1442,7 @@ where _mm256_storeu_si256((state[0..4]).as_mut_ptr().cast::<__m256i>(), s0); _mm256_storeu_si256((state[4..8]).as_mut_ptr().cast::<__m256i>(), s1); _mm256_storeu_si256((state[8..12]).as_mut_ptr().cast::<__m256i>(), s2); - + for i in 0..N_PARTIAL_ROUNDS { state[0] = sbox_monomial(state[0]); state[0] = state[0].add_canonical_u64(FAST_PARTIAL_ROUND_CONSTANTS[i]);