smt: add build_subtrees()
to coordinate subtree building
This commit is contained in:
parent
96d42a4a06
commit
e6a6ad3712
3 changed files with 164 additions and 3 deletions
|
@ -257,6 +257,12 @@ impl Smt {
|
|||
) -> (BTreeMap<NodeIndex, InnerNode>, Vec<SubtreeLeaf>) {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtree(leaves, bottom_depth)
|
||||
}
|
||||
|
||||
pub fn build_subtrees(
|
||||
entries: Vec<(RpoDigest, Word)>,
|
||||
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<SMT_DEPTH>, SmtLeaf>) {
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::build_subtrees(entries)
|
||||
}
|
||||
}
|
||||
|
||||
impl SparseMerkleTree<SMT_DEPTH> for Smt {
|
||||
|
|
|
@ -508,6 +508,57 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
|
|||
|
||||
(inner_nodes, leaves)
|
||||
}
|
||||
|
||||
fn build_subtrees(
|
||||
mut entries: Vec<(Self::Key, Self::Value)>,
|
||||
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<LeafIndex<DEPTH>, Self::Leaf>) {
|
||||
use rayon::prelude::*;
|
||||
|
||||
entries.sort_by_key(|item| {
|
||||
let index = Self::key_to_leaf_index(&item.0);
|
||||
index.value()
|
||||
});
|
||||
|
||||
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
||||
|
||||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: initial_leaves,
|
||||
} = Self::sorted_pairs_to_leaves(entries);
|
||||
|
||||
for current_depth in (8..=DEPTH).step_by(8).rev() {
|
||||
let (nodes, subtrees): (Vec<BTreeMap<_, _>>, Vec<Vec<SubtreeLeaf>>) = leaf_subtrees
|
||||
.into_par_iter()
|
||||
.map(|subtree| {
|
||||
debug_assert!(subtree.is_sorted());
|
||||
debug_assert!(!subtree.is_empty());
|
||||
|
||||
let (nodes, next_leaves) = Self::build_subtree(subtree, current_depth);
|
||||
|
||||
debug_assert!(next_leaves.is_sorted());
|
||||
|
||||
(nodes, next_leaves)
|
||||
})
|
||||
.unzip();
|
||||
|
||||
let mut all_leaves: Vec<SubtreeLeaf> = subtrees.into_iter().flatten().collect();
|
||||
leaf_subtrees = SubtreeLeavesIter::from_leaves(&mut all_leaves).collect();
|
||||
accumulated_nodes.extend(nodes.into_iter().flatten());
|
||||
|
||||
debug_assert!(!leaf_subtrees.is_empty());
|
||||
}
|
||||
|
||||
let leaves: BTreeMap<LeafIndex<DEPTH>, Self::Leaf> = initial_leaves
|
||||
.into_iter()
|
||||
.map(|(key, value)| {
|
||||
// FIXME: unwrap is unreachable?
|
||||
let key = LeafIndex::<DEPTH>::new(key).unwrap();
|
||||
(key, value)
|
||||
})
|
||||
.collect();
|
||||
|
||||
(accumulated_nodes, leaves)
|
||||
}
|
||||
}
|
||||
|
||||
// INNER NODE
|
||||
|
|
|
@ -1,3 +1,6 @@
|
|||
use core::mem;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::{
|
||||
|
@ -7,10 +10,11 @@ use super::{
|
|||
use crate::{
|
||||
hash::rpo::Rpo256,
|
||||
merkle::{
|
||||
digests_to_words, int_to_leaf, int_to_node, smt::SparseMerkleTree, EmptySubtreeRoots,
|
||||
InnerNodeInfo, LeafIndex, MerkleTree,
|
||||
digests_to_words, int_to_leaf, int_to_node,
|
||||
smt::{self, InnerNode, PairComputations, SparseMerkleTree},
|
||||
EmptySubtreeRoots, InnerNodeInfo, LeafIndex, MerkleTree, SubtreeLeaf,
|
||||
},
|
||||
Word, EMPTY_WORD,
|
||||
Felt, Word, EMPTY_WORD, ONE,
|
||||
};
|
||||
|
||||
// TEST DATA
|
||||
|
@ -461,6 +465,106 @@ fn test_simplesmt_check_empty_root_constant() {
|
|||
assert_eq!(empty_root_1_depth, SimpleSmt::<1>::EMPTY_ROOT);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simplesmt_subtrees() {
|
||||
const PAIR_COUNT: u64 = 4096;
|
||||
const DEPTH: u8 = 64;
|
||||
type SimpleSmt = super::SimpleSmt<DEPTH>;
|
||||
|
||||
let entries: Vec<(LeafIndex<DEPTH>, Word)> = (0..PAIR_COUNT)
|
||||
.map(|i| {
|
||||
let leaf_index = ((i as f64 / PAIR_COUNT as f64) * (PAIR_COUNT as f64)) as u64;
|
||||
let key = LeafIndex::new_max_depth(leaf_index);
|
||||
let value: Word = [ONE, ONE, ONE, Felt::new(i)];
|
||||
(key, value)
|
||||
})
|
||||
.collect();
|
||||
let leaves = entries.iter().map(|(key, value)| (key.value(), *value));
|
||||
|
||||
let control = SimpleSmt::with_leaves(leaves).unwrap();
|
||||
|
||||
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default();
|
||||
|
||||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: test_leaves,
|
||||
} = SimpleSmt::sorted_pairs_to_leaves(entries);
|
||||
|
||||
for current_depth in (8..=DEPTH).step_by(8).rev() {
|
||||
for (i, subtree) in mem::take(&mut leaf_subtrees).into_iter().enumerate() {
|
||||
// Pre-assertions.
|
||||
assert!(
|
||||
subtree.is_sorted(),
|
||||
"subtree {i} at bottom-depth {current_depth} is not sorted",
|
||||
);
|
||||
assert!(!subtree.is_empty(), "subtree {i} at bottom-depth {current_depth} is empty!");
|
||||
|
||||
// Do actual things.
|
||||
let (nodes, next_leaves) = SimpleSmt::build_subtree(subtree, current_depth);
|
||||
|
||||
// Post-assertions.
|
||||
assert!(next_leaves.is_sorted());
|
||||
for (&index, test_node) in nodes.iter() {
|
||||
let control_node = control.get_inner_node(index);
|
||||
assert_eq!(
|
||||
test_node, &control_node,
|
||||
"depth {} subtree {}: test node does not match control at index {:?}",
|
||||
current_depth, i, index,
|
||||
);
|
||||
}
|
||||
|
||||
// Update state.
|
||||
accumulated_nodes.extend(nodes);
|
||||
|
||||
for subtree_leaf in next_leaves {
|
||||
smt::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf);
|
||||
}
|
||||
}
|
||||
|
||||
assert!(!leaf_subtrees.is_empty(), "on depth {current_depth}");
|
||||
}
|
||||
|
||||
// Make sure the true leaves match, checking length first and then each individual leaf.
|
||||
let control_leaves: BTreeMap<_, _> = control.leaves().collect();
|
||||
let control_leaves_len = control_leaves.len();
|
||||
let test_leaves_len = test_leaves.len();
|
||||
assert_eq!(test_leaves_len, control_leaves_len);
|
||||
for (col, ref test_leaf) in test_leaves {
|
||||
let &control_leaf = control_leaves.get(&col).unwrap();
|
||||
assert_eq!(test_leaf, control_leaf);
|
||||
}
|
||||
|
||||
// Make sure inner nodes match, checking length first and then each individual leaf.
|
||||
let control_nodes_len = control.inner_nodes().count();
|
||||
let test_nodes_len = accumulated_nodes.len();
|
||||
assert_eq!(test_nodes_len, control_nodes_len);
|
||||
for (index, test_node) in accumulated_nodes.clone() {
|
||||
let control_node = control.get_inner_node(index);
|
||||
assert_eq!(test_node, control_node, "test node does not match control at {index:?}");
|
||||
}
|
||||
|
||||
// After the last iteration of the above for loop, we should have the new root actually in two
|
||||
// places: one in `accumulated_nodes`, and the other as the "next leaves" return from
|
||||
// `build_subtree()`. So let's check both!
|
||||
|
||||
let control_root = control.get_inner_node(NodeIndex::root());
|
||||
|
||||
// That for loop should have left us with only one leaf subtree...
|
||||
let [leaf_subtree]: [_; 1] = leaf_subtrees.try_into().unwrap();
|
||||
// which itself contains only one 'leaf'...
|
||||
let [SubtreeLeaf { hash: test_root_hash, .. }]: [_; 1] = leaf_subtree.try_into().unwrap();
|
||||
// which matches the expected root.
|
||||
assert_eq!(control.root(), test_root_hash);
|
||||
|
||||
// Likewise `accumulated_nodes` should contain a node at the root index...
|
||||
assert!(accumulated_nodes.contains_key(&NodeIndex::root()));
|
||||
// and it should match our actual root.
|
||||
let test_root = accumulated_nodes.get(&NodeIndex::root()).unwrap();
|
||||
assert_eq!(control_root, *test_root);
|
||||
// And of course the root we got from each place should match.
|
||||
assert_eq!(control.root(), test_root_hash);
|
||||
}
|
||||
|
||||
// HELPER FUNCTIONS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue