add test_two_subtrees() test

This commit is contained in:
Qyriad 2024-10-29 14:17:02 -06:00
parent c35e18453a
commit 6db08f4714

View file

@ -623,6 +623,34 @@ pub struct PairComputations<L> {
pub subtrees: Vec<SubtreeLeaf>,
}
impl<L> PairComputations<L> {
#[cfg_attr(not(test), allow(dead_code))]
pub fn split_at_column(mut self, col: u64) -> (Self, Self) {
let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) {
// Surprisingly, Result<T, E> has no method like `unwrap_or_unwrap_err() where T == E`.
// Probably because there's no way to write that where bound.
Ok(split_point) | Err(split_point) => split_point,
};
let subtrees_right = self.subtrees.split_off(split_point);
let subtrees_left = self.subtrees;
let nodes_right = self.nodes.split_off(&col);
let nodes_left = self.nodes;
let left = Self {
nodes: nodes_left,
subtrees: subtrees_left,
};
let right = Self {
nodes: nodes_right,
subtrees: subtrees_right,
};
(left, right)
}
}
// Derive requires `L` to impl Default, even though we don't actually need that.
impl<L> Default for PairComputations<L> {
fn default() -> Self {
@ -696,7 +724,7 @@ mod test {
}
#[test]
fn test_build_subtree_from_leaves() {
fn test_single_subtree() {
// A single subtree's worth of leaves.
const PAIR_COUNT: u64 = 256;
@ -718,4 +746,55 @@ mod test {
);
}
}
#[test]
fn test_two_subtrees() {
// Two subtrees' worth of leaves.
const PAIR_COUNT: u64 = 512;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries(entries.clone()).unwrap();
let leaves = Smt::sorted_pairs_to_leaves(entries);
let (first, second) = leaves.split_at_column(PAIR_COUNT / 2);
assert_eq!(first.subtrees.len(), second.subtrees.len());
let mut current_depth = SMT_DEPTH;
let mut next_leaves: Vec<SubtreeLeaf> = Default::default();
let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth);
next_leaves.extend(leaves);
let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth);
next_leaves.extend(leaves);
// All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle.
let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len();
assert_eq!(total_computed as u64, PAIR_COUNT);
// Verify the computed nodes of both subtrees.
let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes);
for (index, test_node) in computed_nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
current_depth -= 8;
let (nodes, next_leaves) = Smt::build_subtree(next_leaves, current_depth);
assert_eq!(nodes.len(), 8);
assert_eq!(next_leaves.len(), 1);
for (index, test_node) in nodes {
let control_node = control.get_inner_node(index);
assert_eq!(
control_node, test_node,
"subtree-computed node at index {index:?} does not match control",
);
}
}
}