Compare commits

...

10 commits

8 changed files with 621 additions and 85 deletions

1
Cargo.lock generated
View file

@ -534,6 +534,7 @@ dependencies = [
"rand", "rand",
"rand_chacha", "rand_chacha",
"rand_core", "rand_core",
"rayon",
"seq-macro", "seq-macro",
"serde", "serde",
"sha3", "sha3",

View file

@ -35,6 +35,10 @@ harness = false
name = "merkle" name = "merkle"
harness = false harness = false
[[bench]]
name = "parallel-subtree"
harness = false
[[bench]] [[bench]]
name = "store" name = "store"
harness = false harness = false
@ -66,6 +70,7 @@ sha3 = { version = "0.10", default-features = false }
winter-crypto = { version = "0.10", default-features = false } winter-crypto = { version = "0.10", default-features = false }
winter-math = { version = "0.10", default-features = false } winter-math = { version = "0.10", default-features = false }
winter-utils = { version = "0.10", default-features = false } winter-utils = { version = "0.10", default-features = false }
rayon = "1.10.0"
[dev-dependencies] [dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] } criterion = { version = "0.5", features = ["html_reports"] }

View file

@ -0,0 +1,72 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use rand_utils::prng_array;
use winter_utils::Randomizable;
// 2^0, 2^4, 2^8, 2^12, 2^16
const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576];
fn smt_parallel_subtree(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("parallel-subtrees");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
(0..pair_count)
.map(|i| {
let count = pair_count as f64;
let idx = ((i as f64 / count) * (count)) as u64;
let key = RpoDigest::new([
generate_value(&mut seed),
ONE,
Felt::new(i),
Felt::new(idx),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect()
},
|entries| {
// Benchmarked function.
let (leaves, inner_nodes) = Smt::build_subtrees(hint::black_box(entries));
assert!(!leaves.is_empty());
assert!(!inner_nodes.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(960))
.sample_size(60)
.configure_from_args();
targets = smt_parallel_subtree
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -3,7 +3,7 @@ use std::time::Instant;
use clap::Parser; use clap::Parser;
use miden_crypto::{ use miden_crypto::{
hash::rpo::{Rpo256, RpoDigest}, hash::rpo::{Rpo256, RpoDigest},
merkle::{MerkleError, Smt}, merkle::{MerkleError, NodeIndex, Smt},
Felt, Word, ONE, Felt, Word, ONE,
}; };
use rand_utils::rand_value; use rand_utils::rand_value;
@ -33,7 +33,9 @@ pub fn benchmark_smt() {
entries.push((key, value)); entries.push((key, value));
} }
let mut tree = construction(entries, tree_size).unwrap(); let mut tree = construction(entries.clone(), tree_size).unwrap();
let parallel = parallel_construction(entries, tree_size).unwrap();
assert_eq!(tree, parallel);
insertion(&mut tree, tree_size).unwrap(); insertion(&mut tree, tree_size).unwrap();
batched_insertion(&mut tree, tree_size).unwrap(); batched_insertion(&mut tree, tree_size).unwrap();
proof_generation(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap();
@ -56,6 +58,31 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result<Smt, M
Ok(tree) Ok(tree)
} }
pub fn parallel_construction(
entries: Vec<(RpoDigest, Word)>,
size: u64,
) -> Result<Smt, MerkleError> {
println!("Running a parallel construction benchmark:");
let now = Instant::now();
let (inner_nodes, leaves) = Smt::build_subtrees(entries);
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
let leaves = leaves.into_iter().map(|(key, value)| (key.value(), value)).collect();
let tree = Smt::from_raw_parts(inner_nodes, leaves, root)?;
let elapsed = now.elapsed();
println!(
"Parallel-constructed an SMT with {} key-value pairs in {:.3} seconds",
size,
elapsed.as_secs_f32(),
);
println!("Number of leaf nodes: {}\n", tree.leaves().count());
Ok(tree)
}
/// Runs the insertion benchmark for the [`Smt`]. /// Runs the insertion benchmark for the [`Smt`].
pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running an insertion benchmark:"); println!("Running an insertion benchmark:");

View file

@ -78,27 +78,15 @@ impl Smt {
pub fn with_entries( pub fn with_entries(
entries: impl IntoIterator<Item = (RpoDigest, Word)>, entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> { ) -> Result<Self, MerkleError> {
// create an empty tree <Self as SparseMerkleTree<SMT_DEPTH>>::with_entries(entries)
let mut tree = Self::new(); }
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so pub fn from_raw_parts(
// entries with the empty value need additional tracking. inner_nodes: BTreeMap<NodeIndex, InnerNode>,
let mut key_set_to_zero = BTreeSet::new(); leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
for (key, value) in entries { ) -> Result<Self, MerkleError> {
let old_value = tree.insert(key, value); <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
return Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
));
}
if value == EMPTY_WORD {
key_set_to_zero.insert(key);
};
}
Ok(tree)
} }
// PUBLIC ACCESSORS // PUBLIC ACCESSORS
@ -250,12 +238,31 @@ impl Smt {
} }
} }
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
pub fn build_subtree( pub fn build_subtree(
leaves: Vec<SubtreeLeaf>, leaves: Vec<SubtreeLeaf>,
bottom_depth: u8, bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) { ) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) {
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtree(leaves, bottom_depth) <Self as SparseMerkleTree<SMT_DEPTH>>::build_subtree(leaves, bottom_depth)
} }
pub fn build_subtrees(
entries: Vec<(RpoDigest, Word)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<SMT_DEPTH>, SmtLeaf>) {
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtrees(entries)
}
} }
impl SparseMerkleTree<SMT_DEPTH> for Smt { impl SparseMerkleTree<SMT_DEPTH> for Smt {
@ -267,6 +274,45 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn with_entries(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
// create an empty tree
let mut tree = Self::new();
// This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so
// entries with the empty value need additional tracking.
let mut key_set_to_zero = BTreeSet::new();
for (key, value) in entries {
let old_value = tree.insert(key, value);
if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) {
return Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
));
}
if value == EMPTY_WORD {
key_set_to_zero.insert(key);
};
}
Ok(tree)
}
fn root(&self) -> RpoDigest { fn root(&self) -> RpoDigest {
self.root self.root
} }

View file

@ -292,6 +292,21 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
// REQUIRED METHODS // REQUIRED METHODS
// --------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------
/// Construct a tree from an iterator of its keys and values.
fn with_entries(
entries: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> Result<Self, MerkleError>
where
Self: Sized;
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Self::Leaf>,
root: RpoDigest,
) -> Result<Self, MerkleError>
where
Self: Sized;
/// The root of the tree /// The root of the tree
fn root(&self) -> RpoDigest; fn root(&self) -> RpoDigest;
@ -392,17 +407,19 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
accumulator accumulator
} }
/// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// sorted by their position. /// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
/// ///
/// The leaves are 'conceptual' leaves, simply being entities at the bottom of some subtree, not /// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// [`Self::Leaf`]. /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
/// ///
/// # Panics /// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree (more than 256), if `bottom_depth` is /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// lower in the tree than the specified maximum depth (`DEPTH`), or if `leaves` is not sorted. /// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
// FIXME: more complete docstring. /// maximum depth (`DEPTH`), or if `leaves` is not sorted.
fn build_subtree( fn build_subtree(
mut leaves: Vec<SubtreeLeaf>, mut leaves: Vec<SubtreeLeaf>,
bottom_depth: u8, bottom_depth: u8,
@ -492,6 +509,57 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
(inner_nodes, leaves) (inner_nodes, leaves)
} }
fn build_subtrees(
mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<DEPTH>, Self::Leaf>) {
use rayon::prelude::*;
entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
});
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: initial_leaves,
} = Self::sorted_pairs_to_leaves(entries);
for current_depth in (8..=DEPTH).step_by(8).rev() {
let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
.into_par_iter()
.map(|subtree| {
debug_assert!(subtree.is_sorted());
debug_assert!(!subtree.is_empty());
let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth);
debug_assert!(next_leaves.is_sorted());
(nodes, next_leaves)
})
.unzip();
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
debug_assert!(!leaf_subtrees.is_empty());
}
let leaves: BTreeMap<LeafIndex<DEPTH>, Self::Leaf> = initial_leaves
.into_iter()
.map(|(key, value)| {
// FIXME: unwrap is unreachable?
let key = LeafIndex::<DEPTH>::new(key).unwrap();
(key, value)
})
.collect();
(accumulated_nodes, leaves)
}
} }
// INNER NODE // INNER NODE
@ -616,22 +684,18 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. /// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); const COLS_PER_SUBTREE: u64 = u64::pow(2, 8);
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///
/// Note that these represet "conceptual" leaves of some subtree, not necessarily
/// [`SparseMerkleTree::Leaf`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf { pub struct SubtreeLeaf {
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
pub col: u64, pub col: u64,
/// The hash of the node this `SubtreeLeaf` represents.
pub hash: RpoDigest, pub hash: RpoDigest,
} }
impl SubtreeLeaf {
#[cfg_attr(not(test), allow(dead_code))]
fn from_smt_leaf(leaf: &crate::merkle::SmtLeaf) -> Self {
Self {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. /// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PairComputations<K, L> { pub(crate) struct PairComputations<K, L> {
@ -681,21 +745,77 @@ fn add_subtree_leaf(subtrees: &mut Vec<Vec<SubtreeLeaf>>, leaf: SubtreeLeaf) {
} }
} }
#[derive(Debug)]
struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
Self { leaves: leaves.drain(..).peekable() }
}
}
impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();
let mut last_subtree_col = 0;
while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}
if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}
Some(subtree)
}
}
// TESTS // TESTS
// ================================================================================================ // ================================================================================================
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use core::mem;
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::{collections::BTreeMap, vec::Vec};
use super::{SparseMerkleTree, SubtreeLeaf}; use super::{SparseMerkleTree, SubtreeLeaf};
use crate::{ use crate::{
hash::rpo::RpoDigest, hash::rpo::RpoDigest,
merkle::{smt::InnerNode, NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, merkle::{
smt::{InnerNode, PairComputations, SubtreeLeavesIter},
LeafIndex, NodeIndex, Smt, SmtLeaf, SMT_DEPTH,
},
Felt, Word, ONE, Felt, Word, ONE,
}; };
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}
#[test] #[test]
fn test_sorted_pairs_to_leaves() { fn test_sorted_pairs_to_leaves() {
let entries: Vec<(RpoDigest, Word)> = vec![ let entries: Vec<(RpoDigest, Word)> = vec![
@ -744,7 +864,7 @@ mod test {
// Subtree 2. // Subtree 2.
vec![next_leaf()], vec![next_leaf()],
] ]
.map(|subtree| subtree.into_iter().map(SubtreeLeaf::from_smt_leaf).collect()) .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
.to_vec(); .to_vec();
assert_eq!(control_leaves_iter.next(), None); assert_eq!(control_leaves_iter.next(), None);
control_subtree_leaves control_subtree_leaves
@ -871,63 +991,195 @@ mod test {
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default(); let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let starting_leaves = Smt::sorted_pairs_to_leaves(entries); let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
let mut leaf_subtrees = starting_leaves.leaves;
for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { for current_depth in (8..=SMT_DEPTH).step_by(8).rev() {
for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { // There's no flat_map_unzip(), so this is the best we can do.
// Pre-assertions. let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
assert!( .into_iter()
subtree.is_sorted(), .enumerate()
"subtree {i} at bottom-depth {current_depth} is not sorted", .map(|(i, subtree)| {
); // Pre-assertions.
assert!( assert!(
!subtree.is_empty(), subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is empty!", "subtree {i} at bottom-depth {current_depth} is not sorted",
); );
assert!(
// Do actual things. !subtree.is_empty(),
let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); "subtree {i} at bottom-depth {current_depth} is empty!",
// Post-assertions.
assert!(next_leaves.is_sorted());
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
); );
}
// Update state. // Do actual things.
accumulated_nodes.extend(nodes); let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth);
// Post-assertions.
assert!(next_leaves.is_sorted());
for subtree_leaf in next_leaves { for (&index, test_node) in nodes.iter() {
super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); let control_node = control.get_inner_node(index);
} assert_eq!(
} test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, next_leaves)
})
.unzip();
// Update state between each depth iteration.
// FIXME: is this flatten or Box<dyn Iterator> better?
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
} }
// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf);
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() { for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index); let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
} }
assert_eq!(leaf_subtrees.len(), 1); // After the last iteration of the above for loop, we should have the new root node actually
let mut leaf_subtree = leaf_subtrees.pop().unwrap(); // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
assert_eq!(leaf_subtree.len(), 1); // `build_subtree()`. So let's check both!
let root_leaf = leaf_subtree.pop().unwrap();
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash); assert_eq!(control.root(), root_leaf.hash);
// Do we have a root? // Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root())); assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
// And does it match?
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control.root(), test_root.hash()); assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash);
}
#[test]
fn test_multithreaded_subtrees() {
use rayon::prelude::*;
const PAIR_COUNT: u64 = 4096 * 4;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries(entries.clone()).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (8..=SMT_DEPTH).step_by(8).rev() {
let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
.into_par_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(
!subtree.is_empty(),
"subtree {i} at bottom-depth {current_depth} is empty!",
);
let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth);
// Post-assertions.
assert!(next_leaves.is_sorted());
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
(nodes, next_leaves)
})
// FIXME: unzip_into_vecs() instead?
.unzip();
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let index = LeafIndex::new_max_depth(col);
let &control_leaf = control_leaves.get(&index).unwrap();
assert_eq!(test_leaf, control_leaf);
}
// Make sure the inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root node actually
// in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), root_leaf.hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), root_leaf.hash); assert_eq!(control.root(), root_leaf.hash);
} }
} }

