diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index a658d28..a95e7cb 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -353,40 +353,42 @@ pub(crate) trait SparseMerkleTree { fn sorted_pairs_to_leaves( pairs: Vec<(Self::Key, Self::Value)>, ) -> PairComputations { - let mut all_leaves = PairComputations::default(); + let mut accumulator: PairComputations = Default::default(); - let mut buffer: Vec<(Self::Key, Self::Value)> = Default::default(); + // The kv-pairs we've seen so far that correspond to a single leaf. + let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default(); let mut iter = pairs.into_iter().peekable(); while let Some((key, value)) = iter.next() { let col = Self::key_to_leaf_index(&key).index.value(); - let next_col = iter.peek().map(|(key, _)| { + let peeked_col = iter.peek().map(|(key, _v)| { let index = Self::key_to_leaf_index(key); - index.index.value() + let next_col = index.index.value(); + // We panic if `pairs` is not sorted by column. + debug_assert!(next_col >= col); + next_col }); + current_leaf_buffer.push((key, value)); - buffer.push((key, value)); - - if let Some(next_col) = next_col { - assert!(next_col >= col); - } - - if next_col == Some(col) { - // Keep going in our buffer. + // If the next pair is the same column as this one, then we're done after adding this + // pair to the buffer. + if peeked_col == Some(col) { continue; } - // Whether the next pair is a different column, or non-existent, we break off. - let leaf_pairs = mem::take(&mut buffer); + // Otherwise, the next pair is a different column, or there is no next pair. Either way + // it's time to swap out our buffer. + let leaf_pairs = mem::take(&mut current_leaf_buffer); let leaf = Self::pairs_to_leaf(leaf_pairs); let hash = Self::hash_leaf(&leaf); - all_leaves.nodes.insert(col, leaf); - all_leaves.subtrees.push(SubtreeLeaf { col, hash }); - } - assert_eq!(buffer.len(), 0); + accumulator.nodes.insert(col, leaf); + accumulator.add_leaf(SubtreeLeaf { col, hash }); - all_leaves + debug_assert!(current_leaf_buffer.is_empty()); + } + + accumulator } /// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes, @@ -615,39 +617,55 @@ pub struct SubtreeLeaf { pub hash: RpoDigest, } +impl SubtreeLeaf { + #[cfg_attr(not(test), allow(dead_code))] + fn from_smt_leaf(leaf: &crate::merkle::SmtLeaf) -> Self { + Self { + col: leaf.index().index.value(), + hash: leaf.hash(), + } + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. pub nodes: BTreeMap, /// "Conceptual" leaves that will be used for computations. - pub subtrees: Vec, + pub leaves: Vec>, } impl PairComputations { - #[cfg_attr(not(test), allow(dead_code))] - pub fn split_at_column(mut self, col: u64) -> (Self, Self) { - let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) { - // Surprisingly, Result has no method like `unwrap_or_unwrap_err() where T == E`. - // Probably because there's no way to write that where bound. - Ok(split_point) | Err(split_point) => split_point, + pub fn add_leaf(&mut self, leaf: SubtreeLeaf) { + // A depth-8 subtree contains 256 "columns" that can possibly be occupied. + const COLS_PER_SUBTREE: u64 = u64::pow(2, 8); + + let last_subtree = match self.leaves.last_mut() { + // Base case. + None => { + self.leaves.push(vec![leaf]); + return; + }, + Some(last_subtree) => last_subtree, }; - let subtrees_right = self.subtrees.split_off(split_point); - let subtrees_left = self.subtrees; + debug_assert!(!last_subtree.is_empty()); + debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize); - let nodes_right = self.nodes.split_off(&col); - let nodes_left = self.nodes; - - let left = Self { - nodes: nodes_left, - subtrees: subtrees_left, + // The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees. + let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col); + let next_subtree_col = if last_subtree_col.is_multiple_of(&COLS_PER_SUBTREE) { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) }; - let right = Self { - nodes: nodes_right, - subtrees: subtrees_right, - }; - - (left, right) + if leaf.col < next_subtree_col { + last_subtree.push(leaf); + } else { + //std::eprintln!("\tcreating new subtree for column {}", leaf.col); + let next_subtree = vec![leaf]; + self.leaves.push(next_subtree); + } } } @@ -656,7 +674,7 @@ impl Default for PairComputations { fn default() -> Self { Self { nodes: Default::default(), - subtrees: Default::default(), + leaves: Default::default(), } } } @@ -665,50 +683,91 @@ impl Default for PairComputations { // ================================================================================================ #[cfg(test)] mod test { - use alloc::vec::Vec; + use alloc::{collections::BTreeMap, vec::Vec}; use super::SparseMerkleTree; use crate::{ hash::rpo::RpoDigest, - merkle::{smt::SubtreeLeaf, Smt, SmtLeaf, SMT_DEPTH}, - Felt, Word, EMPTY_WORD, ONE, + merkle::{ + smt::{PairComputations, SubtreeLeaf}, + Smt, SmtLeaf, SMT_DEPTH, + }, + Felt, Word, ONE, }; #[test] fn test_sorted_pairs_to_leaves() { let entries: Vec<(RpoDigest, Word)> = vec![ + // Subtree 0. (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), (RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]), // Leaf index collision. (RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]), (RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]), - // Normal single leaf again. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), - // Empty leaf. - (RpoDigest::new([ONE, ONE, ONE, Felt::new(500)]), EMPTY_WORD), - ]; - let mut entries_iter = entries.iter().cloned(); - let mut next_entry = || entries_iter.next().unwrap(); - - let control_leaves: Vec = vec![ - SmtLeaf::Single(next_entry()), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), - SmtLeaf::Single(next_entry()), - SmtLeaf::new_empty(Smt::key_to_leaf_index(&next_entry().0)), + // Subtree 1. Normal single leaf again. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]), + // Subtree 2. Another normal leaf. + (RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]), ]; - let control_subtree_leaves: Vec = control_leaves - .iter() - .map(|leaf| { - let col = leaf.index().index.value(); - let hash = leaf.hash(); - SubtreeLeaf { col, hash } - }) + let control = Smt::with_entries(entries.clone()).unwrap(); + + let control_leaves: Vec = { + let mut entries_iter = entries.iter().cloned(); + let mut next_entry = || entries_iter.next().unwrap(); + let control_leaves = vec![ + // Subtree 0. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(), + // Subtree 1. + SmtLeaf::Single(next_entry()), + SmtLeaf::Single(next_entry()), + // Subtree 2. + SmtLeaf::Single(next_entry()), + ]; + assert_eq!(entries_iter.next(), None); + control_leaves + }; + + let control_subtree_leaves: Vec> = { + let mut control_leaves_iter = control_leaves.iter(); + let mut next_leaf = || control_leaves_iter.next().unwrap(); + + let control_subtree_leaves: Vec> = [ + // Subtree 0. + vec![next_leaf(), next_leaf(), next_leaf()], + // Subtree 1. + vec![next_leaf(), next_leaf()], + // Subtree 2. + vec![next_leaf()], + ] + .map(|subtree| subtree.into_iter().map(SubtreeLeaf::from_smt_leaf).collect()) + .to_vec(); + assert_eq!(control_leaves_iter.next(), None); + control_subtree_leaves + }; + + let subtrees = Smt::sorted_pairs_to_leaves(entries); + // This will check that the hashes, columns, and subtree assignments all match. + assert_eq!(subtrees.leaves, control_subtree_leaves); + + // Then finally we might as well check the computed leaf nodes too. + let control_leaves: BTreeMap = control + .leaves() + .map(|(index, value)| (index.index.value(), value.clone())) .collect(); - let test_subtree_leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; - assert_eq!(control_subtree_leaves, test_subtree_leaves); + for (column, test_leaf) in subtrees.nodes { + if test_leaf.is_empty() { + continue; + } + let control_leaf = control_leaves + .get(&column) + .expect(&format!("no leaf node found for column {column}")); + assert_eq!(control_leaf, &test_leaf); + } } // Helper for the below tests. @@ -733,7 +792,8 @@ mod test { let control = Smt::with_entries(entries.clone()).unwrap(); // `entries` should already be sorted by nature of how we constructed it. - let leaves = Smt::sorted_pairs_to_leaves(entries).subtrees; + let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = leaves.into_iter().next().unwrap(); let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH); assert!(!first_subtree.is_empty()); @@ -756,17 +816,19 @@ mod test { let control = Smt::with_entries(entries.clone()).unwrap(); - let leaves = Smt::sorted_pairs_to_leaves(entries); - let (first, second) = leaves.split_at_column(PAIR_COUNT / 2); - assert_eq!(first.subtrees.len(), second.subtrees.len()); + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + // With two subtrees' worth of leaves, we should have exactly two subtrees. + let [first, second]: [_; 2] = leaves.try_into().unwrap(); + assert_eq!(first.len() as u64, PAIR_COUNT / 2); + assert_eq!(first.len(), second.len()); let mut current_depth = SMT_DEPTH; let mut next_leaves: Vec = Default::default(); - let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth); + let (first_nodes, leaves) = Smt::build_subtree(first, current_depth); next_leaves.extend(leaves); - let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth); + let (second_nodes, leaves) = Smt::build_subtree(second, current_depth); next_leaves.extend(leaves); // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.