smt: implement single subtree-8 hashing, w/ benchmarks & tests

This will be composed into depth-8-subtree-based computation of entire
sparse Merkle trees.
This commit is contained in:
Qyriad 2024-11-14 14:04:15 -07:00
parent 6de9c95f4c
commit 8c8167fcaf
6 changed files with 335 additions and 6 deletions

View file

@ -27,6 +27,10 @@ harness = false
name = "smt"
harness = false
[[bench]]
name = "smt-subtree"
harness = false
[[bench]]
name = "store"
harness = false

136
benches/smt-subtree.rs Normal file
View file

@ -0,0 +1,136 @@
use std::{fmt::Debug, hint, mem, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, Smt, SmtLeaf, SubtreeLeaf, SMT_DEPTH},
Felt, Word, ONE,
};
use rand_utils::prng_array;
use winter_utils::Randomizable;
const PAIR_COUNTS: [u64; 5] = [1, 64, 128, 192, 256];
fn smt_subtree_even(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-even");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|n| {
// A single depth-8 subtree can have a maximum of 255 leaves.
let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64;
let key = RpoDigest::new([
generate_value(&mut seed),
ONE,
Felt::new(n),
Felt::new(leaf_index),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves.dedup_by_key(|leaf| leaf.col);
leaves
},
|leaves| {
// Benchmarked function.
let (subtree, _) =
Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH));
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
fn smt_subtree_random(c: &mut Criterion) {
let mut seed = [0u8; 32];
let mut group = c.benchmark_group("subtree8-rand");
for pair_count in PAIR_COUNTS {
let bench_id = BenchmarkId::from_parameter(pair_count);
group.bench_with_input(bench_id, &pair_count, |b, &pair_count| {
b.iter_batched(
|| {
// Setup.
let entries: Vec<(RpoDigest, Word)> = (0..pair_count)
.map(|i| {
let leaf_index: u8 = generate_value(&mut seed);
let key = RpoDigest::new([
ONE,
ONE,
Felt::new(i),
Felt::new(leaf_index as u64),
]);
let value = generate_word(&mut seed);
(key, value)
})
.collect();
let mut leaves: Vec<_> = entries
.iter()
.map(|(key, value)| {
let leaf = SmtLeaf::new_single(*key, *value);
let col = NodeIndex::from(leaf.index()).value();
let hash = leaf.hash();
SubtreeLeaf { col, hash }
})
.collect();
leaves.sort();
leaves
},
|leaves| {
let (subtree, _) =
Smt::build_subtree(hint::black_box(leaves), hint::black_box(SMT_DEPTH));
assert!(!subtree.is_empty());
},
BatchSize::SmallInput,
);
});
}
}
criterion_group! {
name = smt_subtree_group;
config = Criterion::default()
.measurement_time(Duration::from_secs(40))
.sample_size(60)
.configure_from_args();
targets = smt_subtree_even, smt_subtree_random
}
criterion_main!(smt_subtree_group);
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------
fn generate_value<T: Copy + Debug + Randomizable>(seed: &mut [u8; 32]) -> T {
mem::swap(seed, &mut prng_array(*seed));
let value: [T; 1] = rand_utils::prng_array(*seed);
value[0]
}
fn generate_word(seed: &mut [u8; 32]) -> Word {
mem::swap(seed, &mut prng_array(*seed));
let nums: [u64; 4] = prng_array(*seed);
[Felt::new(nums[0]), Felt::new(nums[1]), Felt::new(nums[2]), Felt::new(nums[3])]
}

View file

@ -23,7 +23,7 @@ pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
pub use smt::{
LeafIndex, MutationSet, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError,
SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
SubtreeLeaf, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod mmr;

View file

@ -6,7 +6,7 @@ use alloc::{
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, SubtreeLeaf, Word, EMPTY_WORD,
};
mod error;
@ -249,6 +249,30 @@ impl Smt {
None
}
}
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
///
/// This function is public so functions returning it can be used in tests and benchmarks, but
/// is otherwise not part of the public API.
#[doc(hidden)]
pub fn build_subtree(
leaves: Vec<SubtreeLeaf>,
bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) {
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtree(leaves, bottom_depth)
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {

View file

@ -410,14 +410,119 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator
}
/// Builds Merkle nodes from a bottom layer of "leaves" -- represented by a horizontal index and
/// the hash of the leaf at that index. `leaves` *must* be sorted by horizontal index, and
/// `leaves` must not contain more than one depth-8 subtree's worth of leaves.
///
/// This function will then calculate the inner nodes above each leaf for 8 layers, as well as
/// the "leaves" for the next 8-deep subtree, so this function can effectively be chained into
/// itself.
///
/// # Panics
/// With debug assertions on, this function panics under invalid inputs: if `leaves` contains
/// more entries than can fit in a depth-8 subtree, if `leaves` contains leaves belonging to
/// different depth-8 subtrees, if `bottom_depth` is lower in the tree than the specified
/// maximum depth (`DEPTH`), or if `leaves` is not sorted.
fn build_subtree(
mut leaves: Vec<SubtreeLeaf>,
bottom_depth: u8,
) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) {
debug_assert!(bottom_depth <= DEPTH);
debug_assert!(Integer::is_multiple_of(&bottom_depth, &SUBTREE_DEPTH));
debug_assert!(leaves.len() <= usize::pow(2, SUBTREE_DEPTH as u32));
let subtree_root = bottom_depth - SUBTREE_DEPTH;
let mut inner_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
let mut next_leaves: Vec<SubtreeLeaf> = Vec::with_capacity(leaves.len() / 2);
for next_depth in (subtree_root..bottom_depth).rev() {
debug_assert!(next_depth <= bottom_depth);
// `next_depth` is the stuff we're making.
// `current_depth` is the stuff we have.
let current_depth = next_depth + 1;
let mut iter = leaves.drain(..).peekable();
while let Some(first) = iter.next() {
// On non-continuous iterations, including the first iteration, `first_column` may
// be a left or right node. On subsequent continuous iterations, we will always call
// `iter.next()` twice.
// On non-continuous iterations (including the very first iteration), this column
// could be either on the left or the right. If the next iteration is not
// discontinuous with our right node, then the next iteration's
let is_right = first.col.is_odd();
let (left, right) = if is_right {
// Discontinuous iteration: we have no left node, so it must be empty.
let left = SubtreeLeaf {
col: first.col - 1,
hash: *EmptySubtreeRoots::entry(DEPTH, current_depth),
};
let right = first;
(left, right)
} else {
let left = first;
let right_col = first.col + 1;
let right = match iter.peek().copied() {
Some(SubtreeLeaf { col, .. }) if col == right_col => {
// Our inputs must be sorted.
debug_assert!(left.col <= col);
// The next leaf in the iterator is our sibling. Use it and consume it!
iter.next().unwrap()
},
// Otherwise, the leaves don't contain our sibling, so our sibling must be
// empty.
_ => SubtreeLeaf {
col: right_col,
hash: *EmptySubtreeRoots::entry(DEPTH, current_depth),
},
};
(left, right)
};
let index = NodeIndex::new_unchecked(current_depth, left.col).parent();
let node = InnerNode { left: left.hash, right: right.hash };
let hash = node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, next_depth);
// If this hash is empty, then it doesn't become a new inner node, nor does it count
// as a leaf for the next depth.
if hash != equivalent_empty_hash {
inner_nodes.insert(index, node);
next_leaves.push(SubtreeLeaf { col: index.value(), hash });
}
}
// Stop borrowing `leaves`, so we can swap it.
// The iterator is empty at this point anyway.
drop(iter);
// After each depth, consider the stuff we just made the new "leaves", and empty the
// other collection.
mem::swap(&mut leaves, &mut next_leaves);
}
(inner_nodes, leaves)
}
}
// INNER NODE
// ================================================================================================
/// This struct is public so functions returning it can be used in `benches/`, but is otherwise not
/// part of the public API.
#[doc(hidden)]
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
@ -530,8 +635,11 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
// SUBTREES
// ================================================================================================
/// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8;
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, 8);
const COLS_PER_SUBTREE: u64 = u64::pow(2, SUBTREE_DEPTH as u32);
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///

