From 33371774327833f0b0be7f4699d739b56b4c3a63 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Tue, 12 Nov 2024 14:11:37 -0700 Subject: [PATCH] smt: add `build_subtrees()` to coordinate subtree building --- src/merkle/smt/full/mod.rs | 6 ++ src/merkle/smt/mod.rs | 51 +++++++++++++++ src/merkle/smt/simple/tests.rs | 110 ++++++++++++++++++++++++++++++++- 3 files changed, 164 insertions(+), 3 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 9c08f26..c6a98f4 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -257,6 +257,12 @@ impl Smt { ) -> (BTreeMap, Vec) { >::build_subtree(leaves, bottom_depth) } + + pub fn build_subtrees( + entries: Vec<(RpoDigest, Word)>, + ) -> (BTreeMap, BTreeMap, SmtLeaf>) { + >::build_subtrees(entries) + } } impl SparseMerkleTree for Smt { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index c7cf8f9..36814a4 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -509,6 +509,57 @@ pub(crate) trait SparseMerkleTree { (inner_nodes, leaves) } + + fn build_subtrees( + mut entries: Vec<(Self::Key, Self::Value)>, + ) -> (BTreeMap, BTreeMap, Self::Leaf>) { + use rayon::prelude::*; + + entries.sort_by_key(|item| { + let index = Self::key_to_leaf_index(&item.0); + index.value() + }); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: initial_leaves, + } = Self::sorted_pairs_to_leaves(entries); + + for current_depth in (8..=DEPTH).step_by(8).rev() { + let (nodes, subtrees): (Vec>, Vec>) = leaf_subtrees + .into_par_iter() + .map(|subtree| { + debug_assert!(subtree.is_sorted()); + debug_assert!(!subtree.is_empty()); + + let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth); + + debug_assert!(next_leaves.is_sorted()); + + (nodes, next_leaves) + }) + .unzip(); + + let mut all_leaves: Vec = subtrees.into_iter().flatten().collect(); + leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect(); + accumulated_nodes.extend(nodes.into_iter().flatten()); + + debug_assert!(!leaf_subtrees.is_empty()); + } + + let leaves: BTreeMap, Self::Leaf> = initial_leaves + .into_iter() + .map(|(key, value)| { + // FIXME: unwrap is unreachable? + let key = LeafIndex::::new(key).unwrap(); + (key, value) + }) + .collect(); + + (accumulated_nodes, leaves) + } } // INNER NODE diff --git a/src/merkle/smt/simple/tests.rs b/src/merkle/smt/simple/tests.rs index b1dd28d..29c63be 100644 --- a/src/merkle/smt/simple/tests.rs +++ b/src/merkle/smt/simple/tests.rs @@ -1,3 +1,6 @@ +use core::mem; +use std::collections::BTreeMap; + use alloc::vec::Vec; use super::{ @@ -7,10 +10,11 @@ use super::{ use crate::{ hash::rpo::Rpo256, merkle::{ - digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots, - InnerNodeInfo, LeafIndex, MerkleTree, + digests_to_words, int_to_leaf, int_to_node, + smt::{self, InnerNode, PairComputations, SparseMerkleTree}, + EmptySubtreeRoots, InnerNodeInfo, LeafIndex, MerkleTree, SubtreeLeaf, }, - Word, EMPTY_WORD, + Felt, Word, EMPTY_WORD, ONE, }; // TEST DATA @@ -461,6 +465,106 @@ fn test_simplesmt_check_empty_root_constant() { assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT); } +#[test] +fn test_simplesmt_subtrees() { + const PAIR_COUNT: u64 = 4096; + const DEPTH: u8 = 64; + type SimpleSmt = super::SimpleSmt; + + let entries: Vec<(LeafIndex, Word)> = (0..PAIR_COUNT) + .map(|i| { + let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64; + let key = LeafIndex::new_max_depth(leaf_index); + let value: Word = [ONE, ONE, ONE, Felt::new(i)]; + (key, value) + }) + .collect(); + let leaves = entries.iter().map(|(key, value)| (key.value(), *value)); + + let control = SimpleSmt::with_leaves(leaves).unwrap(); + + let mut accumulated_nodes: BTreeMap = Default::default(); + + let PairComputations { + leaves: mut leaf_subtrees, + nodes: test_leaves, + } = SimpleSmt::sorted_pairs_to_leaves(entries); + + for current_depth in (8..=DEPTH).step_by(8).rev() { + for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() { + // Pre-assertions. + assert!( + subtree.is_sorted(), + "subtree {i} at bottom-depth {current_depth} is not sorted", + ); + assert!(!subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!"); + + // Do actual things. + let (nodes, next_leaves) = SimpleSmt::build_subtree(subtree, current_depth); + + // Post-assertions. + assert!(next_leaves.is_sorted()); + for (&index, test_node) in nodes.iter() { + let control_node = control.get_inner_node(index); + assert_eq!( + test_node, &control_node, + "depth {} subtree {}: test node does not match control at index {:?}", + current_depth, i, index, + ); + } + + // Update state. + accumulated_nodes.extend(nodes); + + for subtree_leaf in next_leaves { + smt::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); + } + } + + assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}"); + } + + // Make sure the true leaves match, checking length first and then each individual leaf. + let control_leaves: BTreeMap<_, _> = control.leaves().collect(); + let control_leaves_len = control_leaves.len(); + let test_leaves_len = test_leaves.len(); + assert_eq!(test_leaves_len, control_leaves_len); + for (col, ref test_leaf) in test_leaves { + let &control_leaf = control_leaves.get(&col).unwrap(); + assert_eq!(test_leaf, control_leaf); + } + + // Make sure inner nodes match, checking length first and then each individual leaf. + let control_nodes_len = control.inner_nodes().count(); + let test_nodes_len = accumulated_nodes.len(); + assert_eq!(test_nodes_len, control_nodes_len); + for (index, test_node) in accumulated_nodes.clone() { + let control_node = control.get_inner_node(index); + assert_eq!(test_node, control_node, "test node does not match control at {index:?}"); + } + + // After the last iteration of the above for loop, we should have the new root actually in two + // places: one in `accumulated_nodes`, and the other as the "next leaves" return from + // `build_subtree()`. So let's check both! + + let control_root = control.get_inner_node(NodeIndex::root()); + + // That for loop should have left us with only one leaf subtree... + let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap(); + // which itself contains only one 'leaf'... + let [SubtreeLeaf { hash: test_root_hash, .. }]: [_; 1] = leaf_subtree.try_into().unwrap(); + // which matches the expected root. + assert_eq!(control.root(), test_root_hash); + + // Likewise `accumulated_nodes` should contain a node at the root index... + assert!(accumulated_nodes.contains_key(&NodeIndex::root())); + // and it should match our actual root. + let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap(); + assert_eq!(control_root, *test_root); + // And of course the root we got from each place should match. + assert_eq!(control.root(), test_root_hash); +} + // HELPER FUNCTIONS // --------------------------------------------------------------------------------------------