smt: add parallel constructors to Smt and SimpleSmt

What the previous few commits have been leading up to: SparseMerkleTree
now has a function to construct the tree from existing data in parallel.
This is significantly faster than the single-threaded equivalent.
Benchmarks incoming!
This commit is contained in:
Qyriad 2024-11-14 19:01:58 -07:00
parent ecd7e18623
commit 13363307b4
4 changed files with 116 additions and 0 deletions

View file

@ -101,6 +101,19 @@ impl Smt {
Ok(tree) Ok(tree)
} }
/// Parallel counterpart of [`Smt::with_entries()`].
///
/// Collects the provided entries into a vector and delegates to the
/// [`SparseMerkleTree`] implementation, which constructs the tree in parallel. The result is a
/// new [`Smt`] whose leaves are set as specified by `entries`.
///
/// All leaves omitted from the entries list are set to [Self::EMPTY_VALUE].
///
/// # Errors
/// Returns whatever error the underlying `SparseMerkleTree::with_entries_par()` reports.
#[cfg(feature = "concurrent")]
pub fn with_entries_par(
    entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
    // The trait-level constructor takes ownership of a `Vec`, so materialize the iterator first.
    let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
    <Self as SparseMerkleTree<SMT_DEPTH>>::with_entries_par(entries)
}
/// Returns a new [`Smt`] instantiated from already computed leaves and nodes. /// Returns a new [`Smt`] instantiated from already computed leaves and nodes.
/// ///
/// This function performs minimal consistency checking. It is the caller's responsibility to /// This function performs minimal consistency checking. It is the caller's responsibility to

View file

@ -65,6 +65,19 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
// PROVIDED METHODS // PROVIDED METHODS
// --------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------
/// Creates a new sparse Merkle tree from an existing set of key-value pairs, in parallel.
///
/// The entries are handed to [`SparseMerkleTree::build_subtrees()`], which returns the inner
/// nodes and the leaves of the tree. The leaves are re-keyed by their raw `u64` index value, the
/// root hash is read from the inner node stored at [`NodeIndex::root()`], and all three parts are
/// passed to `Self::from_raw_parts()`.
///
/// # Errors
/// Returns any error produced by `Self::from_raw_parts()`.
///
/// # Panics
/// Panics if the computed inner nodes do not contain a node at the root index.
/// NOTE(review): this looks reachable when `entries` is empty (no subtree would be built) --
/// confirm, and consider handling the empty case explicitly.
#[cfg(feature = "concurrent")]
fn with_entries_par(entries: Vec<(Self::Key, Self::Value)>) -> Result<Self, MerkleError>
where
    Self: Sized,
{
    let (inner_nodes, leaves) = Self::build_subtrees(entries);

    // `from_raw_parts()` is given leaves keyed by the plain index value rather than `LeafIndex`.
    let leaves: BTreeMap<u64, Self::Leaf> =
        leaves.into_iter().map(|(index, leaf)| (index.value(), leaf)).collect();

    // The root hash is not returned separately; it is derived from the root inner node.
    let root = inner_nodes
        .get(&NodeIndex::root())
        .expect("build_subtrees() did not produce a node at the root index")
        .hash();

    Self::from_raw_parts(inner_nodes, leaves, root)
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself. /// path to the leaf, as well as the leaf itself.
fn open(&self, key: &Self::Key) -> Self::Opening { fn open(&self, key: &Self::Key) -> Self::Opening {
@ -429,6 +442,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself. /// itself.
/// ///
/// This function is mostly an implementation detail of [`SparseMerkleTree::build_subtrees()`].
///
/// # Panics /// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
@ -522,6 +537,64 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
(inner_nodes, leaves) (inner_nodes, leaves)
} }
/// Computes the raw parts (inner nodes and leaves) for a new sparse Merkle tree from a set of
/// key-value pairs.
///
/// `entries` need not be sorted; they are sorted here by the value of the leaf index each key
/// maps to. The sorted pairs are converted into leaves grouped by subtree, and the subtrees of
/// each level are then built in parallel, with the output of one level becoming the input
/// leaves of the next level up.
///
/// This function is mostly an implementation detail of
/// [`SparseMerkleTree::with_entries_par()`].
///
/// NOTE(review): the depth schedule starts at `SUBTREE_DEPTH` and advances in steps of
/// `SUBTREE_DEPTH`, so it only lands on `DEPTH` when `DEPTH` is a multiple of `SUBTREE_DEPTH`
/// (and is empty when `DEPTH < SUBTREE_DEPTH`) -- confirm callers guarantee this.
///
/// # Panics
/// Panics if a leaf index produced by `sorted_pairs_to_leaves()` is rejected by
/// `LeafIndex::new()`. With debug assertions on, also panics if any subtree is empty or
/// unsorted, or if a level produces no subtrees.
#[cfg(feature = "concurrent")]
fn build_subtrees(
    mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<DEPTH>, Self::Leaf>) {
    use rayon::prelude::*;

    // Order the pairs by the leaf index their key maps to.
    entries.sort_by_key(|(key, _)| Self::key_to_leaf_index(key).value());

    let mut inner_nodes: BTreeMap<NodeIndex, InnerNode> = BTreeMap::new();

    let PairComputations {
        leaves: mut pending_subtrees,
        nodes: initial_leaves,
    } = Self::sorted_pairs_to_leaves(entries);

    // Walk from the deepest level towards the root, one subtree height at a time.
    for depth in (SUBTREE_DEPTH..=DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
        // Every subtree of this level is built independently on the rayon thread pool.
        let (level_nodes, level_roots): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) =
            pending_subtrees
                .into_par_iter()
                .map(|subtree| {
                    debug_assert!(subtree.is_sorted());
                    debug_assert!(!subtree.is_empty());

                    let (subtree_nodes, subtree_roots) = Self::build_subtree(subtree, depth);
                    debug_assert!(subtree_roots.is_sorted());

                    (subtree_nodes, subtree_roots)
                })
                .unzip();

        // The outputs of this level are regrouped into the subtrees of the next level up.
        let mut next_leaves: Vec<SubtreeLeaf> = level_roots.into_iter().flatten().collect();
        pending_subtrees = SubtreeLeavesIter::from_leaves(&mut next_leaves).collect();

        inner_nodes.extend(level_nodes.into_iter().flatten());

        debug_assert!(!pending_subtrees.is_empty());
    }

    let leaves = initial_leaves
        .into_iter()
        .map(|(column, leaf)| {
            // This unwrap *should* be unreachable.
            let index = LeafIndex::<DEPTH>::new(column).unwrap();
            (index, leaf)
        })
        .collect::<BTreeMap<LeafIndex<DEPTH>, Self::Leaf>>();

    (inner_nodes, leaves)
}
} }
// INNER NODE // INNER NODE

