From 520fecaf65e594a970af6dc1652bb888c8614f34 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 28 Oct 2024 15:38:42 -0600 Subject: [PATCH 01/21] add sorted_pairs_to_leaves() and test for it --- src/merkle/smt/mod.rs | 259 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 259 insertions(+) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 03d9d45..fc6ace7 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,7 @@ use alloc::{collections::BTreeMap, vec::Vec}; +use core::mem; + +use num::Integer; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use crate::{ @@ -346,6 +349,146 @@ pub(crate) trait SparseMerkleTree { /// /// The length `path` is guaranteed to be equal to `DEPTH` fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; + + fn sorted_pairs_to_leaves( + pairs: Vec<(Self::Key, Self::Value)>, + ) -> PairComputations { + let mut all_leaves = PairComputations::default(); + + let mut buffer: Vec<(Self::Key, Self::Value)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let next_col = iter.peek().map(|(key, _)| { + let index = Self::key_to_leaf_index(key); + index.index.value() + }); + + buffer.push((key, value)); + + if let Some(next_col) = next_col { + assert!(next_col >= col); + } + + if next_col == Some(col) { + // Keep going in our buffer. + continue; + } + + // Whether the next pair is a different column, or non-existent, we break off. + let leaf_pairs = mem::take(&mut buffer); + let leaf = Self::pairs_to_leaf(leaf_pairs); + let hash = Self::hash_leaf(&leaf); + + all_leaves.nodes.insert(col, leaf); + all_leaves.subtrees.push(SubtreeLeaf { col, hash }); + } + assert_eq!(buffer.len(), 0); + + all_leaves + } + + /// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, + /// sorted by their position. + /// + /// The leaves are 'conceptual' leaves, simply being entities at the bottom of some subtree, not + /// [`Self::Leaf`]. + /// + /// # Panics + /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains + /// more entries than can fit in a depth-8 subtree (more than 256), if `bottom_depth` is + /// lower in the tree than the specified maximum depth (`DEPTH`), or if `leaves` is not sorted. + // FIXME: more complete docstring. + fn build_subtree( + mut leaves: Vec, + bottom_depth: u8, + ) -> (BTreeMap, Vec) { + debug_assert!(bottom_depth <= DEPTH); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &8)); + debug_assert!(leaves.len() <= usize::pow(2, 8)); + + let subtree_root = bottom_depth - 8; + + let mut inner_nodes: BTreeMap = Default::default(); + + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }; + let right = first; + + (left, right) + } else { + let left = first; + + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. + _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }, + }; + + (left, right) + }; + + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + // FIXME: is it possible for this to end up not being sorted? I don't think so. + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + + (inner_nodes, leaves) + } } // INNER NODE @@ -463,3 +606,119 @@ impl MutationSet { self.new_root } } + +// HELPERS +// ================================================================================================ +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + pub col: u64, + pub hash: RpoDigest, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PairComputations { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: BTreeMap, + /// "Conceptual" leaves that will be used for computations. + pub subtrees: Vec, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PairComputations { + fn default() -> Self { + Self { + nodes: Default::default(), + subtrees: Default::default(), + } + } +} + +// TESTS +// ================================================================================================ +#[cfg(test)] +mod test { + use alloc::vec::Vec; + + use super::SparseMerkleTree; + use crate::{ + hash::rpo::RpoDigest, + merkle::{smt::SubtreeLeaf, Smt, SmtLeaf, SMT_DEPTH}, + Felt, Word, EMPTY_WORD, ONE, + }; + + #[test] + fn test_sorted_pairs_to_leaves() { + let entries: Vec<(RpoDigest, Word)> = vec![ + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), + // Leaf index collision. + (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), + // Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), + // Empty leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(500)]), EMPTY_WORD), + ]; + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + + let control_leaves: Vec = vec![ + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_empty(Smt::key_to_leaf_index(&next_entry().0)), + ]; + + let control_subtree_leaves: Vec = control_leaves + .iter() + .map(|leaf| { + let col = leaf.index().index.value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + + let test_subtree_leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; + assert_eq!(control_subtree_leaves, test_subtree_leaves); + } + + #[test] + fn test_build_subtree_from_leaves() { + const PAIR_COUNT: u64 = u64::pow(2, 8); + + let entries: Vec<(RpoDigest, Word)> = (0..PAIR_COUNT) + .map(|i| { + let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64; + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect(); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut leaves: Vec = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = leaf.index().index.value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves.dedup_by_key(|leaf| leaf.col); + + let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH); + assert!(!first_subtree.is_empty()); + + for (index, node) in first_subtree.into_iter() { + let control = control.get_inner_node(index); + assert_eq!( + control, node, + "subtree-computed node at index {index:?} does not match control", + ); + } + } +} From 98e5e0a5b20e995c0fcc4234a123d3ea47557d5e Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 23 Oct 2024 19:26:42 -0600 Subject: [PATCH 02/21] WIP(smt): impl simple subtree8 hashing and benchmarks for it bench(smt-subtree): add a benchmark for single-leaf subtrees make build_subtree also return the next leaf row convert (col, hash) tuples to a dedicated struct --- Cargo.toml | 4 ++ benches/smt-subtree.rs | 136 +++++++++++++++++++++++++++++++++++++ src/merkle/mod.rs | 2 +- src/merkle/smt/full/mod.rs | 9 ++- src/merkle/smt/mod.rs | 2 +- 5 files changed, 150 insertions(+), 3 deletions(-) create mode 100644 benches/smt-subtree.rs diff --git a/Cargo.toml b/Cargo.toml index 5d124c6..ec59d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,10 @@ harness = false name = "smt" harness = false +[[bench]] +name = "smt-subtree" +harness = false + [[bench]] name = "store" harness = false diff --git a/benches/smt-subtree.rs b/benches/smt-subtree.rs new file mode 100644 index 0000000..cd7454a --- /dev/null +++ b/benches/smt-subtree.rs @@ -0,0 +1,136 @@ +use std::{fmt::Debug, hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use miden_crypto::{ + hash::rpo::RpoDigest, + merkle::{NodeIndex, Smt, SmtLeaf, SubtreeLeaf, SMT_DEPTH}, + Felt, Word, ONE, +}; +use rand_utils::prng_array; +use winter_utils::Randomizable; + +const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256]; + +fn smt_subtree_even(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("subtree8-even"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|n| { + // A single depth-8 subtree can have a maximum of 255 leaves. + let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64; + let key = RpoDigest::new([ + generate_value(&mut seed), + ONE, + Felt::new(n), + Felt::new(leaf_index), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let mut leaves: Vec<_> = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = NodeIndex::from(leaf.index()).value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves.dedup_by_key(|leaf| leaf.col); + leaves + }, + |leaves| { + // Benchmarked function. + let (subtree, _) = + Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH)); + assert!(!subtree.is_empty()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +fn smt_subtree_random(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("subtree8-rand"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + let entries: Vec<(RpoDigest, Word)> = (0..pair_count) + .map(|i| { + let leaf_index: u8 = generate_value(&mut seed); + let key = RpoDigest::new([ + ONE, + ONE, + Felt::new(i), + Felt::new(leaf_index as u64), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect(); + + let mut leaves: Vec<_> = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = NodeIndex::from(leaf.index()).value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves + }, + |leaves| { + let (subtree, _) = + Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH)); + assert!(!subtree.is_empty()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + .measurement_time(Duration::from_secs(40)) + .sample_size(60) + .configure_from_args(); + targets = smt_subtree_even, smt_subtree_random +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_value(seed: &mut [u8; 32]) -> T { + mem::swap(seed, &mut prng_array(*seed)); + let value: [T; 1] = rand_utils::prng_array(*seed); + value[0] +} + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} diff --git a/src/merkle/mod.rs b/src/merkle/mod.rs index a562aa5..dc897dc 100644 --- a/src/merkle/mod.rs +++ b/src/merkle/mod.rs @@ -23,7 +23,7 @@ pub use path::{MerklePath, RootPath, ValuePath}; mod smt; pub use smt::{ LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, - SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, + SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; mod mmr; diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 1e2c574..89d0702 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -6,7 +6,7 @@ use alloc::{ use super::{ EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, - MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, SubtreeLeaf, Word, EMPTY_WORD, }; mod error; @@ -249,6 +249,13 @@ impl Smt { None } } + + pub fn build_subtree( + leaves: Vec, + bottom_depth: u8, + ) -> (BTreeMap, Vec) { + >::build_subtree(leaves, bottom_depth) + } } impl SparseMerkleTree for Smt { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index fc6ace7..0c87724 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -496,7 +496,7 @@ pub(crate) trait SparseMerkleTree { #[derive(Debug, Default, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub(crate) struct InnerNode { +pub struct InnerNode { pub left: RpoDigest, pub right: RpoDigest, } From 1bc790586aeb93cd076336e0f22ed3eebd66ebe6 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 29 Oct 2024 12:29:20 -0600 Subject: [PATCH 03/21] refactor test_build_subtree_from_leaves() --- src/merkle/smt/mod.rs | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 0c87724..0c55bd0 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -683,32 +683,29 @@ mod test { assert_eq!(control_subtree_leaves, test_subtree_leaves); } - #[test] - fn test_build_subtree_from_leaves() { - const PAIR_COUNT: u64 = u64::pow(2, 8); - - let entries: Vec<(RpoDigest, Word)> = (0..PAIR_COUNT) + // Helper for the below tests. + fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> { + (0..pair_count) .map(|i| { - let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64; + let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64; let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); let value = [ONE, ONE, ONE, Felt::new(i)]; (key, value) }) - .collect(); + .collect() + } + + #[test] + fn test_build_subtree_from_leaves() { + // A single subtree's worth of leaves. + const PAIR_COUNT: u64 = 256; + + let entries = generate_entries(PAIR_COUNT); let control = Smt::with_entries(entries.clone()).unwrap(); - let mut leaves: Vec = entries - .iter() - .map(|(key, value)| { - let leaf = SmtLeaf::new_single(*key, *value); - let col = leaf.index().index.value(); - let hash = leaf.hash(); - SubtreeLeaf { col, hash } - }) - .collect(); - leaves.sort(); - leaves.dedup_by_key(|leaf| leaf.col); + // `entries` should already be sorted by nature of how we constructed it. + let leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH); assert!(!first_subtree.is_empty()); From c35e18453a0884ababa2eb343fa06983caf28c80 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Fri, 25 Oct 2024 13:31:48 -0600 Subject: [PATCH 04/21] merkle: add a benchmark for constructing 256-leaf balanced trees --- Cargo.toml | 4 ++++ benches/merkle.rs | 59 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 benches/merkle.rs diff --git a/Cargo.toml b/Cargo.toml index ec59d42..74df3ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,10 @@ harness = false name = "smt-subtree" harness = false +[[bench]] +name = "merkle" +harness = false + [[bench]] name = "store" harness = false diff --git a/benches/merkle.rs b/benches/merkle.rs new file mode 100644 index 0000000..5bd434b --- /dev/null +++ b/benches/merkle.rs @@ -0,0 +1,59 @@ +use std::{hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, Criterion}; +use miden_crypto::{merkle::MerkleTree, Felt, Word, ONE}; +use rand_utils::prng_array; + +fn balanced_merkle_even(c: &mut Criterion) { + c.bench_function("balanced-merkle-even", |b| { + b.iter_batched( + || { + let entries: Vec = + (0..256).map(|i| [Felt::new(i), ONE, ONE, Felt::new(i)]).collect(); + assert_eq!(entries.len(), 256); + entries + }, + |leaves| { + let tree = MerkleTree::new(hint::black_box(leaves)).unwrap(); + assert_eq!(tree.depth(), 8); + }, + BatchSize::SmallInput, + ); + }); +} + +fn balanced_merkle_rand(c: &mut Criterion) { + let mut seed = [0u8; 32]; + c.bench_function("balanced-merkle-rand", |b| { + b.iter_batched( + || { + let entries: Vec = (0..256).map(|_| generate_word(&mut seed)).collect(); + assert_eq!(entries.len(), 256); + entries + }, + |leaves| { + let tree = MerkleTree::new(hint::black_box(leaves)).unwrap(); + assert_eq!(tree.depth(), 8); + }, + BatchSize::SmallInput, + ); + }); +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + .measurement_time(Duration::from_secs(20)) + .configure_from_args(); + targets = balanced_merkle_even, balanced_merkle_rand +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +} From 6db08f47149d4b37f056d5a5351b9e9284013ef6 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 29 Oct 2024 14:17:02 -0600 Subject: [PATCH 05/21] add test_two_subtrees() test --- src/merkle/smt/mod.rs | 81 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 0c55bd0..a658d28 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -623,6 +623,34 @@ pub struct PairComputations { pub subtrees: Vec, } +impl PairComputations { + #[cfg_attr(not(test), allow(dead_code))] + pub fn split_at_column(mut self, col: u64) -> (Self, Self) { + let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) { + // Surprisingly, Result has no method like `unwrap_or_unwrap_err() where T == E`. + // Probably because there's no way to write that where bound. + Ok(split_point) | Err(split_point) => split_point, + }; + + let subtrees_right = self.subtrees.split_off(split_point); + let subtrees_left = self.subtrees; + + let nodes_right = self.nodes.split_off(&col); + let nodes_left = self.nodes; + + let left = Self { + nodes: nodes_left, + subtrees: subtrees_left, + }; + let right = Self { + nodes: nodes_right, + subtrees: subtrees_right, + }; + + (left, right) + } +} + // Derive requires `L` to impl Default, even though we don't actually need that. impl Default for PairComputations { fn default() -> Self { @@ -696,7 +724,7 @@ mod test { } #[test] - fn test_build_subtree_from_leaves() { + fn test_single_subtree() { // A single subtree's worth of leaves. const PAIR_COUNT: u64 = 256; @@ -718,4 +746,55 @@ mod test { ); } } + + #[test] + fn test_two_subtrees() { + // Two subtrees' worth of leaves. + const PAIR_COUNT: u64 = 512; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let leaves = Smt::sorted_pairs_to_leaves(entries); + let (first, second) = leaves.split_at_column(PAIR_COUNT / 2); + assert_eq!(first.subtrees.len(), second.subtrees.len()); + + let mut current_depth = SMT_DEPTH; + let mut next_leaves: Vec = Default::default(); + + let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth); + next_leaves.extend(leaves); + + let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth); + next_leaves.extend(leaves); + + // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. + let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); + assert_eq!(total_computed as u64, PAIR_COUNT); + + // Verify the computed nodes of both subtrees. + let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); + for (index, test_node) in computed_nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + current_depth -= 8; + + let (nodes, next_leaves) = Smt::build_subtree(next_leaves, current_depth); + assert_eq!(nodes.len(), 8); + assert_eq!(next_leaves.len(), 1); + + for (index, test_node) in nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + } } From 47e1650a40d2b77e209d65c19e8490632f3cb7c3 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 31 Oct 2024 13:20:53 -0600 Subject: [PATCH 06/21] refactor sorted_pairs_to_leaves() to also group subtrees --- src/merkle/smt/mod.rs | 208 +++++++++++++++++++++++++++--------------- 1 file changed, 135 insertions(+), 73 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index a658d28..a95e7cb 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -353,40 +353,42 @@ pub(crate) trait SparseMerkleTree { fn sorted_pairs_to_leaves( pairs: Vec<(Self::Key, Self::Value)>, ) -> PairComputations { - let mut all_leaves = PairComputations::default(); + let mut accumulator: PairComputations = Default::default(); - let mut buffer: Vec<(Self::Key, Self::Value)> = Default::default(); + // The kv-pairs we've seen so far that correspond to a single leaf. + let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); let mut iter = pairs.into_iter().peekable(); while let Some((key, value)) = iter.next() { let col = Self::key_to_leaf_index(&key).index.value(); - let next_col = iter.peek().map(|(key, _)| { + let peeked_col = iter.peek().map(|(key, _v)| { let index = Self::key_to_leaf_index(key); - index.index.value() + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col }); + current_leaf_buffer.push((key, value)); - buffer.push((key, value)); - - if let Some(next_col) = next_col { - assert!(next_col >= col); - } - - if next_col == Some(col) { - // Keep going in our buffer. + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { continue; } - // Whether the next pair is a different column, or non-existent, we break off. - let leaf_pairs = mem::take(&mut buffer); + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. + let leaf_pairs = mem::take(&mut current_leaf_buffer); let leaf = Self::pairs_to_leaf(leaf_pairs); let hash = Self::hash_leaf(&leaf); - all_leaves.nodes.insert(col, leaf); - all_leaves.subtrees.push(SubtreeLeaf { col, hash }); - } - assert_eq!(buffer.len(), 0); + accumulator.nodes.insert(col, leaf); + accumulator.add_leaf(SubtreeLeaf { col, hash }); - all_leaves + debug_assert!(current_leaf_buffer.is_empty()); + } + + accumulator } /// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, @@ -615,39 +617,55 @@ pub struct SubtreeLeaf { pub hash: RpoDigest, } +impl SubtreeLeaf { + #[cfg_attr(not(test), allow(dead_code))] + fn from_smt_leaf(leaf: &crate::merkle::SmtLeaf) -> Self { + Self { + col: leaf.index().index.value(), + hash: leaf.hash(), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. pub nodes: BTreeMap, /// "Conceptual" leaves that will be used for computations. - pub subtrees: Vec, + pub leaves: Vec>, } impl PairComputations { - #[cfg_attr(not(test), allow(dead_code))] - pub fn split_at_column(mut self, col: u64) -> (Self, Self) { - let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) { - // Surprisingly, Result has no method like `unwrap_or_unwrap_err() where T == E`. - // Probably because there's no way to write that where bound. - Ok(split_point) | Err(split_point) => split_point, + pub fn add_leaf(&mut self, leaf: SubtreeLeaf) { + // A depth-8 subtree contains 256 "columns" that can possibly be occupied. + const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); + + let last_subtree = match self.leaves.last_mut() { + // Base case. + None => { + self.leaves.push(vec![leaf]); + return; + }, + Some(last_subtree) => last_subtree, }; - let subtrees_right = self.subtrees.split_off(split_point); - let subtrees_left = self.subtrees; + debug_assert!(!last_subtree.is_empty()); + debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize); - let nodes_right = self.nodes.split_off(&col); - let nodes_left = self.nodes; - - let left = Self { - nodes: nodes_left, - subtrees: subtrees_left, + // The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees. + let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col); + let next_subtree_col = if last_subtree_col.is_multiple_of(&COLS_PER_SUBTREE) { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) }; - let right = Self { - nodes: nodes_right, - subtrees: subtrees_right, - }; - - (left, right) + if leaf.col < next_subtree_col { + last_subtree.push(leaf); + } else { + //std::eprintln!("\tcreating new subtree for column {}", leaf.col); + let next_subtree = vec![leaf]; + self.leaves.push(next_subtree); + } } } @@ -656,7 +674,7 @@ impl Default for PairComputations { fn default() -> Self { Self { nodes: Default::default(), - subtrees: Default::default(), + leaves: Default::default(), } } } @@ -665,50 +683,91 @@ impl Default for PairComputations { // ================================================================================================ #[cfg(test)] mod test { - use alloc::vec::Vec; + use alloc::{collections::BTreeMap, vec::Vec}; use super::SparseMerkleTree; use crate::{ hash::rpo::RpoDigest, - merkle::{smt::SubtreeLeaf, Smt, SmtLeaf, SMT_DEPTH}, - Felt, Word, EMPTY_WORD, ONE, + merkle::{ + smt::{PairComputations, SubtreeLeaf}, + Smt, SmtLeaf, SMT_DEPTH, + }, + Felt, Word, ONE, }; #[test] fn test_sorted_pairs_to_leaves() { let entries: Vec<(RpoDigest, Word)> = vec![ + // Subtree 0. (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), // Leaf index collision. (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), - // Normal single leaf again. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), - // Empty leaf. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(500)]), EMPTY_WORD), - ]; - let mut entries_iter = entries.iter().cloned(); - let mut next_entry = || entries_iter.next().unwrap(); - - let control_leaves: Vec = vec![ - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_empty(Smt::key_to_leaf_index(&next_entry().0)), + // Subtree 1. Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), + // Subtree 2. Another normal leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), ]; - let control_subtree_leaves: Vec = control_leaves - .iter() - .map(|leaf| { - let col = leaf.index().index.value(); - let hash = leaf.hash(); - SubtreeLeaf { col, hash } - }) + let control = Smt::with_entries(entries.clone()).unwrap(); + + let control_leaves: Vec = { + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + let control_leaves = vec![ + // Subtree 0. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + // Subtree 1. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + // Subtree 2. + SmtLeaf::Single(next_entry()), + ]; + assert_eq!(entries_iter.next(), None); + control_leaves + }; + + let control_subtree_leaves: Vec> = { + let mut control_leaves_iter = control_leaves.iter(); + let mut next_leaf = || control_leaves_iter.next().unwrap(); + + let control_subtree_leaves: Vec> = [ + // Subtree 0. + vec![next_leaf(), next_leaf(), next_leaf()], + // Subtree 1. + vec![next_leaf(), next_leaf()], + // Subtree 2. + vec![next_leaf()], + ] + .map(|subtree| subtree.into_iter().map(SubtreeLeaf::from_smt_leaf).collect()) + .to_vec(); + assert_eq!(control_leaves_iter.next(), None); + control_subtree_leaves + }; + + let subtrees = Smt::sorted_pairs_to_leaves(entries); + // This will check that the hashes, columns, and subtree assignments all match. + assert_eq!(subtrees.leaves, control_subtree_leaves); + + // Then finally we might as well check the computed leaf nodes too. + let control_leaves: BTreeMap = control + .leaves() + .map(|(index, value)| (index.index.value(), value.clone())) .collect(); - let test_subtree_leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; - assert_eq!(control_subtree_leaves, test_subtree_leaves); + for (column, test_leaf) in subtrees.nodes { + if test_leaf.is_empty() { + continue; + } + let control_leaf = control_leaves + .get(&column) + .expect(&format!("no leaf node found for column {column}")); + assert_eq!(control_leaf, &test_leaf); + } } // Helper for the below tests. @@ -733,7 +792,8 @@ mod test { let control = Smt::with_entries(entries.clone()).unwrap(); // `entries` should already be sorted by nature of how we constructed it. - let leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; + let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = leaves.into_iter().next().unwrap(); let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH); assert!(!first_subtree.is_empty()); @@ -756,17 +816,19 @@ mod test { let control = Smt::with_entries(entries.clone()).unwrap(); - let leaves = Smt::sorted_pairs_to_leaves(entries); - let (first, second) = leaves.split_at_column(PAIR_COUNT / 2); - assert_eq!(first.subtrees.len(), second.subtrees.len()); + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + // With two subtrees' worth of leaves, we should have exactly two subtrees. + let [first, second]: [_; 2] = leaves.try_into().unwrap(); + assert_eq!(first.len() as u64, PAIR_COUNT / 2); + assert_eq!(first.len(), second.len()); let mut current_depth = SMT_DEPTH; let mut next_leaves: Vec = Default::default(); - let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth); + let (first_nodes, leaves) = Smt::build_subtree(first, current_depth); next_leaves.extend(leaves); - let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth); + let (second_nodes, leaves) = Smt::build_subtree(second, current_depth); next_leaves.extend(leaves); // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. From 74ab46ca69744bafb84fc0f556ca124f02d89f3a Mon Sep 17 00:00:00 2001 From: Qyriad Date: Thu, 31 Oct 2024 18:58:18 -0600 Subject: [PATCH 07/21] working test_singlethreaded_subtrees() --- src/merkle/smt/mod.rs | 111 +++++++++++++++++++++++++++++++++++++----- 1 file changed, 100 insertions(+), 11 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index a95e7cb..e81be04 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -611,6 +611,10 @@ impl MutationSet { // HELPERS // ================================================================================================ + +/// A depth-8 subtree contains 256 "columns" that can possibly be occupied. +const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); + #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] pub struct SubtreeLeaf { pub col: u64, @@ -637,9 +641,6 @@ pub struct PairComputations { impl PairComputations { pub fn add_leaf(&mut self, leaf: SubtreeLeaf) { - // A depth-8 subtree contains 256 "columns" that can possibly be occupied. - const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); - let last_subtree = match self.leaves.last_mut() { // Base case. None => { @@ -654,7 +655,7 @@ impl PairComputations { // The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees. let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col); - let next_subtree_col = if last_subtree_col.is_multiple_of(&COLS_PER_SUBTREE) { + let next_subtree_col = if Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE) { u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) } else { last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) @@ -662,7 +663,6 @@ impl PairComputations { if leaf.col < next_subtree_col { last_subtree.push(leaf); } else { - //std::eprintln!("\tcreating new subtree for column {}", leaf.col); let next_subtree = vec![leaf]; self.leaves.push(next_subtree); } @@ -683,15 +683,16 @@ impl Default for PairComputations { // ================================================================================================ #[cfg(test)] mod test { + use core::mem; + use alloc::{collections::BTreeMap, vec::Vec}; - use super::SparseMerkleTree; + use num::Integer; + + use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf, COLS_PER_SUBTREE}; use crate::{ hash::rpo::RpoDigest, - merkle::{ - smt::{PairComputations, SubtreeLeaf}, - Smt, SmtLeaf, SMT_DEPTH, - }, + merkle::{NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, Felt, Word, ONE, }; @@ -765,7 +766,7 @@ mod test { } let control_leaf = control_leaves .get(&column) - .expect(&format!("no leaf node found for column {column}")); + .unwrap_or_else(|| panic!("no leaf node found for column {column}")); assert_eq!(control_leaf, &test_leaf); } } @@ -859,4 +860,92 @@ mod test { ); } } + + #[test] + fn test_singlethreaded_subtrees() { + const PAIR_COUNT: u64 = 4096 * 4; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let starting_leaves = Smt::sorted_pairs_to_leaves(entries); + + let mut leaf_subtrees = starting_leaves.leaves; + for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { + for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + // Do actual things. + let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); + + // Post-assertions. + assert!(next_leaves.is_sorted()); + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + // Update state. + accumulated_nodes.extend(nodes); + + for subtree_leaf in next_leaves { + if leaf_subtrees.is_empty() { + leaf_subtrees.push(vec![subtree_leaf]); + continue; + } + + let buffer_max_col = + u64::max(1, leaf_subtrees.last().unwrap().last().unwrap().col); + let next_subtree_col = + if Integer::is_multiple_of(&buffer_max_col, &COLS_PER_SUBTREE) { + u64::next_multiple_of(buffer_max_col + 1, COLS_PER_SUBTREE) + } else { + buffer_max_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + if subtree_leaf.col < next_subtree_col { + leaf_subtrees.last_mut().unwrap().push(subtree_leaf); + } else { + leaf_subtrees.push(vec![subtree_leaf]); + } + } + } + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + assert_eq!(leaf_subtrees.len(), 1); + let mut leaf_subtree = leaf_subtrees.pop().unwrap(); + assert_eq!(leaf_subtree.len(), 1); + let root_leaf = leaf_subtree.pop().unwrap(); + assert_eq!(control.root(), root_leaf.hash); + + // Do we have a root? + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + + // And does it match? + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control.root(), test_root.hash()); + assert_eq!(control.root(), root_leaf.hash); + } } From 49d88600c00019d54d52b019744f62cef2826732 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 12:24:41 -0700 Subject: [PATCH 08/21] make PrecomputedSubtrees more generic --- src/merkle/smt/mod.rs | 352 +++++++++++++++++++++--------------------- 1 file changed, 176 insertions(+), 176 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index e81be04..72e06d5 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -68,28 +68,28 @@ pub(crate) trait SparseMerkleTree { /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { - let leaf = self.get_leaf(key); + let leaf = self.get_leaf(key); - let mut index: NodeIndex = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(key); - leaf_index.into() - }; + let mut index: NodeIndex = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(key); + leaf_index.into() + }; - let merkle_path = { - let mut path = Vec::with_capacity(index.depth() as usize); - for _ in 0..index.depth() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let value = if is_right { left } else { right }; - path.push(value); - } + let merkle_path = { + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let value = if is_right { left } else { right }; + path.push(value); + } - MerklePath::new(path) - }; + MerklePath::new(path) + }; - Self::path_and_leaf_to_opening(merkle_path, leaf) - } + Self::path_and_leaf_to_opening(merkle_path, leaf) +} /// Inserts a value at the specified key, returning the previous value associated with that key. /// Recall that by definition, any key that hasn't been updated is associated with @@ -98,53 +98,53 @@ pub(crate) trait SparseMerkleTree { /// This also recomputes all hashes between the leaf (associated with the key) and the root, /// updating the root itself. fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value { - let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); + let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); - // if the old value and new value are the same, there is nothing to update - if value == old_value { - return value; - } - - let leaf = self.get_leaf(&key); - let node_index = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - leaf_index.into() - }; - - self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); - - old_value + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return value; } + let leaf = self.get_leaf(&key); + let node_index = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + leaf_index.into() + }; + + self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); + + old_value +} + /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. fn recompute_nodes_from_index_to_root( - &mut self, - mut index: NodeIndex, - node_hash_at_index: RpoDigest, - ) { - let mut node_hash = node_hash_at_index; - for node_depth in (0..index.depth()).rev() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let (left, right) = if is_right { - (left, node_hash) - } else { - (node_hash, right) - }; - node_hash = Rpo256::merge(&[left, right]); + &mut self, + mut index: NodeIndex, + node_hash_at_index: RpoDigest, +) { + let mut node_hash = node_hash_at_index; + for node_depth in (0..index.depth()).rev() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let (left, right) = if is_right { + (left, node_hash) + } else { + (node_hash, right) + }; + node_hash = Rpo256::merge(&[left, right]); - if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { - // If a subtree is empty, when can remove the inner node, since it's equal to the - // default value - self.remove_inner_node(index) - } else { - self.insert_inner_node(index, InnerNode { left, right }); - } + if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { + // If a subtree is empty, when can remove the inner node, since it's equal to the + // default value + self.remove_inner_node(index) + } else { + self.insert_inner_node(index, InnerNode { left, right }); } - self.set_root(node_hash); } + self.set_root(node_hash); +} /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle /// tree, allowing for validation before applying those changes. @@ -155,95 +155,95 @@ pub(crate) trait SparseMerkleTree { /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to /// the Merkle tree, or [`drop()`] to discard them. fn compute_mutations( - &self, - kv_pairs: impl IntoIterator, - ) -> MutationSet { - use NodeMutation::*; + &self, + kv_pairs: impl IntoIterator, +) -> MutationSet { + use NodeMutation::*; - let mut new_root = self.root(); - let mut new_pairs: BTreeMap = Default::default(); - let mut node_mutations: BTreeMap = Default::default(); + let mut new_root = self.root(); + let mut new_pairs: BTreeMap = Default::default(); + let mut node_mutations: BTreeMap = Default::default(); - for (key, value) in kv_pairs { - // If the old value and the new value are the same, there is nothing to update. - // For the unusual case that kv_pairs has multiple values at the same key, we'll have - // to check the key-value pairs we've already seen to get the "effective" old value. - let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - if value == old_value { - continue; - } + for (key, value) in kv_pairs { + // If the old value and the new value are the same, there is nothing to update. + // For the unusual case that kv_pairs has multiple values at the same key, we'll have + // to check the key-value pairs we've already seen to get the "effective" old value. + let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + if value == old_value { + continue; + } - let leaf_index = Self::key_to_leaf_index(&key); - let mut node_index = NodeIndex::from(leaf_index); + let leaf_index = Self::key_to_leaf_index(&key); + let mut node_index = NodeIndex::from(leaf_index); - // We need the current leaf's hash to calculate the new leaf, but in the rare case that - // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also - // part of the "current leaf". - let old_leaf = { - let pairs_at_index = new_pairs - .iter() - .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); + // We need the current leaf's hash to calculate the new leaf, but in the rare case that + // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also + // part of the "current leaf". + let old_leaf = { + let pairs_at_index = new_pairs + .iter() + .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); - pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { - // Most of the time `pairs_at_index` should only contain a single entry (or - // none at all), as multi-leaves should be really rare. - let existing_leaf = acc.clone(); - self.construct_prospective_leaf(existing_leaf, k, v) + pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { + // Most of the time `pairs_at_index` should only contain a single entry (or + // none at all), as multi-leaves should be really rare. + let existing_leaf = acc.clone(); + self.construct_prospective_leaf(existing_leaf, k, v) + }) + }; + + let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); + + let mut new_child_hash = Self::hash_leaf(&new_leaf); + + for node_depth in (0..node_index.depth()).rev() { + // Whether the node we're replacing is the right child or the left child. + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = node_mutations + .get(&node_index) + .map(|mutation| match mutation { + Addition(node) => node.clone(), + Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), }) + .unwrap_or_else(|| self.get_inner_node(node_index)); + + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } }; - let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); + // The next iteration will operate on this new node's hash. + new_child_hash = new_node.hash(); - let mut new_child_hash = Self::hash_leaf(&new_leaf); - - for node_depth in (0..node_index.depth()).rev() { - // Whether the node we're replacing is the right child or the left child. - let is_right = node_index.is_value_odd(); - node_index.move_up(); - - let old_node = node_mutations - .get(&node_index) - .map(|mutation| match mutation { - Addition(node) => node.clone(), - Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), - }) - .unwrap_or_else(|| self.get_inner_node(node_index)); - - let new_node = if is_right { - InnerNode { - left: old_node.left, - right: new_child_hash, - } - } else { - InnerNode { - left: new_child_hash, - right: old_node.right, - } - }; - - // The next iteration will operate on this new node's hash. - new_child_hash = new_node.hash(); - - let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); - let is_removal = new_child_hash == equivalent_empty_hash; - let new_entry = if is_removal { Removal } else { Addition(new_node) }; - node_mutations.insert(node_index, new_entry); - } - - // Once we're at depth 0, the last node we made is the new root. - new_root = new_child_hash; - // And then we're done with this pair; on to the next one. - new_pairs.insert(key, value); + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + let is_removal = new_child_hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(new_node) }; + node_mutations.insert(node_index, new_entry); } - MutationSet { - old_root: self.root(), - new_root, - node_mutations, - new_pairs, - } + // Once we're at depth 0, the last node we made is the new root. + new_root = new_child_hash; + // And then we're done with this pair; on to the next one. + new_pairs.insert(key, value); } + MutationSet { + old_root: self.root(), + new_root, + node_mutations, + new_pairs, + } +} + /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// this tree. /// @@ -253,42 +253,42 @@ pub(crate) trait SparseMerkleTree { /// the `mutations` were computed against, and the second item is the actual current root of /// this tree. fn apply_mutations( - &mut self, - mutations: MutationSet, - ) -> Result<(), MerkleError> - where - Self: Sized, - { - use NodeMutation::*; - let MutationSet { - old_root, - node_mutations, - new_pairs, - new_root, - } = mutations; + &mut self, + mutations: MutationSet, +) -> Result<(), MerkleError> +where + Self: Sized, +{ + use NodeMutation::*; + let MutationSet { + old_root, + node_mutations, + new_pairs, + new_root, + } = mutations; - // Guard against accidentally trying to apply mutations that were computed against a - // different tree, including a stale version of this tree. - if old_root != self.root() { - return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); - } - - for (index, mutation) in node_mutations { - match mutation { - Removal => self.remove_inner_node(index), - Addition(node) => self.insert_inner_node(index, node), - } - } - - for (key, value) in new_pairs { - self.insert_value(key, value); - } - - self.set_root(new_root); - - Ok(()) + // Guard against accidentally trying to apply mutations that were computed against a + // different tree, including a stale version of this tree. + if old_root != self.root() { + return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); } + for (index, mutation) in node_mutations { + match mutation { + Removal => self.remove_inner_node(index), + Addition(node) => self.insert_inner_node(index, node), + } + } + + for (key, value) in new_pairs { + self.insert_value(key, value); + } + + self.set_root(new_root); + + Ok(()) +} + // REQUIRED METHODS // --------------------------------------------------------------------------------------------- @@ -332,11 +332,11 @@ pub(crate) trait SparseMerkleTree { /// `existing_leaf` must have the same leaf index as `key` (as determined by /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless. fn construct_prospective_leaf( - &self, - existing_leaf: Self::Leaf, - key: &Self::Key, - value: &Self::Value, - ) -> Self::Leaf; + &self, + existing_leaf: Self::Leaf, + key: &Self::Key, + value: &Self::Value, +) -> Self::Leaf; /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -352,8 +352,8 @@ pub(crate) trait SparseMerkleTree { fn sorted_pairs_to_leaves( pairs: Vec<(Self::Key, Self::Value)>, - ) -> PairComputations { - let mut accumulator: PairComputations = Default::default(); + ) -> PairComputations { + let mut accumulator: PairComputations = Default::default(); // The kv-pairs we've seen so far that correspond to a single leaf. let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); @@ -632,14 +632,14 @@ impl SubtreeLeaf { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct PairComputations { +pub struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. - pub nodes: BTreeMap, + pub nodes: BTreeMap, /// "Conceptual" leaves that will be used for computations. pub leaves: Vec>, } -impl PairComputations { +impl PairComputations { pub fn add_leaf(&mut self, leaf: SubtreeLeaf) { let last_subtree = match self.leaves.last_mut() { // Base case. @@ -670,7 +670,7 @@ impl PairComputations { } // Derive requires `L` to impl Default, even though we don't actually need that. -impl Default for PairComputations { +impl Default for PairComputations { fn default() -> Self { Self { nodes: Default::default(), From 2b04a93a15f5a20b9ac6c271562a8d7b33d385f0 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 13:01:57 -0700 Subject: [PATCH 09/21] rename PrecomputedSubtrees -> PairComputations --- src/merkle/smt/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 72e06d5..4cc34d3 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -632,7 +632,7 @@ impl SubtreeLeaf { } #[derive(Debug, Clone, PartialEq, Eq)] -pub struct PairComputations { +pub(crate) struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. pub nodes: BTreeMap, /// "Conceptual" leaves that will be used for computations. From 60f4dd2161a5efc34d785d22b22968cd86dd2538 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 12:53:27 -0700 Subject: [PATCH 10/21] factor out subtree-append logic --- src/merkle/smt/mod.rs | 395 ++++++++++++++++++++++-------------------- 1 file changed, 203 insertions(+), 192 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 4cc34d3..b8abdba 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -68,28 +68,28 @@ pub(crate) trait SparseMerkleTree { /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { - let leaf = self.get_leaf(key); + let leaf = self.get_leaf(key); - let mut index: NodeIndex = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(key); - leaf_index.into() - }; + let mut index: NodeIndex = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(key); + leaf_index.into() + }; - let merkle_path = { - let mut path = Vec::with_capacity(index.depth() as usize); - for _ in 0..index.depth() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let value = if is_right { left } else { right }; - path.push(value); - } + let merkle_path = { + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let value = if is_right { left } else { right }; + path.push(value); + } - MerklePath::new(path) - }; + MerklePath::new(path) + }; - Self::path_and_leaf_to_opening(merkle_path, leaf) -} + Self::path_and_leaf_to_opening(merkle_path, leaf) + } /// Inserts a value at the specified key, returning the previous value associated with that key. /// Recall that by definition, any key that hasn't been updated is associated with @@ -98,53 +98,53 @@ pub(crate) trait SparseMerkleTree { /// This also recomputes all hashes between the leaf (associated with the key) and the root, /// updating the root itself. fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value { - let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); + let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); - // if the old value and new value are the same, there is nothing to update - if value == old_value { - return value; + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return value; + } + + let leaf = self.get_leaf(&key); + let node_index = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + leaf_index.into() + }; + + self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); + + old_value } - let leaf = self.get_leaf(&key); - let node_index = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - leaf_index.into() - }; - - self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); - - old_value -} - /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. fn recompute_nodes_from_index_to_root( - &mut self, - mut index: NodeIndex, - node_hash_at_index: RpoDigest, -) { - let mut node_hash = node_hash_at_index; - for node_depth in (0..index.depth()).rev() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let (left, right) = if is_right { - (left, node_hash) - } else { - (node_hash, right) - }; - node_hash = Rpo256::merge(&[left, right]); + &mut self, + mut index: NodeIndex, + node_hash_at_index: RpoDigest, + ) { + let mut node_hash = node_hash_at_index; + for node_depth in (0..index.depth()).rev() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let (left, right) = if is_right { + (left, node_hash) + } else { + (node_hash, right) + }; + node_hash = Rpo256::merge(&[left, right]); - if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { - // If a subtree is empty, when can remove the inner node, since it's equal to the - // default value - self.remove_inner_node(index) - } else { - self.insert_inner_node(index, InnerNode { left, right }); + if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { + // If a subtree is empty, when can remove the inner node, since it's equal to the + // default value + self.remove_inner_node(index) + } else { + self.insert_inner_node(index, InnerNode { left, right }); + } } + self.set_root(node_hash); } - self.set_root(node_hash); -} /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle /// tree, allowing for validation before applying those changes. @@ -155,95 +155,95 @@ pub(crate) trait SparseMerkleTree { /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to /// the Merkle tree, or [`drop()`] to discard them. fn compute_mutations( - &self, - kv_pairs: impl IntoIterator, -) -> MutationSet { - use NodeMutation::*; + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + use NodeMutation::*; - let mut new_root = self.root(); - let mut new_pairs: BTreeMap = Default::default(); - let mut node_mutations: BTreeMap = Default::default(); + let mut new_root = self.root(); + let mut new_pairs: BTreeMap = Default::default(); + let mut node_mutations: BTreeMap = Default::default(); - for (key, value) in kv_pairs { - // If the old value and the new value are the same, there is nothing to update. - // For the unusual case that kv_pairs has multiple values at the same key, we'll have - // to check the key-value pairs we've already seen to get the "effective" old value. - let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - if value == old_value { - continue; - } + for (key, value) in kv_pairs { + // If the old value and the new value are the same, there is nothing to update. + // For the unusual case that kv_pairs has multiple values at the same key, we'll have + // to check the key-value pairs we've already seen to get the "effective" old value. + let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + if value == old_value { + continue; + } - let leaf_index = Self::key_to_leaf_index(&key); - let mut node_index = NodeIndex::from(leaf_index); + let leaf_index = Self::key_to_leaf_index(&key); + let mut node_index = NodeIndex::from(leaf_index); - // We need the current leaf's hash to calculate the new leaf, but in the rare case that - // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also - // part of the "current leaf". - let old_leaf = { - let pairs_at_index = new_pairs - .iter() - .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); + // We need the current leaf's hash to calculate the new leaf, but in the rare case that + // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also + // part of the "current leaf". + let old_leaf = { + let pairs_at_index = new_pairs + .iter() + .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); - pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { - // Most of the time `pairs_at_index` should only contain a single entry (or - // none at all), as multi-leaves should be really rare. - let existing_leaf = acc.clone(); - self.construct_prospective_leaf(existing_leaf, k, v) - }) - }; - - let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); - - let mut new_child_hash = Self::hash_leaf(&new_leaf); - - for node_depth in (0..node_index.depth()).rev() { - // Whether the node we're replacing is the right child or the left child. - let is_right = node_index.is_value_odd(); - node_index.move_up(); - - let old_node = node_mutations - .get(&node_index) - .map(|mutation| match mutation { - Addition(node) => node.clone(), - Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { + // Most of the time `pairs_at_index` should only contain a single entry (or + // none at all), as multi-leaves should be really rare. + let existing_leaf = acc.clone(); + self.construct_prospective_leaf(existing_leaf, k, v) }) - .unwrap_or_else(|| self.get_inner_node(node_index)); - - let new_node = if is_right { - InnerNode { - left: old_node.left, - right: new_child_hash, - } - } else { - InnerNode { - left: new_child_hash, - right: old_node.right, - } }; - // The next iteration will operate on this new node's hash. - new_child_hash = new_node.hash(); + let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); - let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); - let is_removal = new_child_hash == equivalent_empty_hash; - let new_entry = if is_removal { Removal } else { Addition(new_node) }; - node_mutations.insert(node_index, new_entry); + let mut new_child_hash = Self::hash_leaf(&new_leaf); + + for node_depth in (0..node_index.depth()).rev() { + // Whether the node we're replacing is the right child or the left child. + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = node_mutations + .get(&node_index) + .map(|mutation| match mutation { + Addition(node) => node.clone(), + Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + }) + .unwrap_or_else(|| self.get_inner_node(node_index)); + + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this new node's hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + let is_removal = new_child_hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(new_node) }; + node_mutations.insert(node_index, new_entry); + } + + // Once we're at depth 0, the last node we made is the new root. + new_root = new_child_hash; + // And then we're done with this pair; on to the next one. + new_pairs.insert(key, value); } - // Once we're at depth 0, the last node we made is the new root. - new_root = new_child_hash; - // And then we're done with this pair; on to the next one. - new_pairs.insert(key, value); + MutationSet { + old_root: self.root(), + new_root, + node_mutations, + new_pairs, + } } - MutationSet { - old_root: self.root(), - new_root, - node_mutations, - new_pairs, - } -} - /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// this tree. /// @@ -253,42 +253,42 @@ pub(crate) trait SparseMerkleTree { /// the `mutations` were computed against, and the second item is the actual current root of /// this tree. fn apply_mutations( - &mut self, - mutations: MutationSet, -) -> Result<(), MerkleError> -where - Self: Sized, -{ - use NodeMutation::*; - let MutationSet { - old_root, - node_mutations, - new_pairs, - new_root, - } = mutations; + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> + where + Self: Sized, + { + use NodeMutation::*; + let MutationSet { + old_root, + node_mutations, + new_pairs, + new_root, + } = mutations; - // Guard against accidentally trying to apply mutations that were computed against a - // different tree, including a stale version of this tree. - if old_root != self.root() { - return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); - } - - for (index, mutation) in node_mutations { - match mutation { - Removal => self.remove_inner_node(index), - Addition(node) => self.insert_inner_node(index, node), + // Guard against accidentally trying to apply mutations that were computed against a + // different tree, including a stale version of this tree. + if old_root != self.root() { + return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); } + + for (index, mutation) in node_mutations { + match mutation { + Removal => self.remove_inner_node(index), + Addition(node) => self.insert_inner_node(index, node), + } + } + + for (key, value) in new_pairs { + self.insert_value(key, value); + } + + self.set_root(new_root); + + Ok(()) } - for (key, value) in new_pairs { - self.insert_value(key, value); - } - - self.set_root(new_root); - - Ok(()) -} - // REQUIRED METHODS // --------------------------------------------------------------------------------------------- @@ -332,11 +332,11 @@ where /// `existing_leaf` must have the same leaf index as `key` (as determined by /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless. fn construct_prospective_leaf( - &self, - existing_leaf: Self::Leaf, - key: &Self::Key, - value: &Self::Value, -) -> Self::Leaf; + &self, + existing_leaf: Self::Leaf, + key: &Self::Key, + value: &Self::Value, + ) -> Self::Leaf; /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -383,7 +383,7 @@ where let hash = Self::hash_leaf(&leaf); accumulator.nodes.insert(col, leaf); - accumulator.add_leaf(SubtreeLeaf { col, hash }); + add_subtree_leaf(&mut accumulator.leaves, SubtreeLeaf { col, hash }); debug_assert!(current_leaf_buffer.is_empty()); } @@ -631,6 +631,7 @@ impl SubtreeLeaf { } } +/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. @@ -679,6 +680,36 @@ impl Default for PairComputations { } } +/// Handles the logic for figuring out whether the new leaf starts a new subtree or not. +fn add_subtree_leaf(subtrees: &mut Vec>, leaf: SubtreeLeaf) { + let last_subtree = match subtrees.last_mut() { + // Base case. + None => { + subtrees.push(vec![leaf]); + return; + }, + Some(last_subtree) => last_subtree, + }; + + debug_assert!(!last_subtree.is_empty()); + debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize); + + // The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees. + let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col); + let next_subtree_col = if Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE) { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + if leaf.col < next_subtree_col { + last_subtree.push(leaf); + } else { + let next_subtree = vec![leaf]; + subtrees.push(next_subtree); + } +} + // TESTS // ================================================================================================ #[cfg(test)] @@ -687,9 +718,7 @@ mod test { use alloc::{collections::BTreeMap, vec::Vec}; - use num::Integer; - - use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf, COLS_PER_SUBTREE}; + use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf}; use crate::{ hash::rpo::RpoDigest, merkle::{NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, @@ -904,25 +933,7 @@ mod test { accumulated_nodes.extend(nodes); for subtree_leaf in next_leaves { - if leaf_subtrees.is_empty() { - leaf_subtrees.push(vec![subtree_leaf]); - continue; - } - - let buffer_max_col = - u64::max(1, leaf_subtrees.last().unwrap().last().unwrap().col); - let next_subtree_col = - if Integer::is_multiple_of(&buffer_max_col, &COLS_PER_SUBTREE) { - u64::next_multiple_of(buffer_max_col + 1, COLS_PER_SUBTREE) - } else { - buffer_max_col.next_multiple_of(COLS_PER_SUBTREE) - }; - - if subtree_leaf.col < next_subtree_col { - leaf_subtrees.last_mut().unwrap().push(subtree_leaf); - } else { - leaf_subtrees.push(vec![subtree_leaf]); - } + super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); } } From 1c5fc8a8300e8ae347da54210fa75bf56f3632f8 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 13 Nov 2024 12:32:21 -0700 Subject: [PATCH 11/21] add SubtreeLeavesIter --- src/merkle/smt/mod.rs | 48 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index b8abdba..e97e6ee 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -710,6 +710,54 @@ fn add_subtree_leaf(subtrees: &mut Vec>, leaf: SubtreeLeaf) { } } +#[derive(Debug)] +struct SubtreeLeavesIter<'s> { + leaves: core::iter::Peekable>, +} + +impl<'s> SubtreeLeavesIter<'s> { + fn from_leaves(leaves: &'s mut Vec) -> Self { + Self { leaves: leaves.drain(..).peekable() } + } +} + +impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> { + type Item = Vec; + + /// Each `next()` collects an entire subtree. + fn next(&mut self) -> Option> { + let mut subtree: Vec = Default::default(); + + let mut last_subtree_col = 0; + + while let Some(leaf) = self.leaves.peek() { + last_subtree_col = u64::max(1, last_subtree_col); + let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE); + let next_subtree_col = if is_exact_multiple { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + last_subtree_col = leaf.col; + if leaf.col < next_subtree_col { + subtree.push(self.leaves.next().unwrap()); + } else if subtree.is_empty() { + continue; + } else { + break; + } + } + + if subtree.is_empty() { + debug_assert!(self.leaves.peek().is_none()); + return None; + } + + Some(subtree) + } +} + // TESTS // ================================================================================================ #[cfg(test)] From aa3197fcc1184310d5b631f8f5ba98351646168d Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 13:38:55 -0700 Subject: [PATCH 12/21] cleanup test_singlethreaded_subtrees a bit --- src/merkle/smt/mod.rs | 52 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index e97e6ee..d3db29f 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -615,9 +615,15 @@ impl MutationSet { /// A depth-8 subtree contains 256 "columns" that can possibly be occupied. const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); +/// Helper struct for organizing the data we care about when computing Merkle subtrees. +/// +/// Note that these represet "conceptual" leaves of some subtree, not necessarily +/// [`SparseMerkleTree::Leaf`]. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] pub struct SubtreeLeaf { + /// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known. pub col: u64, + /// The hash of the node this `SubtreeLeaf` represents. pub hash: RpoDigest, } @@ -769,7 +775,7 @@ mod test { use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf}; use crate::{ hash::rpo::RpoDigest, - merkle::{NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, + merkle::{LeafIndex, NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, Felt, Word, ONE, }; @@ -948,9 +954,11 @@ mod test { let mut accumulated_nodes: BTreeMap = Default::default(); - let starting_leaves = Smt::sorted_pairs_to_leaves(entries); + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); - let mut leaf_subtrees = starting_leaves.leaves; for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { // Pre-assertions. @@ -988,23 +996,45 @@ mod test { assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); } + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); for (index, test_node) in accumulated_nodes.clone() { let control_node = control.get_inner_node(index); assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); } - assert_eq!(leaf_subtrees.len(), 1); - let mut leaf_subtree = leaf_subtrees.pop().unwrap(); - assert_eq!(leaf_subtree.len(), 1); - let root_leaf = leaf_subtree.pop().unwrap(); + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. assert_eq!(control.root(), root_leaf.hash); - // Do we have a root? + // Likewise `accumulated_nodes` should contain a node at the root index... assert!(accumulated_nodes.contains_key(&NodeIndex::root())); - - // And does it match? + // and it should match our actual root. let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); - assert_eq!(control.root(), test_root.hash()); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } } From a14e67bf2aec0cbd3865b33600a614bc7e336cf7 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 13:56:41 -0700 Subject: [PATCH 13/21] SubtreeLeaf::from_smt_leaf() was only used in tests --- src/merkle/smt/mod.rs | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index d3db29f..79e9faa 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -627,16 +627,6 @@ pub struct SubtreeLeaf { pub hash: RpoDigest, } -impl SubtreeLeaf { - #[cfg_attr(not(test), allow(dead_code))] - fn from_smt_leaf(leaf: &crate::merkle::SmtLeaf) -> Self { - Self { - col: leaf.index().index.value(), - hash: leaf.hash(), - } - } -} - /// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct PairComputations { @@ -779,6 +769,13 @@ mod test { Felt, Word, ONE, }; + fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { + SubtreeLeaf { + col: leaf.index().index.value(), + hash: leaf.hash(), + } + } + #[test] fn test_sorted_pairs_to_leaves() { let entries: Vec<(RpoDigest, Word)> = vec![ @@ -827,7 +824,7 @@ mod test { // Subtree 2. vec![next_leaf()], ] - .map(|subtree| subtree.into_iter().map(SubtreeLeaf::from_smt_leaf).collect()) + .map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect()) .to_vec(); assert_eq!(control_leaves_iter.next(), None); control_subtree_leaves From 327499095121a66a320e516659ad6a7f9daf2801 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 17:09:35 -0700 Subject: [PATCH 14/21] smt: make with_entries() a trait method --- src/merkle/smt/full/mod.rs | 48 ++++++++++++++++++++---------------- src/merkle/smt/mod.rs | 7 ++++++ src/merkle/smt/simple/mod.rs | 8 ++++++ 3 files changed, 42 insertions(+), 21 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 89d0702..0b170d6 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -78,27 +78,7 @@ impl Smt { pub fn with_entries( entries: impl IntoIterator, ) -> Result { - // create an empty tree - let mut tree = Self::new(); - - // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so - // entries with the empty value need additional tracking. - let mut key_set_to_zero = BTreeSet::new(); - - for (key, value) in entries { - let old_value = tree.insert(key, value); - - if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) { - return Err(MerkleError::DuplicateValuesForIndex( - LeafIndex::::from(key).value(), - )); - } - - if value == EMPTY_WORD { - key_set_to_zero.insert(key); - }; - } - Ok(tree) + >::with_entries(entries) } // PUBLIC ACCESSORS @@ -267,6 +247,32 @@ impl SparseMerkleTree for Smt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); + fn with_entries( + entries: impl IntoIterator, + ) -> Result { + // create an empty tree + let mut tree = Self::new(); + + // This being a sparse data structure, the EMPTY_WORD is not assigned to the `BTreeMap`, so + // entries with the empty value need additional tracking. + let mut key_set_to_zero = BTreeSet::new(); + + for (key, value) in entries { + let old_value = tree.insert(key, value); + + if old_value != EMPTY_WORD || key_set_to_zero.contains(&key) { + return Err(MerkleError::DuplicateValuesForIndex( + LeafIndex::::from(key).value(), + )); + } + + if value == EMPTY_WORD { + key_set_to_zero.insert(key); + }; + } + Ok(tree) + } + fn root(&self) -> RpoDigest { self.root } diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 79e9faa..7f9df46 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -292,6 +292,13 @@ pub(crate) trait SparseMerkleTree { // REQUIRED METHODS // --------------------------------------------------------------------------------------------- + /// Construct a tree from an iterator of its keys and values. + fn with_entries( + entries: impl IntoIterator, + ) -> Result + where + Self: Sized; + /// The root of the tree fn root(&self) -> RpoDigest; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 04476a0..0a30b04 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -309,6 +309,14 @@ impl SparseMerkleTree for SimpleSmt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); + fn with_entries( + entries: impl IntoIterator, Word)>, + ) -> Result { + >::with_leaves( + entries.into_iter().map(|(key, value)| (key.value(), value)), + ) + } + fn root(&self) -> RpoDigest { self.root } From 5de20ade48ba03b7ab6b0d00126b05f4da0ea7e2 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 5 Nov 2024 13:04:24 -0700 Subject: [PATCH 15/21] convert test_singlethreaded_subtree to use SubtreeLeavesIter --- src/merkle/smt/mod.rs | 75 ++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 33 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 7f9df46..dab9501 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -765,14 +765,15 @@ impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> { // ================================================================================================ #[cfg(test)] mod test { - use core::mem; - use alloc::{collections::BTreeMap, vec::Vec}; - use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf}; + use super::{ + InnerNode, LeafIndex, PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, + SubtreeLeavesIter, + }; use crate::{ hash::rpo::RpoDigest, - merkle::{LeafIndex, NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, + merkle::{NodeIndex, Smt, SMT_DEPTH}, Felt, Word, ONE, }; @@ -964,38 +965,46 @@ mod test { } = Smt::sorted_pairs_to_leaves(entries); for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { - for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { - // Pre-assertions. - assert!( - subtree.is_sorted(), - "subtree {i} at bottom-depth {current_depth} is not sorted", - ); - assert!( - !subtree.is_empty(), - "subtree {i} at bottom-depth {current_depth} is empty!", - ); - - // Do actual things. - let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); - - // Post-assertions. - assert!(next_leaves.is_sorted()); - for (&index, test_node) in nodes.iter() { - let control_node = control.get_inner_node(index); - assert_eq!( - test_node, &control_node, - "depth {} subtree {}: test node does not match control at index {:?}", - current_depth, i, index, + // There's no flat_map_unzip(), so this is the best we can do. + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", ); - } - // Update state. - accumulated_nodes.extend(nodes); + // Do actual things. + let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); + // Post-assertions. + assert!(next_leaves.is_sorted()); - for subtree_leaf in next_leaves { - super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); - } - } + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, next_leaves) + }) + .unzip(); + + // Update state between each depth iteration. + + // FIXME: is this flatten or Box better? + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + + accumulated_nodes.extend(nodes.into_iter().flatten()); assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); } From 6a0721b57d21cf4b40c7ddbe0f4cf5fc3065c0e1 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 5 Nov 2024 13:28:51 -0700 Subject: [PATCH 16/21] implement test_multithreaded_subtree --- Cargo.lock | 1 + Cargo.toml | 1 + src/merkle/smt/mod.rs | 100 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 3e822fb..fca1e2c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -534,6 +534,7 @@ dependencies = [ "rand", "rand_chacha", "rand_core", + "rayon", "seq-macro", "serde", "sha3", diff --git a/Cargo.toml b/Cargo.toml index 74df3ab..09fe0ec 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -66,6 +66,7 @@ sha3 = { version = "0.10", default-features = false } winter-crypto = { version = "0.10", default-features = false } winter-math = { version = "0.10", default-features = false } winter-utils = { version = "0.10", default-features = false } +rayon = "1.10.0" [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index dab9501..c50dc27 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1050,4 +1050,104 @@ mod test { // And of course the root we got from each place should match. assert_eq!(control.root(), root_leaf.hash); } + + #[test] + fn test_multithreaded_subtrees() { + use rayon::prelude::*; + + const PAIR_COUNT: u64 = 4096 * 4; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = Smt::sorted_pairs_to_leaves(entries); + + for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .enumerate() + .map(|(i, subtree)| { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!( + !subtree.is_empty(), + "subtree {i} at bottom-depth {current_depth} is empty!", + ); + + let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); + + // Post-assertions. + assert!(next_leaves.is_sorted()); + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + (nodes, next_leaves) + }) + // FIXME: unzip_into_vecs() instead? + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + + accumulated_nodes.extend(nodes.into_iter().flatten()); + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let index = LeafIndex::new_max_depth(col); + let &control_leaf = control_leaves.get(&index).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure the inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root node actually + // in two places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [root_leaf]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), root_leaf.hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), root_leaf.hash); + } } From bfd64810b4b3b7a4b08f9cb4619d59c3c5104b5c Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 5 Nov 2024 16:17:41 -0700 Subject: [PATCH 17/21] smt: add from_raw_parts() to trait interface --- src/merkle/smt/full/mod.rs | 21 +++++++++++++++++++++ src/merkle/smt/mod.rs | 8 ++++++++ src/merkle/smt/simple/mod.rs | 21 +++++++++++++++++++++ 3 files changed, 50 insertions(+) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 0b170d6..34a677c 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -81,6 +81,14 @@ impl Smt { >::with_entries(entries) } + pub fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + >::from_raw_parts(inner_nodes, leaves, root) + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -247,6 +255,19 @@ impl SparseMerkleTree for Smt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + if cfg!(debug_assertions) { + let root_node = inner_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(root_node.hash(), root); + } + + Ok(Self { root, inner_nodes, leaves }) + } + fn with_entries( entries: impl IntoIterator, ) -> Result { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index c50dc27..a559fcf 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -299,6 +299,14 @@ pub(crate) trait SparseMerkleTree { where Self: Sized; + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result + where + Self: Sized; + /// The root of the tree fn root(&self) -> RpoDigest; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 0a30b04..62ce753 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -100,6 +100,14 @@ impl SimpleSmt { Ok(tree) } + pub fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + >::from_raw_parts(inner_nodes, leaves, root) + } + /// Wrapper around [`SimpleSmt::with_leaves`] which inserts leaves at contiguous indices /// starting at index 0. pub fn with_contiguous_leaves( @@ -309,6 +317,19 @@ impl SparseMerkleTree for SimpleSmt { const EMPTY_VALUE: Self::Value = EMPTY_WORD; const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); + fn from_raw_parts( + inner_nodes: BTreeMap, + leaves: BTreeMap, + root: RpoDigest, + ) -> Result { + if cfg!(debug_assertions) { + let root_node = inner_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(root_node.hash(), root); + } + + Ok(Self { root, inner_nodes, leaves }) + } + fn with_entries( entries: impl IntoIterator, Word)>, ) -> Result { From af96aef74f178f7d58356de4ccaaa16195a6cc17 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Fri, 8 Nov 2024 12:53:28 -0700 Subject: [PATCH 18/21] improve docs for build_subtree() --- src/merkle/smt/full/mod.rs | 13 +++++++++++++ src/merkle/smt/mod.rs | 16 +++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 34a677c..9c08f26 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -238,6 +238,19 @@ impl Smt { } } + /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and + /// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and + /// `leaves` must not contain more than one depth-8 subtree's worth of leaves. + /// + /// This function will then calculate the inner nodes above each leaf for 8 layers, as well as + /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into + /// itself. + /// + /// # Panics + /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains + /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to + /// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified + /// maximum depth (`DEPTH`), or if `leaves` is not sorted. pub fn build_subtree( leaves: Vec, bottom_depth: u8, diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index a559fcf..9d124f2 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -406,17 +406,19 @@ pub(crate) trait SparseMerkleTree { accumulator } - /// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, - /// sorted by their position. + /// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and + /// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and + /// `leaves` must not contain more than one depth-8 subtree's worth of leaves. /// - /// The leaves are 'conceptual' leaves, simply being entities at the bottom of some subtree, not - /// [`Self::Leaf`]. + /// This function will then calculate the inner nodes above each leaf for 8 layers, as well as + /// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into + /// itself. /// /// # Panics /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains - /// more entries than can fit in a depth-8 subtree (more than 256), if `bottom_depth` is - /// lower in the tree than the specified maximum depth (`DEPTH`), or if `leaves` is not sorted. - // FIXME: more complete docstring. + /// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to + /// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified + /// maximum depth (`DEPTH`), or if `leaves` is not sorted. fn build_subtree( mut leaves: Vec, bottom_depth: u8, From 96d42a4a064d06f13d131ab7974dd762e8c00b4a Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 12 Nov 2024 14:05:07 -0700 Subject: [PATCH 19/21] add a parallel construction benchmark to src/main.rs --- src/main.rs | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/src/main.rs b/src/main.rs index 776ccc2..d362b7f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ use std::time::Instant; use clap::Parser; use miden_crypto::{ hash::rpo::{Rpo256, RpoDigest}, - merkle::{MerkleError, Smt}, + merkle::{MerkleError, NodeIndex, Smt}, Felt, Word, ONE, }; use rand_utils::rand_value; @@ -33,7 +33,9 @@ pub fn benchmark_smt() { entries.push((key, value)); } - let mut tree = construction(entries, tree_size).unwrap(); + let mut tree = construction(entries.clone(), tree_size).unwrap(); + let parallel = parallel_construction(entries, tree_size).unwrap(); + assert_eq!(tree, parallel); insertion(&mut tree, tree_size).unwrap(); batched_insertion(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap(); @@ -56,6 +58,31 @@ pub fn construction(entries: Vec<(RpoDigest, Word)>, size: u64) -> Result, + size: u64, +) -> Result { + println!("Running a parallel construction benchmark:"); + let now = Instant::now(); + + let (inner_nodes, leaves) = Smt::build_subtrees(entries); + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); + + let leaves = leaves.into_iter().map(|(key, value)| (key.value(), value)).collect(); + + let tree = Smt::from_raw_parts(inner_nodes, leaves, root)?; + + let elapsed = now.elapsed(); + println!( + "Parallel-constructed an SMT with {} key-value pairs in {:.3} seconds", + size, + elapsed.as_secs_f32(), + ); + println!("Number of leaf nodes: {}\n", tree.leaves().count()); + + Ok(tree) +} + /// Runs the insertion benchmark for the [`Smt`]. pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { println!("Running an insertion benchmark:"); From e6a6ad3712f290108bd85e6005cd69336ba0e263 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 12 Nov 2024 14:11:37 -0700 Subject: [PATCH 20/21] smt: add `build_subtrees()` to coordinate subtree building --- src/merkle/smt/full/mod.rs | 6 ++ src/merkle/smt/mod.rs | 51 +++++++++++++++ src/merkle/smt/simple/tests.rs | 110 ++++++++++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 3 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 9c08f26..c6a98f4 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -257,6 +257,12 @@ impl Smt { ) -> (BTreeMap, Vec) { >::build_subtree(leaves, bottom_depth) } + + pub fn build_subtrees( + entries: Vec<(RpoDigest, Word)>, + ) -> (BTreeMap, BTreeMap, SmtLeaf>) { + >::build_subtrees(entries) + } } impl SparseMerkleTree for Smt { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 9d124f2..f5f659c 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -508,6 +508,57 @@ pub(crate) trait SparseMerkleTree { (inner_nodes, leaves) } + + fn build_subtrees( + mut entries: Vec<(Self::Key, Self::Value)>, + ) -> (BTreeMap, BTreeMap, Self::Leaf>) { + use rayon::prelude::*; + + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (8..=DEPTH).step_by(8).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + + let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth); + + debug_assert!(next_leaves.is_sorted()); + + (nodes, next_leaves) + }) + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + + let leaves: BTreeMap, Self::Leaf> = initial_leaves + .into_iter() + .map(|(key, value)| { + // FIXME: unwrap is unreachable? + let key = LeafIndex::::new(key).unwrap(); + (key, value) + }) + .collect(); + + (accumulated_nodes, leaves) + } } // INNER NODE diff --git a/src/merkle/smt/simple/tests.rs b/src/merkle/smt/simple/tests.rs index b1dd28d..29c63be 100644 --- a/src/merkle/smt/simple/tests.rs +++ b/src/merkle/smt/simple/tests.rs @@ -1,3 +1,6 @@ +use core::mem; +use std::collections::BTreeMap; + use alloc::vec::Vec; use super::{ @@ -7,10 +10,11 @@ use super::{ use crate::{ hash::rpo::Rpo256, merkle::{ - digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots, - InnerNodeInfo, LeafIndex, MerkleTree, + digests_to_words, int_to_leaf, int_to_node, + smt::{self, InnerNode, PairComputations, SparseMerkleTree}, + EmptySubtreeRoots, InnerNodeInfo, LeafIndex, MerkleTree, SubtreeLeaf, }, - Word, EMPTY_WORD, + Felt, Word, EMPTY_WORD, ONE, }; // TEST DATA @@ -461,6 +465,106 @@ fn test_simplesmt_check_empty_root_constant() { assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT); } +#[test] +fn test_simplesmt_subtrees() { + const PAIR_COUNT: u64 = 4096; + const DEPTH: u8 = 64; + type SimpleSmt = super::SimpleSmt; + + let entries: Vec<(LeafIndex, Word)> = (0..PAIR_COUNT) + .map(|i| { + let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64; + let key = LeafIndex::new_max_depth(leaf_index); + let value: Word = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect(); + let leaves = entries.iter().map(|(key, value)| (key.value(), *value)); + + let control = SimpleSmt::with_leaves(leaves).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = SimpleSmt::sorted_pairs_to_leaves(entries); + + for current_depth in (8..=DEPTH).step_by(8).rev() { + for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!(!subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!"); + + // Do actual things. + let (nodes, next_leaves) = SimpleSmt::build_subtree(subtree, current_depth); + + // Post-assertions. + assert!(next_leaves.is_sorted()); + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + // Update state. + accumulated_nodes.extend(nodes); + + for subtree_leaf in next_leaves { + smt::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); + } + } + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let &control_leaf = control_leaves.get(&col).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root actually in two + // places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [SubtreeLeaf { hash: test_root_hash, .. }]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), test_root_hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), test_root_hash); +} + // HELPER FUNCTIONS // -------------------------------------------------------------------------------------------- From 468bd98c124dd4d651cd3fc2331c0640a5424653 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 12 Nov 2024 14:05:38 -0700 Subject: [PATCH 21/21] add a parallel subtree criterion benchmark --- Cargo.toml | 4 +++ benches/parallel-subtree.rs | 72 +++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 benches/parallel-subtree.rs diff --git a/Cargo.toml b/Cargo.toml index 09fe0ec..e769597 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,6 +35,10 @@ harness = false name = "merkle" harness = false +[[bench]] +name = "parallel-subtree" +harness = false + [[bench]] name = "store" harness = false diff --git a/benches/parallel-subtree.rs b/benches/parallel-subtree.rs new file mode 100644 index 0000000..9f45b32 --- /dev/null +++ b/benches/parallel-subtree.rs @@ -0,0 +1,72 @@ +use std::{fmt::Debug, hint, mem, time::Duration}; + +use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion}; +use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE}; +use rand_utils::prng_array; +use winter_utils::Randomizable; + +// 2^0, 2^4, 2^8, 2^12, 2^16 +const PAIR_COUNTS: [u64; 6] = [1, 16, 256, 4096, 65536, 1_048_576]; + +fn smt_parallel_subtree(c: &mut Criterion) { + let mut seed = [0u8; 32]; + + let mut group = c.benchmark_group("parallel-subtrees"); + + for pair_count in PAIR_COUNTS { + let bench_id = BenchmarkId::from_parameter(pair_count); + group.bench_with_input(bench_id, &pair_count, |b, &pair_count| { + b.iter_batched( + || { + // Setup. + (0..pair_count) + .map(|i| { + let count = pair_count as f64; + let idx = ((i as f64 / count) * (count)) as u64; + let key = RpoDigest::new([ + generate_value(&mut seed), + ONE, + Felt::new(i), + Felt::new(idx), + ]); + let value = generate_word(&mut seed); + (key, value) + }) + .collect() + }, + |entries| { + // Benchmarked function. + let (leaves, inner_nodes) = Smt::build_subtrees(hint::black_box(entries)); + assert!(!leaves.is_empty()); + assert!(!inner_nodes.is_empty()); + }, + BatchSize::SmallInput, + ); + }); + } +} + +criterion_group! { + name = smt_subtree_group; + config = Criterion::default() + .measurement_time(Duration::from_secs(960)) + .sample_size(60) + .configure_from_args(); + targets = smt_parallel_subtree +} +criterion_main!(smt_subtree_group); + +// HELPER FUNCTIONS +// -------------------------------------------------------------------------------------------- + +fn generate_value(seed: &mut [u8; 32]) -> T { + mem::swap(seed, &mut prng_array(*seed)); + let value: [T; 1] = rand_utils::prng_array(*seed); + value[0] +} + +fn generate_word(seed: &mut [u8; 32]) -> Word { + mem::swap(seed, &mut prng_array(*seed)); + let nums: [u64; 4] = prng_array(*seed); + [Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])] +}