WIP(smt): allow inner_nodes: to be wrapped in an Arc for async

This commit is contained in:
Qyriad 2024-09-16 17:33:43 -06:00
parent 913384600d
commit e5dd7c6d6a
3 changed files with 64 additions and 9 deletions

View file

@ -32,7 +32,7 @@ name = "store"
harness = false harness = false
[features] [features]
default = ["std"] default = ["std", "async"]
executable = ["dep:clap", "dep:rand-utils", "std"] executable = ["dep:clap", "dep:rand-utils", "std"]
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [ std = [
@ -44,6 +44,7 @@ std = [
"winter-math/std", "winter-math/std",
"winter-utils/std", "winter-utils/std",
] ]
async = ["serde?/rc"]
[dependencies] [dependencies]
blake3 = { version = "1.5", default-features = false } blake3 = { version = "1.5", default-features = false }

View file

@ -1,3 +1,6 @@
#[cfg(feature = "async")]
use std::sync::Arc;
use alloc::{ use alloc::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
string::ToString, string::ToString,
@ -44,7 +47,10 @@ pub const SMT_DEPTH: u8 = 64;
pub struct Smt { pub struct Smt {
root: RpoDigest, root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>, leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
} }
impl Smt { impl Smt {
@ -65,7 +71,7 @@ impl Smt {
Self { Self {
root, root,
leaves: BTreeMap::new(), leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(), inner_nodes: Default::default(),
} }
} }
@ -154,6 +160,23 @@ impl Smt {
}) })
} }
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS // STATE MUTATORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -269,11 +292,11 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
} }
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node); self.inner_nodes_mut().insert(index, inner_node);
} }
fn remove_inner_node(&mut self, index: NodeIndex) { fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index); let _ = self.inner_nodes_mut().remove(&index);
} }
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> { fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {

View file

@ -1,4 +1,6 @@
use alloc::collections::{BTreeMap, BTreeSet}; use alloc::collections::{BTreeMap, BTreeSet};
#[cfg(feature = "async")]
use std::sync::Arc;
use super::{ use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
@ -20,7 +22,10 @@ mod tests;
pub struct SimpleSmt<const DEPTH: u8> { pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest, root: RpoDigest,
leaves: BTreeMap<u64, Word>, leaves: BTreeMap<u64, Word>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
} }
impl<const DEPTH: u8> SimpleSmt<DEPTH> { impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -52,7 +57,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self { Ok(Self {
root, root,
leaves: BTreeMap::new(), leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(), inner_nodes: Default::default(),
}) })
} }
@ -175,6 +180,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
}) })
} }
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS // STATE MUTATORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -271,7 +293,16 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// add subtree's branch nodes (which includes the root) // add subtree's branch nodes (which includes the root)
// -------------- // --------------
for (branch_idx, branch_node) in subtree.inner_nodes { let subtree_nodes;
#[cfg(feature = "async")]
{
subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap();
}
#[cfg(not(feature = "async"))]
{
subtree_nodes = subtree.inner_nodes
}
for (branch_idx, branch_node) in subtree_nodes {
let new_branch_idx = { let new_branch_idx = {
let new_depth = subtree_root_insertion_depth + branch_idx.depth(); let new_depth = subtree_root_insertion_depth + branch_idx.depth();
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into()) let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
@ -280,7 +311,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid") NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
}; };
self.inner_nodes.insert(new_branch_idx, branch_node); self.inner_nodes_mut().insert(new_branch_idx, branch_node);
} }
// recompute nodes starting from subtree root // recompute nodes starting from subtree root
@ -315,11 +346,11 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
} }
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node); self.inner_nodes_mut().insert(index, inner_node);
} }
fn remove_inner_node(&mut self, index: NodeIndex) { fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index); let _ = self.inner_nodes_mut().remove(&index);
} }
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> { fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {