smt: add sorted_pairs_to_leaves() and test for it

This commit is contained in:
Qyriad 2024-11-13 15:32:48 -07:00
parent ae772d2af6
commit 6de9c95f4c
2 changed files with 247 additions and 0 deletions

View file

@ -1,4 +1,7 @@
use alloc::{collections::BTreeMap, vec::Vec};
use core::mem;
use num::Integer;
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
@ -346,6 +349,67 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
///
/// The length `path` is guaranteed to be equal to `DEPTH`
fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening;
/// Performs the initial transforms for constructing a [`SparseMerkleTree`] by composing
/// subtrees. In other words, this function takes the key-value inputs to the tree, and produces
/// the inputs to feed into [`SparseMerkleTree::build_subtree()`].
///
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
///
/// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(
pairs: Vec<(Self::Key, Self::Value)>,
) -> PairComputations<u64, Self::Leaf> {
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
let mut accumulator: PairComputations<u64, Self::Leaf> = Default::default();
let mut accumulated_leaves: Vec<SubtreeLeaf> = Default::default();
// As we iterate, we'll keep track of the kv-pairs we've seen so far that correspond to a
// single leaf. When we see a pair that's in a different leaf, we'll swap these pairs
// out and store them in our accumulated leaves.
let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default();
let mut iter = pairs.into_iter().peekable();
while let Some((key, value)) = iter.next() {
let col = Self::key_to_leaf_index(&key).index.value();
let peeked_col = iter.peek().map(|(key, _v)| {
let index = Self::key_to_leaf_index(key);
let next_col = index.index.value();
// We panic if `pairs` is not sorted by column.
debug_assert!(next_col >= col);
next_col
});
current_leaf_buffer.push((key, value));
// If the next pair is the same column as this one, then we're done after adding this
// pair to the buffer.
if peeked_col == Some(col) {
continue;
}
// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
let leaf = Self::pairs_to_leaf(leaf_pairs);
let hash = Self::hash_leaf(&leaf);
accumulator.nodes.insert(col, leaf);
accumulated_leaves.push(SubtreeLeaf { col, hash });
debug_assert!(current_leaf_buffer.is_empty());
}
// TODO: determine is there is any notable performance difference between computing
// subtree boundaries after the fact as an iterator adapter (like this), versus computing
// subtree boundaries as we go. Either way this function is only used at the beginning of a
// parallel construction, so it should not be a critical path.
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator
}
}
// INNER NODE
@ -463,3 +527,94 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
self.new_root
}
}
// SUBTREES
// ================================================================================================
/// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
const COLS_PER_SUBTREE: u64 = u64::pow(2, 8);
/// Helper struct for organizing the data we care about when computing Merkle subtrees.
///
/// Note that these represet "conceptual" leaves of some subtree, not necessarily
/// [`SparseMerkleTree::Leaf`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)]
pub struct SubtreeLeaf {
/// The 'value' field of [`NodeIndex`]. When computing a subtree, the depth is already known.
pub col: u64,
/// The hash of the node this `SubtreeLeaf` represents.
pub hash: RpoDigest,
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>,
/// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>,
}
// Derive requires `L` to impl Default, even though we don't actually need that.
impl<K, L> Default for PairComputations<K, L> {
fn default() -> Self {
Self {
nodes: Default::default(),
leaves: Default::default(),
}
}
}
#[derive(Debug)]
struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
// TODO: determine if there is any notable performance difference between taking a Vec,
// which many need flattening first, vs storing a `Box<dyn Iterator<Item = SubtreeLeaf>>`.
// The latter may have self-referential properties that are impossible to express in purely
// safe Rust Rust.
Self { leaves: leaves.drain(..).peekable() }
}
}
impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();
let mut last_subtree_col = 0;
while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}
if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}
Some(subtree)
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
mod tests;

92
src/merkle/smt/tests.rs Normal file
View file

@ -0,0 +1,92 @@
use alloc::{collections::BTreeMap, vec::Vec};
use super::{PairComputations, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter};
use crate::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
col: leaf.index().index.value(),
hash: leaf.hash(),
}
}
#[test]
fn test_sorted_pairs_to_leaves() {
let entries: Vec<(RpoDigest, Word)> = vec![
// Subtree 0.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
// Leaf index collision.
(RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
// Subtree 1. Normal single leaf again.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
// Subtree 2. Another normal leaf.
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
];
let control = Smt::with_entries(entries.clone()).unwrap();
let control_leaves: Vec<SmtLeaf> = {
let mut entries_iter = entries.iter().cloned();
let mut next_entry = || entries_iter.next().unwrap();
let control_leaves = vec![
// Subtree 0.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
// Subtree 1.
SmtLeaf::Single(next_entry()),
SmtLeaf::Single(next_entry()),
// Subtree 2.
SmtLeaf::Single(next_entry()),
];
assert_eq!(entries_iter.next(), None);
control_leaves
};
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
let mut control_leaves_iter = control_leaves.iter();
let mut next_leaf = || control_leaves_iter.next().unwrap();
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
// Subtree 0.
vec![next_leaf(), next_leaf(), next_leaf()],
// Subtree 1.
vec![next_leaf(), next_leaf()],
// Subtree 2.
vec![next_leaf()],
]
.map(|subtree| subtree.into_iter().map(smtleaf_to_subtree_leaf).collect())
.to_vec();
assert_eq!(control_leaves_iter.next(), None);
control_subtree_leaves
};
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
// This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves);
// Flattening and re-separating out the leaves into subtrees should have the same result.
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.leaves.clone().into_iter().flatten().collect();
let re_grouped: Vec<Vec<_>> = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
assert_eq!(subtrees.leaves, re_grouped);
// Then finally we might as well check the computed leaf nodes too.
let control_leaves: BTreeMap<u64, SmtLeaf> = control
.leaves()
.map(|(index, value)| (index.index.value(), value.clone()))
.collect();
for (column, test_leaf) in subtrees.nodes {
if test_leaf.is_empty() {
continue;
}
let control_leaf = control_leaves
.get(&column)
.unwrap_or_else(|| panic!("no leaf node found for column {column}"));
assert_eq!(control_leaf, &test_leaf);
}
}