From e5dd7c6d6a9518127887569851c23c6aa4ed52f3 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 16 Sep 2024 17:33:43 -0600 Subject: [PATCH] WIP(smt): allow inner_nodes: to be wrapped in an Arc for async --- Cargo.toml | 3 ++- src/merkle/smt/full/mod.rs | 29 ++++++++++++++++++++++--- src/merkle/smt/simple/mod.rs | 41 +++++++++++++++++++++++++++++++----- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2616341..50a797b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ name = "store" harness = false [features] -default = ["std"] +default = ["std", "async"] executable = ["dep:clap", "dep:rand-utils", "std"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] std = [ @@ -44,6 +44,7 @@ std = [ "winter-math/std", "winter-utils/std", ] +async = ["serde?/rc"] [dependencies] blake3 = { version = "1.5", default-features = false } diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 9c64002..659ef1f 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "async")] +use std::sync::Arc; + use alloc::{ collections::{BTreeMap, BTreeSet}, string::ToString, @@ -44,7 +47,10 @@ pub const SMT_DEPTH: u8 = 64; pub struct Smt { root: RpoDigest, leaves: BTreeMap, + #[cfg(not(feature = "async"))] inner_nodes: BTreeMap, + #[cfg(feature = "async")] + inner_nodes: Arc>, } impl Smt { @@ -65,7 +71,7 @@ impl Smt { Self { root, leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + inner_nodes: Default::default(), } } @@ -154,6 +160,23 @@ impl Smt { }) } + /// Gets a mutable reference to this structure's inner node mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. + fn inner_nodes_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.inner_nodes).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.inner_nodes + } + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -269,11 +292,11 @@ impl SparseMerkleTree for Smt { } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { - self.inner_nodes.insert(index, inner_node); + self.inner_nodes_mut().insert(index, inner_node); } fn remove_inner_node(&mut self, index: NodeIndex) { - let _ = self.inner_nodes.remove(&index); + let _ = self.inner_nodes_mut().remove(&index); } fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option { diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 1744430..3ce3f7c 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -1,4 +1,6 @@ use alloc::collections::{BTreeMap, BTreeSet}; +#[cfg(feature = "async")] +use std::sync::Arc; use super::{ super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, @@ -20,7 +22,10 @@ mod tests; pub struct SimpleSmt { root: RpoDigest, leaves: BTreeMap, + #[cfg(not(feature = "async"))] inner_nodes: BTreeMap, + #[cfg(feature = "async")] + inner_nodes: Arc>, } impl SimpleSmt { @@ -52,7 +57,7 @@ impl SimpleSmt { Ok(Self { root, leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + inner_nodes: Default::default(), }) } @@ -175,6 +180,23 @@ impl SimpleSmt { }) } + /// Gets a mutable reference to this structure's inner node mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. + fn inner_nodes_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.inner_nodes).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.inner_nodes + } + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -271,7 +293,16 @@ impl SimpleSmt { // add subtree's branch nodes (which includes the root) // -------------- - for (branch_idx, branch_node) in subtree.inner_nodes { + let subtree_nodes; + #[cfg(feature = "async")] + { + subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap(); + } + #[cfg(not(feature = "async"))] + { + subtree_nodes = subtree.inner_nodes + } + for (branch_idx, branch_node) in subtree_nodes { let new_branch_idx = { let new_depth = subtree_root_insertion_depth + branch_idx.depth(); let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into()) @@ -280,7 +311,7 @@ impl SimpleSmt { NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid") }; - self.inner_nodes.insert(new_branch_idx, branch_node); + self.inner_nodes_mut().insert(new_branch_idx, branch_node); } // recompute nodes starting from subtree root @@ -315,11 +346,11 @@ impl SparseMerkleTree for SimpleSmt { } fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { - self.inner_nodes.insert(index, inner_node); + self.inner_nodes_mut().insert(index, inner_node); } fn remove_inner_node(&mut self, index: NodeIndex) { - let _ = self.inner_nodes.remove(&index); + let _ = self.inner_nodes_mut().remove(&index); } fn insert_value(&mut self, key: LeafIndex, value: Word) -> Option {