Compare commits
5 commits
b289e7ed73
...
8b0d7bd7b4
Author | SHA1 | Date | |
---|---|---|---|
8b0d7bd7b4 | |||
3817ddbec3 | |||
aa88e29f2c | |||
1ca498b346 | |||
1632ec54aa |
9 changed files with 1163 additions and 29 deletions
209
Cargo.lock
generated
209
Cargo.lock
generated
|
@ -2,6 +2,21 @@
|
|||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "addr2line"
|
||||
version = "0.22.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e4503c46a5c0c7844e948c9a4d6acd9f50cccb4de1c48eb9e291ea17470c678"
|
||||
dependencies = [
|
||||
"gimli",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "adler"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.3"
|
||||
|
@ -84,6 +99,21 @@ version = "1.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0"
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.73"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cc23269a4f8976d0a4d2e7109211a419fe30e8d88d677cd60b6bc79c5732e0a"
|
||||
dependencies = [
|
||||
"addr2line",
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"miniz_oxide",
|
||||
"object",
|
||||
"rustc-demangle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
|
@ -261,7 +291,7 @@ dependencies = [
|
|||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
|
@ -282,7 +312,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -364,6 +394,95 @@ version = "1.0.7"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
|
@ -387,6 +506,12 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.29.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "40ecd4077b5ae9fd2e9e169b102c6c330d0605168eb0e8bf79952b256dbefffd"
|
||||
|
||||
[[package]]
|
||||
name = "glob"
|
||||
version = "0.3.1"
|
||||
|
@ -447,6 +572,15 @@ dependencies = [
|
|||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.11"
|
||||
|
@ -524,24 +658,37 @@ dependencies = [
|
|||
"cc",
|
||||
"clap",
|
||||
"criterion",
|
||||
"futures",
|
||||
"getrandom",
|
||||
"glob",
|
||||
"hex",
|
||||
"itertools 0.13.0",
|
||||
"num",
|
||||
"num-complex",
|
||||
"proptest",
|
||||
"rand",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
"rayon",
|
||||
"seq-macro",
|
||||
"serde",
|
||||
"sha3",
|
||||
"tokio",
|
||||
"winter-crypto",
|
||||
"winter-math",
|
||||
"winter-rand-utils",
|
||||
"winter-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08"
|
||||
dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.3"
|
||||
|
@ -616,6 +763,15 @@ dependencies = [
|
|||
"libm",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "object"
|
||||
version = "0.36.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "084f1a5821ac4c651660a94a7153d27ac9d8a53736203f58b31945ded098070a"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
|
@ -628,6 +784,18 @@ version = "11.1.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.6"
|
||||
|
@ -797,6 +965,12 @@ version = "0.8.4"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-demangle"
|
||||
version = "0.1.24"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.34"
|
||||
|
@ -885,6 +1059,15 @@ dependencies = [
|
|||
"keccak",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "strsim"
|
||||
version = "0.11.1"
|
||||
|
@ -925,6 +1108,28 @@ dependencies = [
|
|||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.40.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"pin-project-lite",
|
||||
"tokio-macros",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-macros"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
|
|
@ -44,21 +44,25 @@ std = [
|
|||
"winter-math/std",
|
||||
"winter-utils/std",
|
||||
]
|
||||
async = ["serde?/rc"]
|
||||
async = ["std", "dep:tokio", "dep:rayon", "dep:futures", "serde?/rc"]
|
||||
|
||||
[dependencies]
|
||||
blake3 = { version = "1.5", default-features = false }
|
||||
clap = { version = "4.5", optional = true, features = ["derive"] }
|
||||
futures = { version = "0.3.30", optional = true }
|
||||
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
|
||||
num-complex = { version = "0.4", default-features = false }
|
||||
rand = { version = "0.8", default-features = false }
|
||||
rand_core = { version = "0.6", default-features = false }
|
||||
rand-utils = { version = "0.9", package = "winter-rand-utils", optional = true }
|
||||
rayon = { version = "1.10.0", optional = true }
|
||||
serde = { version = "1.0", default-features = false, optional = true, features = ["derive"] }
|
||||
sha3 = { version = "0.10", default-features = false }
|
||||
tokio = { version = "1.40", features = ["rt-multi-thread", "macros", "sync"], optional = true }
|
||||
winter-crypto = { version = "0.9", default-features = false }
|
||||
winter-math = { version = "0.9", default-features = false }
|
||||
winter-utils = { version = "0.9", default-features = false }
|
||||
itertools = { version = "0.13.0", default-features = false, features = ["use_alloc"] }
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = { version = "0.5", features = ["html_reports"] }
|
||||
|
|
41
src/main.rs
41
src/main.rs
|
@ -16,10 +16,51 @@ pub struct BenchmarkCmd {
|
|||
size: u64,
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "async"))]
|
||||
fn main() {
|
||||
benchmark_smt();
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() {
|
||||
// FIXME: very incomplete
|
||||
|
||||
let args = BenchmarkCmd::parse();
|
||||
let tree_size = args.size;
|
||||
|
||||
let mut entries = Vec::new();
|
||||
for i in 0..tree_size {
|
||||
//let key = rand_value::<RpoDigest>();
|
||||
let leaf_index = u64::MAX / (i + 1);
|
||||
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
|
||||
let value = [ONE, ONE, ONE, Felt::new(i)];
|
||||
entries.push((key, value));
|
||||
}
|
||||
|
||||
let control = Smt::with_entries(entries.clone()).unwrap();
|
||||
|
||||
let mut tree = Smt::new();
|
||||
println!("Running a parallel construction benchmark:");
|
||||
let now = Instant::now();
|
||||
let mutations = tree.compute_mutations_parallel(entries).await;
|
||||
assert_eq!(mutations.root(), control.root());
|
||||
tree.apply_mutations(mutations.clone()).unwrap();
|
||||
let elapsed = now.elapsed();
|
||||
assert_eq!(tree.root(), mutations.root(), "mutation did not apply the right root?");
|
||||
assert_eq!(control.root(), mutations.root(), "mutation root hash did not match control");
|
||||
assert_eq!(tree.root(), control.root(), "applied root hash did not match control");
|
||||
std::eprintln!("\nassertion checks complete");
|
||||
|
||||
println!(
|
||||
"Constructed an SMT in parallel with {} key-value pairs in {:.3} seconds",
|
||||
tree_size,
|
||||
elapsed.as_secs_f32(),
|
||||
);
|
||||
|
||||
//benchmark_smt();
|
||||
}
|
||||
|
||||
/// Run a benchmark for [`Smt`].
|
||||
pub fn benchmark_smt() {
|
||||
let args = BenchmarkCmd::parse();
|
||||
|
|
|
@ -11,7 +11,7 @@ use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError,
|
|||
/// The position is represented by the pair `(depth, pos)`, where for a given depth `d` elements
|
||||
/// are numbered from $0..(2^d)-1$. Example:
|
||||
///
|
||||
/// ```ignore
|
||||
/// ```text
|
||||
/// depth
|
||||
/// 0 0
|
||||
/// 1 0 1
|
||||
|
@ -72,6 +72,37 @@ impl NodeIndex {
|
|||
Self::new(depth, value)
|
||||
}
|
||||
|
||||
/// Converts a row traversal index to depth-value form and returns it as a [`NodeIndex`].
|
||||
/// See [`NodeIndex::to_traversal_index()`] for more details.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns the same errors under the same conditions as [`NodeIndex::new()`].
|
||||
///
|
||||
/// # Panics
|
||||
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
|
||||
/// indicated by `index` does not fit in a [`u64`].
|
||||
pub const fn from_traversal_index(index: u128) -> Result<Self, MerkleError> {
|
||||
if index == 0 {
|
||||
return Ok(Self { depth: 0, value: 0 });
|
||||
}
|
||||
|
||||
let depth = {
|
||||
let depth = u128::ilog2(index);
|
||||
assert!(depth <= u8::MAX as u32);
|
||||
depth as u8
|
||||
};
|
||||
|
||||
let max_index_for_depth = u128::pow(2, depth as u32) - 1;
|
||||
|
||||
let value = {
|
||||
let value = index - max_index_for_depth;
|
||||
assert!(value <= u64::MAX as u128);
|
||||
value as u64
|
||||
};
|
||||
|
||||
Self::new(depth, value)
|
||||
}
|
||||
|
||||
/// Converts a scalar representation of a depth/value pair to a [`NodeIndex`].
|
||||
///
|
||||
/// This is the inverse operation of [`NodeIndex::to_scalar_index()`]. As `1` represents the
|
||||
|
@ -83,20 +114,29 @@ impl NodeIndex {
|
|||
/// # Panics
|
||||
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
|
||||
/// indicated by `index` does not fit in a [`u64`].
|
||||
pub const fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
|
||||
pub fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
|
||||
let index = index.get() - 1;
|
||||
|
||||
if index == 0 {
|
||||
return Ok(Self { depth: 0, value: 0 });
|
||||
return Ok(Self::root());
|
||||
}
|
||||
|
||||
// The log of 1 is always 0.
|
||||
if index == 1 {
|
||||
return Ok(Self::root().left_child());
|
||||
}
|
||||
|
||||
let depth = {
|
||||
let depth = u128::ilog2(index);
|
||||
let depth = u128::ilog2(index + 1);
|
||||
assert!(depth <= u8::MAX as u32);
|
||||
depth as u8
|
||||
};
|
||||
|
||||
let max_value_for_depth = (1 << depth) - 1;
|
||||
assert!(
|
||||
max_value_for_depth <= u64::MAX as u128,
|
||||
"max_value ({max_value_for_depth}) does not fit in u64",
|
||||
);
|
||||
|
||||
let value = {
|
||||
let value = index - max_value_for_depth;
|
||||
|
@ -125,6 +165,18 @@ impl NodeIndex {
|
|||
self
|
||||
}
|
||||
|
||||
pub const fn left_ancestor_n(mut self, n: u8) -> Self {
|
||||
self.depth += n;
|
||||
self.value <<= n;
|
||||
self
|
||||
}
|
||||
|
||||
pub const fn right_ancestor_n(mut self, n: u8) -> Self {
|
||||
self.depth += n;
|
||||
self.value = (self.value << n) + 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns right child index of the current node.
|
||||
pub const fn right_child(mut self) -> Self {
|
||||
self.depth += 1;
|
||||
|
@ -132,6 +184,64 @@ impl NodeIndex {
|
|||
self
|
||||
}
|
||||
|
||||
/// Returns the parent of the current node.
|
||||
pub const fn parent(mut self) -> Self {
|
||||
self.depth = self.depth.saturating_sub(1);
|
||||
self.value >>= 1;
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the `n`th parent of the current node.
|
||||
pub fn parent_n(mut self, n: u8) -> Self {
|
||||
debug_assert!(n <= self.depth);
|
||||
self.depth = self.depth.saturating_sub(n);
|
||||
self.value >>= n;
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns `true` if and only if `other` is an ancestor of the current node, or the current
|
||||
/// node itself.
|
||||
pub fn contains(&self, mut other: Self) -> bool {
|
||||
if other == *self {
|
||||
return true;
|
||||
}
|
||||
if other.is_root() {
|
||||
return false;
|
||||
}
|
||||
if other.depth < self.depth {
|
||||
return false;
|
||||
}
|
||||
|
||||
other = other.parent_n(other.depth() - self.depth());
|
||||
|
||||
loop {
|
||||
if other == *self {
|
||||
return true;
|
||||
}
|
||||
|
||||
if other.is_root() {
|
||||
return false;
|
||||
}
|
||||
|
||||
if other.depth < self.depth {
|
||||
return false;
|
||||
}
|
||||
|
||||
other = other.parent();
|
||||
}
|
||||
}
|
||||
|
||||
/// The inverse of [`NodeIndex::is_ancestor_of`], except that it does not include itself.
|
||||
pub fn is_descendent_of(self, other: Self) -> bool {
|
||||
self.depth != other.depth && self.value != other.value && other.contains(self)
|
||||
}
|
||||
|
||||
/// Returns `true` if and only if `other` is an ancestor of the current node.
|
||||
pub fn is_ancestor_of(self, other: Self) -> bool {
|
||||
self.depth != other.depth && self.value != other.value && self.contains(other)
|
||||
}
|
||||
|
||||
// PROVIDERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -154,11 +264,13 @@ impl NodeIndex {
|
|||
}
|
||||
|
||||
/// Returns the depth of the current instance.
|
||||
#[inline(always)]
|
||||
pub const fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
/// Returns the value of this index.
|
||||
#[inline(always)]
|
||||
pub const fn value(&self) -> u64 {
|
||||
self.value
|
||||
}
|
||||
|
@ -169,10 +281,26 @@ impl NodeIndex {
|
|||
}
|
||||
|
||||
/// Returns `true` if the depth is `0`.
|
||||
#[inline(always)]
|
||||
pub const fn is_root(&self) -> bool {
|
||||
self.depth == 0
|
||||
}
|
||||
|
||||
/// Converts this [`NodeIndex`] to the equivalent row traversal index (also called level order
|
||||
/// traversal index or breadth-first order index).
|
||||
///
|
||||
/// [`NodeIndex`] denotes a node position by its depth and the index in the row at that depth.
|
||||
/// The row traversal index denotes a node position by counting all the nodes "before" it --
|
||||
/// counting all nodes before it in the row, as well as all nodes in all other rows above it.
|
||||
/// In other words, a node at row traversal index `n` is the `n`th node in the tree.
|
||||
/// For example, a node at depth `2`, value `3` has the root node `0, 0`, its two children
|
||||
/// `1, 0` and `1, 1`, and their children `2, 0`, `2, 1`, `2, 2`, and `2, 3`, 0-indexed, `2, 3`
|
||||
/// is node **`6`** in that order, so its row traversal index is `6`.
|
||||
pub const fn to_traversal_index(&self) -> u128 {
|
||||
(1 << self.depth) - 1 + (self.value as u128)
|
||||
//u128::pow(2, self.depth as u32) - 1 + (self.value as u128)
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -215,6 +343,27 @@ impl Deserializable for NodeIndex {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, PartialOrd, Ord, Hash)]
|
||||
pub struct SubtreeIndex {
|
||||
pub root: NodeIndex,
|
||||
pub depth: u8,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl SubtreeIndex {
|
||||
pub const fn new(root: NodeIndex, depth: u8) -> Self {
|
||||
Self { root, depth }
|
||||
}
|
||||
|
||||
pub const fn left_bound(&self) -> NodeIndex {
|
||||
self.root.left_ancestor_n(self.depth)
|
||||
}
|
||||
|
||||
pub const fn right_bound(&self) -> NodeIndex {
|
||||
self.root.right_ancestor_n(self.depth)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use proptest::prelude::*;
|
||||
|
@ -245,6 +394,21 @@ mod tests {
|
|||
assert!(NodeIndex::new(64, u64::MAX).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scalar_roundtrip() {
|
||||
// Arbitrary value that's at the bottom and not in a corner.
|
||||
let start = NodeIndex::make(64, u64::MAX - 8);
|
||||
|
||||
let mut index = start;
|
||||
while !index.is_root() {
|
||||
let as_scalar = index.to_scalar_index();
|
||||
let round_trip =
|
||||
NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap();
|
||||
assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index");
|
||||
index.move_up();
|
||||
}
|
||||
}
|
||||
|
||||
prop_compose! {
|
||||
fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
|
||||
// unwrap never panics because the range of depth is 0..u64::BITS
|
||||
|
|
|
@ -1,16 +1,25 @@
|
|||
#[cfg(feature = "async")]
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use alloc::{
|
||||
collections::{BTreeMap, BTreeSet},
|
||||
string::ToString,
|
||||
vec::Vec,
|
||||
};
|
||||
#[cfg(feature = "async")]
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
use super::NodeMutation;
|
||||
use super::{
|
||||
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
|
||||
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
|
||||
};
|
||||
#[cfg(feature = "async")]
|
||||
use crate::merkle::index::SubtreeIndex;
|
||||
|
||||
mod error;
|
||||
pub use error::{SmtLeafError, SmtProofError};
|
||||
|
@ -46,7 +55,12 @@ pub const SMT_DEPTH: u8 = 64;
|
|||
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
|
||||
pub struct Smt {
|
||||
root: RpoDigest,
|
||||
|
||||
#[cfg(not(feature = "async"))]
|
||||
leaves: BTreeMap<u64, SmtLeaf>,
|
||||
#[cfg(feature = "async")]
|
||||
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
|
||||
|
||||
#[cfg(not(feature = "async"))]
|
||||
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
|
||||
#[cfg(feature = "async")]
|
||||
|
@ -70,7 +84,7 @@ impl Smt {
|
|||
|
||||
Self {
|
||||
root,
|
||||
leaves: BTreeMap::new(),
|
||||
leaves: Default::default(),
|
||||
inner_nodes: Default::default(),
|
||||
}
|
||||
}
|
||||
|
@ -107,6 +121,22 @@ impl Smt {
|
|||
Ok(tree)
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
|
||||
Arc::clone(&self.leaves)
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub async fn compute_mutations_parallel(
|
||||
&self,
|
||||
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
|
||||
<Self as super::ParallelSparseMerkleTree<SMT_DEPTH>>::compute_mutations_parallel(
|
||||
self, kv_pairs,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
// PUBLIC ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -177,6 +207,23 @@ impl Smt {
|
|||
}
|
||||
}
|
||||
|
||||
/// Gets a mutable reference to this structure's inner leaf mapping.
|
||||
///
|
||||
/// # Panics
|
||||
/// This will panic if we have violated our own invariants and try to mutate these nodes while
|
||||
/// Self::compute_mutations_parallel() is still running.
|
||||
fn leaves_mut(&mut self) -> &mut BTreeMap<u64, SmtLeaf> {
|
||||
#[cfg(feature = "async")]
|
||||
{
|
||||
Arc::get_mut(&mut self.leaves).unwrap()
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "async"))]
|
||||
{
|
||||
&mut self.leaves
|
||||
}
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -241,10 +288,12 @@ impl Smt {
|
|||
|
||||
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
|
||||
|
||||
match self.leaves.get_mut(&leaf_index.value()) {
|
||||
let leaves = self.leaves_mut();
|
||||
|
||||
match leaves.get_mut(&leaf_index.value()) {
|
||||
Some(leaf) => leaf.insert(key, value),
|
||||
None => {
|
||||
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
|
||||
leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
|
||||
|
||||
None
|
||||
},
|
||||
|
@ -255,10 +304,12 @@ impl Smt {
|
|||
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
|
||||
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
|
||||
|
||||
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
|
||||
let leaves = self.leaves_mut();
|
||||
|
||||
if let Some(leaf) = leaves.get_mut(&leaf_index.value()) {
|
||||
let (old_value, is_empty) = leaf.remove(key);
|
||||
if is_empty {
|
||||
self.leaves.remove(&leaf_index.value());
|
||||
leaves.remove(&leaf_index.value());
|
||||
}
|
||||
old_value
|
||||
} else {
|
||||
|
@ -266,6 +317,27 @@ impl Smt {
|
|||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn construct_prospective_leaf(
|
||||
mut existing_leaf: SmtLeaf,
|
||||
key: &RpoDigest,
|
||||
value: &Word,
|
||||
) -> SmtLeaf {
|
||||
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
|
||||
|
||||
match existing_leaf {
|
||||
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
|
||||
_ => {
|
||||
if *value != EMPTY_WORD {
|
||||
existing_leaf.insert(*key, *value);
|
||||
} else {
|
||||
existing_leaf.remove(*key);
|
||||
}
|
||||
|
||||
existing_leaf
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SparseMerkleTree<SMT_DEPTH> for Smt {
|
||||
|
@ -332,24 +404,11 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
|
|||
|
||||
fn construct_prospective_leaf(
|
||||
&self,
|
||||
mut existing_leaf: SmtLeaf,
|
||||
existing_leaf: SmtLeaf,
|
||||
key: &RpoDigest,
|
||||
value: &Word,
|
||||
) -> SmtLeaf {
|
||||
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
|
||||
|
||||
match existing_leaf {
|
||||
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
|
||||
_ => {
|
||||
if *value != EMPTY_WORD {
|
||||
existing_leaf.insert(*key, *value);
|
||||
} else {
|
||||
existing_leaf.remove(*key);
|
||||
}
|
||||
|
||||
existing_leaf
|
||||
},
|
||||
}
|
||||
Smt::construct_prospective_leaf(existing_leaf, key, value)
|
||||
}
|
||||
|
||||
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
|
||||
|
@ -362,12 +421,301 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
|
|||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
impl super::ParallelSparseMerkleTree<SMT_DEPTH> for Smt {
|
||||
// Helpers required only for the parallel version of the SMT trait.
|
||||
fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>> {
|
||||
Arc::clone(&self.inner_nodes)
|
||||
}
|
||||
|
||||
fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
|
||||
Arc::clone(&self.leaves)
|
||||
}
|
||||
|
||||
async fn compute_mutations_parallel<I>(
|
||||
&self,
|
||||
kv_pairs: I,
|
||||
) -> MutationSet<SMT_DEPTH, RpoDigest, Word>
|
||||
where
|
||||
I: IntoIterator<Item = (RpoDigest, Word)>,
|
||||
{
|
||||
use std::time::Instant;
|
||||
|
||||
const SUBTREE_INTERVAL: u8 = 8;
|
||||
|
||||
// FIXME: check for duplicates and return MerkleError.
|
||||
let kv_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
|
||||
|
||||
// The first subtrees we calculate, which include our new leaves.
|
||||
let mut subtrees: HashSet<NodeIndex> = kv_pairs
|
||||
.keys()
|
||||
.map(|key| {
|
||||
let index_for_key = NodeIndex::from(Smt::key_to_leaf_index(key));
|
||||
index_for_key.parent_n(SUBTREE_INTERVAL)
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Node mutations across all tasks will be collected here.
|
||||
// Every time we collect tasks we store all the new known node mutations and their hashes
|
||||
// (so we don't have to recompute them every time we need them).
|
||||
let mut node_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>> =
|
||||
Default::default();
|
||||
// Any leaf hashes done by tasks will be collected here, so hopefully we only hash each leaf
|
||||
// once.
|
||||
let mut cached_leaf_hashes: Arc<HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>> =
|
||||
Default::default();
|
||||
|
||||
for subtree_depth in (0..SMT_DEPTH).step_by(SUBTREE_INTERVAL.into()).rev() {
|
||||
let now = Instant::now();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
for subtree in subtrees.iter().copied() {
|
||||
debug_assert_eq!(subtree.depth(), subtree_depth);
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
|
||||
&self,
|
||||
Arc::clone(&node_mutations),
|
||||
Arc::clone(&kv_pairs),
|
||||
SubtreeIndex::new(subtree, SUBTREE_INTERVAL as u8),
|
||||
);
|
||||
// The "double spawn" here is necessary to allow tokio to run these tasks in
|
||||
// parallel.
|
||||
tasks.spawn(tokio::spawn(async move {
|
||||
let hash = state.get_or_make_hash(subtree);
|
||||
(subtree, hash, state.into_results())
|
||||
}));
|
||||
}
|
||||
|
||||
let task_results = tasks.join_all().await;
|
||||
let elapsed = now.elapsed();
|
||||
std::eprintln!(
|
||||
"joined {} tasks for depth {} in {:.3} milliseconds",
|
||||
task_results.len(),
|
||||
subtree_depth,
|
||||
elapsed.as_secs_f64() * 1000.0,
|
||||
);
|
||||
|
||||
for result in task_results {
|
||||
// FIXME: .expect() error message?
|
||||
let result = result.unwrap();
|
||||
let (subtree, hash, state) = result;
|
||||
let NodeSubtreeResults {
|
||||
new_mutations,
|
||||
cached_leaf_hashes: new_leaf_hashes,
|
||||
} = state;
|
||||
|
||||
Arc::get_mut(&mut node_mutations).unwrap().extend(new_mutations);
|
||||
Arc::get_mut(&mut cached_leaf_hashes).unwrap().extend(new_leaf_hashes);
|
||||
// Make sure the final hash we calculated is in the new mutations.
|
||||
assert_eq!(
|
||||
node_mutations.get(&subtree).unwrap().0,
|
||||
hash,
|
||||
"Stored and returned hashes for subtree '{subtree:?}' differ",
|
||||
);
|
||||
}
|
||||
|
||||
// And advance our subtrees, unless we just did the root depth.
|
||||
if subtree_depth == 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let subtree_count_before_advance = subtrees.len();
|
||||
subtrees =
|
||||
subtrees.into_iter().map(|subtree| subtree.parent_n(SUBTREE_INTERVAL)).collect();
|
||||
// FIXME: remove.
|
||||
assert!(subtrees.len() <= subtree_count_before_advance);
|
||||
}
|
||||
|
||||
let root = NodeIndex::root();
|
||||
let new_root = node_mutations.get(&root).unwrap().0;
|
||||
|
||||
MutationSet {
|
||||
old_root: self.root(),
|
||||
//node_mutations: Arc::into_inner(node_mutations).unwrap().into_iter().collect(),
|
||||
node_mutations: Arc::into_inner(node_mutations)
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.map(|(key, (_hash, node))| (key, node))
|
||||
.collect(),
|
||||
new_pairs: Arc::into_inner(kv_pairs).unwrap(),
|
||||
new_root,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Smt {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) struct NodeSubtreeState<const DEPTH: u8> {
|
||||
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
|
||||
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
|
||||
// This field has invariants!
|
||||
dirtied_indices: HashMap<NodeIndex, bool>,
|
||||
existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
|
||||
new_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)>,
|
||||
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
|
||||
cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
|
||||
indentation: u8,
|
||||
subtree: SubtreeIndex,
|
||||
dirty_high: u128,
|
||||
dirty_low: u128,
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
impl<const DEPTH: u8> NodeSubtreeState<DEPTH> {
|
||||
pub(crate) fn new(
|
||||
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
|
||||
existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
|
||||
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
|
||||
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
|
||||
subtree: SubtreeIndex,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner_nodes,
|
||||
leaves,
|
||||
dirtied_indices: Default::default(),
|
||||
new_mutations: Default::default(),
|
||||
existing_mutations,
|
||||
new_pairs,
|
||||
cached_leaf_hashes: Default::default(),
|
||||
indentation: 0,
|
||||
subtree,
|
||||
dirty_high: 0,
|
||||
dirty_low: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn with_smt(
|
||||
smt: &Smt,
|
||||
existing_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>>,
|
||||
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
|
||||
subtree: SubtreeIndex,
|
||||
) -> Self {
|
||||
Self::new(
|
||||
Arc::clone(&smt.inner_nodes),
|
||||
existing_mutations,
|
||||
Arc::clone(&smt.leaves),
|
||||
new_pairs,
|
||||
subtree,
|
||||
)
|
||||
}
|
||||
|
||||
#[inline(never)] // XXX: for profiling.
|
||||
pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
|
||||
if index_to_check == self.subtree.root {
|
||||
return true;
|
||||
}
|
||||
if let Some(cached) = self.dirtied_indices.get(&index_to_check) {
|
||||
return *cached;
|
||||
}
|
||||
let is_dirty = self
|
||||
.existing_mutations
|
||||
.iter()
|
||||
.map(|(index, _)| *index)
|
||||
.chain(self.new_pairs.iter().map(|(key, _v)| Smt::key_to_leaf_index(key).index))
|
||||
.filter(|&dirtied_index| index_to_check.contains(dirtied_index))
|
||||
.next()
|
||||
.is_some();
|
||||
self.dirtied_indices.insert(index_to_check, is_dirty);
|
||||
is_dirty
|
||||
}
|
||||
|
||||
/// Does NOT check `new_mutations`.
|
||||
#[inline(never)] // XXX: for profiling.
|
||||
pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
|
||||
self.existing_mutations
|
||||
.get(&index)
|
||||
.map(|(hash, _)| *hash)
|
||||
.or_else(|| self.inner_nodes.get(&index).map(|inner_node| InnerNode::hash(&inner_node)))
|
||||
}
|
||||
|
||||
#[inline(never)] // XXX: for profiling.
|
||||
pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
|
||||
let pairs_at_index = self
|
||||
.new_pairs
|
||||
.iter()
|
||||
.filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index);
|
||||
|
||||
let existing_leaf = self
|
||||
.leaves
|
||||
.get(&index.index.value())
|
||||
.cloned()
|
||||
.unwrap_or_else(|| SmtLeaf::new_empty(index));
|
||||
|
||||
pairs_at_index.fold(existing_leaf, |acc, (k, v)| {
|
||||
let existing_leaf = acc.clone();
|
||||
Smt::construct_prospective_leaf(existing_leaf, k, v)
|
||||
})
|
||||
}
|
||||
|
||||
/// Retrieve a cached hash, or recursively compute it.
|
||||
#[inline(never)] // XXX: for profiling.
|
||||
pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
|
||||
use NodeMutation::*;
|
||||
|
||||
// If this is a leaf, then only do leaf stuff.
|
||||
if index.depth() == SMT_DEPTH {
|
||||
let index = LeafIndex::new(index.value()).unwrap();
|
||||
return match self.cached_leaf_hashes.get(&index) {
|
||||
Some(cached_hash) => cached_hash.clone(),
|
||||
None => {
|
||||
let leaf = self.get_effective_leaf(index);
|
||||
let hash = Smt::hash_leaf(&leaf);
|
||||
self.cached_leaf_hashes.insert(index, hash);
|
||||
hash
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
// If we already computed this one earlier as a mutation, just return it.
|
||||
if let Some((hash, _)) = self.new_mutations.get(&index) {
|
||||
return *hash;
|
||||
}
|
||||
|
||||
// Otherwise, we need to know if this node is one of the nodes we're in the process of
|
||||
// recomputing, or if we can safely use the node already in the Merkle tree.
|
||||
if !self.is_index_dirty(index) {
|
||||
return self
|
||||
.get_clean_hash(index)
|
||||
.unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
|
||||
}
|
||||
|
||||
// If we got here, then we have to make, rather than get, this hash.
|
||||
// Make sure we mark this index as now dirty.
|
||||
self.dirtied_indices.insert(index, true);
|
||||
|
||||
// Recurse for the left and right sides.
|
||||
let left = self.get_or_make_hash(index.left_child());
|
||||
let right = self.get_or_make_hash(index.right_child());
|
||||
let node = InnerNode { left, right };
|
||||
let hash = node.hash();
|
||||
let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
|
||||
let is_removal = hash == equivalent_empty_hash;
|
||||
let new_entry = if is_removal { Removal } else { Addition(node) };
|
||||
|
||||
self.new_mutations.insert(index, (hash, new_entry));
|
||||
|
||||
hash
|
||||
}
|
||||
|
||||
fn into_results(self) -> NodeSubtreeResults {
|
||||
NodeSubtreeResults {
|
||||
new_mutations: self.new_mutations,
|
||||
cached_leaf_hashes: self.cached_leaf_hashes,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) struct NodeSubtreeResults {
|
||||
pub(crate) new_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)>,
|
||||
pub(crate) cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
|
||||
}
|
||||
|
||||
// CONVERSIONS
|
||||
// ================================================================================================
|
||||
|
||||
|
|
|
@ -1,6 +1,19 @@
|
|||
use alloc::vec::Vec;
|
||||
#[cfg(feature = "async")]
|
||||
use std::{
|
||||
collections::{BTreeMap, BTreeSet, HashMap},
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
#[cfg(feature = "async")]
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
|
||||
#[cfg(feature = "async")]
|
||||
use crate::merkle::{
|
||||
index::SubtreeIndex,
|
||||
smt::{full::NodeSubtreeState, NodeMutation},
|
||||
};
|
||||
use crate::{
|
||||
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
|
||||
utils::{Deserializable, Serializable},
|
||||
|
@ -568,6 +581,263 @@ fn test_multiple_smt_leaf_serialization_success() {
|
|||
assert_eq!(multiple_leaf, deserialized);
|
||||
}
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
fn setup_subtree_test(kv_count: u64) -> (Vec<(RpoDigest, Word)>, Smt) {
|
||||
// FIXME: override seed.
|
||||
let rand_felt = || rand_utils::rand_value::<Felt>();
|
||||
|
||||
let kv_pairs: Vec<(RpoDigest, Word)> = (0..kv_count)
|
||||
.into_iter()
|
||||
.map(|i| {
|
||||
let leaf_index = u64::MAX / (i + 1);
|
||||
let key =
|
||||
RpoDigest::new([rand_felt(), rand_felt(), rand_felt(), Felt::new(leaf_index)]);
|
||||
let value: Word = [Felt::new(i), rand_felt(), rand_felt(), rand_felt()];
|
||||
(key, value)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let control_smt = Smt::with_entries(kv_pairs.clone()).unwrap();
|
||||
|
||||
(kv_pairs, control_smt)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "async")]
|
||||
fn test_single_node_subtree() {
|
||||
use alloc::collections::BTreeMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::merkle::smt::{full::NodeSubtreeState, NodeMutation};
|
||||
|
||||
const KV_COUNT: u64 = 2_000;
|
||||
|
||||
let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
|
||||
let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
|
||||
|
||||
let test_smt = Smt::new();
|
||||
|
||||
let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
|
||||
let _: () = rt.block_on(async move {
|
||||
// Construct some fake node mutations based on the leaves in the control Smt.
|
||||
let node_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> = control_smt
|
||||
.leaves()
|
||||
.flat_map(|(index, _leaf)| {
|
||||
let subtree = index.index.parent();
|
||||
let mutation = control_smt
|
||||
.inner_nodes
|
||||
.get(&subtree)
|
||||
.cloned()
|
||||
.map(|node| (node.hash(), NodeMutation::Addition(node)))
|
||||
.unwrap_or_else(|| {
|
||||
(
|
||||
*EmptySubtreeRoots::entry(SMT_DEPTH, subtree.depth()),
|
||||
NodeMutation::Removal,
|
||||
)
|
||||
});
|
||||
|
||||
vec![(subtree, mutation)]
|
||||
})
|
||||
.collect();
|
||||
let node_mutations = Arc::new(node_mutations);
|
||||
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::new(
|
||||
Arc::clone(&test_smt.inner_nodes),
|
||||
Arc::clone(&node_mutations),
|
||||
Arc::clone(&control_smt.leaves),
|
||||
Arc::clone(&new_pairs),
|
||||
SubtreeIndex::new(NodeIndex::root(), 8),
|
||||
);
|
||||
|
||||
for (i, (&index, mutation)) in node_mutations.iter().enumerate() {
|
||||
assert!(index.depth() <= SMT_DEPTH, "index {index:?} is invalid");
|
||||
|
||||
let control_hash = if index.depth() < SMT_DEPTH {
|
||||
control_smt.get_inner_node(index).hash()
|
||||
} else {
|
||||
control_smt
|
||||
.leaves
|
||||
.get(&index.value())
|
||||
.map(Smt::hash_leaf)
|
||||
.unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()))
|
||||
};
|
||||
let mutation_hash = mutation.0;
|
||||
let test_hash = state.get_or_make_hash(index);
|
||||
assert_eq!(mutation_hash, control_hash);
|
||||
assert_eq!(
|
||||
test_hash, control_hash,
|
||||
"test_hash != control_hash for mutation {i} at {index:?}",
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Test doing a node subtree from a LeafSubtreeMutationSet.
|
||||
#[test]
|
||||
#[cfg(feature = "async")]
|
||||
fn test_node_subtree_with_leaves() {
|
||||
const KV_COUNT: u64 = 2_000;
|
||||
|
||||
let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
|
||||
let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
|
||||
|
||||
let test_smt = Smt::new();
|
||||
|
||||
let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
|
||||
let _: () = rt.block_on(async move {
|
||||
let mut task_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> = Default::default();
|
||||
|
||||
for leaf_index in control_smt.leaves().map(|(index, _leaf)| index) {
|
||||
let subtree = SubtreeIndex::new(leaf_index.index.parent_n(8), 8);
|
||||
let subtree_pairs: Vec<(RpoDigest, Word)> = control_smt
|
||||
.leaves()
|
||||
.flat_map(|(leaf_index, leaf)| {
|
||||
if subtree.root.contains(leaf_index.index) {
|
||||
leaf.entries()
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
})
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
|
||||
&test_smt,
|
||||
Default::default(),
|
||||
Arc::new(BTreeMap::from_iter(subtree_pairs)),
|
||||
//Arc::clone(&new_pairs),
|
||||
subtree,
|
||||
);
|
||||
let test_subtree_hash = state.get_or_make_hash(subtree.root);
|
||||
let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash();
|
||||
assert_eq!(test_subtree_hash, control_subtree_hash);
|
||||
|
||||
task_mutations.extend(state.new_mutations);
|
||||
}
|
||||
|
||||
let node_mutations = Arc::new(task_mutations);
|
||||
let subtrees: BTreeSet<SubtreeIndex> = control_smt
|
||||
.leaves()
|
||||
.map(|(index, _leaf)| SubtreeIndex::new(index.index.parent_n(8).parent_n(8), 8))
|
||||
.collect();
|
||||
|
||||
for (i, subtree) in subtrees.into_iter().enumerate() {
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
|
||||
&test_smt,
|
||||
Arc::clone(&node_mutations),
|
||||
Arc::clone(&new_pairs),
|
||||
subtree,
|
||||
);
|
||||
|
||||
let control_subtree_hash = control_smt.get_inner_node(subtree.root).hash();
|
||||
let test_subtree_hash = state.get_or_make_hash(subtree.root);
|
||||
assert_eq!(
|
||||
test_subtree_hash, control_subtree_hash,
|
||||
"test subtree hash does not match control hash for subtree {i} '{subtree:?}'",
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[cfg(feature = "async")]
|
||||
fn test_node_subtrees_parallel() {
|
||||
const KV_COUNT: u64 = 2_000;
|
||||
|
||||
let (kv_pairs, control_smt) = setup_subtree_test(KV_COUNT);
|
||||
let new_pairs = Arc::new(BTreeMap::from_iter(kv_pairs));
|
||||
|
||||
let test_smt = Smt::new();
|
||||
|
||||
let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
|
||||
//let rt = tokio::runtime::Builder::new_current_thread().build().unwrap();
|
||||
let _: () = rt.block_on(async move {
|
||||
let mut current_subtree_depth = SMT_DEPTH;
|
||||
|
||||
let subtrees: BTreeSet<SubtreeIndex> = new_pairs
|
||||
.keys()
|
||||
.map(|key| SubtreeIndex::new(Smt::key_to_leaf_index(key).index.parent_n(8), 8))
|
||||
.collect();
|
||||
current_subtree_depth -= 8;
|
||||
|
||||
let mut node_mutations: Arc<HashMap<NodeIndex, (RpoDigest, NodeMutation)>> =
|
||||
Default::default();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
// FIXME
|
||||
let mut now = Instant::now();
|
||||
for subtree in subtrees.iter().copied() {
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
|
||||
&test_smt,
|
||||
Arc::clone(&node_mutations),
|
||||
Arc::clone(&new_pairs),
|
||||
subtree,
|
||||
);
|
||||
tasks.spawn(tokio::spawn(async move {
|
||||
let hash = state.get_or_make_hash(subtree.root);
|
||||
(subtree, hash, state)
|
||||
}));
|
||||
}
|
||||
|
||||
let mut cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest> = Default::default();
|
||||
let mut subtrees = subtrees;
|
||||
let mut tasks = Some(tasks);
|
||||
while current_subtree_depth > 0 {
|
||||
std::eprintln!(
|
||||
"joining {} tasks for depth {current_subtree_depth}",
|
||||
tasks.as_ref().unwrap().len(),
|
||||
);
|
||||
let mut tasks_mutations: HashMap<NodeIndex, (RpoDigest, NodeMutation)> =
|
||||
Default::default();
|
||||
let results = tasks.take().unwrap().join_all().await;
|
||||
let elapsed = now.elapsed();
|
||||
std::eprintln!(" joined in {:.3} milliseconds", elapsed.as_secs_f64() * 1000.0);
|
||||
for result in results {
|
||||
let (subtree, test_hash, state) = result.unwrap();
|
||||
let control_hash = control_smt.get_inner_node(subtree.root).hash();
|
||||
assert_eq!(test_hash, control_hash);
|
||||
|
||||
tasks_mutations.extend(state.new_mutations);
|
||||
cached_leaf_hashes.extend(state.cached_leaf_hashes);
|
||||
}
|
||||
Arc::get_mut(&mut node_mutations).unwrap().extend(tasks_mutations);
|
||||
|
||||
// Move all our subtrees up.
|
||||
current_subtree_depth -= 8;
|
||||
subtrees = subtrees
|
||||
.into_iter()
|
||||
.map(|subtree| {
|
||||
let subtree = SubtreeIndex::new(subtree.root.parent_n(8), 8);
|
||||
assert_eq!(subtree.root.depth(), current_subtree_depth);
|
||||
subtree
|
||||
})
|
||||
.collect();
|
||||
|
||||
// And spawn our new tasks.
|
||||
//std::eprintln!("spawning tasks for depth {current_subtree_depth}");
|
||||
let tasks = tasks.insert(JoinSet::new());
|
||||
// FIXME
|
||||
now = Instant::now();
|
||||
for subtree in subtrees.iter().copied() {
|
||||
let mut state = NodeSubtreeState::<SMT_DEPTH>::with_smt(
|
||||
&test_smt,
|
||||
Arc::clone(&node_mutations),
|
||||
Arc::clone(&new_pairs),
|
||||
subtree,
|
||||
);
|
||||
state.cached_leaf_hashes = cached_leaf_hashes.clone();
|
||||
tasks.spawn(tokio::spawn(async move {
|
||||
let hash = state.get_or_make_hash(subtree.root);
|
||||
(subtree, hash, state)
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
assert!(tasks.is_some());
|
||||
assert_eq!(tasks.as_ref().unwrap().len(), 1);
|
||||
});
|
||||
}
|
||||
|
||||
// HELPERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
|
|
@ -12,6 +12,11 @@ pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
|
|||
mod simple;
|
||||
pub use simple::SimpleSmt;
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
mod parallel;
|
||||
#[cfg(feature = "async")]
|
||||
pub(crate) use parallel::ParallelSparseMerkleTree;
|
||||
|
||||
// CONSTANTS
|
||||
// ================================================================================================
|
||||
|
||||
|
@ -466,6 +471,26 @@ pub(crate) enum NodeMutation {
|
|||
Addition(InnerNode),
|
||||
}
|
||||
|
||||
impl NodeMutation {
|
||||
#[allow(dead_code)]
|
||||
pub fn into_inner_node(self, tree_depth: u8, node_depth: u8) -> InnerNode {
|
||||
use NodeMutation::*;
|
||||
match self {
|
||||
Addition(node) => node,
|
||||
Removal => EmptySubtreeRoots::get_inner_node(tree_depth, node_depth),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn as_hash(&self, tree_depth: u8, node_depth: u8) -> RpoDigest {
|
||||
use NodeMutation::*;
|
||||
match self {
|
||||
Addition(node) => node.hash(),
|
||||
Removal => *EmptySubtreeRoots::entry(tree_depth, node_depth),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
|
||||
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
|
||||
/// `SparseMerkleTree::apply_mutations()`.
|
||||
|
@ -499,7 +524,6 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use proptest::prelude::*;
|
||||
|
|
62
src/merkle/smt/parallel.rs
Normal file
62
src/merkle/smt/parallel.rs
Normal file
|
@ -0,0 +1,62 @@
|
|||
use std::{
|
||||
cmp::Ordering,
|
||||
collections::BTreeMap,
|
||||
sync::{Arc, LazyLock},
|
||||
thread,
|
||||
};
|
||||
|
||||
use crate::merkle::smt::{InnerNode, MutationSet, NodeIndex, SparseMerkleTree};
|
||||
|
||||
static TASK_COUNT: LazyLock<usize> = LazyLock::new(|| {
|
||||
// FIXME: error handling?
|
||||
thread::available_parallelism().unwrap().get()
|
||||
});
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) trait ParallelSparseMerkleTree<const DEPTH: u8>
|
||||
where
|
||||
// Note: these type bounds need to be specified this way or we'll have to duplicate them
|
||||
// everywhere.
|
||||
// https://github.com/rust-lang/rust/issues/130805.
|
||||
Self: SparseMerkleTree<
|
||||
DEPTH,
|
||||
Key: Send + Sync + 'static,
|
||||
Value: Send + Sync + 'static,
|
||||
Leaf: Send + Sync + 'static,
|
||||
>,
|
||||
{
|
||||
/// Shortcut for [`ParallelSparseMerkleTree::compute_mutations_parallel_n()`] with an
|
||||
/// automatically determined number of tasks.
|
||||
///
|
||||
/// Currently, the default number of tasks is the return value of
|
||||
/// [`std::thread::available_parallelism()`], but this may be subject to change in the future.
|
||||
async fn compute_mutations_parallel<I>(
|
||||
&self,
|
||||
kv_pairs: I,
|
||||
) -> MutationSet<DEPTH, Self::Key, Self::Value>
|
||||
where
|
||||
I: IntoIterator<Item = (Self::Key, Self::Value)>,
|
||||
{
|
||||
self.compute_mutations_parallel_n(kv_pairs, *TASK_COUNT).await
|
||||
}
|
||||
|
||||
async fn compute_mutations_parallel_n<I>(
|
||||
&self,
|
||||
_kv_pairs: I,
|
||||
_tasks: usize,
|
||||
) -> MutationSet<DEPTH, Self::Key, Self::Value>
|
||||
where
|
||||
I: IntoIterator<Item = (Self::Key, Self::Value)>,
|
||||
{
|
||||
todo!();
|
||||
}
|
||||
|
||||
fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>>;
|
||||
fn get_leaves(&self) -> Arc<BTreeMap<u64, Self::Leaf>>;
|
||||
fn get_leaf_value(_leaf: &Self::Leaf, _key: &Self::Key) -> Option<Self::Value> {
|
||||
todo!();
|
||||
}
|
||||
fn cmp_keys(_lhs: &Self::Key, _rhs: &Self::Key) -> Ordering {
|
||||
todo!();
|
||||
}
|
||||
}
|
|
@ -395,3 +395,19 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
|
|||
(path, leaf).into()
|
||||
}
|
||||
}
|
||||
|
||||
//#[cfg(feature = "async")]
|
||||
////impl<const DEPTH: u8> super::ParallelSparseMerkleTree<DEPTH, LeafIndex<DEPTH>, Word, Word>
|
||||
//impl<const DEPTH: u8> super::ParallelSparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
|
||||
// fn get_inner_nodes(&self) -> Arc<BTreeMap<NodeIndex, InnerNode>> {
|
||||
// Arc::clone(&self.inner_nodes)
|
||||
// }
|
||||
//
|
||||
// fn get_leaf_value(leaf: &Word, _key: &LeafIndex<DEPTH>) -> Option<Word> {
|
||||
// Some(*leaf)
|
||||
// }
|
||||
//
|
||||
// fn cmp_keys(lhs: &LeafIndex<DEPTH>, rhs: &LeafIndex<DEPTH>) -> Ordering {
|
||||
// LeafIndex::cmp(lhs, rhs)
|
||||
// }
|
||||
//}
|
||||
|
|
Loading…
Add table
Reference in a new issue