Skip to content

Commit

Permalink
Merge pull request #38 from okx/dev-dumi
Browse files Browse the repository at this point in the history
avx512 fix
  • Loading branch information
dloghin authored Oct 18, 2024
2 parents 376d690 + 6b4bce2 commit 8a00c2b
Show file tree
Hide file tree
Showing 6 changed files with 334 additions and 217 deletions.
13 changes: 10 additions & 3 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ pub fn add_avx_a_sc(a_sc: &__m256i, b: &__m256i) -> __m256i {

#[inline(always)]
pub fn add_avx(a: &__m256i, b: &__m256i) -> __m256i {
let a_sc = shift_avx(a);
// let a_sc = toCanonical_avx_s(&a_s);
add_avx_a_sc(&a_sc, b)
unsafe {
let msb = _mm256_set_epi64x(MSB_1, MSB_1, MSB_1, MSB_1);
let a_sc = _mm256_xor_si256(*a, msb);
let c0_s = _mm256_add_epi64(a_sc, *b);
let p_n = _mm256_set_epi64x(P_N_1, P_N_1, P_N_1, P_N_1);
let mask_ = _mm256_cmpgt_epi64(a_sc, c0_s);
let corr_ = _mm256_and_si256(mask_, p_n);
let c_s = _mm256_add_epi64(c0_s, corr_);
_mm256_xor_si256(c_s, msb)
}
}

#[inline(always)]
Expand Down
43 changes: 34 additions & 9 deletions plonky2/src/hash/arch/x86_64/goldilocks_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,22 @@ use crate::hash::hash_types::RichField;
const MSB_: i64 = 0x8000000000000000u64 as i64;
const P8_: i64 = 0xFFFFFFFF00000001u64 as i64;
const P8_N_: i64 = 0xFFFFFFFF;
const ONE_: i64 = 1;

#[allow(non_snake_case)]
#[repr(align(64))]
struct FieldConstants {
MSB_V: [i64; 8],
P8_V: [i64; 8],
P8_N_V: [i64; 8],
pub(crate) struct FieldConstants {
pub(crate) MSB_V: [i64; 8],
pub(crate) P8_V: [i64; 8],
pub(crate) P8_N_V: [i64; 8],
pub(crate) ONE_V: [i64; 8],
}

const FC: FieldConstants = FieldConstants {
pub(crate) const FC: FieldConstants = FieldConstants {
MSB_V: [MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_, MSB_],
P8_V: [P8_, P8_, P8_, P8_, P8_, P8_, P8_, P8_],
P8_N_V: [P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_],
ONE_V: [ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_, ONE_],
};

#[allow(dead_code)]
Expand Down Expand Up @@ -46,13 +49,34 @@ pub fn to_canonical_avx512(a: &__m512i) -> __m512i {

#[inline(always)]
pub fn add_avx512(a: &__m512i, b: &__m512i) -> __m512i {
/*
unsafe {
// let p8_n = _mm512_set_epi64(P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_, P8_N_);
let p8_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let p8_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
let c0 = _mm512_add_epi64(*a, *b);
let result_mask = _mm512_cmpgt_epu64_mask(*a, c0);
_mm512_mask_add_epi64(c0, result_mask, c0, p8_n)
}
*/
unsafe {
let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::<i64>());
let a_sc = _mm512_xor_si512(*a, msb);
let c0_s = _mm512_add_epi64(a_sc, *b);
let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
let mask_ = _mm512_cmpgt_epi64_mask(a_sc, c0_s);
let c_s = _mm512_mask_add_epi64(c0_s, mask_, c0_s, p_n);
_mm512_xor_si512(c_s, msb)
}
}

#[inline(always)]
pub fn add_avx512_s_b_small(a_s: &__m512i, b_small: &__m512i) -> __m512i {
unsafe {
let corr = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
let c0_s = _mm512_add_epi64(*a_s, *b_small);
let mask_ = _mm512_cmpgt_epi64_mask(*a_s, c0_s);
_mm512_mask_add_epi64(c0_s, mask_, c0_s, corr)
}
}

#[inline(always)]
Expand Down Expand Up @@ -82,11 +106,12 @@ pub fn reduce_avx512_128_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
#[inline(always)]
pub fn reduce_avx512_96_64(c_h: &__m512i, c_l: &__m512i) -> __m512i {
unsafe {
let msb = _mm512_load_si512(FC.MSB_V.as_ptr().cast::<i32>());
let p_n = _mm512_load_si512(FC.P8_N_V.as_ptr().cast::<i32>());
let msb = _mm512_load_epi64(FC.MSB_V.as_ptr().cast::<i64>());
let p_n = _mm512_load_epi64(FC.P8_N_V.as_ptr().cast::<i64>());
let c_ls = _mm512_xor_si512(*c_l, msb);
let c2 = _mm512_mul_epu32(*c_h, p_n);
let c_s = add_avx512(&c_ls, &c2);
let c_s = add_avx512_s_b_small(&c_ls, &c2);
// let c_s = add_avx512(&c_ls, &c2);
_mm512_xor_si512(c_s, msb)
}
}
Expand Down
21 changes: 0 additions & 21 deletions plonky2/src/hash/arch/x86_64/poseidon_bn128_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,34 +84,13 @@ unsafe fn sub64(a: &__m256i, b: &__m256i, bin: &__m256i) -> (__m256i, __m256i) {
let zeros = _mm256_set_epi64x(0, 0, 0, 0);
let (r1, b1) = sub64_no_borrow(a, b);

// TODO - delete
/*
let mut v = [0i64; 4];
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *a);
println!("a: {:?}", v);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), *b);
println!("b: {:?}", v);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r1);
println!("r: {:?}", v);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), b1);
println!("b: {:?}", v);
*/

let m1 = _mm256_cmpeq_epi64(*bin, ones);
let m2 = _mm256_cmpeq_epi64(r1, zeros);
let m = _mm256_and_si256(m1, m2);
let bo = _mm256_and_si256(m, ones);
let r = _mm256_sub_epi64(r1, *bin);
let bo = _mm256_or_si256(bo, b1);

// TODO - delete
/*
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), r);
println!("r: {:?}", v);
_mm256_storeu_si256(v.as_mut_ptr().cast::<__m256i>(), bo);
println!("b: {:?}", v);
*/

(r, bo)
}

Expand Down
1 change: 1 addition & 0 deletions plonky2/src/hash/arch/x86_64/poseidon_goldilocks_avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,7 @@ where
let ss0 = add_avx(&s0, &rc0);
let ss1 = add_avx(&s1, &rc1);
let ss2 = add_avx(&s2, &rc2);

(s0, s1, s2) = sbox_avx_m256i(&ss0, &ss1, &ss2);
mds_layer_avx(&mut s0, &mut s1, &mut s2);
round_ctr += 1;
Expand Down
Loading

0 comments on commit 8a00c2b

Please sign in to comment.