Merge latest PoseidonBN128 code #35

Merged 7 commits on Jul 1, 2024
54 changes: 45 additions & 9 deletions README.md
@@ -1,21 +1,57 @@
# Description

This repo is a fork of https://github.com/0xPolygonZero/plonky2. To boost computation speed, several optimizations were implemented:

# Optimizations
- Precompute FFT twiddle factors.
- CUDA implementation of Goldilocks Field NTT (feature `cuda`).
- CUDA implementation of Poseidon (Goldilocks) and Poseidon (BN 128) (feature `cuda`).
- Fixed the AVX implementation for Poseidon (Goldilocks) (target CPU must support AVX2).
- CUDA implementation of Merkle Tree building (feature `cuda`).
- Changed the Merkle Tree structure from recursive to iterative, storing the nodes in a flat, 1-dimensional vector (see the sketch below).
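
For illustration, here is a minimal sketch of such a flat, level-by-level layout (the `build_flat_merkle` helper and its toy `u64` hash are hypothetical; the repo's real code hashes field elements):

```rust
/// Builds a Merkle Tree stored as one flat vector: the `n` leaves come first,
/// then each level above it; the root is the last element.
fn build_flat_merkle(leaves: &[u64], hash: impl Fn(u64, u64) -> u64) -> Vec<u64> {
    let n = leaves.len();
    assert!(n.is_power_of_two());
    let mut nodes = vec![0u64; 2 * n - 1];
    nodes[..n].copy_from_slice(leaves);
    let (mut level_start, mut level_len) = (0usize, n);
    while level_len > 1 {
        let next_start = level_start + level_len;
        for i in 0..level_len / 2 {
            // Each parent hashes two adjacent children from the level below.
            nodes[next_start + i] =
                hash(nodes[level_start + 2 * i], nodes[level_start + 2 * i + 1]);
        }
        level_start = next_start;
        level_len /= 2;
    }
    nodes
}
```

One contiguous buffer also avoids pointer chasing and can be copied to the GPU in a single transfer, which suits the CUDA builder.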

# Dependencies

```
git submodule update --init --recursive
```

## Benchmarking Merkle Tree building with Poseidon hash

Switch to the latest Rust nightly toolchain:
```
rustup update
rustup override set nightly-x86_64-unknown-linux-gnu
```

CPU, no AVX: ``cargo bench merkle``

CPU with AVX2: ``RUSTFLAGS="-C target-feature=+avx2" cargo bench merkle``

CPU with AVX512: ``RUSTFLAGS="-C target-feature=+avx2,+avx512dq" cargo bench merkle``

GPU (CUDA): ``cargo bench merkle --features=cuda``

### Results

The values in the table below are Merkle Tree build times (in milliseconds) for the number of leaves given in the first row, using the hashing method given in the first column. The systems used for benchmarking are:

- first three result columns: AMD Ryzen Threadripper PRO 5975WX (32 cores, AVX2 only) + NVIDIA RTX 4090 (feature `cuda`);

- last three result columns: AMD Ryzen 9 7950X (16 cores, AVX2 and AVX512DQ).


| Number of MT Leaves | 2^13 | 2^14 | 2^15 | | 2^13 | 2^14 | 2^15 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Poseidon (no AVX) | 12.4 | 23.4 | 46.6 | | 12.8 | 25.2 | 50.3 |
| Poseidon (AVX) | 11.4 | 21.3 | 39.2 | | 10.3 | 20.3 | 40.2 |
| Poseidon (AVX512) | - | - | - | | 12.3 | 24.1 | 47.8 |
| Poseidon (GPU) | 8 | 14.3 | 26.5 | | - | - | - |
| Poseidon BN 128 (no AVX) | 111.9 | 223 | 446.3 | | 176.9 | 351 | 699.1 |
| Poseidon BN 128 (AVX) | 146.8 | 291.7 | 581.8 | | 220.1 | 433.5 | 858.8 |
| Poseidon BN 128 (AVX512) | - | - | - | | WIP | WIP | WIP |
| Poseidon BN 128 (GPU) | 37.5 | 57.6 | 92.9 | | - | - | - |

## Running

To see recursion performance, one can run this bench, which generates a chain of three recursion proofs:
2 changes: 2 additions & 0 deletions plonky2/Cargo.toml
@@ -25,6 +25,7 @@ cuda = ["cryptography_cuda/cuda"]
no_cuda = ["cryptography_cuda/no_cuda"]
batch = []
cuda_timing = []
papi = []

[dependencies]
ahash = { workspace = true }
@@ -41,6 +42,7 @@ static_assertions = { workspace = true }
unroll = { workspace = true }
web-time = { version = "1.0.0", optional = true }
once_cell = { version = "1.18.0" }
papi-bindings = { version = "0.5.2" }

# Local dependencies
plonky2_field = { version = "0.2.0", path = "../field", default-features = false }
79 changes: 60 additions & 19 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
@@ -3,26 +3,26 @@ use core::arch::x86_64::*;

use crate::hash::hash_types::RichField;

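// The Goldilocks prime is p = 2^64 - 2^32 + 1 = 0xFFFFFFFF00000001.
// MSB_1 = 2^63 maps u64 values into the signed-comparison ("shifted") domain,
// P_S_1 = p XOR 2^63 is p in that domain, and P_N_1 = 2^64 - p = 2^32 - 1.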
const MSB_1: i64 = 0x8000000000000000u64 as i64;
const P_S_1: i64 = 0x7FFFFFFF00000001u64 as i64;
const P_N_1: i64 = 0xFFFFFFFF;

#[inline(always)]
pub fn shift_avx(a: &__m256i) -> __m256i {
unsafe {
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
_mm256_xor_si256(*a, msb)
}
}

#[allow(dead_code)]
#[inline(always)]
pub fn to_canonical_avx_s(a_s: &__m256i) -> __m256i {
unsafe {
let p_s = _mm256_set_epi64x(P_S_1, P_S_1, P_S_1, P_S_1);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let mask1_ = _mm256_cmpgt_epi64(p_s, *a_s);
let corr1_ = _mm256_andnot_si256(mask1_, p_n);
_mm256_add_epi64(*a_s, corr1_)
}
}
@@ -31,9 +31,9 @@ pub fn to_canonical_avx_s(a_s: &__m256i) -> __m256i {
pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i {
unsafe {
let c0_s = _mm256_add_epi64(*a_sc, *b);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let mask_ = _mm256_cmpgt_epi64(*a_sc, c0_s);
let corr_ = _mm256_and_si256(mask_, p_n);
let c_s = _mm256_add_epi64(c0_s, corr_);
shift_avx(&c_s)
}
@@ -69,14 +69,27 @@ pub fn sub_avx_s_b_small(a_s: &__m256i, b: &__m256i) -> __m256i {
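// Reduces a 128-bit product c = c_h * 2^64 + c_l modulo p, using
// 2^64 ≡ 2^32 - 1 (mod p) and 2^96 ≡ -1 (mod p): writing c_h = c_hh * 2^32 + c_hl,
// c ≡ c_l - c_hh + c_hl * (2^32 - 1) (mod p). Comparisons run in the shifted domain.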
#[inline(always)]
pub fn reduce_avx_128_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
unsafe {
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
let c_hh = _mm256_srli_epi64(*c_h, 32);
let c_ls = _mm256_xor_si256(*c_l, msb);
let c1_s = sub_avx_s_b_small(&c_ls, &c_hh);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let c2 = _mm256_mul_epu32(*c_h, p_n);
let c_s = add_avx_s_b_small(&c1_s, &c2);
_mm256_xor_si256(c_s, msb)
}
}

// Like reduce_avx_128_64, but assumes c_h < 2^32, so the c_hh subtraction vanishes.
#[inline(always)]
pub fn reduce_avx_96_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
unsafe {
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let c_ls = _mm256_xor_si256(*c_l, msb);
let c2 = _mm256_mul_epu32(*c_h, p_n);
let c_s = add_avx_s_b_small(&c_ls, &c2);
_mm256_xor_si256(c_s, msb)
}
}

@@ -128,8 +141,8 @@ pub fn mult_avx_128(a: &__m256i, b: &__m256i) -> (__m256i, __m256i) {
let c_ll = _mm256_mul_epu32(*a, *b);
let c_ll_h = _mm256_srli_epi64(c_ll, 32);
let r0 = _mm256_add_epi64(c_hl, c_ll_h);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let r0_l = _mm256_and_si256(r0, p_n);
let r0_h = _mm256_srli_epi64(r0, 32);
let r1 = _mm256_add_epi64(c_lh, r0_l);
// let r1_l = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(r1)));
@@ -148,6 +161,21 @@ pub fn mult_avx(a: &__m256i, b: &__m256i) -> __m256i {
reduce_avx_128_64(&c_h, &c_l)
}

// Multiply two 64-bit numbers, assuming the product does not overflow 64 bits.
#[inline]
pub unsafe fn mul64_no_overflow(a: &__m256i, b: &__m256i) -> __m256i {
let r = _mm256_mul_epu32(*a, *b);
let ah = _mm256_srli_epi64(*a, 32);
let bh = _mm256_srli_epi64(*b, 32);
let r1 = _mm256_mul_epu32(*a, bh);
let r1 = _mm256_slli_epi64(r1, 32);
let r = _mm256_add_epi64(r, r1);
let r1 = _mm256_mul_epu32(ah, *b);
let r1 = _mm256_slli_epi64(r1, 32);
let r = _mm256_add_epi64(r, r1);
r
}

/*
#[inline(always)]
pub fn mult_avx_v2(a: &__m256i, b: &__m256i) -> __m256i {
@@ -275,3 +303,16 @@ pub fn sbox_avx_m256i(s0: &__m256i, s1: &__m256i, s2: &__m256i) -> (__m256i, __m

(r0, r1, r2)
}

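// One-vector variant of the Poseidon S-box x -> x^7, processing four lanes at once.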
#[allow(dead_code)]
#[inline(always)]
pub fn sbox_avx_one(s0: &__m256i) -> __m256i {
// x^2
let p10 = sqr_avx(s0);
// x^3
let p30 = mult_avx(&p10, s0);
// x^4 = (x^2)^2
let p40 = sqr_avx(&p10);
// x^7
mult_avx(&p40, &p30)
}
150 changes: 150 additions & 0 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
@@ -0,0 +1,150 @@
// use core::arch::asm;
use core::arch::x86_64::*;

use crate::hash::hash_types::RichField;

const MSB_: i64 = 0x8000000000000000u64 as i64;
const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
const P8_n_: i64 = 0xFFFFFFFF;

#[allow(dead_code)]
#[inline(always)]
pub fn shift_avx512(a: &__m512i) -> __m512i {
unsafe {
let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_);
_mm512_xor_si512(*a, msb)
}
}

#[allow(dead_code)]
#[inline(always)]
pub fn to_canonical_avx512(a: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let p8_n = _mm512_set_epi64(P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_);
let result_mask = _mm512_cmpge_epu64_mask(*a, p8);
_mm512_mask_add_epi64(*a, result_mask, *a, p8_n)
}
}

#[inline(always)]
pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_);
let c0 = _mm512_add_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
_mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
}
}

#[inline(always)]
pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let c0 = _mm512_sub_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*b, *a);
_mm512_mask_add_epi64(c0, result_mask, c0, p8)
}
}

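// Reduces a 128-bit product c = c_h * 2^64 + c_l modulo p, eight lanes at a time.
// Same identity as the AVX2 reduce_avx_128_64, but AVX512's native unsigned
// compares make the shifted-domain trick unnecessary.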
#[inline(always)]
pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_);
let c_hh = _mm512_srli_epi64(*c_h, 32);
let c1 = sub_avx512_b_c(c_l, &c_hh);
let c2 = _mm512_mul_epu32(*c_h, p8_n);
add_avx512_b_c(&c1, &c2)
}
}

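// Full 64x64 -> 128-bit multiply per lane, built from four 32x32 partial
// products (schoolbook), returning the (high, low) 64-bit halves.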
#[inline(always)]
pub fn mult_avx512_128(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) {
unsafe {
let a_h = _mm512_srli_epi64(*a, 32);
let b_h = _mm512_srli_epi64(*b, 32);
let c_hh = _mm512_mul_epu32(a_h, b_h);
let c_hl = _mm512_mul_epu32(a_h, *b);
let c_lh = _mm512_mul_epu32(*a, b_h);
let c_ll = _mm512_mul_epu32(*a, *b);
let c_ll_h = _mm512_srli_epi64(c_ll, 32);
let r0 = _mm512_add_epi64(c_hl, c_ll_h);
let p8_n = _mm512_set_epi64(P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_, P8_n_);
let r0_l = _mm512_and_si512(r0, p8_n);
let r0_h = _mm512_srli_epi64(r0, 32);
let r1 = _mm512_add_epi64(c_lh, r0_l);
let r1_l = _mm512_slli_epi64(r1, 32);
let mask = 0xAAAAu16;
let c_l = _mm512_mask_blend_epi32(mask, c_ll, r1_l);
let r2 = _mm512_add_epi64(c_hh, r0_h);
let r1_h = _mm512_srli_epi64(r1, 32);
let c_h = _mm512_add_epi64(r2, r1_h);
(c_h, c_l)
}
}

#[inline(always)]
pub fn mult_avx512(a: &__m512i, b: &__m512i) -> __m512i {
let (c_h, c_l) = mult_avx512_128(a, b);
reduce_avx512_128_64(&c_h, &c_l)
}

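// Squaring needs only three 32x32 partial products because the two cross terms
// coincide: a^2 = c_hh * 2^64 + 2 * c_lh * 2^32 + c_ll. The shifts by 33 and 31
// fold the doubling of the cross term into the carries.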
#[inline(always)]
pub fn sqr_avx512_128(a: &__m512i) -> (__m512i, __m512i) {
unsafe {
let a_h = _mm512_srli_epi64(*a, 32);
let c_ll = _mm512_mul_epu32(*a, *a);
let c_lh = _mm512_mul_epu32(*a, a_h);
let c_hh = _mm512_mul_epu32(a_h, a_h);
let c_ll_hi = _mm512_srli_epi64(c_ll, 33);
let t0 = _mm512_add_epi64(c_lh, c_ll_hi);
let t0_hi = _mm512_srli_epi64(t0, 31);
let res_hi = _mm512_add_epi64(c_hh, t0_hi);
let c_lh_lo = _mm512_slli_epi64(c_lh, 33);
let res_lo = _mm512_add_epi64(c_ll, c_lh_lo);
(res_hi, res_lo)
}
}

#[inline(always)]
pub fn sqr_avx512(a: &__m512i) -> __m512i {
let (c_h, c_l) = sqr_avx512_128(a);
reduce_avx512_128_64(&c_h, &c_l)
}

#[inline(always)]
pub fn sbox_avx512<F>(state: &mut [F; 16])
where
F: RichField,
{
unsafe {
let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::<i32>());
let s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::<i32>());
// x^2
let p10 = sqr_avx512(&s0);
let p11 = sqr_avx512(&s1);
// x^3
let p20 = mult_avx512(&p10, &s0);
let p21 = mult_avx512(&p11, &s1);
// x^4 = (x^2)^2
let s0 = sqr_avx512(&p10);
let s1 = sqr_avx512(&p11);
// x^7
let p10 = mult_avx512(&s0, &p20);
let p11 = mult_avx512(&s1, &p21);
_mm512_storeu_si512((&mut state[0..8]).as_mut_ptr().cast::<i32>(), p10);
_mm512_storeu_si512((&mut state[8..16]).as_mut_ptr().cast::<i32>(), p11);
}
}

#[inline(always)]
pub fn sbox_avx512_one(s0: &__m512i) -> __m512i {
// x^2
let p10 = sqr_avx512(s0);
// x^3
let p30 = mult_avx512(&p10, s0);
// x^4 = (x^2)^2
let p40 = sqr_avx512(&p10);
// x^7
mult_avx512(&p40, &p30)
}
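
For completeness, a sketch of how one might sanity-check the vectorized S-box against scalar arithmetic (a hypothetical test, not part of the PR; it assumes an AVX512-capable build and follows the same pointer-cast load convention as the code above):

```rust
#[cfg(all(test, target_feature = "avx512dq"))]
mod tests {
    use core::arch::x86_64::*;

    use super::*;

    // Scalar reference: x -> x^7 mod p, with p = 2^64 - 2^32 + 1.
    fn sbox_scalar(x: u64) -> u64 {
        const P: u128 = 0xFFFF_FFFF_0000_0001;
        let x = x as u128 % P;
        let x3 = (x * x % P) * x % P;
        let x4 = x3 * x % P;
        (x4 * x3 % P) as u64
    }

    #[test]
    fn sbox_avx512_one_matches_scalar() {
        let xs: [i64; 8] = [0, 1, 2, 3, 5, 7, 11, 13];
        unsafe {
            let v = _mm512_loadu_si512(xs.as_ptr().cast::<i32>());
            let r = sbox_avx512_one(&v);
            let mut out = [0i64; 8];
            _mm512_storeu_si512(out.as_mut_ptr().cast::<i32>(), r);
            for (x, y) in xs.iter().zip(out.iter()) {
                assert_eq!(*y as u64, sbox_scalar(*x as u64));
            }
        }
    }
}
```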
6 changes: 6 additions & 0 deletions plonky2/src/hash/arch/x86_64/mod.rs
@@ -4,7 +4,13 @@
// #[cfg(all(target_feature = "avx2", target_feature = "bmi2"))]
#[cfg(target_feature = "avx2")]
pub mod goldilocks_avx2;
#[cfg(target_feature = "avx512dq")]
pub mod goldilocks_avx512;
#[cfg(target_feature = "avx2")]
pub mod poseidon2_goldilocks_avx2;
#[cfg(target_feature = "avx2")]
pub mod poseidon_bn128_avx2;
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))]
pub mod poseidon_goldilocks_avx2;
#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
pub mod poseidon_goldilocks_avx512;
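
These modules are selected purely at compile time: without the matching `target-feature` flags, the vectorized paths are not compiled at all. For example (an illustrative invocation consistent with the `cfg` gates above):

```
RUSTFLAGS="-C target-feature=+avx2,+avx512dq" cargo build --release
```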