diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 659ef1f..360ae02 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -46,7 +46,12 @@ pub const SMT_DEPTH: u8 = 64; #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct Smt { root: RpoDigest, + + #[cfg(not(feature = "async"))] leaves: BTreeMap, + #[cfg(feature = "async")] + leaves: Arc>, + #[cfg(not(feature = "async"))] inner_nodes: BTreeMap, #[cfg(feature = "async")] @@ -70,7 +75,7 @@ impl Smt { Self { root, - leaves: BTreeMap::new(), + leaves: Default::default(), inner_nodes: Default::default(), } } @@ -107,6 +112,11 @@ impl Smt { Ok(tree) } + #[cfg(feature = "async")] + pub fn get_leaves(&self) -> Arc> { + Arc::clone(&self.leaves) + } + // PUBLIC ACCESSORS // -------------------------------------------------------------------------------------------- @@ -177,6 +187,23 @@ impl Smt { } } + /// Gets a mutable reference to this structure's inner leaf mapping. + /// + /// # Panics + /// This will panic if we have violated our own invariants and try to mutate these nodes while + /// Self::compute_mutations_parallel() is still running. + fn leaves_mut(&mut self) -> &mut BTreeMap { + #[cfg(feature = "async")] + { + Arc::get_mut(&mut self.leaves).unwrap() + } + + #[cfg(not(feature = "async"))] + { + &mut self.leaves + } + } + // STATE MUTATORS // -------------------------------------------------------------------------------------------- @@ -241,10 +268,12 @@ impl Smt { let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - match self.leaves.get_mut(&leaf_index.value()) { + let leaves = self.leaves_mut(); + + match leaves.get_mut(&leaf_index.value()) { Some(leaf) => leaf.insert(key, value), None => { - self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value))); + leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value))); None }, @@ -255,10 +284,12 @@ impl Smt { fn perform_remove(&mut self, key: RpoDigest) -> Option { let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) { + let leaves = self.leaves_mut(); + + if let Some(leaf) = leaves.get_mut(&leaf_index.value()) { let (old_value, is_empty) = leaf.remove(key); if is_empty { - self.leaves.remove(&leaf_index.value()); + leaves.remove(&leaf_index.value()); } old_value } else {