Compare commits

...

5 commits

9 changed files with 1163 additions and 29 deletions

209
Cargo.lock generated
View file

@ -2,6 +2,21 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "addr2line"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678"
dependencies = [
"gimli",
]
[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "aho-corasick"
version = "1.1.3"
@ -84,6 +99,21 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
[[package]]
name = "backtrace"
version = "0.3.73"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a"
dependencies = [
"addr2line",
"cc",
"cfg-if",
"libc",
"miniz_oxide",
"object",
"rustc-demangle",
]
[[package]]
name = "bit-set"
version = "0.5.3"
@ -261,7 +291,7 @@ dependencies = [
"clap",
"criterion-plot",
"is-terminal",
"itertools",
"itertools 0.10.5",
"num-traits",
"once_cell",
"oorandom",
@ -282,7 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools",
"itertools 0.10.5",
]
[[package]]
@ -364,6 +394,95 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "futures"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
[[package]]
name = "futures-executor"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-macro"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
[[package]]
name = "futures-task"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-util"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@ -387,6 +506,12 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "gimli"
version = "0.29.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
[[package]]
name = "glob"
version = "0.3.1"
@ -447,6 +572,15 @@ dependencies = [
"either",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.11"
@ -524,24 +658,37 @@ dependencies = [
"cc",
"clap",
"criterion",
"futures",
"getrandom",
"glob",
"hex",
"itertools 0.13.0",
"num",
"num-complex",
"proptest",
"rand",
"rand_chacha",
"rand_core",
"rayon",
"seq-macro",
"serde",
"sha3",
"tokio",
"winter-crypto",
"winter-math",
"winter-rand-utils",
"winter-utils",
]
[[package]]
name = "miniz_oxide"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
dependencies = [
"adler",
]
[[package]]
name = "num"
version = "0.4.3"
@ -616,6 +763,15 @@ dependencies = [
"libm",
]
[[package]]
name = "object"
version = "0.36.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a"
dependencies = [
"memchr",
]
[[package]]
name = "once_cell"
version = "1.19.0"
@ -628,6 +784,18 @@ version = "11.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
[[package]]
name = "pin-project-lite"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "plotters"
version = "0.3.6"
@ -797,6 +965,12 @@ version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "rustc-demangle"
version = "0.1.24"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
[[package]]
name = "rustix"
version = "0.38.34"
@ -885,6 +1059,15 @@ dependencies = [
"keccak",
]
[[package]]
name = "slab"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67"
dependencies = [
"autocfg",
]
[[package]]
name = "strsim"
version = "0.11.1"
@ -925,6 +1108,28 @@ dependencies = [
"serde_json",
]
[[package]]
name = "tokio"
version = "1.40.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998"
dependencies = [
"backtrace",
"pin-project-lite",
"tokio-macros",
]
[[package]]
name = "tokio-macros"
version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "typenum"
version = "1.17.0"

View file

@ -44,21 +44,25 @@ std = [
"winter-math/std",
"winter-utils/std",
]
async = ["serde?/rc"]
async = ["std", "dep:tokio", "dep:rayon", "dep:futures", "serde?/rc"]
[dependencies]
blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] }
futures = { version = "0.3.30", optional = true }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false }
rand_core = { version = "0.6", default-features = false }
rand-utils = { version = "0.9", package = "winter-rand-utils", optional = true }
rayon = { version = "1.10.0", optional = true }
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
sha3 = { version = "0.10", default-features = false }
tokio = { version = "1.40", features = ["rt-multi-thread", "macros", "sync"], optional = true }
winter-crypto = { version = "0.9", default-features = false }
winter-math = { version = "0.9", default-features = false }
winter-utils = { version = "0.9", default-features = false }
itertools = { version = "0.13.0", default-features = false, features = ["use_alloc"] }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }

View file