View file

@ -100,6 +100,22 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(tree) Ok(tree)
} }
/// The parallel version of [`SimpleSmt::with_leaves()`].
///
/// Returns a new [`SimpleSmt`] instantiated with leaves set as specified by the provided entries.
/// Each `(column, value)` pair is converted into a `(LeafIndex<DEPTH>, value)` pair before being
/// handed to the parallel [`SparseMerkleTree`] constructor.
///
/// All leaves omitted from the entries list are set to [ZERO; 4].
///
/// # Errors
/// Returns an error if any column index is rejected by `LeafIndex::new()` (presumably because it
/// does not fit in a tree of depth `DEPTH` -- confirm), or if the underlying
/// `SparseMerkleTree::with_entries_par()` fails.
#[cfg(feature = "concurrent")]
pub fn with_leaves_par(
    entries: impl IntoIterator<Item = (u64, Word)>,
) -> Result<Self, MerkleError> {
    // Column indices come from the caller, so an invalid one is reported as an error instead of
    // panicking. Collecting into `Result` stops at the first invalid index.
    let entries = entries
        .into_iter()
        .map(|(col, value)| LeafIndex::<DEPTH>::new(col).map(|index| (index, value)))
        .collect::<Result<Vec<_>, _>>()?;

    <Self as SparseMerkleTree<DEPTH>>::with_entries_par(entries)
}
/// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes. /// Returns a new [`SimpleSmt`] instantiated from already computed leaves and nodes.
/// ///
/// This function performs minimal consistency checking. It is the caller's responsibility to /// This function performs minimal consistency checking. It is the caller's responsibility to

View file

@ -412,3 +412,17 @@ fn test_multithreaded_subtrees() {
// And of course the root we got from each place should match. // And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash); assert_eq!(control.root(), root_leaf.hash);
} }
/// Checks that the parallel constructor builds exactly the same tree as the sequential one.
#[test]
#[cfg(feature = "concurrent")]
fn test_with_entries_par() {
    const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 64;

    let entries = generate_entries(PAIR_COUNT);

    // The sequential constructor serves as the reference result.
    let sequential = Smt::with_entries(entries.clone()).unwrap();
    let parallel = Smt::with_entries_par(entries).unwrap();

    // Compare the roots first, then the complete tree contents.
    assert_eq!(parallel.root(), sequential.root());
    assert_eq!(parallel, sequential);
}
}