Skip to content

Commit

Permalink
add salt support in oracle.rs
Browse files Browse the repository at this point in the history
  • Loading branch information
dloghin committed Sep 30, 2024
1 parent 6f925b7 commit 067d4a3
Showing 1 changed file with 18 additions and 11 deletions.
29 changes: 18 additions & 11 deletions plonky2/src/fri/oracle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,11 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
timing: &mut TimingTree,
fft_root_table: Option<&FftRootTable<F>>,
) -> Self {
let pols = polynomials.len();
let degree = polynomials[0].len();
let log_n = log2_strict(degree);

if log_n + rate_bits > 1 && polynomials.len() > 0 {
if log_n + rate_bits > 1 && polynomials.len() > 0 && pols * (1 << (log_n + rate_bits)) < (1 << 31) {
let _num_gpus: usize = std::env::var("NUM_OF_GPUS")
.expect("NUM_OF_GPUS should be set")
.parse()
Expand Down Expand Up @@ -235,14 +236,14 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
pub fn from_coeffs_gpu(
polynomials: &[PolynomialCoeffs<F>],
rate_bits: usize,
_blinding: bool,
blinding: bool,
cap_height: usize,
timing: &mut TimingTree,
_fft_root_table: Option<&FftRootTable<F>>,
log_n: usize,
_degree: usize,
) -> MerkleTree<F, <C as GenericConfig<D>>::Hasher> {
// let salt_size = if blinding { SALT_SIZE } else { 0 };
let salt_size = if blinding { SALT_SIZE } else { 0 };
// println!("salt_size: {:?}", salt_size);
let output_domain_size = log_n + rate_bits;

Expand All @@ -255,8 +256,9 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
let total_num_of_fft = polynomials.len();
// println!("total_num_of_fft: {:?}", total_num_of_fft);

let num_of_cols = total_num_of_fft + salt_size; // if blinding, extend by salt_size
let total_num_input_elements = total_num_of_fft * (1 << log_n);
let total_num_output_elements = total_num_of_fft * (1 << output_domain_size);
let total_num_output_elements = num_of_cols * (1 << output_domain_size);

let mut gpu_input: Vec<F> = polynomials
.into_iter()
Expand All @@ -270,6 +272,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
cfg_lde.are_outputs_on_device = true;
cfg_lde.with_coset = true;
cfg_lde.is_multi_gpu = true;
cfg_lde.salt_size = salt_size as u32;

let mut device_output_data: HostOrDeviceSlice<'_, F> =
HostOrDeviceSlice::cuda_malloc(0 as i32, total_num_output_elements).unwrap();
Expand Down Expand Up @@ -302,7 +305,7 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
}

let mut cfg_trans = TransposeConfig::default();
cfg_trans.batches = total_num_of_fft as u32;
cfg_trans.batches = num_of_cols as u32;
cfg_trans.are_inputs_on_device = true;
cfg_trans.are_outputs_on_device = true;

Expand All @@ -327,10 +330,14 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
MerkleTree::new_from_gpu_leaves(
&device_transpose_data,
1 << output_domain_size,
total_num_of_fft,
num_of_cols,
cap_height
)
);

drop(device_transpose_data);
drop(device_output_data);

mt
}

Expand Down Expand Up @@ -443,11 +450,11 @@ impl<F: RichField + Extendable<D>, C: GenericConfig<D, F = F>, const D: usize>
println!("collect data from gpu used: {:?}", start.elapsed());
r
})
// .chain(
// (0..salt_size)
// .into_par_iter()
// .map(|_| F::rand_vec(degree << rate_bits)),
// )
.chain(
(0..salt_size)
.into_par_iter()
.map(|_| F::rand_vec(degree << rate_bits)),
)
.collect();
println!("real lde elapsed: {:?}", start_lde.elapsed());
return ret;
Expand Down

0 comments on commit 067d4a3

Please sign in to comment.