smt: add build_subtrees() to coordinate subtree building

This commit is contained in:
Qyriad 2024-11-12 14:11:37 -07:00
parent 96d42a4a06
commit e6a6ad3712
3 changed files with 164 additions and 3 deletions

View file

@ -257,6 +257,12 @@ impl Smt {
) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) {
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtree(leaves, bottom_depth)
}
/// Computes the inner nodes and leaves of a sparse Merkle tree from a set of key-value pairs.
///
/// This is a thin inherent wrapper that forwards `entries` unchanged to the
/// [`SparseMerkleTree::build_subtrees`] trait function, instantiated at `SMT_DEPTH`.
///
/// Returns a pair of maps: the inner nodes keyed by [`NodeIndex`], and the leaves keyed by
/// [`LeafIndex`].
///
/// NOTE(review): the default trait implementation sorts `entries` itself, so callers
/// presumably do not need to pre-sort them -- confirm that `Smt` does not override it.
pub fn build_subtrees(
entries: Vec<(RpoDigest, Word)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<SMT_DEPTH>, SmtLeaf>) {
// The fully-qualified path selects the trait function; a plain `Self::build_subtrees` would
// resolve to this inherent function and recurse.
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtrees(entries)
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {

View file

@ -508,6 +508,57 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
(inner_nodes, leaves)
}
/// Computes every inner node and every leaf of a tree from a set of key-value pairs, building
/// the subtrees of each layer in parallel.
///
/// `entries` may be given in any order: they are stably sorted by leaf index before use.
///
/// The tree is processed bottom-up in 8-level strides. On each pass every group of subtree
/// leaves is handed to [`Self::build_subtree`] on the rayon thread pool; the inner nodes it
/// returns are accumulated, and the subtree roots it returns are regrouped to become the
/// input of the next pass.
///
/// NOTE(review): the stride loop is `(8..=DEPTH).step_by(8).rev()`, so the first bottom depth
/// is the largest multiple of 8 that is `<= DEPTH`. This looks like it assumes `DEPTH` is a
/// multiple of 8 -- confirm against the intended callers.
///
/// # Returns
///
/// A pair `(inner_nodes, leaves)`: the accumulated inner nodes keyed by [`NodeIndex`], and the
/// leaves produced by `sorted_pairs_to_leaves` keyed by [`LeafIndex`]. Both maps are empty when
/// `entries` is empty.
///
/// # Panics
///
/// Panics if `LeafIndex::new` rejects a leaf index produced by `sorted_pairs_to_leaves`.
fn build_subtrees(
    mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<DEPTH>, Self::Leaf>) {
    use rayon::prelude::*;

    // With no entries there is nothing to build. Returning early also keeps the
    // `debug_assert!(!leaf_subtrees.is_empty())` below from firing on empty input in debug
    // builds.
    if entries.is_empty() {
        return (BTreeMap::new(), BTreeMap::new());
    }

    // Stable sort by leaf index, so pairs that share a leaf keep their relative order.
    entries.sort_by_key(|item| {
        let index = Self::key_to_leaf_index(&item.0);
        index.value()
    });

    let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();

    // `leaf_subtrees` holds the groups of subtree leaves for the current pass;
    // `initial_leaves` holds the actual tree leaves, converted into the return value at the end.
    let PairComputations {
        leaves: mut leaf_subtrees,
        nodes: initial_leaves,
    } = Self::sorted_pairs_to_leaves(entries);

    for current_depth in (8..=DEPTH).step_by(8).rev() {
        // Build every subtree of this layer in parallel, splitting the results into the inner
        // nodes produced and the leaves for the next layer up.
        let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
            .into_par_iter()
            .map(|subtree| {
                debug_assert!(subtree.is_sorted());
                debug_assert!(!subtree.is_empty());

                let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth);

                debug_assert!(next_leaves.is_sorted());

                (nodes, next_leaves)
            })
            .unzip();

        // Flatten the per-subtree outputs and regroup them into the subtrees of the next pass.
        let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
        leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
        accumulated_nodes.extend(nodes.into_iter().flatten());

        debug_assert!(!leaf_subtrees.is_empty());
    }

    let leaves: BTreeMap<LeafIndex<DEPTH>, Self::Leaf> = initial_leaves
        .into_iter()
        .map(|(key, value)| {
            // NOTE(review): these indices presumably originate from `key_to_leaf_index`, in
            // which case this conversion cannot fail -- confirm in `sorted_pairs_to_leaves`.
            let key = LeafIndex::<DEPTH>::new(key)
                .expect("leaf index from sorted_pairs_to_leaves() must be valid for DEPTH");
            (key, value)
        })
        .collect();

    (accumulated_nodes, leaves)
}
}
// INNER NODE

View file

@ -1,3 +1,6 @@
use core::mem;
use std::collections::BTreeMap;
use alloc::vec::Vec;
use super::{
@ -7,10 +10,11 @@ use super::{
use crate::{
hash::rpo::Rpo256,
merkle::{
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
InnerNodeInfo, LeafIndex, MerkleTree,
digests_to_words, int_to_leaf, int_to_node,
smt::{self, InnerNode, PairComputations, SparseMerkleTree},
EmptySubtreeRoots, InnerNodeInfo, LeafIndex, MerkleTree, SubtreeLeaf,
},
Word, EMPTY_WORD,
Felt, Word, EMPTY_WORD, ONE,
};
// TEST DATA
@ -461,6 +465,106 @@ fn test_simplesmt_check_empty_root_constant() {
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
}
// Builds a depth-64 `SimpleSmt` one 8-level subtree layer at a time and checks every
// intermediate node, every leaf, and the final root against a control tree built with
// `SimpleSmt::with_leaves()`.
#[test]
fn test_simplesmt_subtrees() {
    const PAIR_COUNT: u64 = 4096;
    const DEPTH: u8 = 64;
    type SimpleSmt = super::SimpleSmt<DEPTH>;

    // Generate the key-value pairs under test.
    let mut entries: Vec<(LeafIndex<DEPTH>, Word)> = Vec::with_capacity(PAIR_COUNT as usize);
    for i in 0..PAIR_COUNT {
        let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64;
        let value: Word = [ONE, ONE, ONE, Felt::new(i)];
        entries.push((LeafIndex::new_max_depth(leaf_index), value));
    }

    // The control tree is built the conventional way from the same pairs.
    let control =
        SimpleSmt::with_leaves(entries.iter().map(|(key, value)| (key.value(), *value))).unwrap();

    let mut accumulated_nodes = BTreeMap::<NodeIndex, InnerNode>::new();

    let PairComputations {
        leaves: mut leaf_subtrees,
        nodes: test_leaves,
    } = SimpleSmt::sorted_pairs_to_leaves(entries);

    for current_depth in (8..=DEPTH).step_by(8).rev() {
        // Take this layer's subtrees out, leaving `leaf_subtrees` empty so it can be refilled
        // with the next layer's subtrees as we go.
        let layer = mem::take(&mut leaf_subtrees);

        for (i, subtree) in layer.into_iter().enumerate() {
            // Preconditions on the input of `build_subtree()`.
            assert!(
                subtree.is_sorted(),
                "subtree {i} at bottom-depth {current_depth} is not sorted",
            );
            assert!(!subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!");

            // The operation under test.
            let (nodes, next_leaves) = SimpleSmt::build_subtree(subtree, current_depth);

            // Postconditions: sorted output, and every produced node agrees with the control.
            assert!(next_leaves.is_sorted());
            for (index, test_node) in &nodes {
                let control_node = control.get_inner_node(*index);
                assert_eq!(
                    test_node, &control_node,
                    "depth {} subtree {}: test node does not match control at index {:?}",
                    current_depth, i, index,
                );
            }

            // Carry the results forward into the next layer.
            accumulated_nodes.extend(nodes);
            for next_leaf in next_leaves {
                smt::add_subtree_leaf(&mut leaf_subtrees, next_leaf);
            }
        }

        assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
    }

    // The true leaves must match the control: same count, then each leaf individually.
    let control_leaves: BTreeMap<_, _> = control.leaves().collect();
    assert_eq!(test_leaves.len(), control_leaves.len());
    for (col, test_leaf) in test_leaves {
        let control_leaf = *control_leaves.get(&col).unwrap();
        assert_eq!(&test_leaf, control_leaf);
    }

    // The inner nodes must match the control: same count, then each node individually.
    assert_eq!(accumulated_nodes.len(), control.inner_nodes().count());
    for (index, test_node) in &accumulated_nodes {
        let control_node = control.get_inner_node(*index);
        assert_eq!(*test_node, control_node, "test node does not match control at {index:?}");
    }

    // Once the loop above finishes, the new root lives in two places: in `accumulated_nodes`,
    // and as the sole "next leaf" handed back by the last `build_subtree()` call. Check both.
    let control_root = control.get_inner_node(NodeIndex::root());

    // Exactly one leaf subtree should remain...
    let [leaf_subtree] = <[_; 1]>::try_from(leaf_subtrees).unwrap();
    // ...holding exactly one 'leaf'...
    let [root_leaf] = <[_; 1]>::try_from(leaf_subtree).unwrap();
    let SubtreeLeaf { hash: test_root_hash, .. } = root_leaf;
    // ...whose hash is the expected root.
    assert_eq!(control.root(), test_root_hash);

    // `accumulated_nodes` should likewise hold a node at the root index...
    assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
    // ...and that node should be the control's root node.
    let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
    assert_eq!(control_root, *test_root);

    // NOTE(review): this repeats the root-hash assertion made above; it may have been meant to
    // compare the root taken from `accumulated_nodes` with `test_root_hash` -- confirm.
    assert_eq!(control.root(), test_root_hash);
}
// HELPER FUNCTIONS
// --------------------------------------------------------------------------------------------