convert test_singlethreaded_subtree to use an iterator adapter instead of state mutation

This commit is contained in:
Qyriad 2024-11-05 13:04:24 -07:00
parent 8997c46bb4
commit ee95b29390

View file

@ -683,19 +683,65 @@ fn add_subtree_leaf(subtrees: &mut Vec<Vec<SubtreeLeaf>>, leaf: SubtreeLeaf) {
} }
} }
#[derive(Debug)]
struct SubtreeLeavesIter<'s> {
leaves: core::iter::Peekable<alloc::vec::Drain<'s, SubtreeLeaf>>,
}
impl<'s> SubtreeLeavesIter<'s> {
fn from_leaves(leaves: &'s mut Vec<SubtreeLeaf>) -> Self {
Self { leaves: leaves.drain(..).peekable() }
}
}
impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> {
type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree.
fn next(&mut self) -> Option<Vec<SubtreeLeaf>> {
let mut subtree: Vec<SubtreeLeaf> = Default::default();
let mut last_subtree_col = 0;
while let Some(leaf) = self.leaves.peek() {
last_subtree_col = u64::max(1, last_subtree_col);
let is_exact_multiple = Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE);
let next_subtree_col = if is_exact_multiple {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
last_subtree_col = leaf.col;
if leaf.col < next_subtree_col {
subtree.push(self.leaves.next().unwrap());
} else if subtree.is_empty() {
continue;
} else {
break;
}
}
if subtree.is_empty() {
debug_assert!(self.leaves.peek().is_none());
return None;
}
Some(subtree)
}
}
// TESTS // TESTS
// ================================================================================================ // ================================================================================================
#[cfg(test)] #[cfg(test)]
mod test { mod test {
use core::mem;
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::{collections::BTreeMap, vec::Vec};
use super::{SparseMerkleTree, SubtreeLeaf}; use super::{SparseMerkleTree, SubtreeLeaf};
use crate::{ use crate::{
hash::rpo::RpoDigest, hash::rpo::RpoDigest,
merkle::{ merkle::{
smt::{InnerNode, PairComputations}, smt::{InnerNode, PairComputations, SubtreeLeavesIter},
LeafIndex, NodeIndex, Smt, SmtLeaf, SMT_DEPTH, LeafIndex, NodeIndex, Smt, SmtLeaf, SMT_DEPTH,
}, },
Felt, Word, ONE, Felt, Word, ONE,
@ -889,7 +935,11 @@ mod test {
} = Smt::sorted_pairs_to_leaves(entries); } = Smt::sorted_pairs_to_leaves(entries);
for current_depth in (8..=SMT_DEPTH).step_by(8).rev() { for current_depth in (8..=SMT_DEPTH).step_by(8).rev() {
for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { // There's no flat_map_unzip(), so this is the best we can do.
let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
.into_iter()
.enumerate()
.map(|(i, subtree)| {
// Pre-assertions. // Pre-assertions.
assert!( assert!(
subtree.is_sorted(), subtree.is_sorted(),
@ -902,9 +952,9 @@ mod test {
// Do actual things. // Do actual things.
let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth); let (nodes, next_leaves) = Smt::build_subtree(subtree, current_depth);
// Post-assertions. // Post-assertions.
assert!(next_leaves.is_sorted()); assert!(next_leaves.is_sorted());
for (&index, test_node) in nodes.iter() { for (&index, test_node) in nodes.iter() {
let control_node = control.get_inner_node(index); let control_node = control.get_inner_node(index);
assert_eq!( assert_eq!(
@ -914,13 +964,17 @@ mod test {
); );
} }
// Update state. (nodes, next_leaves)
accumulated_nodes.extend(nodes); })
.unzip();
for subtree_leaf in next_leaves { // Update state between each depth iteration.
super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf);
} // FIXME: is this flatten or Box<dyn Iterator> better?
} let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
accumulated_nodes.extend(nodes.into_iter().flatten());
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
} }