diff --git a/miden-crypto/src/merkle/sparse_path.rs b/miden-crypto/src/merkle/sparse_path.rs index 88bd3d5..d73e12a 100644 --- a/miden-crypto/src/merkle/sparse_path.rs +++ b/miden-crypto/src/merkle/sparse_path.rs @@ -1,7 +1,7 @@ use alloc::vec::Vec; use core::iter; -use super::{EmptySubtreeRoots, MerklePath, RpoDigest, SMT_MAX_DEPTH}; +use super::{EmptySubtreeRoots, MerkleError, MerklePath, RpoDigest, SMT_MAX_DEPTH}; /// A different representation of [`MerklePath`] designed for memory efficiency for Merkle paths /// with empty nodes. @@ -89,6 +89,50 @@ impl SparseMerklePath { pub fn depth(&self) -> u8 { (self.nodes.len() + self.empty_nodes.count_ones() as usize) as u8 } + + /// Get a specific node in this path at a given depth. + /// + /// # Errors + /// Returns [MerkleError::DepthTooBig] if `node_depth` is greater than the total depth of this + /// path. + pub fn get(&self, node_depth: u8) -> Result { + let node = self + .get_nonempty(node_depth)? + .unwrap_or_else(|| *EmptySubtreeRoots::entry(self.depth(), node_depth)); + + Ok(node) + } + + /// Get a specific non-emptynode in this path at a given depth, or `None` if the specified node + /// is an empty node. + /// + /// # Errors + /// Returns [MerkleError::DepthTooBig] if `node_depth` is greater than the total depth of this + /// path. + pub fn get_nonempty(&self, node_depth: u8) -> Result, MerkleError> { + if node_depth > self.depth() { + return Err(MerkleError::DepthTooBig(node_depth.into())); + } + + let empty_bit = 1u64 << node_depth; + let is_empty = (self.empty_nodes & empty_bit) != 0; + + if is_empty { + return Ok(None); + } + + // Our index needs to account for all the empty nodes that aren't in `self.nodes`. + let nonempty_index: usize = { + // TODO: this could also be u64::unbounded_shl(1, node_depth + 1).wrapping_sub(1). + // We should check if that has any performance benefits over using 128-bit integers. + let mask: u64 = ((1u128 << (node_depth + 1)) - 1u128).try_into().unwrap(); + + let empty_before = u64::count_ones(self.empty_nodes & mask); + node_depth as usize - empty_before as usize + }; + + Ok(Some(self.nodes[nonempty_index])) + } } #[cfg(test)] @@ -102,9 +146,7 @@ mod tests { merkle::{SMT_DEPTH, Smt, smt::SparseMerkleTree}, }; - #[test] - fn roundtrip() { - let pair_count: u64 = 8192; + fn make_smt(pair_count: u64) -> Smt { let entries: Vec<(RpoDigest, Word)> = (0..pair_count) .map(|n| { let leaf_index = ((n as f64 / pair_count as f64) * 255.0) as u64; @@ -113,7 +155,13 @@ mod tests { (key, value) }) .collect(); - let tree = Smt::with_entries(entries).unwrap(); + + Smt::with_entries(entries).unwrap() + } + + #[test] + fn roundtrip() { + let tree = make_smt(8192); for (key, _value) in tree.entries() { let control_path = tree.get_path(key); @@ -125,4 +173,21 @@ mod tests { assert_eq!(control_path, test_path); } } + + #[test] + fn random_access() { + let tree = make_smt(8192); + + for (i, (key, _value)) in tree.entries().enumerate() { + let control_path = tree.get_path(key); + let sparse_path = SparseMerklePath::from_path(control_path.clone()).unwrap(); + assert_eq!(control_path.depth(), sparse_path.depth()); + assert_eq!(sparse_path.depth(), SMT_DEPTH); + + for (depth, control_node) in control_path.iter().enumerate() { + let test_node = sparse_path.get(depth as u8).unwrap(); + assert_eq!(*control_node, test_node, "at depth {depth} for entry {i}"); + } + } + } }