WIP(smt): allow inner_nodes: to be wrapped in an Arc for async

This commit is contained in:
Qyriad 2024-09-16 17:33:43 -06:00
parent 913384600d
commit e5dd7c6d6a
3 changed files with 64 additions and 9 deletions

View file

@ -32,7 +32,7 @@ name = "store"
harness = false
[features]
default = ["std"]
default = ["std", "async"]
executable = ["dep:clap", "dep:rand-utils", "std"]
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
@ -44,6 +44,7 @@ std = [
"winter-math/std",
"winter-utils/std",
]
async = ["serde?/rc"]
[dependencies]
blake3 = { version = "1.5", default-features = false }

View file

@ -1,3 +1,6 @@
#[cfg(feature = "async")]
use std::sync::Arc;
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
@ -44,7 +47,10 @@ pub const SMT_DEPTH: u8 = 64;
pub struct Smt {
root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
}
impl Smt {
@ -65,7 +71,7 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
}
}
@ -154,6 +160,23 @@ impl Smt {
})
}
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -269,11 +292,11 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
self.inner_nodes_mut().insert(index, inner_node);
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
let _ = self.inner_nodes_mut().remove(&index);
}
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {

View file

@ -1,4 +1,6 @@
use alloc::collections::{BTreeMap, BTreeSet};
#[cfg(feature = "async")]
use std::sync::Arc;
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
@ -20,7 +22,10 @@ mod tests;
pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest,
leaves: BTreeMap<u64, Word>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -52,7 +57,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
})
}
@ -175,6 +180,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
})
}
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -271,7 +293,16 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// add subtree's branch nodes (which includes the root)
// --------------
for (branch_idx, branch_node) in subtree.inner_nodes {
let subtree_nodes;
#[cfg(feature = "async")]
{
subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap();
}
#[cfg(not(feature = "async"))]
{
subtree_nodes = subtree.inner_nodes
}
for (branch_idx, branch_node) in subtree_nodes {
let new_branch_idx = {
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
@ -280,7 +311,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
};
self.inner_nodes.insert(new_branch_idx, branch_node);
self.inner_nodes_mut().insert(new_branch_idx, branch_node);
}
// recompute nodes starting from subtree root
@ -315,11 +346,11 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
self.inner_nodes_mut().insert(index, inner_node);
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
let _ = self.inner_nodes_mut().remove(&index);
}
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {