From 9007a54385b1b2b05b080018a2580b5d134b6958 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 28 Oct 2024 15:38:42 -0600 Subject: [PATCH] add sorted_pairs_to_leaves() and test for it --- src/merkle/smt/mod.rs | 259 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 259 insertions(+) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 03d9d45..88258b5 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,4 +1,7 @@ use alloc::{collections::BTreeMap, vec::Vec}; +use core::mem; + +use num::Integer; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use crate::{ @@ -346,6 +349,146 @@ pub(crate) trait SparseMerkleTree { /// /// The length `path` is guaranteed to be equal to `DEPTH` fn path_and_leaf_to_opening(path: MerklePath, leaf: Self::Leaf) -> Self::Opening; + + fn sorted_pairs_to_leaves( + pairs: Vec<(Self::Key, Self::Value)>, + ) -> PrecomputedLeaves { + let mut all_leaves = PrecomputedLeaves::default(); + + let mut buffer: Vec<(Self::Key, Self::Value)> = Default::default(); + + let mut iter = pairs.into_iter().peekable(); + while let Some((key, value)) = iter.next() { + let col = Self::key_to_leaf_index(&key).index.value(); + let next_col = iter.peek().map(|(key, _)| { + let index = Self::key_to_leaf_index(key); + index.index.value() + }); + + buffer.push((key, value)); + + if let Some(next_col) = next_col { + assert!(next_col >= col); + } + + if next_col == Some(col) { + // Keep going in our buffer. + continue; + } + + // Whether the next pair is a different column, or non-existent, we break off. + let leaf_pairs = mem::take(&mut buffer); + let leaf = Self::pairs_to_leaf(leaf_pairs); + let hash = Self::hash_leaf(&leaf); + + all_leaves.nodes.insert(col, leaf); + all_leaves.subtrees.push(SubtreeLeaf { col, hash }); + } + assert_eq!(buffer.len(), 0); + + all_leaves + } + + /// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, + /// sorted by their position. + /// + /// The leaves are 'conceptual' leaves, simply being entities at the bottom of some subtree, not + /// [`Self::Leaf`]. + /// + /// # Panics + /// With debug assertions on, this function panics under invalid inputs: if `leaves` contains + /// more entries than can fit in a depth-8 subtree (more than 256), if `bottom_depth` is + /// lower in the tree than the specified maximum depth (`DEPTH`), or if `leaves` is not sorted. + // FIXME: more complete docstring. + fn build_subtree( + mut leaves: Vec, + bottom_depth: u8, + ) -> (BTreeMap, Vec) { + debug_assert!(bottom_depth <= DEPTH); + debug_assert!(Integer::is_multiple_of(&bottom_depth, &8)); + debug_assert!(leaves.len() <= usize::pow(2, 8)); + + let subtree_root = bottom_depth - 8; + + let mut inner_nodes: BTreeMap = Default::default(); + + let mut next_leaves: Vec = Vec::with_capacity(leaves.len() / 2); + + for next_depth in (subtree_root..bottom_depth).rev() { + debug_assert!(next_depth <= bottom_depth); + + // `next_depth` is the stuff we're making. + // `current_depth` is the stuff we have. + let current_depth = next_depth + 1; + + let mut iter = leaves.drain(..).peekable(); + while let Some(first) = iter.next() { + // On non-continuous iterations, including the first iteration, `first_column` may + // be a left or right node. On subsequent continuous iterations, we will always call + // `iter.next()` twice. + + // On non-continuous iterations (including the very first iteration), this column + // could be either on the left or the right. If the next iteration is not + // discontinuous with our right node, then the next iteration's + + let is_right = first.col.is_odd(); + let (left, right) = if is_right { + // Discontinuous iteration: we have no left node, so it must be empty. + + let left = SubtreeLeaf { + col: first.col - 1, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }; + let right = first; + + (left, right) + } else { + let left = first; + + let right_col = first.col + 1; + let right = match iter.peek().copied() { + Some(SubtreeLeaf { col, .. }) if col == right_col => { + // Our inputs must be sorted. + debug_assert!(left.col <= col); + // The next leaf in the iterator is our sibling. Use it and consume it! + iter.next().unwrap() + }, + // Otherwise, the leaves don't contain our sibling, so our sibling must be + // empty. + _ => SubtreeLeaf { + col: right_col, + hash: *EmptySubtreeRoots::entry(DEPTH, current_depth), + }, + }; + + (left, right) + }; + + let index = NodeIndex::new_unchecked(current_depth, left.col).parent(); + let node = InnerNode { left: left.hash, right: right.hash }; + let hash = node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, next_depth); + // If this hash is empty, then it doesn't become a new inner node, nor does it count + // as a leaf for the next depth. + if hash != equivalent_empty_hash { + inner_nodes.insert(index, node); + // FIXME: is it possible for this to end up not being sorted? I don't think so. + next_leaves.push(SubtreeLeaf { col: index.value(), hash }); + } + } + + // Stop borrowing `leaves`, so we can swap it. + // The iterator is empty at this point anyway. + drop(iter); + + // After each depth, consider the stuff we just made the new "leaves", and empty the + // other collection. + mem::swap(&mut leaves, &mut next_leaves); + } + + (inner_nodes, leaves) + } } // INNER NODE @@ -463,3 +606,119 @@ impl MutationSet { self.new_root } } + +// HELPERS +// ================================================================================================ +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Default)] +pub struct SubtreeLeaf { + pub col: u64, + pub hash: RpoDigest, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PrecomputedLeaves { + /// Literal leaves to be added to the sparse Merkle tree's internal mapping. + pub nodes: BTreeMap, + /// "Conceptual" leaves that will be used for computations. + pub subtrees: Vec, +} + +// Derive requires `L` to impl Default, even though we don't actually need that. +impl Default for PrecomputedLeaves { + fn default() -> Self { + Self { + nodes: Default::default(), + subtrees: Default::default(), + } + } +} + +// TESTS +// ================================================================================================ +#[cfg(test)] +mod test { + use alloc::vec::Vec; + + use super::SparseMerkleTree; + use crate::{ + hash::rpo::RpoDigest, + merkle::{smt::SubtreeLeaf, Smt, SmtLeaf, SMT_DEPTH}, + Felt, Word, EMPTY_WORD, ONE, + }; + + #[test] + fn test_sorted_pairs_to_leaves() { + let entries: Vec<(RpoDigest, Word)> = vec![ + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), + // Leaf index collision. + (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), + // Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), + // Empty leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(500)]), EMPTY_WORD), + ]; + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + + let control_leaves: Vec = vec![ + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_empty(Smt::key_to_leaf_index(&next_entry().0)), + ]; + + let control_subtree_leaves: Vec = control_leaves + .iter() + .map(|leaf| { + let col = leaf.index().index.value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + + let test_subtree_leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; + assert_eq!(control_subtree_leaves, test_subtree_leaves); + } + + #[test] + fn test_build_subtree_from_leaves() { + const PAIR_COUNT: u64 = u64::pow(2, 8); + + let entries: Vec<(RpoDigest, Word)> = (0..PAIR_COUNT) + .map(|i| { + let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64; + let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]); + let value = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect(); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let mut leaves: Vec = entries + .iter() + .map(|(key, value)| { + let leaf = SmtLeaf::new_single(*key, *value); + let col = leaf.index().index.value(); + let hash = leaf.hash(); + SubtreeLeaf { col, hash } + }) + .collect(); + leaves.sort(); + leaves.dedup_by_key(|leaf| leaf.col); + + let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH); + assert!(!first_subtree.is_empty()); + + for (index, node) in first_subtree.into_iter() { + let control = control.get_inner_node(index); + assert_eq!( + control, node, + "subtree-computed node at index {index:?} does not match control", + ); + } + } +}