diff --git a/src/merkle/index.rs b/src/merkle/index.rs index e9d5e0e..cedc435 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -83,20 +83,29 @@ impl NodeIndex { /// # Panics /// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value /// indicated by `index` does not fit in a [`u64`]. - pub const fn from_scalar_index(index: NonZero) -> Result { + pub fn from_scalar_index(index: NonZero) -> Result { let index = index.get() - 1; if index == 0 { - return Ok(Self { depth: 0, value: 0 }); + return Ok(Self::root()); + } + + // The log of 1 is always 0. + if index == 1 { + return Ok(Self::root().left_child()); } let depth = { - let depth = u128::ilog2(index); + let depth = u128::ilog2(index + 1); assert!(depth <= u8::MAX as u32); depth as u8 }; let max_value_for_depth = (1 << depth) - 1; + assert!( + max_value_for_depth <= u64::MAX as u128, + "max_value ({max_value_for_depth}) does not fit in u64", + ); let value = { let value = index - max_value_for_depth; @@ -125,6 +134,18 @@ impl NodeIndex { self } + pub const fn left_ancestor_n(mut self, n: u8) -> Self { + self.depth += n; + self.value <<= n; + self + } + + pub const fn right_ancestor_n(mut self, n: u8) -> Self { + self.depth += n; + self.value = (self.value << n) + 1; + self + } + /// Returns right child index of the current node. pub const fn right_child(mut self) -> Self { self.depth += 1; @@ -132,6 +153,68 @@ impl NodeIndex { self } + /// Returns the parent of the current node. + pub const fn parent(mut self) -> Self { + self.depth = self.depth.saturating_sub(1); + self.value >>= 1; + self + } + + /// Returns the `n`th parent of the current node. + pub fn parent_n(mut self, n: u8) -> Self { + debug_assert!(n <= self.depth); + self.depth = self.depth.saturating_sub(n); + self.value >>= n; + + self + } + + /// Returns `true` if and only if `other` is an ancestor of the current node, or the current + /// node itself. + pub fn contains(&self, mut other: Self) -> bool { + if other == *self { + return true; + } + if other.is_root() { + return false; + } + if other.depth < self.depth { + return false; + } + + other = other.parent_n(other.depth() - self.depth()); + + loop { + if other == *self { + return true; + } + + if other.is_root() { + return false; + } + + other = other.parent(); + } + } + + /// Returns the right-most descendent of the current node for a tree of `DEPTH` depth. + pub const fn rightmost_descendent(mut self) -> Self { + while self.depth() < DEPTH { + self = self.right_child(); + } + + self + } + + /// Returns the left-most descendent of the current node for a tree of `DEPTH` depth. + pub const fn leftmost_descendent(mut self) -> Self { + while self.depth() < DEPTH { + self = self.left_child(); + } + + self + } + // PROVIDERS // -------------------------------------------------------------------------------------------- @@ -245,6 +328,21 @@ mod tests { assert!(NodeIndex::new(64, u64::MAX).is_ok()); } + #[test] + fn test_scalar_roundtrip() { + // Arbitrary value that's at the bottom and not in a corner. + let start = NodeIndex::make(64, u64::MAX - 8); + + let mut index = start; + while !index.is_root() { + let as_scalar = index.to_scalar_index(); + let round_trip = + NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap(); + assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index"); + index.move_up(); + } + } + prop_compose! { fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex { // unwrap never panics because the range of depth is 0..u64::BITS