From 3abb959048d7377bea16ea5bbe1b39d97569d101 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 29 Oct 2024 14:17:02 -0600 Subject: [PATCH] add test_two_subtrees() test --- src/merkle/smt/mod.rs | 81 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index b7620d7..0a5de36 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -623,6 +623,34 @@ pub struct PrecomputedLeaves { pub subtrees: Vec, } +impl PrecomputedLeaves { + #[cfg_attr(not(test), allow(dead_code))] + pub fn split_at_column(mut self, col: u64) -> (Self, Self) { + let split_point = match self.subtrees.binary_search_by_key(&col, |key| key.col) { + // Surprisingly, Result has no method like `unwrap_or_unwrap_err() where T == E`. + // Probably because there's no way to write that where bound. + Ok(split_point) | Err(split_point) => split_point, + }; + + let subtrees_right = self.subtrees.split_off(split_point); + let subtrees_left = self.subtrees; + + let nodes_right = self.nodes.split_off(&col); + let nodes_left = self.nodes; + + let left = Self { + nodes: nodes_left, + subtrees: subtrees_left, + }; + let right = Self { + nodes: nodes_right, + subtrees: subtrees_right, + }; + + (left, right) + } +} + // Derive requires `L` to impl Default, even though we don't actually need that. impl Default for PrecomputedLeaves { fn default() -> Self { @@ -696,7 +724,7 @@ mod test { } #[test] - fn test_build_subtree_from_leaves() { + fn test_single_subtree() { // A single subtree's worth of leaves. const PAIR_COUNT: u64 = 256; @@ -718,4 +746,55 @@ mod test { ); } } + + #[test] + fn test_two_subtrees() { + // Two subtrees' worth of leaves. + const PAIR_COUNT: u64 = 512; + + let entries = generate_entries(PAIR_COUNT); + + let control = Smt::with_entries(entries.clone()).unwrap(); + + let leaves = Smt::sorted_pairs_to_leaves(entries); + let (first, second) = leaves.split_at_column(PAIR_COUNT / 2); + assert_eq!(first.subtrees.len(), second.subtrees.len()); + + let mut current_depth = SMT_DEPTH; + let mut next_leaves: Vec = Default::default(); + + let (first_nodes, leaves) = Smt::build_subtree(first.subtrees, current_depth); + next_leaves.extend(leaves); + + let (second_nodes, leaves) = Smt::build_subtree(second.subtrees, current_depth); + next_leaves.extend(leaves); + + // All new inner nodes + the new subtree-leaves should be 512, for one depth-cycle. + let total_computed = first_nodes.len() + second_nodes.len() + next_leaves.len(); + assert_eq!(total_computed as u64, PAIR_COUNT); + + // Verify the computed nodes of both subtrees. + let computed_nodes = first_nodes.clone().into_iter().chain(second_nodes); + for (index, test_node) in computed_nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + + current_depth -= 8; + + let (nodes, next_leaves) = Smt::build_subtree(next_leaves, current_depth); + assert_eq!(nodes.len(), 8); + assert_eq!(next_leaves.len(), 1); + + for (index, test_node) in nodes { + let control_node = control.get_inner_node(index); + assert_eq!( + control_node, test_node, + "subtree-computed node at index {index:?} does not match control", + ); + } + } }