WIP(smt): add simple benchmark for single subtree computation

This commit is contained in:
Qyriad 2024-10-16 09:24:17 -06:00
parent a35c11abfe
commit 6addcd0226
5 changed files with 238 additions and 8 deletions

View file

@ -31,6 +31,10 @@ harness = false
name = "store"
harness = false
[[bench]]
name = "subtree"
harness = false
[features]
default = ["std", "async"]
executable = ["dep:clap", "dep:rand-utils", "std"]

66
benches/subtree.rs Normal file
View file

@ -0,0 +1,66 @@
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, NodeSubtreeComputer, Smt, SparseMerkleTree},
Felt, Word, ONE,
};
const SUBTREE_INTERVAL: u8 = 8;
/// Builds the inputs for one `subtree8` benchmark iteration.
///
/// Generates `tree_size` key/value pairs, constructs a control tree containing them, and picks
/// the subtree root `SUBTREE_INTERVAL` levels above the first entry's leaf. Returns an empty
/// tree (the benchmark recomputes the subtree from scratch over it), the chosen subtree index,
/// the shared entries, and the control hash that the benchmark asserts against.
///
/// Panics if `tree_size` is zero (there would be no key to pick a subtree from).
fn setup_subtree8(tree_size: u64) -> (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest) {
    let entries: BTreeMap<RpoDigest, Word> = (0..tree_size)
        .map(|i| {
            // Spread leaf indices across the key space so entries land in different subtrees.
            let leaf_index = u64::MAX / (i + 1);
            let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
            let value = [ONE, ONE, ONE, Felt::new(i)];
            (key, value)
        })
        .collect();
    let control = Smt::with_entries(entries.clone()).unwrap();
    // Only the first key is needed: map it to its leaf index and walk up SUBTREE_INTERVAL levels.
    let subtree = entries
        .keys()
        .next()
        .map(|key| NodeIndex::from(Smt::key_to_leaf_index(key)).parent_n(SUBTREE_INTERVAL))
        .expect("tree_size must be non-zero");
    let control_hash = control.get_inner_node(subtree).hash();
    (Smt::new(), subtree, Arc::new(entries), control_hash)
}
/// Benchmark body: recomputes the hash of `subtree` over an empty tree overlaid with `entries`
/// and checks the result against the control hash from the fully-built tree.
fn bench_subtree8(args: (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest)) {
    let (smt, subtree, entries, control_hash) = args;
    let mut computer = NodeSubtreeComputer::with_smt(&smt, Default::default(), entries);
    let computed = computer.get_or_make_hash(subtree);
    assert_eq!(control_hash, computed);
}
/// Benchmarks single-subtree hash computation (`NodeSubtreeComputer::get_or_make_hash()`) for
/// several tree sizes, with setup excluded from the measurement via `iter_batched`.
fn smt_subtree8(c: &mut Criterion) {
    let mut group = c.benchmark_group("subtree8");
    group.measurement_time(Duration::from_secs(120));
    group.sample_size(30);
    for &tree_size in [32, 128, 512, 1024].iter() {
        let bench_id = BenchmarkId::from_parameter(tree_size);
        group.bench_with_input(bench_id, &tree_size, |bench, &tree_size| {
            // SmallInput: setup output is passed by value; setup cost is not measured.
            bench.iter_batched(|| setup_subtree8(tree_size), bench_subtree8, BatchSize::SmallInput);
        });
    }
    group.finish();
}
// Register the benchmark group and generate the Criterion-provided `main` entry point
// (Cargo.toml sets `harness = false` for this bench target).
criterion_group!(subtree_group, smt_subtree8);
criterion_main!(subtree_group);

View file

