refactor sorted_pairs_to_leaves() to also group subtrees
This commit is contained in:
parent
6db08f4714
commit
47e1650a40
1 changed files with 135 additions and 73 deletions
|
@ -353,40 +353,42 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
|||
fn sorted_pairs_to_leaves(
|
||||
pairs: Vec<(Self::Key, Self::Value)>,
|
||||
) -> PairComputations<Self::Leaf> {
|
||||
let mut all_leaves = PairComputations::default();
|
||||
let mut accumulator: PairComputations<Self::Leaf> = Default::default();
|
||||
|
||||
let mut buffer: Vec<(Self::Key, Self::Value)> = Default::default();
|
||||
// The kv-pairs we've seen so far that correspond to a single leaf.
|
||||
let mut current_leaf_buffer: Vec<(Self::Key, Self::Value)> = Default::default();
|
||||
|
||||
let mut iter = pairs.into_iter().peekable();
|
||||
while let Some((key, value)) = iter.next() {
|
||||
let col = Self::key_to_leaf_index(&key).index.value();
|
||||
let next_col = iter.peek().map(|(key, _)| {
|
||||
let peeked_col = iter.peek().map(|(key, _v)| {
|
||||
let index = Self::key_to_leaf_index(key);
|
||||
index.index.value()
|
||||
let next_col = index.index.value();
|
||||
// We panic if `pairs` is not sorted by column.
|
||||
debug_assert!(next_col >= col);
|
||||
next_col
|
||||
});
|
||||
current_leaf_buffer.push((key, value));
|
||||
|
||||
buffer.push((key, value));
|
||||
|
||||
if let Some(next_col) = next_col {
|
||||
assert!(next_col >= col);
|
||||
}
|
||||
|
||||
if next_col == Some(col) {
|
||||
// Keep going in our buffer.
|
||||
// If the next pair is the same column as this one, then we're done after adding this
|
||||
// pair to the buffer.
|
||||
if peeked_col == Some(col) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Whether the next pair is a different column, or non-existent, we break off.
|
||||
let leaf_pairs = mem::take(&mut buffer);
|
||||
// Otherwise, the next pair is a different column, or there is no next pair. Either way
|
||||
// it's time to swap out our buffer.
|
||||
let leaf_pairs = mem::take(&mut current_leaf_buffer);
|
||||
let leaf = Self::pairs_to_leaf(leaf_pairs);
|
||||
let hash = Self::hash_leaf(&leaf);
|
||||
|
||||
all_leaves.nodes.insert(col, leaf);
|
||||
all_leaves.subtrees.push(SubtreeLeaf { col, hash });
|
||||
}
|
||||
assert_eq!(buffer.len(), 0);
|
||||
accumulator.nodes.insert(col, leaf);
|
||||
accumulator.add_leaf(SubtreeLeaf { col, hash });
|
||||
|
||||
all_leaves
|
||||
debug_assert!(current_leaf_buffer.is_empty());
|
||||
}
|
||||
|
||||
accumulator
|
||||
}
|
||||
|
||||
/// Builds Merkle nodes from a bottom layer of tuples of horizontal indices and their hashes,
|
||||
|
@ -615,39 +617,55 @@ pub struct SubtreeLeaf {
|
|||
pub hash: RpoDigest,
|
||||
}
|
||||
|
||||
impl SubtreeLeaf {
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
fn from_smt_leaf(leaf: &crate::merkle::SmtLeaf) -> Self {
|
||||
Self {
|
||||
col: leaf.index().index.value(),
|
||||
hash: leaf.hash(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct PairComputations<L> {
|
||||
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
|
||||
pub nodes: BTreeMap<u64, L>,
|
||||
/// "Conceptual" leaves that will be used for computations.
|
||||
pub subtrees: Vec<SubtreeLeaf>,
|
||||
pub leaves: Vec<Vec<SubtreeLeaf>>,
|
||||
}
|
||||
|
||||
impl<L> PairComputations<L> {
|
||||
#[cfg_attr(not(test), allow(dead_code))]
|
||||
pub fn split_at_column(mut self, col: u64) -> (Self, Self) {
|
||||
let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) {
|
||||
// Surprisingly, Result<T, E> has no method like `unwrap_or_unwrap_err() where T == E`.
|
||||
// Probably because there's no way to write that where bound.
|
||||
Ok(split_point) | Err(split_point) => split_point,
|
||||
pub fn add_leaf(&mut self, leaf: SubtreeLeaf) {
|
||||
// A depth-8 subtree contains 256 "columns" that can possibly be occupied.
|
||||
const COLS_PER_SUBTREE: u64 = u64::pow(2, 8);
|
||||
|
||||
let last_subtree = match self.leaves.last_mut() {
|
||||
// Base case.
|
||||
None => {
|
||||
self.leaves.push(vec![leaf]);
|
||||
return;
|
||||
},
|
||||
Some(last_subtree) => last_subtree,
|
||||
};
|
||||
|
||||
let subtrees_right = self.subtrees.split_off(split_point);
|
||||
let subtrees_left = self.subtrees;
|
||||
debug_assert!(!last_subtree.is_empty());
|
||||
debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize);
|
||||
|
||||
let nodes_right = self.nodes.split_off(&col);
|
||||
let nodes_left = self.nodes;
|
||||
|
||||
let left = Self {
|
||||
nodes: nodes_left,
|
||||
subtrees: subtrees_left,
|
||||
// The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees.
|
||||
let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col);
|
||||
let next_subtree_col = if last_subtree_col.is_multiple_of(&COLS_PER_SUBTREE) {
|
||||
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
|
||||
} else {
|
||||
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
|
||||
};
|
||||
let right = Self {
|
||||
nodes: nodes_right,
|
||||
subtrees: subtrees_right,
|
||||
};
|
||||
|
||||
(left, right)
|
||||
if leaf.col < next_subtree_col {
|
||||
last_subtree.push(leaf);
|
||||
} else {
|
||||
//std::eprintln!("\tcreating new subtree for column {}", leaf.col);
|
||||
let next_subtree = vec![leaf];
|
||||
self.leaves.push(next_subtree);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -656,7 +674,7 @@ impl<L> Default for PairComputations<L> {
|
|||
fn default() -> Self {
|
||||
Self {
|
||||
nodes: Default::default(),
|
||||
subtrees: Default::default(),
|
||||
leaves: Default::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -665,50 +683,91 @@ impl<L> Default for PairComputations<L> {
|
|||
// ================================================================================================
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use alloc::vec::Vec;
|
||||
use alloc::{collections::BTreeMap, vec::Vec};
|
||||
|
||||
use super::SparseMerkleTree;
|
||||
use crate::{
|
||||
hash::rpo::RpoDigest,
|
||||
merkle::{smt::SubtreeLeaf, Smt, SmtLeaf, SMT_DEPTH},
|
||||
Felt, Word, EMPTY_WORD, ONE,
|
||||
merkle::{
|
||||
smt::{PairComputations, SubtreeLeaf},
|
||||
Smt, SmtLeaf, SMT_DEPTH,
|
||||
},
|
||||
Felt, Word, ONE,
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_sorted_pairs_to_leaves() {
|
||||
let entries: Vec<(RpoDigest, Word)> = vec![
|
||||
// Subtree 0.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(17)]), [ONE; 4]),
|
||||
// Leaf index collision.
|
||||
(RpoDigest::new([ONE, ONE, Felt::new(10), Felt::new(20)]), [ONE; 4]),
|
||||
(RpoDigest::new([ONE, ONE, Felt::new(20), Felt::new(20)]), [ONE; 4]),
|
||||
// Normal single leaf again.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]),
|
||||
// Empty leaf.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(500)]), EMPTY_WORD),
|
||||
];
|
||||
let mut entries_iter = entries.iter().cloned();
|
||||
let mut next_entry = || entries_iter.next().unwrap();
|
||||
|
||||
let control_leaves: Vec<SmtLeaf> = vec![
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::new_empty(Smt::key_to_leaf_index(&next_entry().0)),
|
||||
// Subtree 1. Normal single leaf again.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(400)]), [ONE; 4]), // Subtree boundary.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(401)]), [ONE; 4]),
|
||||
// Subtree 2. Another normal leaf.
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(1024)]), [ONE; 4]),
|
||||
];
|
||||
|
||||
let control_subtree_leaves: Vec<SubtreeLeaf> = control_leaves
|
||||
.iter()
|
||||
.map(|leaf| {
|
||||
let col = leaf.index().index.value();
|
||||
let hash = leaf.hash();
|
||||
SubtreeLeaf { col, hash }
|
||||
})
|
||||
let control = Smt::with_entries(entries.clone()).unwrap();
|
||||
|
||||
let control_leaves: Vec<SmtLeaf> = {
|
||||
let mut entries_iter = entries.iter().cloned();
|
||||
let mut next_entry = || entries_iter.next().unwrap();
|
||||
let control_leaves = vec![
|
||||
// Subtree 0.
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::new_multiple(vec![next_entry(), next_entry()]).unwrap(),
|
||||
// Subtree 1.
|
||||
SmtLeaf::Single(next_entry()),
|
||||
SmtLeaf::Single(next_entry()),
|
||||
// Subtree 2.
|
||||
SmtLeaf::Single(next_entry()),
|
||||
];
|
||||
assert_eq!(entries_iter.next(), None);
|
||||
control_leaves
|
||||
};
|
||||
|
||||
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = {
|
||||
let mut control_leaves_iter = control_leaves.iter();
|
||||
let mut next_leaf = || control_leaves_iter.next().unwrap();
|
||||
|
||||
let control_subtree_leaves: Vec<Vec<SubtreeLeaf>> = [
|
||||
// Subtree 0.
|
||||
vec![next_leaf(), next_leaf(), next_leaf()],
|
||||
// Subtree 1.
|
||||
vec![next_leaf(), next_leaf()],
|
||||
// Subtree 2.
|
||||
vec![next_leaf()],
|
||||
]
|
||||
.map(|subtree| subtree.into_iter().map(SubtreeLeaf::from_smt_leaf).collect())
|
||||
.to_vec();
|
||||
assert_eq!(control_leaves_iter.next(), None);
|
||||
control_subtree_leaves
|
||||
};
|
||||
|
||||
let subtrees = Smt::sorted_pairs_to_leaves(entries);
|
||||
// This will check that the hashes, columns, and subtree assignments all match.
|
||||
assert_eq!(subtrees.leaves, control_subtree_leaves);
|
||||
|
||||
// Then finally we might as well check the computed leaf nodes too.
|
||||
let control_leaves: BTreeMap<u64, SmtLeaf> = control
|
||||
.leaves()
|
||||
.map(|(index, value)| (index.index.value(), value.clone()))
|
||||
.collect();
|
||||
|
||||
let test_subtree_leaves = Smt::sorted_pairs_to_leaves(entries).subtrees;
|
||||
assert_eq!(control_subtree_leaves, test_subtree_leaves);
|
||||
for (column, test_leaf) in subtrees.nodes {
|
||||
if test_leaf.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let control_leaf = control_leaves
|
||||
.get(&column)
|
||||
.expect(&format!("no leaf node found for column {column}"));
|
||||
assert_eq!(control_leaf, &test_leaf);
|
||||
}
|
||||
}
|
||||
|
||||
// Helper for the below tests.
|
||||
|
@ -733,7 +792,8 @@ mod test {
|
|||
let control = Smt::with_entries(entries.clone()).unwrap();
|
||||
|
||||
// `entries` should already be sorted by nature of how we constructed it.
|
||||
let leaves = Smt::sorted_pairs_to_leaves(entries).subtrees;
|
||||
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
|
||||
let leaves = leaves.into_iter().next().unwrap();
|
||||
|
||||
let (first_subtree, _) = Smt::build_subtree(leaves, SMT_DEPTH);
|
||||
assert!(!first_subtree.is_empty());
|
||||
|
@ -756,17 +816,19 @@ mod test {
|
|||
|
||||
let control = Smt::with_entries(entries.clone()).unwrap();
|
||||
|
||||
let leaves = Smt::sorted_pairs_to_leaves(entries);
|
||||
let (first, second) = leaves.split_at_column(PAIR_COUNT / 2);
|
||||
assert_eq!(first.subtrees.len(), second.subtrees.len());
|
||||
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
|
||||
// With two subtrees' worth of leaves, we should have exactly two subtrees.
|
||||
let [first, second]: [_; 2] = leaves.try_into().unwrap();
|
||||
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
|
||||
assert_eq!(first.len(), second.len());
|
||||
|
||||
let mut current_depth = SMT_DEPTH;
|
||||
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
|
||||
|
||||
let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth);
|
||||
let (first_nodes, leaves) = Smt::build_subtree(first, current_depth);
|
||||
next_leaves.extend(leaves);
|
||||
|
||||
let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth);
|
||||
let (second_nodes, leaves) = Smt::build_subtree(second, current_depth);
|
||||
next_leaves.extend(leaves);
|
||||
|
||||
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
|
||||
|
|
Loading…
Add table
Reference in a new issue