@ -16,10 +16,51 @@ pub struct BenchmarkCmd {
size: u64,
}
#[cfg(not(feature = "async"))]
/// Entry point when the `async` feature is disabled: runs the standard SMT benchmark.
fn main() {
    benchmark_smt();
}
#[cfg(feature = "async")]
#[tokio::main(flavor = "multi_thread")]
/// Entry point when the `async` feature is enabled: benchmarks parallel SMT construction
/// against a sequentially-built control tree and cross-checks the resulting roots.
async fn main() {
    // FIXME: very incomplete
    let args = BenchmarkCmd::parse();
    let tree_size = args.size;

    // Build `tree_size` deterministic entries; `u64::MAX / (i + 1)` spreads keys across the
    // whole leaf-index space instead of clustering them near 0.
    let mut entries = Vec::new();
    for i in 0..tree_size {
        //let key = rand_value::<RpoDigest>();
        let leaf_index = u64::MAX / (i + 1);
        let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
        let value = [ONE, ONE, ONE, Felt::new(i)];
        entries.push((key, value));
    }

    // Sequentially-built control tree used to validate the parallel result.
    let control = Smt::with_entries(entries.clone()).unwrap();

    let mut tree = Smt::new();
    println!("Running a parallel construction benchmark:");
    let now = Instant::now();
    let mutations = tree.compute_mutations_parallel(entries).await;
    assert_eq!(mutations.root(), control.root());
    tree.apply_mutations(mutations.clone()).unwrap();
    // NOTE(review): the measured interval includes the assert and `apply_mutations`, not just
    // the parallel computation — confirm that is intended for this benchmark.
    let elapsed = now.elapsed();

    // Cross-check all three roots: the computed mutation set, the tree it was applied to, and
    // the sequential control.
    assert_eq!(tree.root(), mutations.root(), "mutation did not apply the right root?");
    assert_eq!(control.root(), mutations.root(), "mutation root hash did not match control");
    assert_eq!(tree.root(), control.root(), "applied root hash did not match control");
    std::eprintln!("\nassertion checks complete");

    println!(
        "Constructed an SMT in parallel with {} key-value pairs in {:.3} seconds",
        tree_size,
        elapsed.as_secs_f32(),
    );

    //benchmark_smt();
}
/// Run a benchmark for [`Smt`].
pub fn benchmark_smt() {
let args = BenchmarkCmd::parse();

View file

@ -11,7 +11,7 @@ use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError,
/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
/// are numbered from $0..(2^d)-1$. Example:
///
/// ```ignore
/// ```text
/// depth
/// 0 0
/// 1 0 1
@ -72,6 +72,37 @@ impl NodeIndex {
Self::new(depth, value)
}
/// Converts a row traversal index to depth-value form and returns it as a [`NodeIndex`].
/// See [`NodeIndex::to_traversal_index()`] for more details.
///
/// # Errors
/// Returns the same errors under the same conditions as [`NodeIndex::new()`].
///
/// # Panics
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
/// indicated by `index` does not fit in a [`u64`].
pub const fn from_traversal_index(index: u128) -> Result<Self, MerkleError> {
    if index == 0 {
        return Ok(Self { depth: 0, value: 0 });
    }
    // A node at depth `d` with row-value `v` has traversal index `2^d - 1 + v`, so the
    // indices of row `d` occupy `[2^d - 1, 2^(d+1) - 2]`. Shifting by one maps that range to
    // `[2^d, 2^(d+1) - 1]`, whose base-2 logarithm is exactly `d`. (Using `ilog2(index)`
    // would mis-compute the depth for the leftmost node of every row, where `index + 1` is a
    // power of two — the same off-by-one fixed in `from_scalar_index`.)
    let depth = {
        let depth = u128::ilog2(index + 1);
        assert!(depth <= u8::MAX as u32);
        depth as u8
    };
    // The first traversal index in this row; subtracting it recovers the in-row value.
    let first_index_in_row = u128::pow(2, depth as u32) - 1;
    let value = {
        let value = index - first_index_in_row;
        assert!(value <= u64::MAX as u128);
        value as u64
    };
    Self::new(depth, value)
}
/// Converts a scalar representation of a depth/value pair to a [`NodeIndex`].
///
/// This is the inverse operation of [`NodeIndex::to_scalar_index()`]. As `1` represents the
@ -83,20 +114,29 @@ impl NodeIndex {
/// # Panics
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
/// indicated by `index` does not fit in a [`u64`].
pub const fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
pub fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
let index = index.get() - 1;
if index == 0 {
return Ok(Self { depth: 0, value: 0 });
return Ok(Self::root());
}
// The log of 1 is always 0.
if index == 1 {
return Ok(Self::root().left_child());
}
let depth = {
let depth = u128::ilog2(index);
let depth = u128::ilog2(index + 1);
assert!(depth <= u8::MAX as u32);
depth as u8
};
let max_value_for_depth = (1 << depth) - 1;
assert!(
max_value_for_depth <= u64::MAX as u128,
"max_value ({max_value_for_depth}) does not fit in u64",
);
let value = {
let value = index - max_value_for_depth;
@ -125,6 +165,18 @@ impl NodeIndex {
self
}
/// Returns the index of the leftmost node `n` levels *deeper* than this node.
///
/// NOTE(review): despite the name, this moves down the tree (depth increases), so the
/// returned index is a descendant bound, not an ancestor — confirm the intended naming.
pub const fn left_ancestor_n(mut self, n: u8) -> Self {
    self.depth += n;
    self.value <<= n;
    self
}
/// Returns the index of the rightmost node `n` levels *deeper* than this node.
///
/// NOTE(review): despite the name, this moves down the tree (depth increases), so the
/// returned index is a descendant bound, not an ancestor — confirm the intended naming.
pub const fn right_ancestor_n(mut self, n: u8) -> Self {
    self.depth += n;
    self.value = (self.value << n) + 1;
    self
}
/// Returns right child index of the current node.
pub const fn right_child(mut self) -> Self {
self.depth += 1;
@ -132,6 +184,64 @@ impl NodeIndex {
self
}
/// Returns the parent of the current node; the root (depth `0`) is its own parent.
pub const fn parent(mut self) -> Self {
    self.value >>= 1;
    self.depth = self.depth.saturating_sub(1);
    self
}
/// Returns the `n`th parent of the current node.
///
/// Depth saturates at `0`, but callers are expected to keep `n` within the current depth.
pub fn parent_n(mut self, n: u8) -> Self {
    debug_assert!(n <= self.depth);
    self.value >>= n;
    self.depth = self.depth.saturating_sub(n);
    self
}
/// Returns `true` if and only if `other` lies inside the subtree rooted at the current node —
/// i.e. `self` is an ancestor of `other`, or `other` is the current node itself.
///
/// (The previous doc comment stated the relation backwards; the implementation has always
/// checked that `self` is the ancestor.)
pub fn contains(&self, other: Self) -> bool {
    // A node strictly above this one can never be inside this subtree.
    if other.depth < self.depth {
        return false;
    }
    // Walk `other` up to our depth; it is in our subtree iff that ancestor is us. This also
    // covers equality (`parent_n(0)` is the identity). The original trailing loop was dead:
    // after lifting `other` to `self.depth`, one comparison fully decides the answer.
    other.parent_n(other.depth() - self.depth()) == *self
}
/// Returns `true` if and only if the current node is a strict descendant of `other`
/// (a node is not considered a descendant of itself).
///
/// This is the converse of [`NodeIndex::is_ancestor_of`].
pub fn is_descendent_of(self, other: Self) -> bool {
    // Comparing raw `value`s across different depths is meaningless (the old
    // `self.value != other.value` guard wrongly rejected descendants of any value-0
    // ancestor, e.g. `(2, 0)` vs the root). Depth inequality alone excludes the
    // self-equality case, since `contains` already implies `self.depth >= other.depth`.
    self.depth != other.depth && other.contains(self)
}
/// Returns `true` if and only if the current node is a strict ancestor of `other`
/// (a node is not considered an ancestor of itself).
pub fn is_ancestor_of(self, other: Self) -> bool {
    // Comparing raw `value`s across different depths is meaningless (the old
    // `self.value != other.value` guard wrongly rejected descendants of any value-0
    // ancestor). Depth inequality alone excludes the self-equality case, since
    // `contains` already implies `other.depth >= self.depth`.
    self.depth != other.depth && self.contains(other)
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
@ -154,11 +264,13 @@ impl NodeIndex {
}
/// Returns the depth of the current instance.
///
/// Depth `0` is the root row; each level further from the root increases the depth by one.
#[inline(always)]
pub const fn depth(&self) -> u8 {
    self.depth
}
/// Returns the value of this index.
///
/// The value is the 0-based position of this node within its row at `depth`.
#[inline(always)]
pub const fn value(&self) -> u64 {
    self.value
}
@ -169,10 +281,26 @@ impl NodeIndex {
}
/// Returns `true` if the depth is `0`.
///
/// The root is the only node that lives at depth `0`.
#[inline(always)]
pub const fn is_root(&self) -> bool {
    self.depth == 0
}
/// Converts this [`NodeIndex`] to the equivalent row traversal index (also called level order
/// traversal index or breadth-first order index).
///
/// [`NodeIndex`] denotes a node position by its depth and the index in the row at that depth.
/// The row traversal index denotes a node position by counting all the nodes "before" it --
/// counting all nodes before it in the row, as well as all nodes in all other rows above it.
/// In other words, a node at row traversal index `n` is the `n`th node in the tree.
///
/// For example, a node at depth `2`, value `3` has the root node `0, 0`, its two children
/// `1, 0` and `1, 1`, and their children `2, 0`, `2, 1`, `2, 2`, and `2, 3`, 0-indexed, `2, 3`
/// is node **`6`** in that order, so its row traversal index is `6`.
pub const fn to_traversal_index(&self) -> u128 {
    // All rows above depth `d` contain `2^d - 1` nodes in total; the in-row value is added
    // on top of that. (A commented-out duplicate of this line was removed.)
    (1 << self.depth) - 1 + (self.value as u128)
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -215,6 +343,27 @@ impl Deserializable for NodeIndex {
}
}
/// Identifies a fixed-height subtree: a root node plus the number of levels the subtree
/// extends below that root.
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
pub struct SubtreeIndex {
    // The node at which this subtree is rooted.
    pub root: NodeIndex,
    // How many levels below `root` the subtree extends.
    pub depth: u8,
}

#[allow(dead_code)]
impl SubtreeIndex {
    /// Creates a new subtree index from its root node and its height below that root.
    pub const fn new(root: NodeIndex, depth: u8) -> Self {
        Self { root, depth }
    }

    /// Returns the index of the leftmost node on the subtree's bottom row.
    pub const fn left_bound(&self) -> NodeIndex {
        self.root.left_ancestor_n(self.depth)
    }

    /// Returns the index of the rightmost node on the subtree's bottom row.
    pub const fn right_bound(&self) -> NodeIndex {
        self.root.right_ancestor_n(self.depth)
    }
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
@ -245,6 +394,21 @@ mod tests {
assert!(NodeIndex::new(64, u64::MAX).is_ok());
}
#[test]
fn test_scalar_roundtrip() {
    // Arbitrary value that's at the bottom and not in a corner.
    let mut index = NodeIndex::make(64, u64::MAX - 8);
    loop {
        if index.is_root() {
            break;
        }
        let scalar = index.to_scalar_index();
        let recovered = NodeIndex::from_scalar_index(NonZero::new(scalar).unwrap()).unwrap();
        assert_eq!(index, recovered, "{index:?} did not round-trip as a scalar index");
        index.move_up();
    }
}
prop_compose! {
fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
// unwrap never panics because the range of depth is 0..u64::BITS

View file

@ -1,16 +1,25 @@
#[cfg(feature = "async")]
use std::sync::Arc;
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};
#[cfg(feature = "async")]
use tokio::task::JoinSet;
#[cfg(feature = "async")]
use super::NodeMutation;
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
#[cfg(feature = "async")]
use crate::merkle::index::SubtreeIndex;
mod error;
pub use error::{SmtLeafError, SmtProofError};
@ -46,7 +55,12 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
#[cfg(not(feature = "async"))]
leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(feature = "async")]
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
@ -70,7 +84,7 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
leaves: Default::default(),
inner_nodes: Default::default(),
}
}
@ -107,6 +121,22 @@ impl Smt {
Ok(tree)
}
#[cfg(feature = "async")]
/// Returns a cheap, reference-counted handle to this tree's leaf mapping.
pub fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
    Arc::clone(&self.leaves)
}
#[cfg(feature = "async")]
/// Computes a mutation set for `kv_pairs` by hashing subtrees concurrently.
///
/// Thin wrapper that forwards to the [`super::ParallelSparseMerkleTree`] implementation.
pub async fn compute_mutations_parallel(
    &self,
    kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
    <Self as super::ParallelSparseMerkleTree<SMT_DEPTH>>::compute_mutations_parallel(
        self, kv_pairs,
    )
    .await
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
@ -177,6 +207,23 @@ impl Smt {
}
}
/// Gets a mutable reference to this structure's inner leaf mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn leaves_mut(&mut self) -> &mut BTreeMap<u64, SmtLeaf> {
    #[cfg(feature = "async")]
    {
        // `Arc::get_mut` yields `Some` only while this is the sole handle to the map, so any
        // outstanding clone held by a parallel task turns this into the documented panic.
        Arc::get_mut(&mut self.leaves).unwrap()
    }
    #[cfg(not(feature = "async"))]
    {
        &mut self.leaves
    }
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -241,10 +288,12 @@ impl Smt {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
match self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
match leaves.get_mut(&leaf_index.value()) {
Some(leaf) => leaf.insert(key, value),
None => {
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
None
},
@ -255,10 +304,12 @@ impl Smt {
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
if let Some(leaf) = leaves.get_mut(&leaf_index.value()) {
let (old_value, is_empty) = leaf.remove(key);
if is_empty {
self.leaves.remove(&leaf_index.value());
leaves.remove(&leaf_index.value());
}
old_value
} else {
@ -266,6 +317,27 @@ impl Smt {
None
}
}
/// Returns the leaf as it would look with `(key, value)` applied, without touching the tree.
///
/// An `EMPTY_WORD` value is treated as a removal of `key` from the leaf.
fn construct_prospective_leaf(
    mut existing_leaf: SmtLeaf,
    key: &RpoDigest,
    value: &Word,
) -> SmtLeaf {
    debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));

    match existing_leaf {
        SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
        _ => {
            if *value != EMPTY_WORD {
                existing_leaf.insert(*key, *value);
            } else {
                existing_leaf.remove(*key);
            }

            existing_leaf
        },
    }
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {
@ -332,24 +404,11 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
Smt::construct_prospective_leaf(existing_leaf, key, value)
}
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
@ -362,12 +421,301 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}
}
#[cfg(feature = "async")]
impl super::ParallelSparseMerkleTree<SMT_DEPTH> for Smt {
    // Helpers required only for the parallel version of the SMT trait.

    /// Returns a reference-counted handle to this tree's inner-node map.
    fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>> {
        Arc::clone(&self.inner_nodes)
    }

    /// Returns a reference-counted handle to this tree's leaf map.
    fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
        Arc::clone(&self.leaves)
    }

    /// Computes the mutation set for `kv_pairs` by recomputing dirtied subtrees bottom-up,
    /// with each `SUBTREE_INTERVAL`-level subtree hashed in its own tokio task.
    async fn compute_mutations_parallel<I>(
        &self,
        kv_pairs: I,
    ) -> MutationSet<SMT_DEPTH, RpoDigest, Word>
    where
        I: IntoIterator<Item = (RpoDigest, Word)>,
    {
        use std::time::Instant;

        // Number of tree levels covered by each per-task subtree.
        const SUBTREE_INTERVAL: u8 = 8;

        // FIXME: check for duplicates and return MerkleError.
        let kv_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));

        // The first subtrees we calculate, which include our new leaves.
        let mut subtrees: HashSet<NodeIndex> = kv_pairs
            .keys()
            .map(|key| {
                let index_for_key = NodeIndex::from(Smt::key_to_leaf_index(key));
                index_for_key.parent_n(SUBTREE_INTERVAL)
            })
            .collect();

        // Node mutations across all tasks will be collected here.
        // Every time we collect tasks we store all the new known node mutations and their hashes
        // (so we don't have to recompute them every time we need them).
        let mut node_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>> =
            Default::default();

        // Any leaf hashes done by tasks will be collected here, so hopefully we only hash each leaf
        // once.
        let mut cached_leaf_hashes: Arc<HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>> =
            Default::default();

        // Walk upward one `SUBTREE_INTERVAL`-sized band of depths per iteration: with
        // SMT_DEPTH = 64 and interval 8, this visits subtree-root depths 56, 48, ..., 0.
        for subtree_depth in (0..SMT_DEPTH).step_by(SUBTREE_INTERVAL.into()).rev() {
            let now = Instant::now();

            let mut tasks = JoinSet::new();
            for subtree in subtrees.iter().copied() {
                debug_assert_eq!(subtree.depth(), subtree_depth);

                let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
                    &self,
                    Arc::clone(&node_mutations),
                    Arc::clone(&kv_pairs),
                    SubtreeIndex::new(subtree, SUBTREE_INTERVAL as u8),
                );

                // The "double spawn" here is necessary to allow tokio to run these tasks in
                // parallel.
                tasks.spawn(tokio::spawn(async move {
                    let hash = state.get_or_make_hash(subtree);
                    (subtree, hash, state.into_results())
                }));
            }

            let task_results = tasks.join_all().await;
            let elapsed = now.elapsed();
            std::eprintln!(
                "joined {} tasks for depth {} in {:.3} milliseconds",
                task_results.len(),
                subtree_depth,
                elapsed.as_secs_f64() * 1000.0,
            );

            for result in task_results {
                // FIXME: .expect() error message?
                let result = result.unwrap();
                let (subtree, hash, state) = result;
                let NodeSubtreeResults {
                    new_mutations,
                    cached_leaf_hashes: new_leaf_hashes,
                } = state;
                // All task handles were joined above, so `Arc::get_mut` should succeed here;
                // a `None` would mean a task outlived its join and is a bug.
                Arc::get_mut(&mut node_mutations).unwrap().extend(new_mutations);
                Arc::get_mut(&mut cached_leaf_hashes).unwrap().extend(new_leaf_hashes);

                // Make sure the final hash we calculated is in the new mutations.
                assert_eq!(
                    node_mutations.get(&subtree).unwrap().0,
                    hash,
                    "Stored and returned hashes for subtree '{subtree:?}' differ",
                );
            }

            // And advance our subtrees, unless we just did the root depth.
            if subtree_depth == 0 {
                continue;
            }
            let subtree_count_before_advance = subtrees.len();
            subtrees =
                subtrees.into_iter().map(|subtree| subtree.parent_n(SUBTREE_INTERVAL)).collect();
            // FIXME: remove.
            assert!(subtrees.len() <= subtree_count_before_advance);
        }

        let root = NodeIndex::root();
        let new_root = node_mutations.get(&root).unwrap().0;

        MutationSet {
            old_root: self.root(),
            //node_mutations: Arc::into_inner(node_mutations).unwrap().into_iter().collect(),
            node_mutations: Arc::into_inner(node_mutations)
                .unwrap()
                .into_iter()
                .map(|(key, (_hash, node))| (key, node))
                .collect(),
            new_pairs: Arc::into_inner(kv_pairs).unwrap(),
            new_root,
        }
    }
}
impl Default for Smt {
    /// Returns an empty tree, equivalent to [`Smt::new()`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(feature = "async")]
/// Per-task state for recomputing the hashes of a single subtree.
///
/// Holds shared, read-only views of the tree alongside this task's private caches and the
/// mutations it discovers.
pub(crate) struct NodeSubtreeState<const DEPTH: u8> {
    // Read-only snapshot of the tree's inner nodes.
    inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
    // Read-only snapshot of the tree's leaves.
    leaves: Arc<BTreeMap<u64, SmtLeaf>>,
    // This field has invariants! Memoized per-index dirtiness; written only by
    // `is_index_dirty()` / `get_or_make_hash()`.
    dirtied_indices: HashMap<NodeIndex, bool>,
    // Mutations already computed by earlier (deeper) passes, shared across tasks.
    existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
    // Mutations discovered by this task; harvested via `into_results()`.
    new_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)>,
    // The batch of key-value pairs being applied to the tree.
    new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
    // Leaf hashes computed by this task, so each leaf is hashed at most once.
    cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
    // NOTE(review): the following three fields appear unused in this file — confirm.
    indentation: u8,
    subtree: SubtreeIndex,
    dirty_high: u128,
    dirty_low: u128,
}
#[cfg(feature = "async")]
impl<const DEPTH: u8> NodeSubtreeState<DEPTH> {
/// Creates task state from pre-shared handles to the tree's data; all private caches start
/// empty.
pub(crate) fn new(
    inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
    existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
    leaves: Arc<BTreeMap<u64, SmtLeaf>>,
    new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
    subtree: SubtreeIndex,
) -> Self {
    Self {
        inner_nodes,
        leaves,
        dirtied_indices: Default::default(),
        new_mutations: Default::default(),
        existing_mutations,
        new_pairs,
        cached_leaf_hashes: Default::default(),
        indentation: 0,
        subtree,
        dirty_high: 0,
        dirty_low: 0,
    }
}
/// Creates task state directly from an [`Smt`], cloning its shared node and leaf handles.
pub(crate) fn with_smt(
    smt: &Smt,
    existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
    new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
    subtree: SubtreeIndex,
) -> Self {
    Self::new(
        Arc::clone(&smt.inner_nodes),
        existing_mutations,
        Arc::clone(&smt.leaves),
        new_pairs,
        subtree,
    )
}
/// Returns `true` if the subtree rooted at `index_to_check` contains any node touched by an
/// existing recorded mutation or by one of the new key-value pairs.
///
/// Results are memoized in `self.dirtied_indices`.
#[inline(never)] // XXX: for profiling.
pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
    // The subtree root itself is always treated as dirty: it is the node this task exists
    // to recompute.
    if index_to_check == self.subtree.root {
        return true;
    }
    if let Some(&cached) = self.dirtied_indices.get(&index_to_check) {
        return cached;
    }
    // Dirty iff any known mutation, or any leaf targeted by a new pair, falls inside this
    // node's subtree. `any()` short-circuits (the old `.filter(..).next().is_some()` spelled
    // the same thing less directly), and `.keys().copied()` replaces `.iter().map(|(k, _)| *k)`.
    let is_dirty = self
        .existing_mutations
        .keys()
        .copied()
        .chain(self.new_pairs.keys().map(|key| Smt::key_to_leaf_index(key).index))
        .any(|dirtied_index| index_to_check.contains(dirtied_index));
    self.dirtied_indices.insert(index_to_check, is_dirty);
    is_dirty
}
/// Does NOT check `new_mutations`.
#[inline(never)] // XXX: for profiling.
pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
self.existing_mutations
.get(&index)
.map(|(hash, _)| *hash)
.or_else(|| self.inner_nodes.get(&index).map(|inner_node| InnerNode::hash(&inner_node)))
}
/// Computes the leaf as it will appear at `index` once all `new_pairs` are applied, starting
/// from the leaf currently stored in the tree (or an empty leaf if none exists there).
#[inline(never)] // XXX: for profiling.
pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
    // Only the new pairs that land in this exact leaf slot are relevant.
    let pairs_at_index = self
        .new_pairs
        .iter()
        .filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index);

    let existing_leaf = self
        .leaves
        .get(&index.index.value())
        .cloned()
        .unwrap_or_else(|| SmtLeaf::new_empty(index));

    // `fold` passes ownership of the accumulator into each step, so the leaf can be threaded
    // straight through `construct_prospective_leaf` — the previous per-iteration
    // `acc.clone()` was a redundant allocation.
    pairs_at_index.fold(existing_leaf, |acc, (k, v)| Smt::construct_prospective_leaf(acc, k, v))
}
/// Retrieve a cached hash, or recursively compute it.
#[inline(never)] // XXX: for profiling.
pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
    use NodeMutation::*;

    // If this is a leaf, then only do leaf stuff.
    if index.depth() == SMT_DEPTH {
        let index = LeafIndex::new(index.value()).unwrap();
        return match self.cached_leaf_hashes.get(&index) {
            Some(cached_hash) => cached_hash.clone(),
            None => {
                let leaf = self.get_effective_leaf(index);
                let hash = Smt::hash_leaf(&leaf);
                self.cached_leaf_hashes.insert(index, hash);
                hash
            },
        };
    }

    // If we already computed this one earlier as a mutation, just return it.
    if let Some((hash, _)) = self.new_mutations.get(&index) {
        return *hash;
    }

    // Otherwise, we need to know if this node is one of the nodes we're in the process of
    // recomputing, or if we can safely use the node already in the Merkle tree.
    if !self.is_index_dirty(index) {
        // Clean but absent from the tree means an empty subtree at this depth.
        return self
            .get_clean_hash(index)
            .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
    }

    // If we got here, then we have to make, rather than get, this hash.
    // Make sure we mark this index as now dirty.
    self.dirtied_indices.insert(index, true);

    // Recurse for the left and right sides.
    let left = self.get_or_make_hash(index.left_child());
    let right = self.get_or_make_hash(index.right_child());
    let node = InnerNode { left, right };
    let hash = node.hash();
    // A node whose hash equals the empty-subtree hash at its depth is recorded as a removal
    // rather than an addition.
    let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
    let is_removal = hash == equivalent_empty_hash;
    let new_entry = if is_removal { Removal } else { Addition(node) };
    self.new_mutations.insert(index, (hash, new_entry));

    hash
}
/// Consumes the state, yielding only the data the spawning task needs to collect: this
/// task's new mutations and its cached leaf hashes.
fn into_results(self) -> NodeSubtreeResults {
    NodeSubtreeResults {
        new_mutations: self.new_mutations,
        cached_leaf_hashes: self.cached_leaf_hashes,
    }
}
}
#[derive(Debug, Clone, PartialEq)]
#[cfg(feature = "async")]
/// The outputs harvested from a [`NodeSubtreeState`] once its subtree is finished.
pub(crate) struct NodeSubtreeResults {
    // New node mutations (with their hashes) computed by the task.
    pub(crate) new_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)>,
    // Leaf hashes computed by the task, cached for reuse.
    pub(crate) cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
}
// CONVERSIONS
// ================================================================================================