@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
pub use smt::{
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
InnerNode, LeafIndex, MutationSet, NodeSubtreeComputer, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
SmtProof, SmtProofError, SparseMerkleTree, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod mmr;

View file

@ -1,5 +1,5 @@
#[cfg(feature = "async")]
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use alloc::{
collections::{BTreeMap, BTreeSet},
@ -12,6 +12,9 @@ use super::{
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
#[cfg(feature = "async")]
use super::NodeMutation;
mod error;
pub use error::{SmtLeafError, SmtProofError};
@ -297,6 +300,27 @@ impl Smt {
None
}
}
/// Returns the leaf that would exist at `key`'s position if `value` were written into
/// `existing_leaf`, without touching the tree itself.
///
/// An empty leaf becomes a single-entry leaf holding the pair; otherwise the pair is inserted
/// into (or, for an `EMPTY_WORD` value, removed from) the existing leaf.
fn construct_prospective_leaf(
    mut existing_leaf: SmtLeaf,
    key: &RpoDigest,
    value: &Word,
) -> SmtLeaf {
    debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

    if let SmtLeaf::Empty(_) = existing_leaf {
        return SmtLeaf::new_single(*key, *value);
    }

    if *value == EMPTY_WORD {
        existing_leaf.remove(*key);
    } else {
        existing_leaf.insert(*key, *value);
    }
    existing_leaf
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {
@ -399,6 +423,141 @@ impl Default for Smt {
}
}
/// Just a [`NodeMutation`] with its hash already computed and stored.
#[cfg(feature = "async")]
pub struct ComputedNodeMutation {
/// The mutation (addition or removal) recorded for a node index.
pub mutation: NodeMutation,
/// The node's hash after applying `mutation`.
pub hash: RpoDigest,
}
#[cfg(feature = "async")]
/// Computes subtree hashes for an [`Smt`] overlaid with pending key/value pairs and
/// previously computed node mutations, memoizing intermediate results as it recurses.
pub struct NodeSubtreeComputer {
/// Inner nodes of the underlying tree (shared, read-only).
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
/// Leaves of the underlying tree, keyed by leaf-index value (shared, read-only).
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
/// Mutations already computed elsewhere; treated as part of the tree's effective state.
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
/// Mutations computed by this instance during `get_or_make_hash()`.
new_mutations: HashMap<NodeIndex, ComputedNodeMutation>,
/// Key/value pairs not yet applied to the tree, overlaid on top of `leaves`.
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
/// Cache indices we know to be dirty.
dirtied_indices: HashMap<NodeIndex, bool>,
/// Cache of leaf hashes computed from the effective (overlaid) leaves.
cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
}
#[cfg(feature = "async")]
impl NodeSubtreeComputer {
    /// Creates a computer over `smt`'s nodes and leaves, overlaid with already-computed
    /// `existing_mutations` and not-yet-applied key/value `new_pairs`.
    pub fn with_smt(
        smt: &Smt,
        existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
        new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
    ) -> Self {
        Self {
            inner_nodes: Arc::clone(&smt.inner_nodes),
            leaves: Arc::clone(&smt.leaves),
            existing_mutations,
            new_mutations: Default::default(),
            new_pairs,
            dirtied_indices: Default::default(),
            cached_leaf_hashes: Default::default(),
        }
    }

    /// Returns `true` if the subtree rooted at `index_to_check` contains any index touched by
    /// `existing_mutations` or `new_pairs`. The result is memoized in `dirtied_indices`.
    pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
        if let Some(cached) = self.dirtied_indices.get(&index_to_check) {
            return *cached;
        }

        // An index is dirty if there is a new pair at it, a known existing mutation at it, or
        // it is an ancestor of one of those.
        let is_dirty = self
            .existing_mutations
            .keys()
            .copied()
            .chain(self.new_pairs.keys().map(|key| Smt::key_to_leaf_index(key).index))
            .any(|dirtied_index| index_to_check.contains(dirtied_index));

        // This is somewhat expensive to compute, so cache this.
        self.dirtied_indices.insert(index_to_check, is_dirty);
        is_dirty
    }

    /// Returns the leaf at `index` as it would look with all of `new_pairs` applied on top of
    /// the tree's stored leaves.
    pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
        let pairs_at_index = self
            .new_pairs
            .iter()
            .filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index);

        let existing_leaf = self
            .leaves
            .get(&index.index.value())
            .cloned()
            .unwrap_or_else(|| SmtLeaf::new_empty(index));

        // `acc` is owned by the fold, so it can be passed straight through without cloning.
        pairs_at_index
            .fold(existing_leaf, |acc, (k, v)| Smt::construct_prospective_leaf(acc, k, v))
    }

    /// Returns the hash for `index` from existing mutations or the stored tree, if any.
    ///
    /// Does NOT check `new_mutations`.
    pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
        self.existing_mutations
            .get(&index)
            .map(|ComputedNodeMutation { hash, .. }| *hash)
            .or_else(|| self.inner_nodes.get(&index).map(InnerNode::hash))
    }

    /// Retrieve a cached hash, or recursively compute it.
    pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
        use NodeMutation::*;

        // If this is a leaf, then only do leaf stuff.
        if index.depth() == SMT_DEPTH {
            let index = LeafIndex::new(index.value()).unwrap();
            return match self.cached_leaf_hashes.get(&index) {
                // Digests are plain hash data: copy rather than clone.
                Some(cached_hash) => *cached_hash,
                None => {
                    let leaf = self.get_effective_leaf(index);
                    let hash = Smt::hash_leaf(&leaf);
                    self.cached_leaf_hashes.insert(index, hash);
                    hash
                },
            };
        }

        // If we already computed this one earlier as a mutation, just return it.
        if let Some(ComputedNodeMutation { hash, .. }) = self.new_mutations.get(&index) {
            return *hash;
        }

        // Otherwise, we need to know if this node is one of the nodes we're in the process of
        // recomputing, or if we can safely use the node already in the Merkle tree.
        if !self.is_index_dirty(index) {
            return self
                .get_clean_hash(index)
                .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
        }

        // If we got here, then we have to make, rather than get, this hash. Note that
        // `is_index_dirty()` above has already cached `true` for this index.

        // Recurse for the left and right sides.
        let left = self.get_or_make_hash(index.left_child());
        let right = self.get_or_make_hash(index.right_child());
        let node = InnerNode { left, right };
        let hash = node.hash();
        let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
        // A node whose hash equals the empty-subtree root is effectively a removal.
        let is_removal = hash == equivalent_empty_hash;
        let new_entry = if is_removal { Removal } else { Addition(node) };

        self.new_mutations
            .insert(index, ComputedNodeMutation { hash, mutation: new_entry });

        hash
    }
}
// CONVERSIONS
// ================================================================================================

View file

@ -7,7 +7,9 @@ use crate::{
};
mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
pub use full::{
NodeSubtreeComputer, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
};
mod simple;
pub use simple::SimpleSmt;
@ -43,7 +45,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
pub trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone + Ord;
/// The type for a value
@ -346,7 +348,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
@ -459,7 +461,7 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum NodeMutation {
pub enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
@ -499,7 +501,6 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;