Skip to content

Commit

Permalink
implemented change leaves in range
Browse files Browse the repository at this point in the history
  • Loading branch information
dloghin committed Apr 15, 2024
1 parent d17c3ec commit 94ff49a
Showing 1 changed file with 230 additions and 20 deletions.
250 changes: 230 additions & 20 deletions plonky2/src/hash/merkle_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::sync::Arc;
use alloc::vec::Vec;
use core::mem::MaybeUninit;
use core::slice;
use std::collections::HashSet;
#[cfg(feature = "cuda")]
use std::os::raw::c_void;
#[cfg(feature = "cuda")]
Expand All @@ -30,9 +31,9 @@ use crate::hash::hash_types::RichField;
#[cfg(feature = "cuda")]
use crate::hash::hash_types::NUM_HASH_OUT_ELTS;
use crate::hash::merkle_proofs::MerkleProof;
use crate::plonk::config::{GenericHashOut, Hasher};
#[cfg(feature = "cuda")]
use crate::plonk::config::HasherType;
use crate::plonk::config::{GenericHashOut, Hasher};
use crate::util::log2_strict;

#[cfg(feature = "cuda")]
Expand Down Expand Up @@ -942,6 +943,113 @@ impl<F: RichField, H: Hasher<F>> MerkleTree<F, H> {
}
}

pub fn change_leaves_in_range_and_update(
&mut self,
new_leaves: Vec<Vec<F>>,
start_index: usize,
end_index: usize,
) {
assert_eq!(new_leaves.len(), end_index - start_index);
assert_eq!(new_leaves[0].len(), self.leaf_size);

let tree_leaves_count = self.leaves.len() / self.leaf_size;
assert!(start_index < end_index);
assert!(end_index < tree_leaves_count);

let cap_height = log2_strict(self.cap.len());
let mut leaves = self.leaves.clone();

leaves[start_index * self.leaf_size..end_index * self.leaf_size]
.par_chunks_exact_mut(self.leaf_size)
.zip(new_leaves.clone())
.for_each(|(x, y)| {
for j in 0..self.leaf_size {
x[j] = y[j];
}
});

let digests_len = self.digests.len();
let cap_len = self.cap.0.len();
let digests_buf = capacity_up_to_mut(&mut self.digests, digests_len);
let cap_buf = capacity_up_to_mut(&mut self.cap.0, cap_len);
self.leaves = leaves;
if digests_buf.is_empty() {
cap_buf[start_index..end_index]
.par_iter_mut()
.zip(new_leaves)
.for_each(|(cap, leaf)| {
cap.write(H::hash_or_noop(leaf.as_slice()));
});
} else {
let subtree_leaves_len = tree_leaves_count >> cap_height;
let subtree_digests_len = digests_buf.len() >> cap_height;

let mut positions: Vec<usize> = (start_index..end_index)
.map(|idx| {
let subtree_idx = idx / subtree_leaves_len;
let subtree_offset = subtree_idx * subtree_digests_len;
let idx_in_subtree =
subtree_digests_len - subtree_leaves_len + idx % subtree_leaves_len;
subtree_offset + idx_in_subtree
})
.collect();

// TODO change to parallel loop
for i in 0..positions.len() {
digests_buf[positions[i]].write(H::hash_or_noop(new_leaves[i].as_slice()));
}

if subtree_digests_len > 2 {
let rounds = log2_strict(tree_leaves_count) - cap_height - 1;
for _ in 0..rounds {
let mut parent_indexes: HashSet<usize> = HashSet::new();
let parents: Vec<usize> = positions
.par_iter()
.map(|pos| {
let subtree_offset = pos / subtree_digests_len;
let idx_in_subtree = pos % subtree_digests_len;
let mut parent_idx = 0;
if idx_in_subtree > 1 {
parent_idx = idx_in_subtree / 2 - 1;
}
subtree_offset * subtree_digests_len + parent_idx
})
.collect();
for p in parents {
parent_indexes.insert(p);
}
positions = parent_indexes.into_iter().collect();

// TODO change to parallel loop
for i in 0..positions.len() {
let subtree_offset = positions[i] / subtree_digests_len;
let idx_in_subtree = positions[i] % subtree_digests_len;
let digest_idx = subtree_offset * subtree_digests_len + 2 * (idx_in_subtree + 1);
unsafe {
let left_digest = digests_buf[digest_idx].assume_init();
let right_digest = digests_buf[digest_idx + 1].assume_init();
digests_buf[positions[i]].write(H::two_to_one(left_digest, right_digest));
}
}
}
}

let mut cap_indexes: HashSet<usize> = HashSet::new();
for idx in start_index..end_index {
cap_indexes.insert(idx / subtree_leaves_len);
}

unsafe {
for idx in cap_indexes {
let digest_idx = idx * subtree_digests_len;
let left_digest = digests_buf[digest_idx].assume_init();
let right_digest = digests_buf[digest_idx + 1].assume_init();
cap_buf[idx].write(H::two_to_one(left_digest, right_digest));
}
}
}
}

/// Create a Merkle proof from a leaf index.
pub fn prove(&self, leaf_index: usize) -> MerkleProof<F, H> {
let cap_height = log2_strict(self.cap.len());
Expand Down Expand Up @@ -1057,7 +1165,13 @@ mod tests {
});
}

fn verify_change_leaf_and_update_range(leaves_count: usize, leaf_size: usize, cap_height: usize, start_index: usize, end_index: usize) {
fn verify_change_leaf_and_update_range_one_by_one(
leaves_count: usize,
leaf_size: usize,
cap_height: usize,
start_index: usize,
end_index: usize,
) {
use plonky2_field::types::Field;

const D: usize = 2;
Expand All @@ -1070,7 +1184,9 @@ mod tests {
let mut leaves1_1d: Vec<F> = raw_leaves.into_iter().flatten().collect();
let leaves2_1d: Vec<F> = leaves1_1d.clone();

let mut tree2 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(leaves2_1d, leaf_size, cap_height);
let mut tree2 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(
leaves2_1d, leaf_size, cap_height,
);

// v1
let now = Instant::now();
Expand All @@ -1079,7 +1195,9 @@ mod tests {
leaves1_1d[i * leaf_size + j] = vals[i - start_index][j];
}
}
let tree1 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(leaves1_1d, leaf_size, cap_height);
let tree1 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(
leaves1_1d, leaf_size, cap_height,
);
println!("Time V1: {} ms", now.elapsed().as_millis());

// v2
Expand All @@ -1095,8 +1213,11 @@ mod tests {

// compare leaves
let t2leaves = tree2.get_leaves_1d();
tree1.get_leaves_1d().chunks_exact(leaf_size).enumerate().for_each(
|(i, x)| {
tree1
.get_leaves_1d()
.chunks_exact(leaf_size)
.enumerate()
.for_each(|(i, x)| {
let mut ok = true;
for j in 0..leaf_size {
if x[j] != t2leaves[i * leaf_size + j] {
Expand All @@ -1107,26 +1228,110 @@ mod tests {
if !ok {
println!("Leaves different at index {:?}", i);
}
assert!(ok);
});

// compare trees
tree1.digests.into_iter().enumerate().for_each(|(i, x)| {
let y = tree2.digests[i];
if x != y {
println!("Digests different at index {:?}", i);
}
assert_eq!(x, y);
});
tree1.cap.0.into_iter().enumerate().for_each(|(i, x)| {
let y = tree2.cap.0[i];
if x != y {
println!("Cap different at index {:?}", i);
}
assert_eq!(x, y);
});
}

fn verify_change_leaf_and_update_range(
leaves_count: usize,
leaf_size: usize,
cap_height: usize,
start_index: usize,
end_index: usize,
) {
// use plonky2_field::types::Field;

const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;

let raw_leaves: Vec<Vec<F>> = random_data::<F>(leaves_count, leaf_size);
let vals: Vec<Vec<F>> = random_data::<F>(end_index - start_index, leaf_size);

let mut leaves1_1d: Vec<F> = raw_leaves.into_iter().flatten().collect();
let leaves2_1d: Vec<F> = leaves1_1d.clone();

let mut tree2 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(
leaves2_1d, leaf_size, cap_height,
);

// compare trees
tree1.digests.into_iter().enumerate().for_each(
|(i,x)| {
let y = tree2.digests[i];
if x != y {
println!("Digests different at index {:?}", i);
}
// v1
let now = Instant::now();
for i in start_index..end_index {
for j in 0..leaf_size {
leaves1_1d[i * leaf_size + j] = vals[i - start_index][j];
}
}
let tree1 = MerkleTree::<F, <C as GenericConfig<D>>::Hasher>::new_from_1d(
leaves1_1d, leaf_size, cap_height,
);
tree1.cap.0.into_iter().enumerate().for_each(
|(i,x)| {
let y = tree2.cap.0[i];
if x != y {
println!("Cap different at index {:?}", i);
println!("Time V1: {} ms", now.elapsed().as_millis());

// v2
let now = Instant::now();
/*
for idx in start_index..end_index {
let mut leaf: Vec<F> = vec![F::from_canonical_u64(0); leaf_size];
for j in 0..leaf_size {
leaf[j] = vals[idx - start_index][j];
}
tree2.change_leaf_and_update(leaf, idx);
}
*/
tree2.change_leaves_in_range_and_update(vals, start_index, end_index);
println!("Time V2: {} ms", now.elapsed().as_millis());

// compare leaves
let t2leaves = tree2.get_leaves_1d();
tree1
.get_leaves_1d()
.chunks_exact(leaf_size)
.enumerate()
.for_each(|(i, x)| {
let mut ok = true;
for j in 0..leaf_size {
if x[j] != t2leaves[i * leaf_size + j] {
ok = false;
break;
}
}
if !ok {
println!("Leaves different at index {:?}", i);
}
assert!(ok);
});

// compare trees
tree1.digests.into_iter().enumerate().for_each(|(i, x)| {
let y = tree2.digests[i];
if x != y {
println!("Digests different at index {:?}", i);
}
);
assert_eq!(x, y);
});
tree1.cap.0.into_iter().enumerate().for_each(|(i, x)| {
let y = tree2.cap.0[i];
if x != y {
println!("Cap different at index {:?}", i);
}
assert_eq!(x, y);
});
}

#[test]
Expand Down Expand Up @@ -1177,7 +1382,12 @@ mod tests {

#[test]
fn test_change_leaf_and_update_range() -> Result<()> {
verify_change_leaf_and_update_range(1024, 68, 0, 32, 48);
for h in 0..11 {
println!("Run verify_change_leaf_and_update_range_one_by_one() for height {:?}", h);
verify_change_leaf_and_update_range_one_by_one(1024, 68, h, 32, 48);
println!("Run verify_change_leaf_and_update_range() for height {:?}", h);
verify_change_leaf_and_update_range(1024, 68, h, 32, 48);
}

Ok(())
}
Expand Down

0 comments on commit 94ff49a

Please sign in to comment.