View file

@ -100,6 +100,14 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(tree) Ok(tree)
} }
pub fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
}
/// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices
/// starting at index 0. /// starting at index 0.
pub fn with_contiguous_leaves( pub fn with_contiguous_leaves(
@ -309,6 +317,27 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_VALUE: Self::Value = EMPTY_WORD;
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) {
let root_node = inner_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(root_node.hash(), root);
}
Ok(Self { root, inner_nodes, leaves })
}
fn with_entries(
entries: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> Result<Self, MerkleError> {
<SimpleSmt<DEPTH>>::with_leaves(
entries.into_iter().map(|(key, value)| (key.value(), value)),
)
}
fn root(&self) -> RpoDigest { fn root(&self) -> RpoDigest {
self.root self.root
} }

View file

@ -1,3 +1,6 @@
use core::mem;
use std::collections::BTreeMap;
use alloc::vec::Vec; use alloc::vec::Vec;
use super::{ use super::{
@ -7,10 +10,11 @@ use super::{
use crate::{ use crate::{
hash::rpo::Rpo256, hash::rpo::Rpo256,
merkle::{ merkle::{
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots, digests_to_words, int_to_leaf, int_to_node,
InnerNodeInfo, LeafIndex, MerkleTree, smt::{self, InnerNode, PairComputations, SparseMerkleTree},
EmptySubtreeRoots, InnerNodeInfo, LeafIndex, MerkleTree, SubtreeLeaf,
}, },
Word, EMPTY_WORD, Felt, Word, EMPTY_WORD, ONE,
}; };
// TEST DATA // TEST DATA
@ -461,6 +465,106 @@ fn test_simplesmt_check_empty_root_constant() {
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT); assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
} }
#[test]
fn test_simplesmt_subtrees() {
const PAIR_COUNT: u64 = 4096;
const DEPTH: u8 = 64;
type SimpleSmt = super::SimpleSmt<DEPTH>;
let entries: Vec<(LeafIndex<DEPTH>, Word)> = (0..PAIR_COUNT)
.map(|i| {
let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64;
let key = LeafIndex::new_max_depth(leaf_index);
let value: Word = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect();
let leaves = entries.iter().map(|(key, value)| (key.value(), *value));
let control = SimpleSmt::with_leaves(leaves).unwrap();
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = SimpleSmt::sorted_pairs_to_leaves(entries);
for current_depth in (8..=DEPTH).step_by(8).rev() {
for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() {
// Pre-assertions.
assert!(
subtree.is_sorted(),
"subtree {i} at bottom-depth {current_depth} is not sorted",
);
assert!(!subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!");
// Do actual things.
let (nodes, next_leaves) = SimpleSmt::build_subtree(subtree, current_depth);
// Post-assertions.
assert!(next_leaves.is_sorted());
for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index);
assert_eq!(
test_node, &control_node,
"depth {} subtree {}: test node does not match control at index {:?}",
current_depth, i, index,
);
}
// Update state.
accumulated_nodes.extend(nodes);
for subtree_leaf in next_leaves {
smt::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf);
}
}
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
}
// Make sure the true leaves match, checking length first and then each individual leaf.
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
let control_leaves_len = control_leaves.len();
let test_leaves_len = test_leaves.len();
assert_eq!(test_leaves_len, control_leaves_len);
for (col, ref test_leaf) in test_leaves {
let &control_leaf = control_leaves.get(&col).unwrap();
assert_eq!(test_leaf, control_leaf);
}
// Make sure inner nodes match, checking length first and then each individual leaf.
let control_nodes_len = control.inner_nodes().count();
let test_nodes_len = accumulated_nodes.len();
assert_eq!(test_nodes_len, control_nodes_len);
for (index, test_node) in accumulated_nodes.clone() {
let control_node = control.get_inner_node(index);
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
}
// After the last iteration of the above for loop, we should have the new root actually in two
// places: one in `accumulated_nodes`, and the other as the "next leaves" return from
// `build_subtree()`. So let's check both!
let control_root = control.get_inner_node(NodeIndex::root());
// That for loop should have left us with only one leaf subtree...
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
// which itself contains only one 'leaf'...
let [SubtreeLeaf { hash: test_root_hash, .. }]: [_; 1] = leaf_subtree.try_into().unwrap();
// which matches the expected root.
assert_eq!(control.root(), test_root_hash);
// Likewise `accumulated_nodes` should contain a node at the root index...
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
// and it should match our actual root.
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
assert_eq!(control_root, *test_root);
// And of course the root we got from each place should match.
assert_eq!(control.root(), test_root_hash);
}
// HELPER FUNCTIONS // HELPER FUNCTIONS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------