View file

@ -1,6 +1,19 @@
use alloc::vec::Vec;
#[cfg(feature = "async")]
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
sync::Arc,
time::Instant,
};
#[cfg(feature = "async")]
use tokio::task::JoinSet;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
#[cfg(feature = "async")]
use crate::merkle::{
index::SubtreeIndex,
smt::{full::NodeSubtreeState, NodeMutation},
};
use crate::{
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
utils::{Deserializable, Serializable},
@ -568,6 +581,263 @@ fn test_multiple_smt_leaf_serialization_success() {
assert_eq!(multiple_leaf, deserialized);
}
#[cfg(feature = "async")]
/// Generates `kv_count` pseudo-random key-value pairs (keyed so they spread across the whole
/// leaf-index space) plus a sequentially-built control SMT to compare parallel results against.
fn setup_subtree_test(kv_count: u64) -> (Vec<(RpoDigest, Word)>, Smt) {
    // FIXME: override seed.
    let rand_felt = || rand_utils::rand_value::<Felt>();
    // A `Range` is already an iterator; the previous `.into_iter()` call was redundant.
    let kv_pairs: Vec<(RpoDigest, Word)> = (0..kv_count)
        .map(|i| {
            // `u64::MAX / (i + 1)` spreads leaf indices instead of clustering them near 0.
            let leaf_index = u64::MAX / (i + 1);
            let key =
                RpoDigest::new([rand_felt(), rand_felt(), rand_felt(), Felt::new(leaf_index)]);
            let value: Word = [Felt::new(i), rand_felt(), rand_felt(), rand_felt()];
            (key, value)
        })
        .collect();
    let control_smt = Smt::with_entries(kv_pairs.clone()).unwrap();
    (kv_pairs, control_smt)
}
#[test]
#[cfg(feature = "async")]
/// Checks that `NodeSubtreeState::get_or_make_hash` reproduces the control tree's hashes for
/// every leaf-parent node, given fake mutations derived from the control tree itself.
fn test_single_node_subtree() {
    use alloc::collections::BTreeMap;
    use std::{collections::HashMap, sync::Arc};

    use crate::merkle::smt::{full::NodeSubtreeState, NodeMutation};

    const KV_COUNT: u64 = 2_000;

    let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
    let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
    let test_smt = Smt::new();

    // A current-thread runtime suffices: the state machine is driven directly, not via tasks.
    let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    let _: () = rt.block_on(async move {
        // Construct some fake node mutations based on the leaves in the control Smt.
        let node_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> = control_smt
            .leaves()
            .flat_map(|(index, _leaf)| {
                let subtree = index.index.parent();
                // A parent absent from `inner_nodes` corresponds to an empty subtree, i.e. a
                // removal mutation.
                let mutation = control_smt
                    .inner_nodes
                    .get(&subtree)
                    .cloned()
                    .map(|node| (node.hash(), NodeMutation::Addition(node)))
                    .unwrap_or_else(|| {
                        (
                            *EmptySubtreeRoots::entry(SMT_DEPTH, subtree.depth()),
                            NodeMutation::Removal,
                        )
                    });
                vec![(subtree, mutation)]
            })
            .collect();
        let node_mutations = Arc::new(node_mutations);

        let mut state = NodeSubtreeState::<SMT_DEPTH>::new(
            Arc::clone(&test_smt.inner_nodes),
            Arc::clone(&node_mutations),
            Arc::clone(&control_smt.leaves),
            Arc::clone(&new_pairs),
            SubtreeIndex::new(NodeIndex::root(), 8),
        );

        // Every fake mutation's hash must agree with both the control tree and a fresh
        // computation through the state machine.
        for (i, (&index, mutation)) in node_mutations.iter().enumerate() {
            assert!(index.depth() <= SMT_DEPTH, "index {index:?} is invalid");
            let control_hash = if index.depth() < SMT_DEPTH {
                control_smt.get_inner_node(index).hash()
            } else {
                control_smt
                    .leaves
                    .get(&index.value())
                    .map(Smt::hash_leaf)
                    .unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()))
            };
            let mutation_hash = mutation.0;
            let test_hash = state.get_or_make_hash(index);
            assert_eq!(mutation_hash, control_hash);
            assert_eq!(
                test_hash, control_hash,
                "test_hash != control_hash for mutation {i} at {index:?}",
            );
        }
    });
}
// Test doing a node subtree from a LeafSubtreeMutationSet.
//
// Phase 1 recomputes each leaf-adjacent depth-8 subtree from only the pairs under it; phase 2
// recomputes the subtrees one 8-level band higher, seeded with phase 1's mutations.
#[test]
#[cfg(feature = "async")]
fn test_node_subtree_with_leaves() {
    const KV_COUNT: u64 = 2_000;
    let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
    let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
    // The tree under test starts empty; all hashes must be derived from the supplied pairs.
    let test_smt = Smt::new();
    let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    let _: () = rt.block_on(async move {
        let mut task_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> = Default::default();
        // Phase 1: for each leaf's depth-8 ancestor subtree, rebuild the subtree hash from only
        // the control entries inside it, then collect the resulting node mutations.
        for leaf_index in control_smt.leaves().map(|(index, _leaf)| index) {
            let subtree = SubtreeIndex::new(leaf_index.index.parent_n(8), 8);
            // All control-tree entries whose leaves lie under this subtree's root.
            let subtree_pairs: Vec<(RpoDigest, Word)> = control_smt
                .leaves()
                .flat_map(|(leaf_index, leaf)| {
                    if subtree.root.contains(leaf_index.index) {
                        leaf.entries()
                    } else {
                        vec![]
                    }
                })
                .cloned()
                .collect();
            // No prior mutations: this state sees only the pairs inside the subtree.
            let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
                &test_smt,
                Default::default(),
                Arc::new(BTreeMap::from_iter(subtree_pairs)),
                //Arc::clone(&new_pairs),
                subtree,
            );
            let test_subtree_hash = state.get_or_make_hash(subtree.root);
            let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash();
            assert_eq!(test_subtree_hash, control_subtree_hash);
            task_mutations.extend(state.new_mutations);
        }
        // Phase 2: the subtrees one 8-level band above phase 1's, this time seeded with the
        // accumulated mutations plus the full pair set.
        let node_mutations = Arc::new(task_mutations);
        let subtrees: BTreeSet<SubtreeIndex> = control_smt
            .leaves()
            .map(|(index, _leaf)| SubtreeIndex::new(index.index.parent_n(8).parent_n(8), 8))
            .collect();
        for (i, subtree) in subtrees.into_iter().enumerate() {
            let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
                &test_smt,
                Arc::clone(&node_mutations),
                Arc::clone(&new_pairs),
                subtree,
            );
            let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash();
            let test_subtree_hash = state.get_or_make_hash(subtree.root);
            assert_eq!(
                test_subtree_hash, control_subtree_hash,
                "test subtree hash does not match control hash for subtree {i} '{subtree:?}'",
            );
        }
    });
}
// Checks full-depth hashing when each 8-level band of subtrees is computed by parallel tasks
// seeded with the mutations and cached leaf hashes produced by the band below it.
#[test]
#[cfg(feature = "async")]
fn test_node_subtrees_parallel() {
    const KV_COUNT: u64 = 2_000;
    let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
    let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
    let test_smt = Smt::new();
    let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
    //let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
    let _: () = rt.block_on(async move {
        let mut current_subtree_depth = SMT_DEPTH;
        // Start with the deepest band: the depth-8 subtrees containing each new leaf.
        let subtrees: BTreeSet<SubtreeIndex> = new_pairs
            .keys()
            .map(|key| SubtreeIndex::new(Smt::key_to_leaf_index(key).index.parent_n(8), 8))
            .collect();
        current_subtree_depth -= 8;
        let mut node_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>> =
            Default::default();
        let mut tasks = JoinSet::new();
        // FIXME
        let mut now = Instant::now();
        // Spawn one task per subtree in the deepest band.
        // NOTE(review): `tasks.spawn(tokio::spawn(...))` nests a free-standing task inside the
        // JoinSet task, which is why `join_all()` below yields `Result<_, JoinError>` values
        // needing `unwrap()`; `tasks.spawn(async move { ... })` alone would suffice — confirm.
        for subtree in subtrees.iter().copied() {
            let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
                &test_smt,
                Arc::clone(&node_mutations),
                Arc::clone(&new_pairs),
                subtree,
            );
            tasks.spawn(tokio::spawn(async move {
                let hash = state.get_or_make_hash(subtree.root);
                (subtree, hash, state)
            }));
        }
        let mut cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest> = Default::default();
        let mut subtrees = subtrees;
        // Wrapped in `Option` so the current band's JoinSet can be `take()`n while the next
        // band's set is installed with `insert()`.
        let mut tasks = Some(tasks);
        while current_subtree_depth > 0 {
            std::eprintln!(
                "joining {} tasks for depth {current_subtree_depth}",
                tasks.as_ref().unwrap().len(),
            );
            let mut tasks_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> =
                Default::default();
            let results = tasks.take().unwrap().join_all().await;
            let elapsed = now.elapsed();
            std::eprintln!(" joined in {:.3} milliseconds", elapsed.as_secs_f64() * 1000.0);
            // Verify each subtree root against the control tree and fold this band's outputs
            // into the shared state used by the next band.
            for result in results {
                let (subtree, test_hash, state) = result.unwrap();
                let control_hash = control_smt.get_inner_node(subtree.root).hash();
                assert_eq!(test_hash, control_hash);
                tasks_mutations.extend(state.new_mutations);
                cached_leaf_hashes.extend(state.cached_leaf_hashes);
            }
            // All tasks holding clones of `node_mutations` were joined above, so `get_mut`
            // on the Arc cannot fail here.
            Arc::get_mut(&mut node_mutations).unwrap().extend(tasks_mutations);
            // Move all our subtrees up.
            current_subtree_depth -= 8;
            subtrees = subtrees
                .into_iter()
                .map(|subtree| {
                    let subtree = SubtreeIndex::new(subtree.root.parent_n(8), 8);
                    assert_eq!(subtree.root.depth(), current_subtree_depth);
                    subtree
                })
                .collect();
            // And spawn our new tasks.
            //std::eprintln!("spawning tasks for depth {current_subtree_depth}");
            let tasks = tasks.insert(JoinSet::new());
            // FIXME
            now = Instant::now();
            for subtree in subtrees.iter().copied() {
                let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
                    &test_smt,
                    Arc::clone(&node_mutations),
                    Arc::clone(&new_pairs),
                    subtree,
                );
                // Seed each task with the leaf hashes already computed by deeper bands.
                state.cached_leaf_hashes = cached_leaf_hashes.clone();
                tasks.spawn(tokio::spawn(async move {
                    let hash = state.get_or_make_hash(subtree.root);
                    (subtree, hash, state)
                }));
            }
        }
        // After walking up to depth 0, exactly one task remains: the root subtree's.
        assert!(tasks.is_some());
        assert_eq!(tasks.as_ref().unwrap().len(), 1);
    });
}
// HELPERS
// --------------------------------------------------------------------------------------------

