From 8b104655477b6fd42de771e1f860f0e0144e1f6e Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 13 Nov 2024 15:32:48 -0700 Subject: [PATCH 01/10] smt: add sorted_pairs_to_leaves() and test for it --- src/merkle/smt/mod.rs | 155 ++++++++++++++++++++++++++++++++++++++++ src/merkle/smt/tests.rs | 91 +++++++++++++++++++++++ 2 files changed, 246 insertions(+) create mode 100644 src/merkle/smt/tests.rs diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 03d9d45..1972304 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,7 @@ use alloc::{collections::BTreeMap, vec::Vec}; +use core::mem; + +use num::Integer; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use crate::{ @@ -346,6 +349,67 @@ pub(crate) trait SparseMerkleTree { /// /// The length `path` is guaranteed to be equal to `DEPTH` fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; + + /// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing + /// subtrees. In other words, this function takes the key-value inputs to the tree, and produces + /// the inputs to feed into [`SparseMerkleTree::build_subtree()`]. + /// + /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If + /// `pairs` is not correctly sorted, the returned computations will be incorrect. + /// + /// # Panics + /// With debug assertions on, this function panics if it detects that `pairs` is not correctly + /// sorted. Without debug assertions, the returned computations will be incorrect. + fn sorted_pairs_to_leaves( + pairs: Vec<(Self::Key, Self::Value)>, + ) -> PairComputations { + debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); + + let mut accumulator: PairComputations = Default::default(); + let mut accumulated_leaves: Vec = Default::default(); + + // As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a + // single leaf. When we see a pair that's in a different leaf, we'll swap these pairs + // out and store them in our accumulated leaves. + let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let peeked_col = iter.peek().map(|(key, _v)| { + let index = Self::key_to_leaf_index(key); + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col + }); + current_leaf_buffer.push((key, value)); + + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { + continue; + } + + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. + let leaf_pairs = mem::take(&mut current_leaf_buffer); + let leaf = Self::pairs_to_leaf(leaf_pairs); + let hash = Self::hash_leaf(&leaf); + + accumulator.nodes.insert(col, leaf); + accumulated_leaves.push(SubtreeLeaf { col, hash }); + + debug_assert!(current_leaf_buffer.is_empty()); + } + + // TODO: determine is there is any notable performance difference between computing + // subtree boundaries after the fact as an iterator adapter (like this), versus computing + // subtree boundaries as we go. Either way this function is only used at the beginning of a + // parallel construction, so it should not be a critical path. + accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); + accumulator + } } // INNER NODE @@ -463,3 +527,94 @@ impl MutationSet { self.new_root } } + +// SUBTREES +// ================================================================================================ +/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. +const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); + +/// Helper struct for organizing the data we care about when computing Merkle subtrees. +/// +/// Note that these represet "conceptual" leaves of some subtree, not necessarily +/// the leaf type for the sparse Merkle tree. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. + pub col: u64, + /// The hash of the node this `SubtreeLeaf` represents. + pub hash: RpoDigest, +} + +/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct PairComputations { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: BTreeMap, + /// "Conceptual" leaves that will be used for computations. + pub leaves: Vec>, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PairComputations { + fn default() -> Self { + Self { + nodes: Default::default(), + leaves: Default::default(), + } + } +} + +#[derive(Debug)] +struct SubtreeLeavesIter<'s> { + leaves: core::iter::Peekable>, +} +impl<'s> SubtreeLeavesIter<'s> { + fn from_leaves(leaves: &'s mut Vec) -> Self { + // TODO: determine if there is any notable performance difference between taking a Vec, + // which many need flattening first, vs storing a `Box>`. + // The latter may have self-referential properties that are impossible to express in purely + // safe Rust Rust. + Self { leaves: leaves.drain(..).peekable() } + } +} +impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> { + type Item = Vec; + + /// Each `next()` collects an entire subtree. + fn next(&mut self) -> Option> { + let mut subtree: Vec = Default::default(); + + let mut last_subtree_col = 0; + + while let Some(leaf) = self.leaves.peek() { + last_subtree_col = u64::max(1, last_subtree_col); + let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); + let next_subtree_col = if is_exact_multiple { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + last_subtree_col = leaf.col; + if leaf.col < next_subtree_col { + subtree.push(self.leaves.next().unwrap()); + } else if subtree.is_empty() { + continue; + } else { + break; + } + } + + if subtree.is_empty() { + debug_assert!(self.leaves.peek().is_none()); + return None; + } + + Some(subtree) + } +} + +// TESTS +// ================================================================================================ +#[cfg(test)] +mod tests; diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs new file mode 100644 index 0000000..fbaa3fa --- /dev/null +++ b/src/merkle/smt/tests.rs @@ -0,0 +1,91 @@ +use alloc::{collections::BTreeMap, vec::Vec}; + +use super::{PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter}; +use crate::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE}; + +fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { + SubtreeLeaf { + col: leaf.index().index.value(), + hash: leaf.hash(), + } +} + +#[test] +fn test_sorted_pairs_to_leaves() { + let entries: Vec<(RpoDigest, Word)> = vec![ + // Subtree 0. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), + // Leaf index collision. + (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), + // Subtree 1. Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), + // Subtree 2. Another normal leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), + ]; + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let control_leaves: Vec = { + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + let control_leaves = vec![ + // Subtree 0. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + // Subtree 1. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + // Subtree 2. + SmtLeaf::Single(next_entry()), + ]; + assert_eq!(entries_iter.next(), None); + control_leaves + }; + + let control_subtree_leaves: Vec> = { + let mut control_leaves_iter = control_leaves.iter(); + let mut next_leaf = || control_leaves_iter.next().unwrap(); + + let control_subtree_leaves: Vec> = [ + // Subtree 0. + vec![next_leaf(), next_leaf(), next_leaf()], + // Subtree 1. + vec![next_leaf(), next_leaf()], + // Subtree 2. + vec![next_leaf()], + ] + .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) + .to_vec(); + assert_eq!(control_leaves_iter.next(), None); + control_subtree_leaves + }; + + let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); + // This will check that the hashes, columns, and subtree assignments all match. + assert_eq!(subtrees.leaves, control_subtree_leaves); + + // Flattening and re-separating out the leaves into subtrees should have the same result. + let mut all_leaves: Vec = subtrees.leaves.clone().into_iter().flatten().collect(); + let re_grouped: Vec> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + assert_eq!(subtrees.leaves, re_grouped); + + // Then finally we might as well check the computed leaf nodes too. + let control_leaves: BTreeMap = control + .leaves() + .map(|(index, value)| (index.index.value(), value.clone())) + .collect(); + + for (column, test_leaf) in subtrees.nodes { + if test_leaf.is_empty() { + continue; + } + let control_leaf = control_leaves + .get(&column) + .unwrap_or_else(|| panic!("no leaf node found for column {column}")); + assert_eq!(control_leaf, &test_leaf); + } +} From 16456aa724ae6c86f77e9ad98892196c7fcd2f1e Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 14:04:15 -0700 Subject: [PATCH 02/10] smt: implement single subtree-8 hashing, w/ benchmarks & tests This will be composed into depth-8-subtree-based computation of entire sparse Merkle trees. --- Cargo.toml | 4 ++ benches/smt-subtree.rs | 136 +++++++++++++++++++++++++++++++++++++ src/merkle/mod.rs | 2 +- src/merkle/smt/full/mod.rs | 26 ++++++- src/merkle/smt/mod.rs | 112 +++++++++++++++++++++++++++++- src/merkle/smt/tests.rs | 61 ++++++++++++++++- 6 files changed, 335 insertions(+), 6 deletions(-) create mode 100644 benches/smt-subtree.rs diff --git a/Cargo.toml b/Cargo.toml index 5d124c6..ec59d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,10 @@ harness = false name = "smt" harness = false +[[bench]] +name = "smt-subtree" +harness = false + [[bench]] name = "store" harness = false diff --git a/benches/smt-subtree.rs b/benches/smt-subtree.rs new file mode 100644 index 0000000..cd7454a --- /dev/null +++ b/benches/smt-subtree.rs @@ -0,0 +1,136 @@ +use std::{fmt::Debug, hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use miden_crypto::{ + hash::rpo::RpoDigest, + merkle::{NodeIndex, Smt, SmtLeaf, SubtreeLeaf, SMT_DEPTH}, + Felt, Word, ONE, +}; +use rand_utils::prng_array; +use winter_utils::Randomizable; + +const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256]; + +fn smt_subtree_even(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("subtree8-even"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|n| { + // A single depth-8 subtree can have a maximum of 255 leaves. + let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64; + let key = RpoDigest::new([ + generate_value(&mut seed), + ONE, + Felt::new(n), + Felt::new(leaf_index), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let mut leaves: Vec<_> = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = NodeIndex::from(leaf.index()).value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves.dedup_by_key(|leaf| leaf.col); + leaves + }, + |leaves| { + // Benchmarked function. + let (subtree, _) = + Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH)); + assert!(!subtree.is_empty()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn smt_subtree_random(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("subtree8-rand"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|i| { + let leaf_index: u8 = generate_value(&mut seed); + let key = RpoDigest::new([ + ONE, + ONE, + Felt::new(i), + Felt::new(leaf_index as u64), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let mut leaves: Vec<_> = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = NodeIndex::from(leaf.index()).value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves + }, + |leaves| { + let (subtree, _) = + Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH)); + assert!(!subtree.is_empty()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + .measurement_time(Duration::from_secs(40)) + .sample_size(60) + .configure_from_args(); + targets = smt_subtree_even, smt_subtree_random +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_value(seed: &mut [u8; 32]) -> T { + mem::swap(seed, &mut prng_array(*seed)); + let value: [T; 1] = rand_utils::prng_array(*seed); + value[0] +} + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index a562aa5..dc897dc 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -23,7 +23,7 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; pub use smt::{ LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, - SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, + SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index c8133e2..a89b4b3 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -6,7 +6,7 @@ use alloc::{ use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, - MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, SubtreeLeaf, Word, EMPTY_WORD, }; mod error; @@ -249,6 +249,30 @@ impl Smt { None } } + + /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and + /// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and + /// `leaves` must not contain more than one depth-8 subtree's worth of leaves. + /// + /// This function will then calculate the inner nodes above each leaf for 8 layers, as well as + /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into + /// itself. + /// + /// # Panics + /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains + /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to + /// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified + /// maximum depth (`DEPTH`), or if `leaves` is not sorted. + /// + /// This function is public so functions returning it can be used in tests and benchmarks, but + /// is otherwise not part of the public API. + #[doc(hidden)] + pub fn build_subtree( + leaves: Vec, + bottom_depth: u8, + ) -> (BTreeMap, Vec) { + >::build_subtree(leaves, bottom_depth) + } } impl SparseMerkleTree for Smt { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 1972304..b769c4b 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -410,14 +410,119 @@ pub(crate) trait SparseMerkleTree { accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); accumulator } + + /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and + /// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and + /// `leaves` must not contain more than one depth-8 subtree's worth of leaves. + /// + /// This function will then calculate the inner nodes above each leaf for 8 layers, as well as + /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into + /// itself. + /// + /// # Panics + /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains + /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to + /// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified + /// maximum depth (`DEPTH`), or if `leaves` is not sorted. + fn build_subtree( + mut leaves: Vec, + bottom_depth: u8, + ) -> (BTreeMap, Vec) { + debug_assert!(bottom_depth <= DEPTH); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH)); + debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32)); + + let subtree_root = bottom_depth - SUBTREE_DEPTH; + + let mut inner_nodes: BTreeMap = Default::default(); + + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }; + let right = first; + + (left, right) + } else { + let left = first; + + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. + _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }, + }; + + (left, right) + }; + + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + + (inner_nodes, leaves) + } } // INNER NODE // ================================================================================================ +/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not +/// part of the public API. +#[doc(hidden)] #[derive(Debug, Default, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub(crate) struct InnerNode { +pub struct InnerNode { pub left: RpoDigest, pub right: RpoDigest, } @@ -530,8 +635,11 @@ impl MutationSet { // SUBTREES // ================================================================================================ +/// A subtree is of depth 8. +const SUBTREE_DEPTH: u8 = 8; + /// A depth-8 subtree contains 256 "columns" that can possibly be occupied. -const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); +const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32); /// Helper struct for organizing the data we care about when computing Merkle subtrees. /// diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index fbaa3fa..d889de8 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -1,7 +1,14 @@ use alloc::{collections::BTreeMap, vec::Vec}; -use super::{PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter}; -use crate::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE}; +use super::{ + NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, + COLS_PER_SUBTREE, SUBTREE_DEPTH, +}; +use crate::{ + hash::rpo::RpoDigest, + merkle::{Smt, SMT_DEPTH}, + Felt, Word, ONE, +}; fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { SubtreeLeaf { @@ -89,3 +96,53 @@ fn test_sorted_pairs_to_leaves() { assert_eq!(control_leaf, &test_leaf); } } + +// Helper for the below tests. +fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { + (0..pair_count) + .map(|i| { + let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect() +} + +#[test] +fn test_single_subtree() { + // A single subtree's worth of leaves. + const PAIR_COUNT: u64 = COLS_PER_SUBTREE; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + // `entries` should already be sorted by nature of how we constructed it. + let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = leaves.into_iter().next().unwrap(); + + let (first_subtree, next_leaves) = Smt::build_subtree(leaves, SMT_DEPTH); + assert!(!first_subtree.is_empty()); + + // The inner nodes computed from that subtree should match the nodes in our control tree. + for (index, node) in first_subtree.into_iter() { + let control = control.get_inner_node(index); + assert_eq!( + control, node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + // The "next leaves" returned should also have matching hashes from the equivalent nodes in + // our control tree. + for SubtreeLeaf { col, hash } in next_leaves { + let index = NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, col).unwrap(); + let control_node = control.get_inner_node(index); + let control = control_node.hash(); + assert_eq!( + control, hash, + "subtree-computed next leaf at index {index:?} does not match control", + ); + } +} From 1863dab6d3b8422f1722c753e67a4837a95305b3 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 14:45:26 -0700 Subject: [PATCH 03/10] merkle: add a benchmark for constructing 256-balanced trees This is intended for comparison with the benchmarks from the previous commit. This benchmark represents the theoretical perfect-efficiency performance we could possibly (but impractically) get for computing depth-8 sparse Merkle subtrees. --- Cargo.toml | 4 +++ benches/merkle.rs | 66 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 benches/merkle.rs diff --git a/Cargo.toml b/Cargo.toml index ec59d42..74df3ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,10 @@ harness = false name = "smt-subtree" harness = false +[[bench]] +name = "merkle" +harness = false + [[bench]] name = "store" harness = false diff --git a/benches/merkle.rs b/benches/merkle.rs new file mode 100644 index 0000000..7d6bb2c --- /dev/null +++ b/benches/merkle.rs @@ -0,0 +1,66 @@ +//! Benchmark for building a [`miden_crypto::merkle::MerkleTree`]. This is intended to be compared +//! with the results from `benches/smt-subtree.rs`, as building a fully balanced Merkle tree with +//! 256 leaves should indicate the *absolute best* performance we could *possibly* get for building +//! a depth-8 sparse Merkle subtree, though practically speaking building a fully balanced Merkle +//! tree will perform better than the sparse version. At the time of this writing (2024/11/24), this +//! benchmark is about four times more efficient than the equivalent benchmark in +//! `benches/smt-subtree.rs`. +use std::{hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE}; +use rand_utils::prng_array; + +fn balanced_merkle_even(c: &mut Criterion) { + c.bench_function("balanced-merkle-even", |b| { + b.iter_batched( + || { + let entries: Vec = + (0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect(); + assert_eq!(entries.len(), 256); + entries + }, + |leaves| { + let tree = MerkleTree::new(hint::black_box(leaves)).unwrap(); + assert_eq!(tree.depth(), 8); + }, + BatchSize::SmallInput, + ); + }); +} + +fn balanced_merkle_rand(c: &mut Criterion) { + let mut seed = [0u8; 32]; + c.bench_function("balanced-merkle-rand", |b| { + b.iter_batched( + || { + let entries: Vec = (0..256).map(|_| generate_word(&mut seed)).collect(); + assert_eq!(entries.len(), 256); + entries + }, + |leaves| { + let tree = MerkleTree::new(hint::black_box(leaves)).unwrap(); + assert_eq!(tree.depth(), 8); + }, + BatchSize::SmallInput, + ); + }); +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + .measurement_time(Duration::from_secs(20)) + .configure_from_args(); + targets = balanced_merkle_even, balanced_merkle_rand +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} From cd1dc7c7c8074bf5c80376b63fd40e41d8bc3338 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 15:59:29 -0700 Subject: [PATCH 04/10] smt: test that SparseMerkleTree::build_subtree() is composable --- src/merkle/smt/tests.rs | 63 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index d889de8..cfe69c7 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -146,3 +146,66 @@ fn test_single_subtree() { ); } } + +// Test that not just can we compute a subtree correctly, but we can feed the results of one +// subtree into computing another. In other words, test that `build_subtree()` is correctly +// composable. +#[test] +fn test_two_subtrees() { + // Two subtrees' worth of leaves. + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + // With two subtrees' worth of leaves, we should have exactly two subtrees. + let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); + assert_eq!(first.len() as u64, PAIR_COUNT / 2); + assert_eq!(first.len(), second.len()); + + let mut current_depth = SMT_DEPTH; + let mut next_leaves: Vec = Default::default(); + + let (first_nodes, leaves) = Smt::build_subtree(first, current_depth); + next_leaves.extend(leaves); + + let (second_nodes, leaves) = Smt::build_subtree(second, current_depth); + next_leaves.extend(leaves); + + // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. + let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); + assert_eq!(total_computed as u64, PAIR_COUNT); + + // Verify the computed nodes of both subtrees. + let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); + for (index, test_node) in computed_nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + current_depth -= SUBTREE_DEPTH; + + let (nodes, next_leaves) = Smt::build_subtree(next_leaves, current_depth); + assert_eq!(nodes.len(), SUBTREE_DEPTH as usize); + assert_eq!(next_leaves.len(), 1); + + for (index, test_node) in nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + for SubtreeLeaf { col, hash } in next_leaves { + let index = NodeIndex::new(current_depth - SUBTREE_DEPTH, col).unwrap(); + let control_node = control.get_inner_node(index); + let control = control_node.hash(); + assert_eq!(control, hash); + } +} From 475c8264d7bc02f087f304458de8decebe4bf0eb Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 16:00:58 -0700 Subject: [PATCH 05/10] smt: test that subtree logic can correctly construct an entire tree This commit ensures that `SparseMerkleTree::build_subtree()` can correctly compose into building an entire sparse Merkle tree, without yet getting into potential complications concurrency introduces. --- src/merkle/smt/tests.rs | 105 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 103 insertions(+), 2 deletions(-) diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index cfe69c7..5336a65 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -1,8 +1,8 @@ use alloc::{collections::BTreeMap, vec::Vec}; use super::{ - NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, - COLS_PER_SUBTREE, SUBTREE_DEPTH, + InnerNode, LeafIndex, NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, + SubtreeLeavesIter, COLS_PER_SUBTREE, SUBTREE_DEPTH, }; use crate::{ hash::rpo::RpoDigest, @@ -209,3 +209,104 @@ fn test_two_subtrees() { assert_eq!(control, hash); } } + +#[test] +fn test_singlethreaded_subtrees() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + // There's no flat_map_unzip(), so this is the best we can do. + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + // Do actual things. + let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); + // Post-assertions. + assert!(next_leaves.is_sorted()); + + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, next_leaves) + }) + .unzip(); + + // Update state between each depth iteration. + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, first checking length and then checking each individual + // leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf, "test leaf at column {col} does not match control"); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [Vec<_>; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [SubtreeLeaf; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); +} From 5b9480a9f5de326a86919c027d813cebb6fb9fa9 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 16:46:28 -0700 Subject: [PATCH 06/10] smt: implement test for basic parallelized subtree computation w/ rayon Building on the previous commit, this commit implements a test proving that `SparseMerkleTree::build_subtree()` can be composed into itself not just concurrently, but in parallel, without issue. --- Cargo.lock | 1 + Cargo.toml | 4 +- src/merkle/smt/tests.rs | 101 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 3e822fb..fca1e2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -534,6 +534,7 @@ dependencies = [ "rand", "rand_chacha", "rand_core", + "rayon", "seq-macro", "serde", "sha3", diff --git a/Cargo.toml b/Cargo.toml index 74df3ab..347cca8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,9 +40,10 @@ name = "store" harness = false [features] -default = ["std"] +default = ["std", "concurrent"] executable = ["dep:clap", "dep:rand-utils", "std"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] +concurrent = ["dep:rayon"] std = [ "blake3/std", "dep:cc", @@ -66,6 +67,7 @@ sha3 = { version = "0.10", default-features = false } winter-crypto = { version = "0.10", default-features = false } winter-math = { version = "0.10", default-features = false } winter-utils = { version = "0.10", default-features = false } +rayon = { version = "1.10.0", optional = true } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index 5336a65..aa0a459 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -310,3 +310,104 @@ fn test_singlethreaded_subtrees() { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } + +/// The parallel version of `test_singlethreaded_subtree()`. +#[test] +#[cfg(feature = "concurrent")] +fn test_multithreaded_subtrees() { + use rayon::prelude::*; + + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); + + // Post-assertions. + assert!(next_leaves.is_sorted()); + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, next_leaves) + }) + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + + accumulated_nodes.extend(nodes.into_iter().flatten()); + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); +} From 38422f592bd16a94f02242d6e60cda712650c75f Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 17:49:52 -0700 Subject: [PATCH 07/10] smt: add from_raw_parts() to trait interface This commit adds a new required method to the SparseMerkleTree trait, to allow generic construction from pre-computed parts. This will be used to add a generic version of `with_entries()` in a later commit. --- src/merkle/smt/full/mod.rs | 30 ++++++++++++++++++++++++++++++ src/merkle/smt/mod.rs | 10 ++++++++++ src/merkle/smt/simple/mod.rs | 30 ++++++++++++++++++++++++++++++ 3 files changed, 70 insertions(+) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index a89b4b3..d57a5a5 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -101,6 +101,23 @@ impl Smt { Ok(tree) } + /// Returns a new [`Smt`] instantiated from already computed leaves and nodes. + /// + /// This function performs minimal consistency checking. It is the caller's responsibility to + /// ensure the passed arguments are correct and consistent with each other. + /// + /// # Panics + /// With debug assertions on, this function panics if `root` does not match the root node in + /// `inner_nodes`. + pub fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Self { + // Our particular implementation of `from_raw_parts()` never returns `Err`. + >::from_raw_parts(inner_nodes, leaves, root).unwrap() + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -284,6 +301,19 @@ impl SparseMerkleTree for Smt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + if cfg!(debug_assertions) { + let root_node = inner_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(root_node.hash(), root); + } + + Ok(Self { root, inner_nodes, leaves }) + } + fn root(&self) -> RpoDigest { self.root } diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index b769c4b..8f71ee8 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -292,6 +292,16 @@ pub(crate) trait SparseMerkleTree { // REQUIRED METHODS // --------------------------------------------------------------------------------------------- + /// Construct this type from already computed leaves and nodes. The caller ensures passed + /// arguments are correct and consistent with each other. + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result + where + Self: Sized; + /// The root of the tree fn root(&self) -> RpoDigest; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 04476a0..1ded87f 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -100,6 +100,23 @@ impl SimpleSmt { Ok(tree) } + /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes. + /// + /// This function performs minimal consistency checking. It is the caller's responsibility to + /// ensure the passed arguments are correct and consistent with each other. + /// + /// # Panics + /// With debug assertions on, this function panics if `root` does not match the root node in + /// `inner_nodes`. + pub fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Self { + // Our particular implementation of `from_raw_parts()` never returns `Err`. + >::from_raw_parts(inner_nodes, leaves, root).unwrap() + } + /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices /// starting at index 0. pub fn with_contiguous_leaves( @@ -309,6 +326,19 @@ impl SparseMerkleTree for SimpleSmt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + if cfg!(debug_assertions) { + let root_node = inner_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(root_node.hash(), root); + } + + Ok(Self { root, inner_nodes, leaves }) + } + fn root(&self) -> RpoDigest { self.root } From cc144a6ef5b4024b06a487cd273b2dce4ff548f3 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:01:58 -0700 Subject: [PATCH 08/10] smt: add parallel constructors to Smt and SimpleSmt What the previous few commits have been leading up to: SparseMerkleTree now has a function to construct the tree from existing data in parallel. This is significantly faster than the singlethreaded equivalent. Benchmarks incoming! --- src/merkle/smt/full/mod.rs | 13 +++++++ src/merkle/smt/mod.rs | 73 ++++++++++++++++++++++++++++++++++++ src/merkle/smt/simple/mod.rs | 17 +++++++++ src/merkle/smt/tests.rs | 14 +++++++ 4 files changed, 117 insertions(+) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index d57a5a5..572766c 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -101,6 +101,19 @@ impl Smt { Ok(tree) } + /// The parallel version of [`Smt::with_entries()`]. + /// + /// Returns a new [`Smt`] instantiated with leaves set as specified by the provided entries, + /// constructed in parallel. + /// + /// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE]. + #[cfg(feature = "concurrent")] + pub fn with_entries_par( + entries: impl IntoIterator, + ) -> Result { + >::with_entries_par(Vec::from_iter(entries)) + } + /// Returns a new [`Smt`] instantiated from already computed leaves and nodes. /// /// This function performs minimal consistency checking. It is the caller's responsibility to diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 8f71ee8..97f7c8b 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -65,6 +65,19 @@ pub(crate) trait SparseMerkleTree { // PROVIDED METHODS // --------------------------------------------------------------------------------------------- + /// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel. + #[cfg(feature = "concurrent")] + fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result + where + Self: Sized, + { + let (inner_nodes, leaves) = Self::build_subtrees(entries); + let leaves: BTreeMap = + leaves.into_iter().map(|(index, leaf)| (index.value(), leaf)).collect(); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + Self::from_raw_parts(inner_nodes, leaves, root) + } + /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { @@ -429,6 +442,8 @@ pub(crate) trait SparseMerkleTree { /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into /// itself. /// + /// This function is mostly an implementation detail of [`SparseMerkleTree::build_subtrees()`]. + /// /// # Panics /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to @@ -522,6 +537,64 @@ pub(crate) trait SparseMerkleTree { (inner_nodes, leaves) } + + /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. + /// + /// `entries` need not be sorted. This function will sort them. + /// + /// This function is mostly an implementation detail of + /// [`SparseMerkleTree::with_entries_par()`]. + #[cfg(feature = "concurrent")] + fn build_subtrees( + mut entries: Vec<(Self::Key, Self::Value)>, + ) -> (BTreeMap, BTreeMap, Self::Leaf>) { + use rayon::prelude::*; + + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + + let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth); + + debug_assert!(next_leaves.is_sorted()); + + (nodes, next_leaves) + }) + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + + let leaves: BTreeMap, Self::Leaf> = initial_leaves + .into_iter() + .map(|(key, value)| { + // This unwrap *should* be unreachable. + let key = LeafIndex::::new(key).unwrap(); + (key, value) + }) + .collect(); + + (accumulated_nodes, leaves) + } } // INNER NODE diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 1ded87f..4c5e404 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -100,6 +100,23 @@ impl SimpleSmt { Ok(tree) } + /// The parallel version of [`SimpleSmt::with_leaves()`]. + /// + /// Returns a new [`SimpleSmt`] instantiated with leaves set as specified by the provided + /// entries. + /// + /// All leaves omitted from the entries list are set to [ZERO; 4]. + #[cfg(feature = "concurrent")] + pub fn with_leaves_par( + entries: impl IntoIterator, + ) -> Result { + let entries: Vec<_> = entries + .into_iter() + .map(|(col, value)| (LeafIndex::::new(col).unwrap(), value)) + .collect(); + >::with_entries_par(entries) + } + /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes. /// /// This function performs minimal consistency checking. It is the caller's responsibility to diff --git a/src/merkle/smt/tests.rs b/src/merkle/smt/tests.rs index aa0a459..1235fca 100644 --- a/src/merkle/smt/tests.rs +++ b/src/merkle/smt/tests.rs @@ -411,3 +411,17 @@ fn test_multithreaded_subtrees() { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } + +#[test] +#[cfg(feature = "concurrent")] +fn test_with_entries_par() { + const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let smt = Smt::with_entries_par(entries.clone()).unwrap(); + assert_eq!(smt.root(), control.root()); + assert_eq!(smt, control); +} From ec2dfdf4b8efd20d34d44683078bf9e210ab2b52 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:44:26 -0700 Subject: [PATCH 09/10] smt: add benchmarks for parallel construction --- Cargo.toml | 5 +++ benches/parallel-subtree.rs | 75 +++++++++++++++++++++++++++++++++++++ src/main.rs | 28 +++++++++++++- 3 files changed, 107 insertions(+), 1 deletion(-) create mode 100644 benches/parallel-subtree.rs diff --git a/Cargo.toml b/Cargo.toml index 347cca8..750748a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,11 @@ harness = false name = "merkle" harness = false +[[bench]] +name = "parallel-subtree" +harness = false +required-features = ["concurrent"] + [[bench]] name = "store" harness = false diff --git a/benches/parallel-subtree.rs b/benches/parallel-subtree.rs new file mode 100644 index 0000000..65c3918 --- /dev/null +++ b/benches/parallel-subtree.rs @@ -0,0 +1,75 @@ +use std::{fmt::Debug, hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE}; +use rand_utils::prng_array; +use winter_utils::Randomizable; + +// 2^0, 2^4, 2^8, 2^12, 2^16 +const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576]; + +fn smt_parallel_subtree(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("parallel-subtrees"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|i| { + let count = pair_count as f64; + let idx = ((i as f64 / count) * (count)) as u64; + let key = RpoDigest::new([ + generate_value(&mut seed), + ONE, + Felt::new(i), + Felt::new(idx), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let control = Smt::with_entries(entries.clone()).unwrap(); + (entries, control) + }, + |(entries, control)| { + // Benchmarked function. + let tree = Smt::with_entries_par(hint::black_box(entries)).unwrap(); + assert_eq!(tree.root(), control.root()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + //.measurement_time(Duration::from_secs(960)) + .measurement_time(Duration::from_secs(60)) + .sample_size(10) + .configure_from_args(); + targets = smt_parallel_subtree +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_value(seed: &mut [u8; 32]) -> T { + mem::swap(seed, &mut prng_array(*seed)); + let value: [T; 1] = rand_utils::prng_array(*seed); + value[0] +} + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} diff --git a/src/main.rs b/src/main.rs index 776ccc2..018dd50 100644 --- a/src/main.rs +++ b/src/main.rs @@ -33,7 +33,12 @@ pub fn benchmark_smt() { entries.push((key, value)); } - let mut tree = construction(entries, tree_size).unwrap(); + let mut tree = construction(entries.clone(), tree_size).unwrap(); + #[cfg(feature = "concurrent")] + { + let parallel = parallel_construction(entries, tree_size).unwrap(); + assert_eq!(tree, parallel); + } insertion(&mut tree, tree_size).unwrap(); batched_insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); @@ -56,6 +61,27 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result, + size: u64, +) -> Result { + println!("Running a parallel construction benchmark:"); + let now = Instant::now(); + + let tree = Smt::with_entries_par(entries).unwrap(); + + let elapsed = now.elapsed(); + println!( + "Parallel-constructed an SMT with {} key-value pairs in {:.3} seconds", + size, + elapsed.as_secs_f32(), + ); + println!("Number of leaf nodes: {}\n", tree.leaves().count()); + + Ok(tree) +} + /// Runs the insertion benchmark for the [`Smt`]. pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); From 3f52ef32a38b37756a4bfb3b7b528504ee75e155 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 14 Nov 2024 19:48:50 -0700 Subject: [PATCH 10/10] add news item for smt parallel subtree construction --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc22853..d225165 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 0.11.0 (2024-10-30) - [BREAKING] Updated Winterfell dependency to v0.10 (#338). +- Added `Smt::with_entries_par()`, a parallel version of `with_entries()` with significantly better performance (#341). ## 0.11.0 (2024-10-17)