WIP(smt): allow leaves to be wrapped in an Arc for async

This commit is contained in:
Qyriad 2024-10-11 13:44:19 -06:00
parent 1632ec54aa
commit 1ca498b346

View file

@ -46,7 +46,12 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
#[cfg(not(feature = "async"))]
leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(feature = "async")]
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
@ -70,7 +75,7 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
leaves: Default::default(),
inner_nodes: Default::default(),
}
}
@ -107,6 +112,11 @@ impl Smt {
Ok(tree)
}
#[cfg(feature = "async")]
pub fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
Arc::clone(&self.leaves)
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
@ -177,6 +187,23 @@ impl Smt {
}
}
/// Gets a mutable reference to this structure's inner leaf mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn leaves_mut(&mut self) -> &mut BTreeMap<u64, SmtLeaf> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.leaves).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.leaves
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -241,10 +268,12 @@ impl Smt {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
match self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
match leaves.get_mut(&leaf_index.value()) {
Some(leaf) => leaf.insert(key, value),
None => {
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
None
},
@ -255,10 +284,12 @@ impl Smt {
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
if let Some(leaf) = leaves.get_mut(&leaf_index.value()) {
let (old_value, is_empty) = leaf.remove(key);
if is_empty {
self.leaves.remove(&leaf_index.value());
leaves.remove(&leaf_index.value());
}
old_value
} else {