From b289e7ed73a229dc95e94f90ecf636139292ab14 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 23 Sep 2024 15:43:45 -0600 Subject: [PATCH] feat(smt): impl lowest common ancestor for leaf indices --- src/merkle/smt/mod.rs | 79 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 0b7ceb9..3c91cc4 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -379,6 +379,48 @@ impl LeafIndex { pub fn value(&self) -> u64 { self.index.value() } + + /// Lowest common ancestor — finds the lowest (highest depth) [`NodeIndex`] that is an ancestor + /// of both `self` and `rhs`. + /// + /// The general case algorithm is `O(n)`, however leaf indexes are always at the same depth, + /// and we only need find the depth of the lowest-common ancestor (since we can trivially get + /// its horizontal position based on either child's position), so we can reduce this to + /// `O(log n)`. + pub fn lca(&self, other: &Self) -> NodeIndex { + let mut self_scalar = self.index.to_scalar_index(); + let mut other_scalar = other.index.to_scalar_index(); + + while self_scalar != other_scalar { + self_scalar >>= 1; + other_scalar >>= 1; + } + + // Once we've shifted them enough to be equal, we've found a scalar index with the depth of + // the lowest common ancestor. Time to convert that scalar index to a depth, and apply that + // depth to either of our `NodeIndex`s to get the full position of that ancestor. + + // In general, we can get the depth of a binary tree's scalar index by taking the binary + // logarithm of that index. However, for the root node, the scalar index is 0, and the + // logarithm is undefined for 0, so we trivally special case the root index. + if self_scalar == 0 { + return NodeIndex::root(); + } + + let depth = { + let depth = u128::ilog2(self_scalar); + // The scalar index should not be able to exceed `u8::MAX + u64::MAX` (as those are the + // maximum values `NodeIndex` can hold), and the binary logarithm of `u8::MAX + + // u64::MAX` is 64, which fits in a u8. In other words, this assert should only be + // possible to fail if `to_scalar_index()` is wildly incorrect. + debug_assert!(depth <= u8::MAX as u32); + depth as u8 + }; + + let mut lca = self.index; + lca.move_up_to(depth); + lca + } } impl LeafIndex { @@ -456,3 +498,40 @@ impl MutationSet { self.new_root } } + + +#[cfg(test)] +mod tests { + use proptest::prelude::*; + + use crate::merkle::{LeafIndex, NodeIndex, SMT_DEPTH}; + + prop_compose! { + fn leaf_index()(value in 0..2u64.pow(u64::BITS - 1)) -> LeafIndex { + LeafIndex::new(value).unwrap() + } + } + + proptest! { + /// Tests that the O(log n) algorithm has the same results as the naïve version. + #[test] + fn test_leaf_lca(left in leaf_index(), right in leaf_index()) { + let control: NodeIndex = { + let mut left = left.index; + let mut right = right.index; + + loop { + if left == right { + break left; + } + left.move_up(); + right.move_up(); + } + }; + + let actual: NodeIndex = left.lca(&right); + + assert_eq!(actual, control); + } + } +}