View file

@ -12,6 +12,11 @@ pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
mod simple;
pub use simple::SimpleSmt;
#[cfg(feature = "async")]
mod parallel;
#[cfg(feature = "async")]
pub(crate) use parallel::ParallelSparseMerkleTree;
// CONSTANTS
// ================================================================================================
@ -466,6 +471,26 @@ pub(crate) enum NodeMutation {
Addition(InnerNode),
}
impl NodeMutation {
#[allow(dead_code)]
pub fn into_inner_node(self, tree_depth: u8, node_depth: u8) -> InnerNode {
use NodeMutation::*;
match self {
Addition(node) => node,
Removal => EmptySubtreeRoots::get_inner_node(tree_depth, node_depth),
}
}
#[allow(dead_code)]
pub fn as_hash(&self, tree_depth: u8, node_depth: u8) -> RpoDigest {
use NodeMutation::*;
match self {
Addition(node) => node.hash(),
Removal => *EmptySubtreeRoots::entry(tree_depth, node_depth),
}
}
}
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
@ -499,7 +524,6 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;

View file

@ -0,0 +1,62 @@
use std::{
cmp::Ordering,
collections::BTreeMap,
sync::{Arc, LazyLock},
thread,
};
use crate::merkle::smt::{InnerNode, MutationSet, NodeIndex, SparseMerkleTree};
/// Default number of tasks for [`ParallelSparseMerkleTree::compute_mutations_parallel()`].
///
/// Resolved lazily from [`std::thread::available_parallelism()`]. On platforms or
/// configurations where the parallelism cannot be determined, falls back to a single task
/// instead of panicking (resolves the previous `FIXME: error handling?`).
static TASK_COUNT: LazyLock<usize> =
    LazyLock::new(|| thread::available_parallelism().map(|n| n.get()).unwrap_or(1));
