From 71e6cce812299b8a0b4a56bab291e9d6a745a127 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Fri, 21 Mar 2025 13:59:36 +0100 Subject: [PATCH] SparseMerklePath: implement iterators --- src/merkle/sparse_path.rs | 106 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/src/merkle/sparse_path.rs b/src/merkle/sparse_path.rs index cb97903..dad8cb7 100644 --- a/src/merkle/sparse_path.rs +++ b/src/merkle/sparse_path.rs @@ -117,6 +117,87 @@ impl SparseMerklePath { } } +// ITERATORS +// ================================================================================================ + +impl IntoIterator for SparseMerklePath { + type Item = ::Item; + type IntoIter = SparseMerkleIter; + + fn into_iter(self) -> SparseMerkleIter { + let tree_depth = self.depth(); + SparseMerkleIter { + path: self, + next_depth: Some(0), + tree_depth, + } + } +} + +/// Owning iterator for [`SparseMerklePath`]. +// TODO: add a non-owning iterator too. +pub struct SparseMerkleIter { + /// The "inner" value we're iterating over. + path: SparseMerklePath, + + /// The depth a `next()` call will get. It will only be None if someone calls `next_back()` at + /// depth 0, to indicate that all further `next_back()` calls must also return `None`. + next_depth: Option, + + /// "Cached" value of `path.depth()`. + tree_depth: u8, +} + +impl Iterator for SparseMerkleIter { + type Item = RpoDigest; + + fn next(&mut self) -> Option { + // If `next_depth` is None, then someone called `next_back()` at depth 0. + let next_depth = self.next_depth.unwrap_or(0); + if next_depth > self.tree_depth { + return None; + } + + match self.path.get(self.tree_depth, next_depth) { + Some(node) => { + self.next_depth = Some(next_depth + 1); + Some(node) + }, + None => None, + } + } + + // SparseMerkleIter always knows its exact size. + fn size_hint(&self) -> (usize, Option) { + let next_depth = self.next_depth.unwrap_or(0); + let len: usize = self.path.depth().into(); + let remaining = len - next_depth as usize; + (remaining, Some(remaining)) + } +} + +impl ExactSizeIterator for SparseMerkleIter { + fn len(&self) -> usize { + let next_depth = self.next_depth.unwrap_or(0); + (self.path.depth() - next_depth) as usize + } +} + +impl DoubleEndedIterator for SparseMerkleIter { + fn next_back(&mut self) -> Option { + // While `next_depth` is None, all calls to `next_back()` also return `None`. + let next_depth = self.next_depth?; + + match self.path.get(self.tree_depth, next_depth) { + Some(node) => { + self.next_depth = if next_depth == 0 { None } else { Some(next_depth - 1) }; + Some(node) + }, + None => None, + } + } +} + #[cfg(test)] mod tests { use alloc::vec::Vec; @@ -179,4 +260,29 @@ mod tests { } } } + + #[test] + fn iterator() { + let tree = make_smt(8192); + + for (i, (key, _value)) in tree.entries().enumerate() { + let path = tree.path(key); + let sparse_path = SparseMerklePath::from_path(SMT_DEPTH, path.clone()).unwrap(); + assert_eq!(path.depth(), sparse_path.depth()); + assert_eq!(sparse_path.depth(), SMT_DEPTH); + for (depth, iter_node) in sparse_path.clone().into_iter().enumerate() { + let control_node = sparse_path.get(SMT_DEPTH, depth as u8).unwrap(); + assert_eq!(control_node, iter_node, "at depth {depth} for entry {i}"); + } + + let iter = sparse_path.clone().into_iter().enumerate().rev().skip(1); + for (depth, iter_node) in iter { + let control_node = sparse_path.get(SMT_DEPTH, depth as u8).unwrap(); + assert_eq!( + control_node, iter_node, + "at depth {depth} for entry {i} during reverse-iteration", + ); + } + } + } }