Merge pull request #35 from okx/dev-dumi
Merge latest PoseidonBN128 code
dloghin authored Jul 1, 2024
2 parents d99fe3b + beb6e11 commit f757d1e
Showing 16 changed files with 7,652 additions and 75 deletions.
54 changes: 45 additions & 9 deletions README.md
@@ -1,21 +1,57 @@
# description
this repo is a fork of https://github.com/0xPolygonZero/plonky2. several optimizations were implemented to boost the computation speed.
# Description

# optimizations
- precompute of fft twiddle factors
- cuda implementation of Goldilocks Field NTT (feature `cuda`)
This repo is a fork of https://github.com/0xPolygonZero/plonky2. To boost speed, several optimizations were implemented:

# Optimizations
- Precompute FFT twiddle factors.
- CUDA implementation of Goldilocks Field NTT (feature `cuda`).
- CUDA implementation of Poseidon (Goldilocks) and Poseidon (BN 128) (feature `cuda`).
- Fixed the AVX implementation for Poseidon (Goldilocks) (target CPU must support AVX2).
- CUDA implementation of Merkle Tree building (feature `cuda`).
- Changed the Merkle Tree structure from recursive to iterative (a 1-dimensional vector of digests); see the sketch below.
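
A minimal sketch of the iterative 1-D layout (assumptions: a hypothetical `hash_two` stand-in instead of the repo's Poseidon digests, and a power-of-two leaf count):

```
// Builds a Merkle Tree level by level into one flat vector:
// positions 0..n hold the leaves, each parent level is appended
// right after the level below it, and the last element is the root.
fn build_merkle_tree(leaves: &[u64]) -> Vec<u64> {
    assert!(leaves.len().is_power_of_two());
    let n = leaves.len();
    let mut digests = Vec::with_capacity(2 * n - 1);
    digests.extend_from_slice(leaves);
    let (mut start, mut len) = (0, n);
    while len > 1 {
        for i in 0..len / 2 {
            let left = digests[start + 2 * i];
            let right = digests[start + 2 * i + 1];
            digests.push(hash_two(left, right));
        }
        start += len;
        len /= 2;
    }
    digests
}

// Illustrative stand-in hash, not part of this repo.
fn hash_two(a: u64, b: u64) -> u64 {
    (a ^ b.rotate_left(17)).wrapping_mul(0x9E37_79B9_7F4A_7C15)
}
```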

# Dependencies

# dependencies
```
git submodule update --init --recursive
```

# run examples
- cuda NTT
```
cargo run --release -p plonky2_field --features=cuda --example fft
```

## Benchmarking Merkle Tree building with Poseidon hash

Set the latest Rust nightly:
```
rustup update
rustup override set nightly-x86_64-unknown-linux-gnu
```

CPU, no AVX: ``cargo bench merkle``

CPU with AVX2: ``RUSTFLAGS="-C target-feature=+avx2" cargo bench merkle``

CPU with AVX512: ``RUSTFLAGS="-C target-feature=+avx512dq" cargo bench merkle``

GPU (CUDA): ``cargo bench merkle --features=cuda``

### Results

The values in the table below are Merkle Tree build times (in milliseconds) for the number of leaves given in the first row, using the hashing method given in the first column. The systems used for benchmarking are:

- first three result columns: AMD Ryzen Threadripper PRO 5975WX 32-Cores (AVX2 only) + NVIDIA RTX 4090 (feature `cuda`);

- last three result columns: AMD Ryzen 9 7950X 16-Core (AVX2 and AVX512DQ).


| Number of MT Leaves | 2^13 | 2^14 | 2^15 | | 2^13 | 2^14 | 2^15 |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Poseidon (no AVX) | 12.4 | 23.4 | 46.6 | | 12.8 | 25.2 | 50.3 |
| Poseidon (AVX) | 11.4 | 21.3 | 39.2 | | 10.3 | 20.3 | 40.2 |
| Poseidon (AVX512) | - | - | - | | 12.3 | 24.1 | 47.8 |
| Poseidon (GPU) | 8 | 14.3 | 26.5 | | - | - | - |
| Poseidon BN 128 (no AVX) | 111.9 | 223 | 446.3 | | 176.9 | 351 | 699.1 |
| Poseidon BN 128 (AVX) | 146.8 | 291.7 | 581.8 | | 220.1 | 433.5 | 858.8 |
| Poseidon BN 128 (AVX512) | - | - | - | | WIP | WIP | WIP |
| Poseidon BN 128 (GPU) | 37.5 | 57.6 | 92.9 | | - | - | - |

## Running

To see recursion performance, one can run this bench, which generates a chain of three recursion proofs:
2 changes: 2 additions & 0 deletions plonky2/Cargo.toml
@@ -25,6 +25,7 @@ cuda = ["cryptography_cuda/cuda"]
no_cuda = ["cryptography_cuda/no_cuda"]
batch = []
cuda_timing = []
papi = []

[dependencies]
ahash = { workspace = true }
@@ -41,6 +42,7 @@ static_assertions = { workspace = true }
unroll = { workspace = true }
web-time = { version = "1.0.0", optional = true }
once_cell = { version = "1.18.0" }
papi-bindings = { version = "0.5.2" }

# Local dependencies
plonky2_field = { version = "0.2.0", path = "../field", default-features = false }
79 changes: 60 additions & 19 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
@@ -3,26 +3,26 @@ use core::arch::x86_64::*;

use crate::hash::hash_types::RichField;

const MSB_: i64 = 0x8000000000000000u64 as i64;
const P_s_: i64 = 0x7FFFFFFF00000001u64 as i64;
const P_n_: i64 = 0xFFFFFFFF;
const MSB_1: i64 = 0x8000000000000000u64 as i64;
const P_S_1: i64 = 0x7FFFFFFF00000001u64 as i64;
const P_N_1: i64 = 0xFFFFFFFF;

#[inline(always)]
pub fn shift_avx(a: &__m256i) -> __m256i {
unsafe {
let MSB = _mm256_set_epi64x(MSB_, MSB_, MSB_, MSB_);
_mm256_xor_si256(*a, MSB)
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
_mm256_xor_si256(*a, msb)
}
}

#[allow(dead_code)]
#[inline(always)]
pub fn toCanonical_avx_s(a_s: &__m256i) -> __m256i {
pub fn to_canonical_avx_s(a_s: &__m256i) -> __m256i {
unsafe {
let P_s = _mm256_set_epi64x(P_s_, P_s_, P_s_, P_s_);
let P_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
let mask1_ = _mm256_cmpgt_epi64(P_s, *a_s);
let corr1_ = _mm256_andnot_si256(mask1_, P_n);
let p_s = _mm256_set_epi64x(P_S_1, P_S_1, P_S_1, P_S_1);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let mask1_ = _mm256_cmpgt_epi64(p_s, *a_s);
let corr1_ = _mm256_andnot_si256(mask1_, p_n);
_mm256_add_epi64(*a_s, corr1_)
}
}
@@ -31,9 +31,9 @@ pub fn toCanonical_avx_s(a_s: &__m256i) -> __m256i {
pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i {
unsafe {
let c0_s = _mm256_add_epi64(*a_sc, *b);
let P_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let mask_ = _mm256_cmpgt_epi64(*a_sc, c0_s);
let corr_ = _mm256_and_si256(mask_, P_n);
let corr_ = _mm256_and_si256(mask_, p_n);
let c_s = _mm256_add_epi64(c0_s, corr_);
shift_avx(&c_s)
}
@@ -69,14 +69,27 @@ pub fn sub_avx_s_b_small(a_s: &__m256i, b: &__m256i) -> __m256i {
#[inline(always)]
pub fn reduce_avx_128_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
unsafe {
let MSB = _mm256_set_epi64x(MSB_, MSB_, MSB_, MSB_);
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
let c_hh = _mm256_srli_epi64(*c_h, 32);
let c_ls = _mm256_xor_si256(*c_l, MSB);
let c_ls = _mm256_xor_si256(*c_l, msb);
let c1_s = sub_avx_s_b_small(&c_ls, &c_hh);
let P_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
let c2 = _mm256_mul_epu32(*c_h, P_n);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let c2 = _mm256_mul_epu32(*c_h, p_n);
let c_s = add_avx_s_b_small(&c1_s, &c2);
_mm256_xor_si256(c_s, MSB)
_mm256_xor_si256(c_s, msb)
}
}

// Here we assume c_h < 2^32, i.e. the full input is at most 96 bits wide
#[inline(always)]
pub fn reduce_avx_96_64(c_h: &__m256i, c_l: &__m256i) -> __m256i {
unsafe {
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let c_ls = _mm256_xor_si256(*c_l, msb);
let c2 = _mm256_mul_epu32(*c_h, p_n);
let c_s = add_avx_s_b_small(&c_ls, &c2);
_mm256_xor_si256(c_s, msb)
}
}
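
// For reference, a scalar sketch of the reductions above (illustration only,
// not part of this commit). For the Goldilocks prime p = 2^64 - 2^32 + 1 we
// have 2^64 = 2^32 - 1 (mod p) and 2^96 = -1 (mod p), so with
// c_h = c_hh * 2^32 + c_hl the input satisfies
// c_h * 2^64 + c_l = (2^32 - 1) * c_hl - c_hh + c_l (mod p). As in the vector
// code, the result is congruent mod p but not necessarily canonical.
//
// fn reduce_128_64_scalar(c_h: u64, c_l: u64) -> u64 {
//     let c_hh = c_h >> 32;
//     let c_hl = c_h & 0xFFFFFFFF;
//     let (r, borrow) = c_l.overflowing_sub(c_hh);
//     // the wraparound added 2^64 = 2^32 - 1 (mod p); subtract it
//     let r = if borrow { r.wrapping_sub(0xFFFFFFFF) } else { r };
//     let (r, carry) = r.overflowing_add(c_hl * 0xFFFFFFFF);
//     // the wraparound dropped 2^64 = 2^32 - 1 (mod p); add it back
//     if carry { r.wrapping_add(0xFFFFFFFF) } else { r }
// }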

@@ -128,8 +141,8 @@ pub fn mult_avx_128(a: &__m256i, b: &__m256i) -> (__m256i, __m256i) {
let c_ll = _mm256_mul_epu32(*a, *b);
let c_ll_h = _mm256_srli_epi64(c_ll, 32);
let r0 = _mm256_add_epi64(c_hl, c_ll_h);
let P_n = _mm256_set_epi64x(P_n_, P_n_, P_n_, P_n_);
let r0_l = _mm256_and_si256(r0, P_n);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let r0_l = _mm256_and_si256(r0, p_n);
let r0_h = _mm256_srli_epi64(r0, 32);
let r1 = _mm256_add_epi64(c_lh, r0_l);
// let r1_l = _mm256_castps_si256(_mm256_moveldup_ps(_mm256_castsi256_ps(r1)));
@@ -148,6 +161,21 @@ pub fn mult_avx(a: &__m256i, b: &__m256i) -> __m256i {
reduce_avx_128_64(&c_h, &c_l)
}

// Multiply two 64-bit numbers, assuming the product does not overflow 64 bits.
#[inline]
pub unsafe fn mul64_no_overflow(a: &__m256i, b: &__m256i) -> __m256i {
let r = _mm256_mul_epu32(*a, *b);
let ah = _mm256_srli_epi64(*a, 32);
let bh = _mm256_srli_epi64(*b, 32);
let r1 = _mm256_mul_epu32(*a, bh);
let r1 = _mm256_slli_epi64(r1, 32);
let r = _mm256_add_epi64(r, r1);
let r1 = _mm256_mul_epu32(ah, *b);
let r1 = _mm256_slli_epi64(r1, 32);
let r = _mm256_add_epi64(r, r1);
r
}
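
// Scalar view of the computation above (illustration only): writing
// a = a_h * 2^32 + a_l and b = b_h * 2^32 + b_l, the truncated product is
// a_l*b_l + ((a_l*b_h + a_h*b_l) << 32), i.e. a.wrapping_mul(b); the
// a_h*b_h term only affects bits >= 64 and is dropped.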

/*
#[inline(always)]
pub fn mult_avx_v2(a: &__m256i, b: &__m256i) -> __m256i {
@@ -275,3 +303,16 @@ pub fn sbox_avx_m256i(s0: &__m256i, s1: &__m256i, s2: &__m256i) -> (__m256i, __m

(r0, r1, r2)
}

#[allow(dead_code)]
#[inline(always)]
pub fn sbox_avx_one(s0: &__m256i) -> __m256i {
// x^2
let p10 = sqr_avx(s0);
// x^3
let p30 = mult_avx(&p10, s0);
// x^4 = (x^2)^2
let p40 = sqr_avx(&p10);
// x^7
mult_avx(&p40, &p30)
}
150 changes: 150 additions & 0 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
@@ -0,0 +1,150 @@
// use core::arch::asm;
use core::arch::x86_64::*;

use crate::hash::hash_types::RichField;

const MSB_: i64 = 0x8000000000000000u64 as i64;
const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
const P8_N_: i64 = 0xFFFFFFFF;

#[allow(dead_code)]
#[inline(always)]
pub fn shift_avx512(a: &__m512i) -> __m512i {
unsafe {
let msb = _mm512_set_epi64(MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_);
_mm512_xor_si512(*a, msb)
}
}

#[allow(dead_code)]
#[inline(always)]
pub fn to_canonical_avx512(a: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let result_mask = _mm512_cmpge_epu64_mask(*a, p8);
_mm512_mask_add_epi64(*a, result_mask, *a, p8_n)
}
}

#[inline(always)]
pub fn add_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let c0 = _mm512_add_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
_mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
}
}

#[inline(always)]
pub fn sub_avx512_b_c(a: &__m512i, b: &__m512i) -> __m512i {
unsafe {
let p8 = _mm512_set_epi64(P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_);
let c0 = _mm512_sub_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*b, *a);
_mm512_mask_add_epi64(c0, result_mask, c0, p8)
}
}

#[inline(always)]
pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
unsafe {
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let c_hh = _mm512_srli_epi64(*c_h, 32);
let c1 = sub_avx512_b_c(c_l, &c_hh);
let c2 = _mm512_mul_epu32(*c_h, p8_n);
add_avx512_b_c(&c1, &c2)
}
}

#[inline(always)]
pub fn mult_avx512_128(a: &__m512i, b: &__m512i) -> (__m512i, __m512i) {
unsafe {
let a_h = _mm512_srli_epi64(*a, 32);
let b_h = _mm512_srli_epi64(*b, 32);
let c_hh = _mm512_mul_epu32(a_h, b_h);
let c_hl = _mm512_mul_epu32(a_h, *b);
let c_lh = _mm512_mul_epu32(*a, b_h);
let c_ll = _mm512_mul_epu32(*a, *b);
let c_ll_h = _mm512_srli_epi64(c_ll, 32);
let r0 = _mm512_add_epi64(c_hl, c_ll_h);
let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let r0_l = _mm512_and_si512(r0, p8_n);
let r0_h = _mm512_srli_epi64(r0, 32);
let r1 = _mm512_add_epi64(c_lh, r0_l);
let r1_l = _mm512_slli_epi64(r1, 32);
let mask = 0xAAAAu16;
let c_l = _mm512_mask_blend_epi32(mask, c_ll, r1_l);
let r2 = _mm512_add_epi64(c_hh, r0_h);
let r1_h = _mm512_srli_epi64(r1, 32);
let c_h = _mm512_add_epi64(r2, r1_h);
(c_h, c_l)
}
}

#[inline(always)]
pub fn mult_avx512(a: &__m512i, b: &__m512i) -> __m512i {
let (c_h, c_l) = mult_avx512_128(a, b);
reduce_avx512_128_64(&c_h, &c_l)
}

#[inline(always)]
pub fn sqr_avx512_128(a: &__m512i) -> (__m512i, __m512i) {
unsafe {
let a_h = _mm512_srli_epi64(*a, 32);
let c_ll = _mm512_mul_epu32(*a, *a);
let c_lh = _mm512_mul_epu32(*a, a_h);
let c_hh = _mm512_mul_epu32(a_h, a_h);
let c_ll_hi = _mm512_srli_epi64(c_ll, 33);
let t0 = _mm512_add_epi64(c_lh, c_ll_hi);
let t0_hi = _mm512_srli_epi64(t0, 31);
let res_hi = _mm512_add_epi64(c_hh, t0_hi);
let c_lh_lo = _mm512_slli_epi64(c_lh, 33);
let res_lo = _mm512_add_epi64(c_ll, c_lh_lo);
(res_hi, res_lo)
}
}
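
// Note on the shift counts above (illustration only): with a = a_h * 2^32 + a_l,
// a^2 = c_hh * 2^64 + c_lh * 2^33 + c_ll, because the cross product a_l * a_h
// occurs twice (hence 33 = 32 + 1). The carry into the high half is t0 >> 31
// since t0 = c_lh + (c_ll >> 33) holds the product's bits 33..128 scaled by
// 2^-33, and 64 - 33 = 31.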

#[inline(always)]
pub fn sqr_avx512(a: &__m512i) -> __m512i {
let (c_h, c_l) = sqr_avx512_128(a);
reduce_avx512_128_64(&c_h, &c_l)
}

#[inline(always)]
pub fn sbox_avx512<F>(state: &mut [F; 16])
where
F: RichField,
{
unsafe {
let s0 = _mm512_loadu_si512((&state[0..8]).as_ptr().cast::<i32>());
let s1 = _mm512_loadu_si512((&state[8..16]).as_ptr().cast::<i32>());
// x^2
let p10 = sqr_avx512(&s0);
let p11 = sqr_avx512(&s1);
// x^3
let p20 = mult_avx512(&p10, &s0);
let p21 = mult_avx512(&p11, &s1);
// x^4 = (x^2)^2
let s0 = sqr_avx512(&p10);
let s1 = sqr_avx512(&p11);
// x^7
let p10 = mult_avx512(&s0, &p20);
let p11 = mult_avx512(&s1, &p21);
_mm512_storeu_si512((&mut state[0..8]).as_mut_ptr().cast::<i32>(), p10);
_mm512_storeu_si512((&mut state[8..16]).as_mut_ptr().cast::<i32>(), p11);
}
}

#[inline(always)]
pub fn sbox_avx512_one(s0: &__m512i) -> __m512i {
// x^2
let p10 = sqr_avx512(s0);
// x^3
let p30 = mult_avx512(&p10, s0);
// x^4 = (x^2)^2
let p40 = sqr_avx512(&p10);
// x^7
mult_avx512(&p40, &p30)
}
6 changes: 6 additions & 0 deletions plonky2/src/hash/arch/x86_64/mod.rs
@@ -4,7 +4,13 @@
// #[cfg(all(target_feature = "avx2", target_feature = "bmi2"))]
#[cfg(target_feature = "avx2")]
pub mod goldilocks_avx2;
#[cfg(target_feature = "avx512dq")]
pub mod goldilocks_avx512;
#[cfg(target_feature = "avx2")]
pub mod poseidon2_goldilocks_avx2;
#[cfg(target_feature = "avx2")]
pub mod poseidon_bn128_avx2;
#[cfg(all(target_feature = "avx2", not(target_feature = "avx512dq")))]
pub mod poseidon_goldilocks_avx2;
#[cfg(all(target_feature = "avx2", target_feature = "avx512dq"))]
pub mod poseidon_goldilocks_avx512;