/// Extension trait adding parallel mutation computation to [`SparseMerkleTree`] types.
// allow(dead_code): not yet wired into any tree implementation; see the `todo!()` stubs below.
#[allow(dead_code)]
pub(crate) trait ParallelSparseMerkleTree<const DEPTH: u8>
where
    // Note: these type bounds need to be specified this way or we'll have to duplicate them
    // everywhere.
    // https://github.com/rust-lang/rust/issues/130805.
    Self: SparseMerkleTree<
        DEPTH,
        Key: Send + Sync + 'static,
        Value: Send + Sync + 'static,
        Leaf: Send + Sync + 'static,
    >,
{
    /// Shortcut for [`ParallelSparseMerkleTree::compute_mutations_parallel_n()`] with an
    /// automatically determined number of tasks.
    ///
    /// Currently, the default number of tasks is the return value of
    /// [`std::thread::available_parallelism()`], but this may be subject to change in the future.
    async fn compute_mutations_parallel<I>(
        &self,
        kv_pairs: I,
    ) -> MutationSet<DEPTH, Self::Key, Self::Value>
    where
        I: IntoIterator<Item = (Self::Key, Self::Value)>,
    {
        self.compute_mutations_parallel_n(kv_pairs, *TASK_COUNT).await
    }

    /// Computes prospective mutations for `kv_pairs` using up to `tasks` parallel tasks.
    ///
    /// Not yet implemented: currently panics via `todo!()`.
    async fn compute_mutations_parallel_n<I>(
        &self,
        _kv_pairs: I,
        _tasks: usize,
    ) -> MutationSet<DEPTH, Self::Key, Self::Value>
    where
        I: IntoIterator<Item = (Self::Key, Self::Value)>,
    {
        todo!();
    }

    /// Returns a shared handle to the tree's inner-node map.
    fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>>;

    /// Returns a shared handle to the tree's leaf map, keyed by leaf index.
    fn get_leaves(&self) -> Arc<BTreeMap<u64, Self::Leaf>>;

    /// Extracts the value stored for `key` in `leaf`, if any.
    ///
    /// Not yet implemented: currently panics via `todo!()`.
    fn get_leaf_value(_leaf: &Self::Leaf, _key: &Self::Key) -> Option<Self::Value> {
        todo!();
    }

    /// Defines a total order over this tree type's keys.
    ///
    /// Not yet implemented: currently panics via `todo!()`.
    fn cmp_keys(_lhs: &Self::Key, _rhs: &Self::Key) -> Ordering {
        todo!();
    }
}

View file

@ -395,3 +395,19 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
(path, leaf).into()
}
}
//#[cfg(feature = "async")]
////impl<const DEPTH: u8> super::ParallelSparseMerkleTree<DEPTH, LeafIndex<DEPTH>, Word, Word>
//impl<const DEPTH: u8> super::ParallelSparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
// fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>> {
// Arc::clone(&self.inner_nodes)
// }
//
// fn get_leaf_value(leaf: &Word, _key: &LeafIndex<DEPTH>) -> Option<Word> {
// Some(*leaf)
// }
//
// fn cmp_keys(lhs: &LeafIndex<DEPTH>, rhs: &LeafIndex<DEPTH>) -> Ordering {
// LeafIndex::cmp(lhs, rhs)
// }
//}