From 034a2a66030aa70af3a9c1ff21eac094fc74904b Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:01:58 -0700 Subject: [PATCH 1/3] smt: add parallel constructors to Smt and SimpleSmt What the previous few commits have been leading up to: SparseMerkleTree now has a function to construct the tree from existing data in parallel. This is significantly faster than the singlethreaded equivalent. Benchmarks incoming! --- src/merkle/smt/full/mod.rs | 13 +++++++ src/merkle/smt/mod.rs | 73 ++++++++++++++++++++++++++++++++++++ src/merkle/smt/simple/mod.rs | 16 ++++++++ src/merkle/smt/tests.rs | 14 +++++++ 4 files changed, 116 insertions(+) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 4f6ec62..fa680bf 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -101,6 +101,19 @@ impl Smt { Ok(tree) } + /// The parallel version of [`Smt::with_entries()`]. + /// + /// Returns a new [`Smt`] instantiated with leaves set as specified by the provided entries, + /// constructed in parallel. + /// + /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE]. + #[cfg(feature = "concurrent")] + pub fn with_entries_par( + entries: impl IntoIterator, + ) -> Result { + >::with_entries_par(Vec::from_iter(entries)) + } + /// Returns a new [`Smt`] instantiated from already computed leaves and nodes. /// /// This function performs minimal consistency checking. It is the caller's responsibility to diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index f5ae2d4..43caeb3 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -65,6 +65,19 @@ pub(crate) trait SparseMerkleTree { // PROVIDED METHODS // --------------------------------------------------------------------------------------------- + /// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel. 
+ #[cfg(feature = "concurrent")] + fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result + where + Self: Sized, + { + let (inner_nodes, leaves) = Self::build_subtrees(entries); + let leaves: BTreeMap = + leaves.into_iter().map(|(index, leaf)| (index.value(), leaf)).collect(); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + Self::from_raw_parts(inner_nodes, leaves, root) + } + /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { @@ -429,6 +442,8 @@ pub(crate) trait SparseMerkleTree { /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into /// itself. /// + /// This function is mostly an implementation detail of [`SparseMerkleTree::build_subtrees()`]. + /// /// # Panics /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to @@ -522,6 +537,64 @@ pub(crate) trait SparseMerkleTree { (inner_nodes, leaves) } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// `entries` need not be sorted. This function will sort them. + /// + /// This function is mostly an implementation detail of + /// [`SparseMerkleTree::with_entries_par()`]. 
+ #[cfg(feature = "concurrent")] + fn build_subtrees( + mut entries: Vec<(Self::Key, Self::Value)>, + ) -> (BTreeMap, BTreeMap, Self::Leaf>) { + use rayon::prelude::*; + + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + + let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth); + + debug_assert!(next_leaves.is_sorted()); + + (nodes, next_leaves) + }) + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + + let leaves: BTreeMap, Self::Leaf> = initial_leaves + .into_iter() + .map(|(key, value)| { + // This unwrap *should* be unreachable. + let key = LeafIndex::::new(key).unwrap(); + (key, value) + }) + .collect(); + + (accumulated_nodes, leaves) + } } // INNER NODE diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 1ded87f..a95aed1 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -100,6 +100,22 @@ impl SimpleSmt { Ok(tree) } + /// The parallel version of [`SimpleSmt::with_leaves()`]. + /// + /// Returns a new [`SimpleSmt`] instantiated with leaves set as specified by the provided entries. + /// + /// All leaves omitted from the entries list are set to [ZERO; 4]. 
+ #[cfg(feature = "concurrent")] + pub fn with_leaves_par( + entries: impl IntoIterator, + ) -> Result { + let entries: Vec<_> = entries + .into_iter() + .map(|(col, value)| (LeafIndex::::new(col).unwrap(), value)) + .collect(); + >::with_entries_par(entries) + } + /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes. /// /// This function performs minimal consistency checking. It is the caller's responsibility to diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index b16dc2e..7701417 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -412,3 +412,17 @@ fn test_multithreaded_subtrees() { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } + +#[test] +#[cfg(feature = "concurrent")] +fn test_with_entries_par() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let smt = Smt::with_entries_par(entries.clone()).unwrap(); + assert_eq!(smt.root(), control.root()); + assert_eq!(smt, control); +} From 3b7ce3d2537071a1a1f450ad7c86764f8abb8d18 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:44:26 -0700 Subject: [PATCH 2/3] smt: add benchmarks for parallel construction --- Cargo.toml | 5 +++ benches/parallel-subtree.rs | 75 +++++++++++++++++++++++++++++++++++++ src/main.rs | 28 +++++++++++++- 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 benches/parallel-subtree.rs diff --git a/Cargo.toml b/Cargo.toml index 347cca8..750748a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,11 @@ harness = false name = "merkle" harness = false +[[bench]] +name = "parallel-subtree" +harness = false +required-features = ["concurrent"] + [[bench]] name = "store" harness = false diff --git a/benches/parallel-subtree.rs b/benches/parallel-subtree.rs new file mode 100644 index 0000000..65c3918 --- /dev/null +++ 
b/benches/parallel-subtree.rs @@ -0,0 +1,75 @@ +use std::{fmt::Debug, hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE}; +use rand_utils::prng_array; +use winter_utils::Randomizable; + +// 2^0, 2^4, 2^8, 2^12, 2^16, 2^20 +const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576]; + +fn smt_parallel_subtree(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("parallel-subtrees"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|i| { + let count = pair_count as f64; + let idx = ((i as f64 / count) * (count)) as u64; + let key = RpoDigest::new([ + generate_value(&mut seed), + ONE, + Felt::new(i), + Felt::new(idx), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let control = Smt::with_entries(entries.clone()).unwrap(); + (entries, control) + }, + |(entries, control)| { + // Benchmarked function. + let tree = Smt::with_entries_par(hint::black_box(entries)).unwrap(); + assert_eq!(tree.root(), control.root()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group! 
{ + name = smt_subtree_group; + config = Criterion::default() + //.measurement_time(Duration::from_secs(960)) + .measurement_time(Duration::from_secs(60)) + .sample_size(10) + .configure_from_args(); + targets = smt_parallel_subtree +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_value(seed: &mut [u8; 32]) -> T { + mem::swap(seed, &mut prng_array(*seed)); + let value: [T; 1] = rand_utils::prng_array(*seed); + value[0] +} + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} diff --git a/src/main.rs b/src/main.rs index 776ccc2..018dd50 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,12 @@ pub fn benchmark_smt() { entries.push((key, value)); } - let mut tree = construction(entries, tree_size).unwrap(); + let mut tree = construction(entries.clone(), tree_size).unwrap(); + #[cfg(feature = "concurrent")] + { + let parallel = parallel_construction(entries, tree_size).unwrap(); + assert_eq!(tree, parallel); + } insertion(&mut tree, tree_size).unwrap(); batched_insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); @@ -56,6 +61,27 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result, + size: u64, +) -> Result { + println!("Running a parallel construction benchmark:"); + let now = Instant::now(); + + let tree = Smt::with_entries_par(entries).unwrap(); + + let elapsed = now.elapsed(); + println!( + "Parallel-constructed an SMT with {} key-value pairs in {:.3} seconds", + size, + elapsed.as_secs_f32(), + ); + println!("Number of leaf nodes: {}\n", tree.leaves().count()); + + Ok(tree) +} + /// Runs the insertion benchmark for the [`Smt`]. 
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); From 006954c1bd6a9a59c69ec2673933896d4004d03e Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:48:50 -0700 Subject: [PATCH 3/3] add news item for smt parallel subtree construction --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc22853..d225165 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.11.0 (2024-10-30) - [BREAKING] Updated Winterfell dependency to v0.10 (#338). +- Added `Smt::with_entries_par()`, a parallel version of `with_entries()` with significantly better performance (#341). ## 0.11.0 (2024-10-17)