feat: implements concurrent Smt::compute_mutations
(#365)
This commit is contained in:
parent
d569c71666
commit
1b77fa8039
8 changed files with 771 additions and 461 deletions
|
@ -11,6 +11,8 @@
|
||||||
## 0.13.1 (2024-12-26)
|
## 0.13.1 (2024-12-26)
|
||||||
|
|
||||||
- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
|
- Generate reverse mutations set on applying of mutations set, implemented serialization of `MutationsSet` (#355).
|
||||||
|
- Added parallel implementation of `Smt::compute_mutations` with better performance (#365).
|
||||||
|
- Implemented parallel leaf hashing in `Smt::process_sorted_pairs_to_leaves` (#365).
|
||||||
|
|
||||||
## 0.13.0 (2024-11-24)
|
## 0.13.0 (2024-11-24)
|
||||||
|
|
||||||
|
|
|
@ -45,7 +45,7 @@ name = "store"
|
||||||
harness = false
|
harness = false
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
concurrent = ["dep:rayon"]
|
concurrent = ["dep:rayon", "hashbrown?/rayon"]
|
||||||
default = ["std", "concurrent"]
|
default = ["std", "concurrent"]
|
||||||
executable = ["dep:clap", "dep:rand-utils", "std"]
|
executable = ["dep:clap", "dep:rand-utils", "std"]
|
||||||
smt_hashmaps = ["dep:hashbrown"]
|
smt_hashmaps = ["dep:hashbrown"]
|
||||||
|
|
77
src/main.rs
77
src/main.rs
|
@ -13,8 +13,14 @@ use rand_utils::rand_value;
|
||||||
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
|
#[clap(name = "Benchmark", about = "SMT benchmark", version, rename_all = "kebab-case")]
|
||||||
pub struct BenchmarkCmd {
|
pub struct BenchmarkCmd {
|
||||||
/// Size of the tree
|
/// Size of the tree
|
||||||
#[clap(short = 's', long = "size")]
|
#[clap(short = 's', long = "size", default_value = "1000000")]
|
||||||
size: usize,
|
size: usize,
|
||||||
|
/// Number of insertions
|
||||||
|
#[clap(short = 'i', long = "insertions", default_value = "1000")]
|
||||||
|
insertions: usize,
|
||||||
|
/// Number of updates
|
||||||
|
#[clap(short = 'u', long = "updates", default_value = "1000")]
|
||||||
|
updates: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
@ -25,7 +31,10 @@ fn main() {
|
||||||
pub fn benchmark_smt() {
|
pub fn benchmark_smt() {
|
||||||
let args = BenchmarkCmd::parse();
|
let args = BenchmarkCmd::parse();
|
||||||
let tree_size = args.size;
|
let tree_size = args.size;
|
||||||
|
let insertions = args.insertions;
|
||||||
|
let updates = args.updates;
|
||||||
|
|
||||||
|
assert!(updates <= tree_size, "Cannot update more than `size`");
|
||||||
// prepare the `leaves` vector for tree creation
|
// prepare the `leaves` vector for tree creation
|
||||||
let mut entries = Vec::new();
|
let mut entries = Vec::new();
|
||||||
for i in 0..tree_size {
|
for i in 0..tree_size {
|
||||||
|
@ -35,9 +44,9 @@ pub fn benchmark_smt() {
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tree = construction(entries.clone(), tree_size).unwrap();
|
let mut tree = construction(entries.clone(), tree_size).unwrap();
|
||||||
insertion(&mut tree).unwrap();
|
insertion(&mut tree.clone(), insertions).unwrap();
|
||||||
batched_insertion(&mut tree).unwrap();
|
batched_insertion(&mut tree.clone(), insertions).unwrap();
|
||||||
batched_update(&mut tree, entries).unwrap();
|
batched_update(&mut tree.clone(), entries, updates).unwrap();
|
||||||
proof_generation(&mut tree).unwrap();
|
proof_generation(&mut tree).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -47,23 +56,20 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: usize) -> Result<Smt,
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let tree = Smt::with_entries(entries)?;
|
let tree = Smt::with_entries(entries)?;
|
||||||
let elapsed = now.elapsed().as_secs_f32();
|
let elapsed = now.elapsed().as_secs_f32();
|
||||||
|
println!("Constructed an SMT with {size} key-value pairs in {elapsed:.1} seconds");
|
||||||
println!("Constructed a SMT with {size} key-value pairs in {elapsed:.1} seconds");
|
|
||||||
println!("Number of leaf nodes: {}\n", tree.leaves().count());
|
println!("Number of leaf nodes: {}\n", tree.leaves().count());
|
||||||
|
|
||||||
Ok(tree)
|
Ok(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Runs the insertion benchmark for the [`Smt`].
|
/// Runs the insertion benchmark for the [`Smt`].
|
||||||
pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
|
pub fn insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
|
||||||
const NUM_INSERTIONS: usize = 1_000;
|
|
||||||
|
|
||||||
println!("Running an insertion benchmark:");
|
println!("Running an insertion benchmark:");
|
||||||
|
|
||||||
let size = tree.num_leaves();
|
let size = tree.num_leaves();
|
||||||
let mut insertion_times = Vec::new();
|
let mut insertion_times = Vec::new();
|
||||||
|
|
||||||
for i in 0..NUM_INSERTIONS {
|
for i in 0..insertions {
|
||||||
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
let test_key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
||||||
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
|
let test_value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
|
||||||
|
|
||||||
|
@ -74,22 +80,20 @@ pub fn insertion(tree: &mut Smt) -> Result<(), MerkleError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average insertion time measured by {NUM_INSERTIONS} inserts into an SMT with {size} leaves is {:.0} μs\n",
|
"The average insertion time measured by {insertions} inserts into an SMT with {size} leaves is {:.0} μs\n",
|
||||||
// calculate the average
|
// calculate the average
|
||||||
insertion_times.iter().sum::<u128>() as f64 / (NUM_INSERTIONS as f64),
|
insertion_times.iter().sum::<u128>() as f64 / (insertions as f64),
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
|
pub fn batched_insertion(tree: &mut Smt, insertions: usize) -> Result<(), MerkleError> {
|
||||||
const NUM_INSERTIONS: usize = 1_000;
|
|
||||||
|
|
||||||
println!("Running a batched insertion benchmark:");
|
println!("Running a batched insertion benchmark:");
|
||||||
|
|
||||||
let size = tree.num_leaves();
|
let size = tree.num_leaves();
|
||||||
|
|
||||||
let new_pairs: Vec<(RpoDigest, Word)> = (0..NUM_INSERTIONS)
|
let new_pairs: Vec<(RpoDigest, Word)> = (0..insertions)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
|
||||||
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
|
let value = [ONE, ONE, ONE, Felt::new((size + i) as u64)];
|
||||||
|
@ -101,24 +105,24 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
|
||||||
let mutations = tree.compute_mutations(new_pairs);
|
let mutations = tree.compute_mutations(new_pairs);
|
||||||
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
let compute_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"The average insert-batch computation time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
||||||
|
compute_elapsed,
|
||||||
|
compute_elapsed * 1000_f64 / insertions as f64, // time in μs
|
||||||
|
);
|
||||||
|
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
tree.apply_mutations(mutations)?;
|
tree.apply_mutations(mutations)?;
|
||||||
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average insert-batch computation time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
"The average insert-batch application time measured by a {insertions}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
||||||
compute_elapsed,
|
|
||||||
compute_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
|
|
||||||
);
|
|
||||||
|
|
||||||
println!(
|
|
||||||
"An average insert-batch application time measured by a {NUM_INSERTIONS}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
|
||||||
apply_elapsed,
|
apply_elapsed,
|
||||||
apply_elapsed * 1000_f64 / NUM_INSERTIONS as f64, // time in μs
|
apply_elapsed * 1000_f64 / insertions as f64, // time in μs
|
||||||
);
|
);
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average batch insertion time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
|
"The average batch insertion time measured by a {insertions}-batch into an SMT with {size} leaves totals to {:.1} ms",
|
||||||
(compute_elapsed + apply_elapsed),
|
(compute_elapsed + apply_elapsed),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -127,8 +131,11 @@ pub fn batched_insertion(tree: &mut Smt) -> Result<(), MerkleError> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result<(), MerkleError> {
|
pub fn batched_update(
|
||||||
const NUM_UPDATES: usize = 1_000;
|
tree: &mut Smt,
|
||||||
|
entries: Vec<(RpoDigest, Word)>,
|
||||||
|
updates: usize,
|
||||||
|
) -> Result<(), MerkleError> {
|
||||||
const REMOVAL_PROBABILITY: f64 = 0.2;
|
const REMOVAL_PROBABILITY: f64 = 0.2;
|
||||||
|
|
||||||
println!("Running a batched update benchmark:");
|
println!("Running a batched update benchmark:");
|
||||||
|
@ -139,7 +146,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
|
||||||
let new_pairs =
|
let new_pairs =
|
||||||
entries
|
entries
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.choose_multiple(&mut rng, NUM_UPDATES)
|
.choose_multiple(&mut rng, updates)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(key, _)| {
|
.map(|(key, _)| {
|
||||||
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
|
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
|
||||||
|
@ -151,7 +158,7 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
|
||||||
(key, value)
|
(key, value)
|
||||||
});
|
});
|
||||||
|
|
||||||
assert_eq!(new_pairs.len(), NUM_UPDATES);
|
assert_eq!(new_pairs.len(), updates);
|
||||||
|
|
||||||
let now = Instant::now();
|
let now = Instant::now();
|
||||||
let mutations = tree.compute_mutations(new_pairs);
|
let mutations = tree.compute_mutations(new_pairs);
|
||||||
|
@ -162,19 +169,19 @@ pub fn batched_update(tree: &mut Smt, entries: Vec<(RpoDigest, Word)>) -> Result
|
||||||
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
let apply_elapsed = now.elapsed().as_secs_f64() * 1000_f64; // time in ms
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average update-batch computation time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
"The average update-batch computation time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
||||||
compute_elapsed,
|
compute_elapsed,
|
||||||
compute_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
|
compute_elapsed * 1000_f64 / updates as f64, // time in μs
|
||||||
);
|
);
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average update-batch application time measured by a {NUM_UPDATES}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
"The average update-batch application time measured by a {updates}-batch into an SMT with {size} leaves over {:.1} ms is {:.0} μs",
|
||||||
apply_elapsed,
|
apply_elapsed,
|
||||||
apply_elapsed * 1000_f64 / NUM_UPDATES as f64, // time in μs
|
apply_elapsed * 1000_f64 / updates as f64, // time in μs
|
||||||
);
|
);
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average batch update time measured by a 1k-batch into an SMT with {size} leaves totals to {:.1} ms",
|
"The average batch update time measured by a {updates}-batch into an SMT with {size} leaves totals to {:.1} ms",
|
||||||
(compute_elapsed + apply_elapsed),
|
(compute_elapsed + apply_elapsed),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
@ -203,7 +210,7 @@ pub fn proof_generation(tree: &mut Smt) -> Result<(), MerkleError> {
|
||||||
}
|
}
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
"An average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
|
"The average proving time measured by {NUM_PROOFS} value proofs in an SMT with {size} leaves in {:.0} μs",
|
||||||
// calculate the average
|
// calculate the average
|
||||||
insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
|
insertion_times.iter().sum::<u128>() as f64 / (NUM_PROOFS as f64),
|
||||||
);
|
);
|
||||||
|
|
|
@ -22,10 +22,10 @@ pub use path::{MerklePath, RootPath, ValuePath};
|
||||||
|
|
||||||
mod smt;
|
mod smt;
|
||||||
#[cfg(feature = "internal")]
|
#[cfg(feature = "internal")]
|
||||||
pub use smt::build_subtree_for_bench;
|
pub use smt::{build_subtree_for_bench, SubtreeLeaf};
|
||||||
pub use smt::{
|
pub use smt::{
|
||||||
InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
|
InnerNode, LeafIndex, MutationSet, NodeMutation, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
|
||||||
SmtProof, SmtProofError, SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
|
SmtProof, SmtProofError, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
|
||||||
};
|
};
|
||||||
|
|
||||||
mod mmr;
|
mod mmr;
|
||||||
|
|
580
src/merkle/smt/full/concurrent/mod.rs
Normal file
580
src/merkle/smt/full/concurrent/mod.rs
Normal file
|
@ -0,0 +1,580 @@
|
||||||
|
use alloc::{collections::BTreeSet, vec::Vec};
|
||||||
|
use core::mem;
|
||||||
|
|
||||||
|
use num::Integer;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet,
|
||||||
|
NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
|
||||||
|
};
|
||||||
|
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests;
|
||||||
|
|
||||||
|
type MutatedSubtreeLeaves = Vec<Vec<SubtreeLeaf>>;
|
||||||
|
|
||||||
|
impl Smt {
|
||||||
|
/// Parallel implementation of [`Smt::with_entries()`].
|
||||||
|
///
|
||||||
|
/// This method constructs a new sparse Merkle tree concurrently by processing subtrees in
|
||||||
|
/// parallel, working from the bottom up. The process works as follows:
|
||||||
|
///
|
||||||
|
/// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf
|
||||||
|
/// indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
|
||||||
|
///
|
||||||
|
/// 2. The subtrees are then processed in parallel:
|
||||||
|
/// - For each subtree, compute the inner nodes from depth D down to depth D-8.
|
||||||
|
/// - Each subtree computation yields a new subtree root and its associated inner nodes.
|
||||||
|
///
|
||||||
|
/// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration,
|
||||||
|
/// which processes the next 8 levels up. This continues until the final root of the tree is
|
||||||
|
/// computed at depth 0.
|
||||||
|
pub(crate) fn with_entries_concurrent(
|
||||||
|
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||||
|
) -> Result<Self, MerkleError> {
|
||||||
|
let mut seen_keys = BTreeSet::new();
|
||||||
|
let entries: Vec<_> = entries
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, value)| {
|
||||||
|
if seen_keys.insert(key) {
|
||||||
|
Ok((key, value))
|
||||||
|
} else {
|
||||||
|
Err(MerkleError::DuplicateValuesForIndex(
|
||||||
|
LeafIndex::<SMT_DEPTH>::from(key).value(),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
if entries.is_empty() {
|
||||||
|
return Ok(Self::default());
|
||||||
|
}
|
||||||
|
let (inner_nodes, leaves) = Self::build_subtrees(entries);
|
||||||
|
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
|
||||||
|
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parallel implementation of [`Smt::compute_mutations()`].
|
||||||
|
///
|
||||||
|
/// This method computes mutations by recursively processing subtrees in parallel, working from
|
||||||
|
/// the bottom up. The process works as follows:
|
||||||
|
///
|
||||||
|
/// 1. First, the input key-value pairs are sorted and grouped into subtrees based on their leaf
|
||||||
|
/// indices. Each subtree covers a range of 256 (2^8) possible leaf positions.
|
||||||
|
///
|
||||||
|
/// 2. The subtrees containing modifications are then processed in parallel:
|
||||||
|
/// - For each modified subtree, compute node mutations from depth D up to depth D-8
|
||||||
|
/// - Each subtree computation yields a new root at depth D-8 and its associated mutations
|
||||||
|
///
|
||||||
|
/// 3. These subtree roots become the "leaves" for the next iteration, which processes the next
|
||||||
|
/// 8 levels up. This continues until reaching the tree's root at depth 0.
|
||||||
|
pub(crate) fn compute_mutations_concurrent(
|
||||||
|
&self,
|
||||||
|
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||||
|
) -> MutationSet<SMT_DEPTH, RpoDigest, Word>
|
||||||
|
where
|
||||||
|
Self: Sized + Sync,
|
||||||
|
{
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
// Collect and sort key-value pairs by their corresponding leaf index
|
||||||
|
let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
|
||||||
|
sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());
|
||||||
|
|
||||||
|
// Convert sorted pairs into mutated leaves and capture any new pairs
|
||||||
|
let (mut subtree_leaves, new_pairs) =
|
||||||
|
self.sorted_pairs_to_mutated_subtree_leaves(sorted_kv_pairs);
|
||||||
|
let mut node_mutations = NodeMutations::default();
|
||||||
|
|
||||||
|
// Process each depth level in reverse, stepping by the subtree depth
|
||||||
|
for depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||||
|
// Parallel processing of each subtree to generate mutations and roots
|
||||||
|
let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
|
||||||
|
.into_par_iter()
|
||||||
|
.map(|subtree| {
|
||||||
|
debug_assert!(subtree.is_sorted() && !subtree.is_empty());
|
||||||
|
self.build_subtree_mutations(subtree, SMT_DEPTH, depth)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
// Prepare leaves for the next depth level
|
||||||
|
subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||||
|
|
||||||
|
// Aggregate all node mutations
|
||||||
|
node_mutations.extend(mutations_per_subtree.into_iter().flatten());
|
||||||
|
|
||||||
|
debug_assert!(!subtree_leaves.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize the mutation set with updated roots and mutations
|
||||||
|
MutationSet {
|
||||||
|
old_root: self.root(),
|
||||||
|
new_root: subtree_leaves[0][0].hash,
|
||||||
|
node_mutations,
|
||||||
|
new_pairs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
|
||||||
|
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
|
||||||
|
/// the inputs to feed into [`build_subtree()`].
|
||||||
|
///
|
||||||
|
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
|
||||||
|
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
|
||||||
|
/// sorted. Without debug assertions, the returned computations will be incorrect.
|
||||||
|
fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> {
|
||||||
|
Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes leaves from a set of key-value pairs and current leaf values.
|
||||||
|
/// Derived from `sorted_pairs_to_leaves`
|
||||||
|
fn sorted_pairs_to_mutated_subtree_leaves(
|
||||||
|
&self,
|
||||||
|
pairs: Vec<(RpoDigest, Word)>,
|
||||||
|
) -> (MutatedSubtreeLeaves, UnorderedMap<RpoDigest, Word>) {
|
||||||
|
// Map to track new key-value pairs for mutated leaves
|
||||||
|
let mut new_pairs = UnorderedMap::new();
|
||||||
|
|
||||||
|
let accumulator = Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
|
||||||
|
let mut leaf = self.get_leaf(&leaf_pairs[0].0);
|
||||||
|
|
||||||
|
for (key, value) in leaf_pairs {
|
||||||
|
// Check if the value has changed
|
||||||
|
let old_value =
|
||||||
|
new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
|
||||||
|
|
||||||
|
// Skip if the value hasn't changed
|
||||||
|
if value == old_value {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, update the leaf and track the new key-value pair
|
||||||
|
leaf = self.construct_prospective_leaf(leaf, &key, &value);
|
||||||
|
new_pairs.insert(key, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf
|
||||||
|
});
|
||||||
|
(accumulator.leaves, new_pairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the node mutations and the root of a subtree
|
||||||
|
fn build_subtree_mutations(
|
||||||
|
&self,
|
||||||
|
mut leaves: Vec<SubtreeLeaf>,
|
||||||
|
tree_depth: u8,
|
||||||
|
bottom_depth: u8,
|
||||||
|
) -> (NodeMutations, SubtreeLeaf)
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
debug_assert!(bottom_depth <= tree_depth);
|
||||||
|
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
|
||||||
|
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
|
||||||
|
|
||||||
|
let subtree_root_depth = bottom_depth - SUBTREE_DEPTH;
|
||||||
|
let mut node_mutations: NodeMutations = Default::default();
|
||||||
|
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
|
||||||
|
|
||||||
|
for current_depth in (subtree_root_depth..bottom_depth).rev() {
|
||||||
|
debug_assert!(current_depth <= bottom_depth);
|
||||||
|
|
||||||
|
let next_depth = current_depth + 1;
|
||||||
|
let mut iter = leaves.drain(..).peekable();
|
||||||
|
|
||||||
|
while let Some(first_leaf) = iter.next() {
|
||||||
|
// This constructs a valid index because next_depth will never exceed the depth of
|
||||||
|
// the tree.
|
||||||
|
let parent_index = NodeIndex::new_unchecked(next_depth, first_leaf.col).parent();
|
||||||
|
let parent_node = self.get_inner_node(parent_index);
|
||||||
|
let combined_node = Self::fetch_sibling_pair(&mut iter, first_leaf, parent_node);
|
||||||
|
let combined_hash = combined_node.hash();
|
||||||
|
|
||||||
|
let &empty_hash = EmptySubtreeRoots::entry(tree_depth, current_depth);
|
||||||
|
|
||||||
|
// Add the parent node even if it is empty for proper upward updates
|
||||||
|
next_leaves.push(SubtreeLeaf {
|
||||||
|
col: parent_index.value(),
|
||||||
|
hash: combined_hash,
|
||||||
|
});
|
||||||
|
|
||||||
|
node_mutations.insert(
|
||||||
|
parent_index,
|
||||||
|
if combined_hash != empty_hash {
|
||||||
|
NodeMutation::Addition(combined_node)
|
||||||
|
} else {
|
||||||
|
NodeMutation::Removal
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
drop(iter);
|
||||||
|
leaves = mem::take(&mut next_leaves);
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_assert_eq!(leaves.len(), 1);
|
||||||
|
let root_leaf = leaves.pop().unwrap();
|
||||||
|
(node_mutations, root_leaf)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Constructs an `InnerNode` representing the sibling pair of which `first_leaf` is a part:
|
||||||
|
/// - If `first_leaf` is a right child, the left child is copied from the `parent_node`.
|
||||||
|
/// - If `first_leaf` is a left child, the right child is taken from `iter` if it was also
|
||||||
|
/// mutated or copied from the `parent_node`.
|
||||||
|
///
|
||||||
|
/// Returns the `InnerNode` containing the hashes of the sibling pair.
|
||||||
|
fn fetch_sibling_pair(
|
||||||
|
iter: &mut core::iter::Peekable<alloc::vec::Drain<SubtreeLeaf>>,
|
||||||
|
first_leaf: SubtreeLeaf,
|
||||||
|
parent_node: InnerNode,
|
||||||
|
) -> InnerNode {
|
||||||
|
let is_right_node = first_leaf.col.is_odd();
|
||||||
|
|
||||||
|
if is_right_node {
|
||||||
|
let left_leaf = SubtreeLeaf {
|
||||||
|
col: first_leaf.col - 1,
|
||||||
|
hash: parent_node.left,
|
||||||
|
};
|
||||||
|
InnerNode {
|
||||||
|
left: left_leaf.hash,
|
||||||
|
right: first_leaf.hash,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let right_col = first_leaf.col + 1;
|
||||||
|
let right_leaf = match iter.peek().copied() {
|
||||||
|
Some(SubtreeLeaf { col, .. }) if col == right_col => iter.next().unwrap(),
|
||||||
|
_ => SubtreeLeaf { col: right_col, hash: parent_node.right },
|
||||||
|
};
|
||||||
|
InnerNode {
|
||||||
|
left: first_leaf.hash,
|
||||||
|
right: right_leaf.hash,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Processes sorted key-value pairs to compute leaves for a subtree.
|
||||||
|
///
|
||||||
|
/// This function groups key-value pairs by their corresponding column index and processes each
|
||||||
|
/// group to construct leaves. The actual construction of the leaf is delegated to the
|
||||||
|
/// `process_leaf` callback, allowing flexibility for different use cases (e.g., creating
|
||||||
|
/// new leaves or mutating existing ones).
|
||||||
|
///
|
||||||
|
/// # Parameters
|
||||||
|
/// - `pairs`: A vector of sorted key-value pairs. The pairs *must* be sorted by leaf index
|
||||||
|
/// column (not simply by key). If the input is not sorted correctly, the function will
|
||||||
|
/// produce incorrect results and may panic in debug mode.
|
||||||
|
/// - `process_leaf`: A callback function used to process each group of key-value pairs
|
||||||
|
/// corresponding to the same column index. The callback takes a vector of key-value pairs for
|
||||||
|
/// a single column and returns the constructed leaf for that column.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// A `PairComputations<u64, Self::Leaf>` containing:
|
||||||
|
/// - `nodes`: A mapping of column indices to the constructed leaves.
|
||||||
|
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
|
||||||
|
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// This function will panic in debug mode if the input `pairs` are not sorted by column index.
|
||||||
|
fn process_sorted_pairs_to_leaves<F>(
|
||||||
|
pairs: Vec<(RpoDigest, Word)>,
|
||||||
|
mut process_leaf: F,
|
||||||
|
) -> PairComputations<u64, SmtLeaf>
|
||||||
|
where
|
||||||
|
F: FnMut(Vec<(RpoDigest, Word)>) -> SmtLeaf,
|
||||||
|
{
|
||||||
|
use rayon::prelude::*;
|
||||||
|
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
|
||||||
|
|
||||||
|
let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();
|
||||||
|
|
||||||
|
// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
|
||||||
|
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
|
||||||
|
// out and store them in our accumulated leaves.
|
||||||
|
let mut current_leaf_buffer: Vec<(RpoDigest, Word)> = Default::default();
|
||||||
|
|
||||||
|
let mut iter = pairs.into_iter().peekable();
|
||||||
|
while let Some((key, value)) = iter.next() {
|
||||||
|
let col = Self::key_to_leaf_index(&key).index.value();
|
||||||
|
let peeked_col = iter.peek().map(|(key, _v)| {
|
||||||
|
let index = Self::key_to_leaf_index(key);
|
||||||
|
let next_col = index.index.value();
|
||||||
|
// We panic if `pairs` is not sorted by column.
|
||||||
|
debug_assert!(next_col >= col);
|
||||||
|
next_col
|
||||||
|
});
|
||||||
|
current_leaf_buffer.push((key, value));
|
||||||
|
|
||||||
|
// If the next pair is the same column as this one, then we're done after adding this
|
||||||
|
// pair to the buffer.
|
||||||
|
if peeked_col == Some(col) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, the next pair is a different column, or there is no next pair. Either way
|
||||||
|
// it's time to swap out our buffer.
|
||||||
|
let leaf_pairs = mem::take(&mut current_leaf_buffer);
|
||||||
|
let leaf = process_leaf(leaf_pairs);
|
||||||
|
|
||||||
|
accumulator.nodes.insert(col, leaf);
|
||||||
|
|
||||||
|
debug_assert!(current_leaf_buffer.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute the leaves from the nodes concurrently
|
||||||
|
let mut accumulated_leaves: Vec<SubtreeLeaf> = accumulator
|
||||||
|
.nodes
|
||||||
|
.clone()
|
||||||
|
.into_par_iter()
|
||||||
|
.map(|(col, leaf)| SubtreeLeaf { col, hash: Self::hash_leaf(&leaf) })
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Sort the leaves by column
|
||||||
|
accumulated_leaves.par_sort_by_key(|leaf| leaf.col);
|
||||||
|
|
||||||
|
// TODO: determine is there is any notable performance difference between computing
|
||||||
|
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
|
||||||
|
// subtree boundaries as we go. Either way this function is only used at the beginning of a
|
||||||
|
// parallel construction, so it should not be a critical path.
|
||||||
|
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
|
||||||
|
accumulator
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
||||||
|
///
|
||||||
|
/// `entries` need not be sorted. This function will sort them.
|
||||||
|
fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
|
||||||
|
entries.sort_by_key(|item| {
|
||||||
|
let index = Self::key_to_leaf_index(&item.0);
|
||||||
|
index.value()
|
||||||
|
});
|
||||||
|
Self::build_subtrees_from_sorted_entries(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
||||||
|
///
|
||||||
|
/// This function is mostly an implementation detail of
|
||||||
|
/// [`Smt::with_entries_concurrent()`].
|
||||||
|
fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
let mut accumulated_nodes: InnerNodes = Default::default();
|
||||||
|
|
||||||
|
let PairComputations {
|
||||||
|
leaves: mut leaf_subtrees,
|
||||||
|
nodes: initial_leaves,
|
||||||
|
} = Self::sorted_pairs_to_leaves(entries);
|
||||||
|
|
||||||
|
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||||
|
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
|
||||||
|
leaf_subtrees
|
||||||
|
.into_par_iter()
|
||||||
|
.map(|subtree| {
|
||||||
|
debug_assert!(subtree.is_sorted());
|
||||||
|
debug_assert!(!subtree.is_empty());
|
||||||
|
let (nodes, subtree_root) =
|
||||||
|
build_subtree(subtree, SMT_DEPTH, current_depth);
|
||||||
|
(nodes, subtree_root)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
|
||||||
|
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||||
|
accumulated_nodes.extend(nodes.into_iter().flatten());
|
||||||
|
|
||||||
|
debug_assert!(!leaf_subtrees.is_empty());
|
||||||
|
}
|
||||||
|
(accumulated_nodes, initial_leaves)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SUBTREES
|
||||||
|
// ================================================================================================
|
||||||
|
|
||||||
|
/// A subtree is of depth 8.
|
||||||
|
const SUBTREE_DEPTH: u8 = 8;
|
||||||
|
|
||||||
|
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
|
||||||
|
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
|
||||||
|
|
||||||
|
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
|
||||||
|
///
|
||||||
|
/// Note that these represet "conceptual" leaves of some subtree, not necessarily
|
||||||
|
/// the leaf type for the sparse Merkle tree.
|
||||||
|
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
|
||||||
|
pub struct SubtreeLeaf {
|
||||||
|
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
|
||||||
|
pub col: u64,
|
||||||
|
/// The hash of the node this `SubtreeLeaf` represents.
|
||||||
|
pub hash: RpoDigest,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Helper struct to organize the return value of [`Smt::sorted_pairs_to_leaves()`].
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct PairComputations<K, L> {
|
||||||
|
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
|
||||||
|
pub nodes: UnorderedMap<K, L>,
|
||||||
|
/// "Conceptual" leaves that will be used for computations.
|
||||||
|
pub leaves: Vec<Vec<SubtreeLeaf>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Derive requires `L` to impl Default, even though we don't actually need that.
|
||||||
|
impl<K, L> Default for PairComputations<K, L> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
nodes: Default::default(),
|
||||||
|
leaves: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub(crate) struct SubtreeLeavesIter<'s> {
|
||||||
|
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
|
||||||
|
}
|
||||||
|
impl<'s> SubtreeLeavesIter<'s> {
|
||||||
|
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
|
||||||
|
// TODO: determine if there is any notable performance difference between taking a Vec,
|
||||||
|
// which many need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
|
||||||
|
// The latter may have self-referential properties that are impossible to express in purely
|
||||||
|
// safe Rust Rust.
|
||||||
|
Self { leaves: leaves.drain(..).peekable() }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl Iterator for SubtreeLeavesIter<'_> {
|
||||||
|
type Item = Vec<SubtreeLeaf>;
|
||||||
|
|
||||||
|
/// Each `next()` collects an entire subtree.
|
||||||
|
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
|
||||||
|
let mut subtree: Vec<SubtreeLeaf> = Default::default();
|
||||||
|
|
||||||
|
let mut last_subtree_col = 0;
|
||||||
|
|
||||||
|
while let Some(leaf) = self.leaves.peek() {
|
||||||
|
last_subtree_col = u64::max(1, last_subtree_col);
|
||||||
|
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
|
||||||
|
let next_subtree_col = if is_exact_multiple {
|
||||||
|
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
|
||||||
|
} else {
|
||||||
|
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
|
||||||
|
};
|
||||||
|
|
||||||
|
last_subtree_col = leaf.col;
|
||||||
|
if leaf.col < next_subtree_col {
|
||||||
|
subtree.push(self.leaves.next().unwrap());
|
||||||
|
} else if subtree.is_empty() {
|
||||||
|
continue;
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if subtree.is_empty() {
|
||||||
|
debug_assert!(self.leaves.peek().is_none());
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(subtree)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HELPER FUNCTIONS
|
||||||
|
// ================================================================================================
|
||||||
|
|
||||||
|
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
|
||||||
|
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
|
||||||
|
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
|
||||||
|
///
|
||||||
|
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
|
||||||
|
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
|
||||||
|
/// itself.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
|
||||||
|
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
|
||||||
|
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
|
||||||
|
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
|
||||||
|
#[cfg(feature = "concurrent")]
|
||||||
|
pub(crate) fn build_subtree(
|
||||||
|
mut leaves: Vec<SubtreeLeaf>,
|
||||||
|
tree_depth: u8,
|
||||||
|
bottom_depth: u8,
|
||||||
|
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
|
||||||
|
debug_assert!(bottom_depth <= tree_depth);
|
||||||
|
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
|
||||||
|
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
|
||||||
|
let subtree_root = bottom_depth - SUBTREE_DEPTH;
|
||||||
|
let mut inner_nodes: UnorderedMap<NodeIndex, InnerNode> = Default::default();
|
||||||
|
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
|
||||||
|
for next_depth in (subtree_root..bottom_depth).rev() {
|
||||||
|
debug_assert!(next_depth <= bottom_depth);
|
||||||
|
// `next_depth` is the stuff we're making.
|
||||||
|
// `current_depth` is the stuff we have.
|
||||||
|
let current_depth = next_depth + 1;
|
||||||
|
let mut iter = leaves.drain(..).peekable();
|
||||||
|
while let Some(first) = iter.next() {
|
||||||
|
// On non-continuous iterations, including the first iteration, `first_column` may
|
||||||
|
// be a left or right node. On subsequent continuous iterations, we will always call
|
||||||
|
// `iter.next()` twice.
|
||||||
|
// On non-continuous iterations (including the very first iteration), this column
|
||||||
|
// could be either on the left or the right. If the next iteration is not
|
||||||
|
// discontinuous with our right node, then the next iteration's
|
||||||
|
let is_right = first.col.is_odd();
|
||||||
|
let (left, right) = if is_right {
|
||||||
|
// Discontinuous iteration: we have no left node, so it must be empty.
|
||||||
|
let left = SubtreeLeaf {
|
||||||
|
col: first.col - 1,
|
||||||
|
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
|
||||||
|
};
|
||||||
|
let right = first;
|
||||||
|
(left, right)
|
||||||
|
} else {
|
||||||
|
let left = first;
|
||||||
|
let right_col = first.col + 1;
|
||||||
|
let right = match iter.peek().copied() {
|
||||||
|
Some(SubtreeLeaf { col, .. }) if col == right_col => {
|
||||||
|
// Our inputs must be sorted.
|
||||||
|
debug_assert!(left.col <= col);
|
||||||
|
// The next leaf in the iterator is our sibling. Use it and consume it!
|
||||||
|
iter.next().unwrap()
|
||||||
|
},
|
||||||
|
// Otherwise, the leaves don't contain our sibling, so our sibling must be
|
||||||
|
// empty.
|
||||||
|
_ => SubtreeLeaf {
|
||||||
|
col: right_col,
|
||||||
|
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
(left, right)
|
||||||
|
};
|
||||||
|
let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
|
||||||
|
let node = InnerNode { left: left.hash, right: right.hash };
|
||||||
|
let hash = node.hash();
|
||||||
|
let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
|
||||||
|
// If this hash is empty, then it doesn't become a new inner node, nor does it count
|
||||||
|
// as a leaf for the next depth.
|
||||||
|
if hash != equivalent_empty_hash {
|
||||||
|
inner_nodes.insert(index, node);
|
||||||
|
next_leaves.push(SubtreeLeaf { col: index.value(), hash });
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Stop borrowing `leaves`, so we can swap it.
|
||||||
|
// The iterator is empty at this point anyway.
|
||||||
|
drop(iter);
|
||||||
|
// After each depth, consider the stuff we just made the new "leaves", and empty the
|
||||||
|
// other collection.
|
||||||
|
mem::swap(&mut leaves, &mut next_leaves);
|
||||||
|
}
|
||||||
|
debug_assert_eq!(leaves.len(), 1);
|
||||||
|
let root = leaves.pop().unwrap();
|
||||||
|
(inner_nodes, root)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "internal")]
|
||||||
|
pub fn build_subtree_for_bench(
|
||||||
|
leaves: Vec<SubtreeLeaf>,
|
||||||
|
tree_depth: u8,
|
||||||
|
bottom_depth: u8,
|
||||||
|
) -> (UnorderedMap<NodeIndex, InnerNode>, SubtreeLeaf) {
|
||||||
|
build_subtree(leaves, tree_depth, bottom_depth)
|
||||||
|
}
|
|
@ -1,14 +1,16 @@
|
||||||
use alloc::{collections::BTreeMap, vec::Vec};
|
use alloc::{
|
||||||
|
collections::{BTreeMap, BTreeSet},
|
||||||
|
vec::Vec,
|
||||||
|
};
|
||||||
|
|
||||||
|
use rand::{prelude::IteratorRandom, thread_rng, Rng};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
build_subtree, InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree,
|
build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest,
|
||||||
SubtreeLeaf, SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH,
|
Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE,
|
||||||
};
|
SMT_DEPTH, SUBTREE_DEPTH,
|
||||||
use crate::{
|
|
||||||
hash::rpo::RpoDigest,
|
|
||||||
merkle::{Smt, SMT_DEPTH},
|
|
||||||
Felt, Word, ONE,
|
|
||||||
};
|
};
|
||||||
|
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE};
|
||||||
|
|
||||||
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
|
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
|
||||||
SubtreeLeaf {
|
SubtreeLeaf {
|
||||||
|
@ -32,9 +34,7 @@ fn test_sorted_pairs_to_leaves() {
|
||||||
// Subtree 2. Another normal leaf.
|
// Subtree 2. Another normal leaf.
|
||||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
|
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
|
||||||
];
|
];
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
let control_leaves: Vec<SmtLeaf> = {
|
let control_leaves: Vec<SmtLeaf> = {
|
||||||
let mut entries_iter = entries.iter().cloned();
|
let mut entries_iter = entries.iter().cloned();
|
||||||
let mut next_entry = || entries_iter.next().unwrap();
|
let mut next_entry = || entries_iter.next().unwrap();
|
||||||
|
@ -52,11 +52,9 @@ fn test_sorted_pairs_to_leaves() {
|
||||||
assert_eq!(entries_iter.next(), None);
|
assert_eq!(entries_iter.next(), None);
|
||||||
control_leaves
|
control_leaves
|
||||||
};
|
};
|
||||||
|
|
||||||
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
|
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
|
||||||
let mut control_leaves_iter = control_leaves.iter();
|
let mut control_leaves_iter = control_leaves.iter();
|
||||||
let mut next_leaf = || control_leaves_iter.next().unwrap();
|
let mut next_leaf = || control_leaves_iter.next().unwrap();
|
||||||
|
|
||||||
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
|
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
|
||||||
// Subtree 0.
|
// Subtree 0.
|
||||||
vec![next_leaf(), next_leaf(), next_leaf()],
|
vec![next_leaf(), next_leaf(), next_leaf()],
|
||||||
|
@ -70,22 +68,18 @@ fn test_sorted_pairs_to_leaves() {
|
||||||
assert_eq!(control_leaves_iter.next(), None);
|
assert_eq!(control_leaves_iter.next(), None);
|
||||||
control_subtree_leaves
|
control_subtree_leaves
|
||||||
};
|
};
|
||||||
|
|
||||||
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
|
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
|
||||||
// This will check that the hashes, columns, and subtree assignments all match.
|
// This will check that the hashes, columns, and subtree assignments all match.
|
||||||
assert_eq!(subtrees.leaves, control_subtree_leaves);
|
assert_eq!(subtrees.leaves, control_subtree_leaves);
|
||||||
|
|
||||||
// Flattening and re-separating out the leaves into subtrees should have the same result.
|
// Flattening and re-separating out the leaves into subtrees should have the same result.
|
||||||
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
|
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
|
||||||
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
|
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
|
||||||
assert_eq!(subtrees.leaves, re_grouped);
|
assert_eq!(subtrees.leaves, re_grouped);
|
||||||
|
|
||||||
// Then finally we might as well check the computed leaf nodes too.
|
// Then finally we might as well check the computed leaf nodes too.
|
||||||
let control_leaves: BTreeMap<u64, SmtLeaf> = control
|
let control_leaves: BTreeMap<u64, SmtLeaf> = control
|
||||||
.leaves()
|
.leaves()
|
||||||
.map(|(index, value)| (index.index.value(), value.clone()))
|
.map(|(index, value)| (index.index.value(), value.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for (column, test_leaf) in subtrees.nodes {
|
for (column, test_leaf) in subtrees.nodes {
|
||||||
if test_leaf.is_empty() {
|
if test_leaf.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
|
@ -96,7 +90,6 @@ fn test_sorted_pairs_to_leaves() {
|
||||||
assert_eq!(control_leaf, &test_leaf);
|
assert_eq!(control_leaf, &test_leaf);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper for the below tests.
|
// Helper for the below tests.
|
||||||
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
|
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
|
||||||
(0..pair_count)
|
(0..pair_count)
|
||||||
|
@ -108,23 +101,41 @@ fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
fn generate_updates(entries: Vec<(RpoDigest, Word)>, updates: usize) -> Vec<(RpoDigest, Word)> {
|
||||||
|
const REMOVAL_PROBABILITY: f64 = 0.2;
|
||||||
|
let mut rng = thread_rng();
|
||||||
|
// Assertion to ensure input keys are unique
|
||||||
|
assert!(
|
||||||
|
entries.iter().map(|(key, _)| key).collect::<BTreeSet<_>>().len() == entries.len(),
|
||||||
|
"Input entries contain duplicate keys!"
|
||||||
|
);
|
||||||
|
let mut sorted_entries: Vec<(RpoDigest, Word)> = entries
|
||||||
|
.into_iter()
|
||||||
|
.choose_multiple(&mut rng, updates)
|
||||||
|
.into_iter()
|
||||||
|
.map(|(key, _)| {
|
||||||
|
let value = if rng.gen_bool(REMOVAL_PROBABILITY) {
|
||||||
|
EMPTY_WORD
|
||||||
|
} else {
|
||||||
|
[ONE, ONE, ONE, Felt::new(rng.gen())]
|
||||||
|
};
|
||||||
|
(key, value)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
sorted_entries.sort_by_key(|(key, _)| Smt::key_to_leaf_index(key).value());
|
||||||
|
sorted_entries
|
||||||
|
}
|
||||||
#[test]
|
#[test]
|
||||||
fn test_single_subtree() {
|
fn test_single_subtree() {
|
||||||
// A single subtree's worth of leaves.
|
// A single subtree's worth of leaves.
|
||||||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
|
||||||
|
|
||||||
let entries = generate_entries(PAIR_COUNT);
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
// `entries` should already be sorted by nature of how we constructed it.
|
// `entries` should already be sorted by nature of how we constructed it.
|
||||||
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
|
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
|
||||||
let leaves = leaves.into_iter().next().unwrap();
|
let leaves = leaves.into_iter().next().unwrap();
|
||||||
|
|
||||||
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
|
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
|
||||||
assert!(!first_subtree.is_empty());
|
assert!(!first_subtree.is_empty());
|
||||||
|
|
||||||
// The inner nodes computed from that subtree should match the nodes in our control tree.
|
// The inner nodes computed from that subtree should match the nodes in our control tree.
|
||||||
for (index, node) in first_subtree.into_iter() {
|
for (index, node) in first_subtree.into_iter() {
|
||||||
let control = control.get_inner_node(index);
|
let control = control.get_inner_node(index);
|
||||||
|
@ -133,7 +144,6 @@ fn test_single_subtree() {
|
||||||
"subtree-computed node at index {index:?} does not match control",
|
"subtree-computed node at index {index:?} does not match control",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// The root returned should also match the equivalent node in the control tree.
|
// The root returned should also match the equivalent node in the control tree.
|
||||||
let control_root_index =
|
let control_root_index =
|
||||||
NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
|
NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, subtree_root.col).expect("Valid root index");
|
||||||
|
@ -144,7 +154,6 @@ fn test_single_subtree() {
|
||||||
"Subtree-computed root at index {control_root_index:?} does not match control"
|
"Subtree-computed root at index {control_root_index:?} does not match control"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that not just can we compute a subtree correctly, but we can feed the results of one
|
// Test that not just can we compute a subtree correctly, but we can feed the results of one
|
||||||
// subtree into computing another. In other words, test that `build_subtree()` is correctly
|
// subtree into computing another. In other words, test that `build_subtree()` is correctly
|
||||||
// composable.
|
// composable.
|
||||||
|
@ -152,30 +161,22 @@ fn test_single_subtree() {
|
||||||
fn test_two_subtrees() {
|
fn test_two_subtrees() {
|
||||||
// Two subtrees' worth of leaves.
|
// Two subtrees' worth of leaves.
|
||||||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
|
||||||
|
|
||||||
let entries = generate_entries(PAIR_COUNT);
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
|
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
|
||||||
// With two subtrees' worth of leaves, we should have exactly two subtrees.
|
// With two subtrees' worth of leaves, we should have exactly two subtrees.
|
||||||
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
|
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
|
||||||
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
|
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
|
||||||
assert_eq!(first.len(), second.len());
|
assert_eq!(first.len(), second.len());
|
||||||
|
|
||||||
let mut current_depth = SMT_DEPTH;
|
let mut current_depth = SMT_DEPTH;
|
||||||
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
|
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
|
||||||
|
|
||||||
let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
|
let (first_nodes, first_root) = build_subtree(first, SMT_DEPTH, current_depth);
|
||||||
next_leaves.push(first_root);
|
next_leaves.push(first_root);
|
||||||
|
|
||||||
let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
|
let (second_nodes, second_root) = build_subtree(second, SMT_DEPTH, current_depth);
|
||||||
next_leaves.push(second_root);
|
next_leaves.push(second_root);
|
||||||
|
|
||||||
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
|
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
|
||||||
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
|
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
|
||||||
assert_eq!(total_computed as u64, PAIR_COUNT);
|
assert_eq!(total_computed as u64, PAIR_COUNT);
|
||||||
|
|
||||||
// Verify the computed nodes of both subtrees.
|
// Verify the computed nodes of both subtrees.
|
||||||
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
|
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
|
||||||
for (index, test_node) in computed_nodes {
|
for (index, test_node) in computed_nodes {
|
||||||
|
@ -185,13 +186,10 @@ fn test_two_subtrees() {
|
||||||
"subtree-computed node at index {index:?} does not match control",
|
"subtree-computed node at index {index:?} does not match control",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
current_depth -= SUBTREE_DEPTH;
|
current_depth -= SUBTREE_DEPTH;
|
||||||
|
|
||||||
let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
|
let (nodes, root_leaf) = build_subtree(next_leaves, SMT_DEPTH, current_depth);
|
||||||
assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
|
assert_eq!(nodes.len(), SUBTREE_DEPTH as usize);
|
||||||
assert_eq!(root_leaf.col, 0);
|
assert_eq!(root_leaf.col, 0);
|
||||||
|
|
||||||
for (index, test_node) in nodes {
|
for (index, test_node) in nodes {
|
||||||
let control_node = control.get_inner_node(index);
|
let control_node = control.get_inner_node(index);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
|
@ -199,30 +197,23 @@ fn test_two_subtrees() {
|
||||||
"subtree-computed node at index {index:?} does not match control",
|
"subtree-computed node at index {index:?} does not match control",
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
|
let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, root_leaf.col).unwrap();
|
||||||
let control_root = control.get_inner_node(index).hash();
|
let control_root = control.get_inner_node(index).hash();
|
||||||
assert_eq!(control_root, root_leaf.hash, "Root mismatch");
|
assert_eq!(control_root, root_leaf.hash, "Root mismatch");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_singlethreaded_subtrees() {
|
fn test_singlethreaded_subtrees() {
|
||||||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
||||||
|
|
||||||
let entries = generate_entries(PAIR_COUNT);
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
||||||
|
|
||||||
let PairComputations {
|
let PairComputations {
|
||||||
leaves: mut leaf_subtrees,
|
leaves: mut leaf_subtrees,
|
||||||
nodes: test_leaves,
|
nodes: test_leaves,
|
||||||
} = Smt::sorted_pairs_to_leaves(entries);
|
} = Smt::sorted_pairs_to_leaves(entries);
|
||||||
|
|
||||||
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||||
// There's no flat_map_unzip(), so this is the best we can do.
|
// There's no flat_map_unzip(), so this is the best we can do.
|
||||||
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, subtree)| {
|
.map(|(i, subtree)| {
|
||||||
|
@ -235,10 +226,8 @@ fn test_singlethreaded_subtrees() {
|
||||||
!subtree.is_empty(),
|
!subtree.is_empty(),
|
||||||
"subtree {i} at bottom-depth {current_depth} is empty!",
|
"subtree {i} at bottom-depth {current_depth} is empty!",
|
||||||
);
|
);
|
||||||
|
|
||||||
// Do actual things.
|
// Do actual things.
|
||||||
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
|
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
|
||||||
|
|
||||||
// Post-assertions.
|
// Post-assertions.
|
||||||
for (&index, test_node) in nodes.iter() {
|
for (&index, test_node) in nodes.iter() {
|
||||||
let control_node = control.get_inner_node(index);
|
let control_node = control.get_inner_node(index);
|
||||||
|
@ -248,19 +237,14 @@ fn test_singlethreaded_subtrees() {
|
||||||
current_depth, i, index,
|
current_depth, i, index,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
(nodes, subtree_root)
|
(nodes, subtree_root)
|
||||||
})
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
// Update state between each depth iteration.
|
// Update state between each depth iteration.
|
||||||
|
|
||||||
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||||
accumulated_nodes.extend(nodes.into_iter().flatten());
|
accumulated_nodes.extend(nodes.into_iter().flatten());
|
||||||
|
|
||||||
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
|
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the true leaves match, first checking length and then checking each individual
|
// Make sure the true leaves match, first checking length and then checking each individual
|
||||||
// leaf.
|
// leaf.
|
||||||
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
|
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
|
||||||
|
@ -272,7 +256,6 @@ fn test_singlethreaded_subtrees() {
|
||||||
let &control_leaf = control_leaves.get(&index).unwrap();
|
let &control_leaf = control_leaves.get(&index).unwrap();
|
||||||
assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
|
assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the inner nodes match, checking length first and then each individual leaf.
|
// Make sure the inner nodes match, checking length first and then each individual leaf.
|
||||||
let control_nodes_len = control.inner_nodes().count();
|
let control_nodes_len = control.inner_nodes().count();
|
||||||
let test_nodes_len = accumulated_nodes.len();
|
let test_nodes_len = accumulated_nodes.len();
|
||||||
|
@ -281,20 +264,16 @@ fn test_singlethreaded_subtrees() {
|
||||||
let control_node = control.get_inner_node(index);
|
let control_node = control.get_inner_node(index);
|
||||||
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
|
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// After the last iteration of the above for loop, we should have the new root node actually
|
// After the last iteration of the above for loop, we should have the new root node actually
|
||||||
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
|
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
|
||||||
// `build_subtree()`. So let's check both!
|
// `build_subtree()`. So let's check both!
|
||||||
|
|
||||||
let control_root = control.get_inner_node(NodeIndex::root());
|
let control_root = control.get_inner_node(NodeIndex::root());
|
||||||
|
|
||||||
// That for loop should have left us with only one leaf subtree...
|
// That for loop should have left us with only one leaf subtree...
|
||||||
let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
|
let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap();
|
||||||
// which itself contains only one 'leaf'...
|
// which itself contains only one 'leaf'...
|
||||||
let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
|
let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap();
|
||||||
// which matches the expected root.
|
// which matches the expected root.
|
||||||
assert_eq!(control.root(), root_leaf.hash);
|
assert_eq!(control.root(), root_leaf.hash);
|
||||||
|
|
||||||
// Likewise `accumulated_nodes` should contain a node at the root index...
|
// Likewise `accumulated_nodes` should contain a node at the root index...
|
||||||
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
|
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
|
||||||
// and it should match our actual root.
|
// and it should match our actual root.
|
||||||
|
@ -303,28 +282,20 @@ fn test_singlethreaded_subtrees() {
|
||||||
// And of course the root we got from each place should match.
|
// And of course the root we got from each place should match.
|
||||||
assert_eq!(control.root(), root_leaf.hash);
|
assert_eq!(control.root(), root_leaf.hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The parallel version of `test_singlethreaded_subtree()`.
|
/// The parallel version of `test_singlethreaded_subtree()`.
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "concurrent")]
|
|
||||||
fn test_multithreaded_subtrees() {
|
fn test_multithreaded_subtrees() {
|
||||||
use rayon::prelude::*;
|
use rayon::prelude::*;
|
||||||
|
|
||||||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
||||||
|
|
||||||
let entries = generate_entries(PAIR_COUNT);
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
||||||
|
|
||||||
let PairComputations {
|
let PairComputations {
|
||||||
leaves: mut leaf_subtrees,
|
leaves: mut leaf_subtrees,
|
||||||
nodes: test_leaves,
|
nodes: test_leaves,
|
||||||
} = Smt::sorted_pairs_to_leaves(entries);
|
} = Smt::sorted_pairs_to_leaves(entries);
|
||||||
|
|
||||||
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||||
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
||||||
.into_par_iter()
|
.into_par_iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(i, subtree)| {
|
.map(|(i, subtree)| {
|
||||||
|
@ -337,9 +308,7 @@ fn test_multithreaded_subtrees() {
|
||||||
!subtree.is_empty(),
|
!subtree.is_empty(),
|
||||||
"subtree {i} at bottom-depth {current_depth} is empty!",
|
"subtree {i} at bottom-depth {current_depth} is empty!",
|
||||||
);
|
);
|
||||||
|
|
||||||
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
|
let (nodes, subtree_root) = build_subtree(subtree, SMT_DEPTH, current_depth);
|
||||||
|
|
||||||
// Post-assertions.
|
// Post-assertions.
|
||||||
for (&index, test_node) in nodes.iter() {
|
for (&index, test_node) in nodes.iter() {
|
||||||
let control_node = control.get_inner_node(index);
|
let control_node = control.get_inner_node(index);
|
||||||
|
@ -349,17 +318,13 @@ fn test_multithreaded_subtrees() {
|
||||||
current_depth, i, index,
|
current_depth, i, index,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
(nodes, subtree_root)
|
(nodes, subtree_root)
|
||||||
})
|
})
|
||||||
.unzip();
|
.unzip();
|
||||||
|
|
||||||
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||||
accumulated_nodes.extend(nodes.into_iter().flatten());
|
accumulated_nodes.extend(nodes.into_iter().flatten());
|
||||||
|
|
||||||
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
|
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the true leaves match, checking length first and then each individual leaf.
|
// Make sure the true leaves match, checking length first and then each individual leaf.
|
||||||
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
|
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
|
||||||
let control_leaves_len = control_leaves.len();
|
let control_leaves_len = control_leaves.len();
|
||||||
|
@ -370,7 +335,6 @@ fn test_multithreaded_subtrees() {
|
||||||
let &control_leaf = control_leaves.get(&index).unwrap();
|
let &control_leaf = control_leaves.get(&index).unwrap();
|
||||||
assert_eq!(test_leaf, control_leaf);
|
assert_eq!(test_leaf, control_leaf);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure the inner nodes match, checking length first and then each individual leaf.
|
// Make sure the inner nodes match, checking length first and then each individual leaf.
|
||||||
let control_nodes_len = control.inner_nodes().count();
|
let control_nodes_len = control.inner_nodes().count();
|
||||||
let test_nodes_len = accumulated_nodes.len();
|
let test_nodes_len = accumulated_nodes.len();
|
||||||
|
@ -379,20 +343,16 @@ fn test_multithreaded_subtrees() {
|
||||||
let control_node = control.get_inner_node(index);
|
let control_node = control.get_inner_node(index);
|
||||||
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
|
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
|
||||||
}
|
}
|
||||||
|
|
||||||
// After the last iteration of the above for loop, we should have the new root node actually
|
// After the last iteration of the above for loop, we should have the new root node actually
|
||||||
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
|
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
|
||||||
// `build_subtree()`. So let's check both!
|
// `build_subtree()`. So let's check both!
|
||||||
|
|
||||||
let control_root = control.get_inner_node(NodeIndex::root());
|
let control_root = control.get_inner_node(NodeIndex::root());
|
||||||
|
|
||||||
// That for loop should have left us with only one leaf subtree...
|
// That for loop should have left us with only one leaf subtree...
|
||||||
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
|
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
|
||||||
// which itself contains only one 'leaf'...
|
// which itself contains only one 'leaf'...
|
||||||
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
|
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
|
||||||
// which matches the expected root.
|
// which matches the expected root.
|
||||||
assert_eq!(control.root(), root_leaf.hash);
|
assert_eq!(control.root(), root_leaf.hash);
|
||||||
|
|
||||||
// Likewise `accumulated_nodes` should contain a node at the root index...
|
// Likewise `accumulated_nodes` should contain a node at the root index...
|
||||||
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
|
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
|
||||||
// and it should match our actual root.
|
// and it should match our actual root.
|
||||||
|
@ -401,17 +361,86 @@ fn test_multithreaded_subtrees() {
|
||||||
// And of course the root we got from each place should match.
|
// And of course the root we got from each place should match.
|
||||||
assert_eq!(control.root(), root_leaf.hash);
|
assert_eq!(control.root(), root_leaf.hash);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "concurrent")]
|
fn test_with_entries_concurrent() {
|
||||||
fn test_with_entries_parallel() {
|
|
||||||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
||||||
|
|
||||||
let entries = generate_entries(PAIR_COUNT);
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
|
||||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
|
||||||
let smt = Smt::with_entries(entries.clone()).unwrap();
|
let smt = Smt::with_entries(entries.clone()).unwrap();
|
||||||
assert_eq!(smt.root(), control.root());
|
assert_eq!(smt.root(), control.root());
|
||||||
assert_eq!(smt, control);
|
assert_eq!(smt, control);
|
||||||
}
|
}
|
||||||
|
/// Concurrent mutations
|
||||||
|
#[test]
|
||||||
|
fn test_singlethreaded_subtree_mutations() {
|
||||||
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
||||||
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
let updates = generate_updates(entries.clone(), 1000);
|
||||||
|
let tree = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||||
|
let control = tree.compute_mutations_sequential(updates.clone());
|
||||||
|
let mut node_mutations = NodeMutations::default();
|
||||||
|
let (mut subtree_leaves, new_pairs) = tree.sorted_pairs_to_mutated_subtree_leaves(updates);
|
||||||
|
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||||
|
// There's no flat_map_unzip(), so this is the best we can do.
|
||||||
|
let (mutations_per_subtree, mut subtree_roots): (Vec<_>, Vec<_>) = subtree_leaves
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, subtree)| {
|
||||||
|
// Pre-assertions.
|
||||||
|
assert!(
|
||||||
|
subtree.is_sorted(),
|
||||||
|
"subtree {i} at bottom-depth {current_depth} is not sorted",
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!subtree.is_empty(),
|
||||||
|
"subtree {i} at bottom-depth {current_depth} is empty!",
|
||||||
|
);
|
||||||
|
// Calculate the mutations for this subtree.
|
||||||
|
let (mutations_per_subtree, subtree_root) =
|
||||||
|
tree.build_subtree_mutations(subtree, SMT_DEPTH, current_depth);
|
||||||
|
// Check that the mutations match the control tree.
|
||||||
|
for (&index, mutation) in mutations_per_subtree.iter() {
|
||||||
|
let control_mutation = control.node_mutations().get(&index).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
control_mutation, mutation,
|
||||||
|
"depth {} subtree {}: mutation does not match control at index {:?}",
|
||||||
|
current_depth, i, index,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
(mutations_per_subtree, subtree_root)
|
||||||
|
})
|
||||||
|
.unzip();
|
||||||
|
subtree_leaves = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
||||||
|
node_mutations.extend(mutations_per_subtree.into_iter().flatten());
|
||||||
|
assert!(!subtree_leaves.is_empty(), "on depth {current_depth}");
|
||||||
|
}
|
||||||
|
let [subtree]: [Vec<_>; 1] = subtree_leaves.try_into().unwrap();
|
||||||
|
let [root_leaf]: [SubtreeLeaf; 1] = subtree.try_into().unwrap();
|
||||||
|
// Check that the new root matches the control.
|
||||||
|
assert_eq!(control.new_root, root_leaf.hash);
|
||||||
|
// Check that the node mutations match the control.
|
||||||
|
assert_eq!(control.node_mutations().len(), node_mutations.len());
|
||||||
|
for (&index, mutation) in control.node_mutations().iter() {
|
||||||
|
let test_mutation = node_mutations.get(&index).unwrap();
|
||||||
|
assert_eq!(test_mutation, mutation);
|
||||||
|
}
|
||||||
|
// Check that the new pairs match the control
|
||||||
|
assert_eq!(control.new_pairs.len(), new_pairs.len());
|
||||||
|
for (&key, &value) in control.new_pairs.iter() {
|
||||||
|
let test_value = new_pairs.get(&key).unwrap();
|
||||||
|
assert_eq!(test_value, &value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[test]
|
||||||
|
fn test_compute_mutations_parallel() {
|
||||||
|
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;
|
||||||
|
let entries = generate_entries(PAIR_COUNT);
|
||||||
|
let tree = Smt::with_entries(entries.clone()).unwrap();
|
||||||
|
let updates = generate_updates(entries, 1000);
|
||||||
|
let control = tree.compute_mutations_sequential(updates.clone());
|
||||||
|
let mutations = tree.compute_mutations(updates);
|
||||||
|
assert_eq!(mutations.root(), control.root());
|
||||||
|
assert_eq!(mutations.old_root(), control.old_root());
|
||||||
|
assert_eq!(mutations.node_mutations(), control.node_mutations());
|
||||||
|
assert_eq!(mutations.new_pairs(), control.new_pairs());
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
use alloc::{collections::BTreeSet, string::ToString, vec::Vec};
|
use alloc::{string::ToString, vec::Vec};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
|
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
|
||||||
|
@ -15,6 +15,12 @@ mod proof;
|
||||||
pub use proof::SmtProof;
|
pub use proof::SmtProof;
|
||||||
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||||
|
|
||||||
|
// Concurrent implementation
|
||||||
|
#[cfg(feature = "concurrent")]
|
||||||
|
mod concurrent;
|
||||||
|
#[cfg(feature = "internal")]
|
||||||
|
pub use concurrent::{build_subtree_for_bench, SubtreeLeaf};
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests;
|
mod tests;
|
||||||
|
|
||||||
|
@ -81,23 +87,7 @@ impl Smt {
|
||||||
) -> Result<Self, MerkleError> {
|
) -> Result<Self, MerkleError> {
|
||||||
#[cfg(feature = "concurrent")]
|
#[cfg(feature = "concurrent")]
|
||||||
{
|
{
|
||||||
let mut seen_keys = BTreeSet::new();
|
Self::with_entries_concurrent(entries)
|
||||||
let entries: Vec<_> = entries
|
|
||||||
.into_iter()
|
|
||||||
.map(|(key, value)| {
|
|
||||||
if seen_keys.insert(key) {
|
|
||||||
Ok((key, value))
|
|
||||||
} else {
|
|
||||||
Err(MerkleError::DuplicateValuesForIndex(
|
|
||||||
LeafIndex::<SMT_DEPTH>::from(key).value(),
|
|
||||||
))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.collect::<Result<_, _>>()?;
|
|
||||||
if entries.is_empty() {
|
|
||||||
return Ok(Self::default());
|
|
||||||
}
|
|
||||||
<Self as SparseMerkleTree<SMT_DEPTH>>::with_entries_par(entries)
|
|
||||||
}
|
}
|
||||||
#[cfg(not(feature = "concurrent"))]
|
#[cfg(not(feature = "concurrent"))]
|
||||||
{
|
{
|
||||||
|
@ -112,9 +102,12 @@ impl Smt {
|
||||||
///
|
///
|
||||||
/// # Errors
|
/// # Errors
|
||||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||||
pub fn with_entries_sequential(
|
#[cfg(any(not(feature = "concurrent"), test))]
|
||||||
|
fn with_entries_sequential(
|
||||||
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
|
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||||
) -> Result<Self, MerkleError> {
|
) -> Result<Self, MerkleError> {
|
||||||
|
use alloc::collections::BTreeSet;
|
||||||
|
|
||||||
// create an empty tree
|
// create an empty tree
|
||||||
let mut tree = Self::new();
|
let mut tree = Self::new();
|
||||||
|
|
||||||
|
@ -252,7 +245,14 @@ impl Smt {
|
||||||
&self,
|
&self,
|
||||||
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
|
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||||
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
|
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
|
||||||
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
|
#[cfg(feature = "concurrent")]
|
||||||
|
{
|
||||||
|
self.compute_mutations_concurrent(kv_pairs)
|
||||||
|
}
|
||||||
|
#[cfg(not(feature = "concurrent"))]
|
||||||
|
{
|
||||||
|
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
|
/// Applies the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
use alloc::{collections::BTreeMap, vec::Vec};
|
use alloc::vec::Vec;
|
||||||
use core::{hash::Hash, mem};
|
use core::hash::Hash;
|
||||||
|
|
||||||
use num::Integer;
|
|
||||||
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||||
|
|
||||||
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
|
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
|
||||||
|
@ -11,6 +10,8 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
mod full;
|
mod full;
|
||||||
|
#[cfg(feature = "internal")]
|
||||||
|
pub use full::{build_subtree_for_bench, SubtreeLeaf};
|
||||||
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
|
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
|
||||||
|
|
||||||
mod simple;
|
mod simple;
|
||||||
|
@ -75,17 +76,6 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
||||||
// PROVIDED METHODS
|
// PROVIDED METHODS
|
||||||
// ---------------------------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------------------------
|
||||||
|
|
||||||
/// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel.
|
|
||||||
#[cfg(feature = "concurrent")]
|
|
||||||
fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result<Self, MerkleError>
|
|
||||||
where
|
|
||||||
Self: Sized,
|
|
||||||
{
|
|
||||||
let (inner_nodes, leaves) = Self::build_subtrees(entries);
|
|
||||||
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
|
|
||||||
Self::from_raw_parts(inner_nodes, leaves, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
|
||||||
/// path to the leaf, as well as the leaf itself.
|
/// path to the leaf, as well as the leaf itself.
|
||||||
fn open(&self, key: &Self::Key) -> Self::Opening {
|
fn open(&self, key: &Self::Key) -> Self::Opening {
|
||||||
|
@ -178,6 +168,15 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
||||||
fn compute_mutations(
|
fn compute_mutations(
|
||||||
&self,
|
&self,
|
||||||
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
|
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
|
||||||
|
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
|
||||||
|
self.compute_mutations_sequential(kv_pairs)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sequential version of [`SparseMerkleTree::compute_mutations()`].
|
||||||
|
/// This is the default implementation.
|
||||||
|
fn compute_mutations_sequential(
|
||||||
|
&self,
|
||||||
|
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
|
||||||
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
|
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
|
||||||
use NodeMutation::*;
|
use NodeMutation::*;
|
||||||
|
|
||||||
|
@ -457,118 +456,6 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
||||||
///
|
///
|
||||||
/// The length `path` is guaranteed to be equal to `DEPTH`
|
/// The length `path` is guaranteed to be equal to `DEPTH`
|
||||||
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
|
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
|
||||||
|
|
||||||
/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
|
|
||||||
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
|
|
||||||
/// the inputs to feed into [`build_subtree()`].
|
|
||||||
///
|
|
||||||
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
|
|
||||||
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
|
|
||||||
/// sorted. Without debug assertions, the returned computations will be incorrect.
|
|
||||||
fn sorted_pairs_to_leaves(
|
|
||||||
pairs: Vec<(Self::Key, Self::Value)>,
|
|
||||||
) -> PairComputations<u64, Self::Leaf> {
|
|
||||||
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
|
|
||||||
|
|
||||||
let mut accumulator: PairComputations<u64, Self::Leaf> = Default::default();
|
|
||||||
let mut accumulated_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(pairs.len() / 2);
|
|
||||||
|
|
||||||
// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
|
|
||||||
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
|
|
||||||
// out and store them in our accumulated leaves.
|
|
||||||
let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default();
|
|
||||||
|
|
||||||
let mut iter = pairs.into_iter().peekable();
|
|
||||||
while let Some((key, value)) = iter.next() {
|
|
||||||
let col = Self::key_to_leaf_index(&key).index.value();
|
|
||||||
let peeked_col = iter.peek().map(|(key, _v)| {
|
|
||||||
let index = Self::key_to_leaf_index(key);
|
|
||||||
let next_col = index.index.value();
|
|
||||||
// We panic if `pairs` is not sorted by column.
|
|
||||||
debug_assert!(next_col >= col);
|
|
||||||
next_col
|
|
||||||
});
|
|
||||||
current_leaf_buffer.push((key, value));
|
|
||||||
|
|
||||||
// If the next pair is the same column as this one, then we're done after adding this
|
|
||||||
// pair to the buffer.
|
|
||||||
if peeked_col == Some(col) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise, the next pair is a different column, or there is no next pair. Either way
|
|
||||||
// it's time to swap out our buffer.
|
|
||||||
let leaf_pairs = mem::take(&mut current_leaf_buffer);
|
|
||||||
let leaf = Self::pairs_to_leaf(leaf_pairs);
|
|
||||||
let hash = Self::hash_leaf(&leaf);
|
|
||||||
|
|
||||||
accumulator.nodes.insert(col, leaf);
|
|
||||||
accumulated_leaves.push(SubtreeLeaf { col, hash });
|
|
||||||
|
|
||||||
debug_assert!(current_leaf_buffer.is_empty());
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: determine is there is any notable performance difference between computing
|
|
||||||
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
|
|
||||||
// subtree boundaries as we go. Either way this function is only used at the beginning of a
|
|
||||||
// parallel construction, so it should not be a critical path.
|
|
||||||
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
|
|
||||||
accumulator
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
|
||||||
///
|
|
||||||
/// `entries` need not be sorted. This function will sort them.
|
|
||||||
#[cfg(feature = "concurrent")]
|
|
||||||
fn build_subtrees(
|
|
||||||
mut entries: Vec<(Self::Key, Self::Value)>,
|
|
||||||
) -> (InnerNodes, Leaves<Self::Leaf>) {
|
|
||||||
entries.sort_by_key(|item| {
|
|
||||||
let index = Self::key_to_leaf_index(&item.0);
|
|
||||||
index.value()
|
|
||||||
});
|
|
||||||
Self::build_subtrees_from_sorted_entries(entries)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
|
||||||
///
|
|
||||||
/// This function is mostly an implementation detail of
|
|
||||||
/// [`SparseMerkleTree::with_entries_par()`].
|
|
||||||
#[cfg(feature = "concurrent")]
|
|
||||||
fn build_subtrees_from_sorted_entries(
|
|
||||||
entries: Vec<(Self::Key, Self::Value)>,
|
|
||||||
) -> (InnerNodes, Leaves<Self::Leaf>) {
|
|
||||||
use rayon::prelude::*;
|
|
||||||
|
|
||||||
let mut accumulated_nodes: InnerNodes = Default::default();
|
|
||||||
|
|
||||||
let PairComputations {
|
|
||||||
leaves: mut leaf_subtrees,
|
|
||||||
nodes: initial_leaves,
|
|
||||||
} = Self::sorted_pairs_to_leaves(entries);
|
|
||||||
|
|
||||||
for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
|
||||||
let (nodes, mut subtree_roots): (Vec<BTreeMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
|
||||||
.into_par_iter()
|
|
||||||
.map(|subtree| {
|
|
||||||
debug_assert!(subtree.is_sorted());
|
|
||||||
debug_assert!(!subtree.is_empty());
|
|
||||||
|
|
||||||
let (nodes, subtree_root) = build_subtree(subtree, DEPTH, current_depth);
|
|
||||||
(nodes, subtree_root)
|
|
||||||
})
|
|
||||||
.unzip();
|
|
||||||
|
|
||||||
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut subtree_roots).collect();
|
|
||||||
accumulated_nodes.extend(nodes.into_iter().flatten());
|
|
||||||
|
|
||||||
debug_assert!(!leaf_subtrees.is_empty());
|
|
||||||
}
|
|
||||||
(accumulated_nodes, initial_leaves)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// INNER NODE
|
// INNER NODE
|
||||||
|
@ -820,198 +707,3 @@ impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> De
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SUBTREES
|
|
||||||
// ================================================================================================
|
|
||||||
|
|
||||||
/// A subtree is of depth 8.
|
|
||||||
const SUBTREE_DEPTH: u8 = 8;
|
|
||||||
|
|
||||||
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
|
|
||||||
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
|
|
||||||
|
|
||||||
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
|
|
||||||
///
|
|
||||||
/// Note that these represet "conceptual" leaves of some subtree, not necessarily
|
|
||||||
/// the leaf type for the sparse Merkle tree.
|
|
||||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
|
|
||||||
pub struct SubtreeLeaf {
|
|
||||||
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
|
|
||||||
pub col: u64,
|
|
||||||
/// The hash of the node this `SubtreeLeaf` represents.
|
|
||||||
pub hash: RpoDigest,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub(crate) struct PairComputations<K, L> {
|
|
||||||
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
|
|
||||||
pub nodes: UnorderedMap<K, L>,
|
|
||||||
/// "Conceptual" leaves that will be used for computations.
|
|
||||||
pub leaves: Vec<Vec<SubtreeLeaf>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Derive requires `L` to impl Default, even though we don't actually need that.
|
|
||||||
impl<K, L> Default for PairComputations<K, L> {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
nodes: Default::default(),
|
|
||||||
leaves: Default::default(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug)]
|
|
||||||
struct SubtreeLeavesIter<'s> {
|
|
||||||
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
|
|
||||||
}
|
|
||||||
impl<'s> SubtreeLeavesIter<'s> {
|
|
||||||
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
|
|
||||||
// TODO: determine if there is any notable performance difference between taking a Vec,
|
|
||||||
// which many need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
|
|
||||||
// The latter may have self-referential properties that are impossible to express in purely
|
|
||||||
// safe Rust Rust.
|
|
||||||
Self { leaves: leaves.drain(..).peekable() }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
impl Iterator for SubtreeLeavesIter<'_> {
|
|
||||||
type Item = Vec<SubtreeLeaf>;
|
|
||||||
|
|
||||||
/// Each `next()` collects an entire subtree.
|
|
||||||
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
|
|
||||||
let mut subtree: Vec<SubtreeLeaf> = Default::default();
|
|
||||||
|
|
||||||
let mut last_subtree_col = 0;
|
|
||||||
|
|
||||||
while let Some(leaf) = self.leaves.peek() {
|
|
||||||
last_subtree_col = u64::max(1, last_subtree_col);
|
|
||||||
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
|
|
||||||
let next_subtree_col = if is_exact_multiple {
|
|
||||||
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
|
|
||||||
} else {
|
|
||||||
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
|
|
||||||
};
|
|
||||||
|
|
||||||
last_subtree_col = leaf.col;
|
|
||||||
if leaf.col < next_subtree_col {
|
|
||||||
subtree.push(self.leaves.next().unwrap());
|
|
||||||
} else if subtree.is_empty() {
|
|
||||||
continue;
|
|
||||||
} else {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if subtree.is_empty() {
|
|
||||||
debug_assert!(self.leaves.peek().is_none());
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
Some(subtree)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// HELPER FUNCTIONS
|
|
||||||
// ================================================================================================
|
|
||||||
|
|
||||||
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
|
|
||||||
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
|
|
||||||
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
|
|
||||||
///
|
|
||||||
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
|
|
||||||
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
|
|
||||||
/// itself.
|
|
||||||
///
|
|
||||||
/// # Panics
|
|
||||||
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
|
|
||||||
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
|
|
||||||
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
|
|
||||||
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
|
|
||||||
fn build_subtree(
|
|
||||||
mut leaves: Vec<SubtreeLeaf>,
|
|
||||||
tree_depth: u8,
|
|
||||||
bottom_depth: u8,
|
|
||||||
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
|
|
||||||
debug_assert!(bottom_depth <= tree_depth);
|
|
||||||
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
|
|
||||||
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
|
|
||||||
let subtree_root = bottom_depth - SUBTREE_DEPTH;
|
|
||||||
let mut inner_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
|
||||||
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
|
|
||||||
for next_depth in (subtree_root..bottom_depth).rev() {
|
|
||||||
debug_assert!(next_depth <= bottom_depth);
|
|
||||||
// `next_depth` is the stuff we're making.
|
|
||||||
// `current_depth` is the stuff we have.
|
|
||||||
let current_depth = next_depth + 1;
|
|
||||||
let mut iter = leaves.drain(..).peekable();
|
|
||||||
while let Some(first) = iter.next() {
|
|
||||||
// On non-continuous iterations, including the first iteration, `first_column` may
|
|
||||||
// be a left or right node. On subsequent continuous iterations, we will always call
|
|
||||||
// `iter.next()` twice.
|
|
||||||
// On non-continuous iterations (including the very first iteration), this column
|
|
||||||
// could be either on the left or the right. If the next iteration is not
|
|
||||||
// discontinuous with our right node, then the next iteration's
|
|
||||||
let is_right = first.col.is_odd();
|
|
||||||
let (left, right) = if is_right {
|
|
||||||
// Discontinuous iteration: we have no left node, so it must be empty.
|
|
||||||
let left = SubtreeLeaf {
|
|
||||||
col: first.col - 1,
|
|
||||||
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
|
|
||||||
};
|
|
||||||
let right = first;
|
|
||||||
(left, right)
|
|
||||||
} else {
|
|
||||||
let left = first;
|
|
||||||
let right_col = first.col + 1;
|
|
||||||
let right = match iter.peek().copied() {
|
|
||||||
Some(SubtreeLeaf { col, .. }) if col == right_col => {
|
|
||||||
// Our inputs must be sorted.
|
|
||||||
debug_assert!(left.col <= col);
|
|
||||||
// The next leaf in the iterator is our sibling. Use it and consume it!
|
|
||||||
iter.next().unwrap()
|
|
||||||
},
|
|
||||||
// Otherwise, the leaves don't contain our sibling, so our sibling must be
|
|
||||||
// empty.
|
|
||||||
_ => SubtreeLeaf {
|
|
||||||
col: right_col,
|
|
||||||
hash: *EmptySubtreeRoots::entry(tree_depth, current_depth),
|
|
||||||
},
|
|
||||||
};
|
|
||||||
(left, right)
|
|
||||||
};
|
|
||||||
let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
|
|
||||||
let node = InnerNode { left: left.hash, right: right.hash };
|
|
||||||
let hash = node.hash();
|
|
||||||
let &equivalent_empty_hash = EmptySubtreeRoots::entry(tree_depth, next_depth);
|
|
||||||
// If this hash is empty, then it doesn't become a new inner node, nor does it count
|
|
||||||
// as a leaf for the next depth.
|
|
||||||
if hash != equivalent_empty_hash {
|
|
||||||
inner_nodes.insert(index, node);
|
|
||||||
next_leaves.push(SubtreeLeaf { col: index.value(), hash });
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Stop borrowing `leaves`, so we can swap it.
|
|
||||||
// The iterator is empty at this point anyway.
|
|
||||||
drop(iter);
|
|
||||||
// After each depth, consider the stuff we just made the new "leaves", and empty the
|
|
||||||
// other collection.
|
|
||||||
mem::swap(&mut leaves, &mut next_leaves);
|
|
||||||
}
|
|
||||||
debug_assert_eq!(leaves.len(), 1);
|
|
||||||
let root = leaves.pop().unwrap();
|
|
||||||
(inner_nodes, root)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(feature = "internal")]
|
|
||||||
pub fn build_subtree_for_bench(
|
|
||||||
leaves: Vec<SubtreeLeaf>,
|
|
||||||
tree_depth: u8,
|
|
||||||
bottom_depth: u8,
|
|
||||||
) -> (BTreeMap<NodeIndex, InnerNode>, SubtreeLeaf) {
|
|
||||||
build_subtree(leaves, tree_depth, bottom_depth)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TESTS
|
|
||||||
// ================================================================================================
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests;
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue