diff --git a/src/merkle/index.rs b/src/merkle/index.rs index e9d5e0e..577f712 100644 --- a/src/merkle/index.rs +++ b/src/merkle/index.rs @@ -83,20 +83,32 @@ impl NodeIndex { /// # Panics /// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value /// indicated by `index` does not fit in a [`u64`]. - pub const fn from_scalar_index(index: NonZero) -> Result { + pub fn from_scalar_index(index: NonZero) -> Result { let index = index.get() - 1; if index == 0 { - return Ok(Self { depth: 0, value: 0 }); + return Ok(Self::root()); + } + + // The log of 1 is always 0. + if index == 1 { + return Ok(Self::root().left_child()); } let depth = { - let depth = u128::ilog2(index); + let depth = u128::ilog2(index + 1); assert!(depth <= u8::MAX as u32); + //let depth = f64::log2(index as f64).round(); + //std::eprintln!("depth for scalar index {index} is {depth}"); + //assert!(depth <= u8::MAX as f64); depth as u8 }; let max_value_for_depth = (1 << depth) - 1; + assert!( + max_value_for_depth <= u64::MAX as u128, + "max_value ({max_value_for_depth}) does not fit in u64", + ); let value = { let value = index - max_value_for_depth; @@ -132,6 +144,55 @@ impl NodeIndex { self } + /// Returns the parent of the current node. + pub const fn parent(mut self) -> Self { + self.depth = self.depth.saturating_sub(1); + self.value >>= 1; + self + } + + /// Returns the `n`th parent of the current node. + pub const fn parent_n(mut self, n: u8) -> Self { + debug_assert!(n < self.depth); + let delta = self.depth.saturating_sub(n); + self.depth = self.depth.saturating_sub(delta); + self.value >>= delta as u32; + self + } + + /// Returns `true` if and only if `other` is a child of the current node. + pub const fn contains(&self, mut other: Self) -> bool { + loop { + if self.depth == other.depth && self.value == other.value { + return true; + } + + if other.is_root() { + return false; + } + + other = other.parent(); + } + } + + /// Returns the right-most descendent of the current node for a tree of `DEPTH` depth. + pub const fn rightmost_descendent(mut self) -> Self { + while self.depth() < DEPTH { + self = self.right_child(); + } + + self + } + + /// Returns the left-most descendent of the current node for a tree of `DEPTH` depth. + pub const fn leftmost_descendent(mut self) -> Self { + while self.depth() < DEPTH { + self = self.left_child(); + } + + self + } + // PROVIDERS // -------------------------------------------------------------------------------------------- @@ -245,6 +306,52 @@ mod tests { assert!(NodeIndex::new(64, u64::MAX).is_ok()); } + //#[test] + //fn test_traversal_roundtrip() { + // // Arbitrary value that's at the bottom and not in a corner. + // let start = NodeIndex::make(64, u64::MAX - 8); + // + // let mut index = start; + // while !index.is_root() { + // std::dbg!(&index); + // let as_traversal = index.to_traversal_index(); + // let as_scalar = index.to_scalar_index() - 1; + // assert_eq!(as_traversal, as_scalar as u128); + // let round_trip = NodeIndex::from_traversal_index(as_traversal).unwrap(); + // assert_eq!(index, round_trip, "{:?} did not round-trip as a traversal index", index); + // index.move_up(); + // } + // assert!(index.is_root()); + // let root_control = NodeIndex::root(); + // assert_eq!(index, root_control); + // + // // Traversal index 0 should be root. + // assert_eq!(index, NodeIndex::from_traversal_index(0).unwrap()); + //} + + #[test] + fn test_scalar_roundtrip() { + // Arbitrary value that's at the bottom and not in a corner. + let start = NodeIndex::make(64, u64::MAX - 8); + + let mut index = start; + while !index.is_root() { + let as_scalar = index.to_scalar_index(); + let round_trip = + NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap(); + assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index"); + index.move_up(); + } + + //let start = NodeIndex::root().left_child().to_scalar_index(); + //let max = u64::MAX as u128; + //for scalar in start..max { + // let index = NodeIndex::from_scalar_index(NonZero::new(scalar).unwrap()).unwrap(); + // let round_trip = index.to_scalar_index(); + // assert_eq!(scalar, round_trip, "scalar index {scalar} ({index:?}) did not round-trip"); + //} + } + prop_compose! { fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex { // unwrap never panics because the range of depth is 0..u64::BITS