View file

@ -1,8 +1,15 @@
use alloc::{collections::BTreeMap, vec::Vec};
use super::{PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter};
use super::{
NodeIndex, PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter,
COLS_PER_SUBTREE, SUBTREE_DEPTH,
};
use crate::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
use crate::{
hash::rpo::RpoDigest,
merkle::{Smt, SMT_DEPTH},
Felt, Word, ONE,
};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
@ -90,3 +97,53 @@ fn test_sorted_pairs_to_leaves() {
assert_eq!(control_leaf, &test_leaf);
}
}
// Helper for the below tests.
fn generate_entries(pair_count: u64) -> Vec<(RpoDigest, Word)> {
(0..pair_count)
.map(|i| {
let leaf_index = ((i as f64 / pair_count as f64) * (pair_count as f64)) as u64;
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect()
}
#[test]
fn test_single_subtree() {
// A single subtree's worth of leaves.
const PAIR_COUNT: u64 = COLS_PER_SUBTREE;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries(entries.clone()).unwrap();
// `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
let leaves = leaves.into_iter().next().unwrap();
let (first_subtree, next_leaves) = Smt::build_subtree(leaves, SMT_DEPTH);
assert!(!first_subtree.is_empty());
// The inner nodes computed from that subtree should match the nodes in our control tree.
for (index, node) in first_subtree.into_iter() {
let control = control.get_inner_node(index);
assert_eq!(
control, node,
"subtree-computed node at index {index:?} does not match control",
);
}
// The "next leaves" returned should also have matching hashes from the equivalent nodes in
// our control tree.
for SubtreeLeaf { col, hash } in next_leaves {
let index = NodeIndex::new(SMT_DEPTH - SUBTREE_DEPTH, col).unwrap();
let control_node = control.get_inner_node(index);
let control = control_node.hash();
assert_eq!(
control, hash,
"subtree-computed next leaf at index {index:?} does not match control",
);
}
}