From a12e62ff2263f00d26ae94e1db07a758d60c75b8 Mon Sep 17 00:00:00 2001 From: Bobbin Threadbare <43513081+bobbinth@users.noreply.github.com> Date: Sun, 18 Aug 2024 09:35:12 -0700 Subject: [PATCH 1/5] feat: improve MMR api (#324) --- CHANGELOG.md | 5 ++ src/merkle/mmr/full.rs | 41 ++++++++++--- src/merkle/mmr/partial.rs | 44 +++++++------- src/merkle/mmr/tests.rs | 119 +++++++++++++++++++------------------- 4 files changed, 117 insertions(+), 92 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4612b0c..7ddf280 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +## 0.11.0 (TBD) + +- [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234). +- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234). + ## 0.10.0 (2024-08-06) * Added more `RpoDigest` and `RpxDigest` conversions (#311). diff --git a/src/merkle/mmr/full.rs b/src/merkle/mmr/full.rs index b2ea2df..7c1d3f0 100644 --- a/src/merkle/mmr/full.rs +++ b/src/merkle/mmr/full.rs @@ -72,19 +72,36 @@ impl Mmr { // FUNCTIONALITY // ============================================================================================ - /// Given a leaf position, returns the Merkle path to its corresponding peak. If the position - /// is greater-or-equal than the tree size an error is returned. + /// Returns an [MmrProof] for the leaf at the specified position. /// /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element /// has position 0, the second position 1, and so on. - pub fn open(&self, pos: usize, target_forest: usize) -> Result { + /// + /// # Errors + /// Returns an error if the specified leaf position is out of bounds for this MMR. + pub fn open(&self, pos: usize) -> Result { + self.open_at(pos, self.forest) + } + + /// Returns an [MmrProof] for the leaf at the specified position using the state of the MMR + /// at the specified `forest`. + /// + /// Note: The leaf position is the 0-indexed number corresponding to the order the leaves were + /// added, this corresponds to the MMR size _prior_ to adding the element. So the 1st element + /// has position 0, the second position 1, and so on. + /// + /// # Errors + /// Returns an error if: + /// - The specified leaf position is out of bounds for this MMR. + /// - The specified `forest` value is not valid for this MMR. + pub fn open_at(&self, pos: usize, forest: usize) -> Result { // find the target tree responsible for the MMR position let tree_bit = - leaf_to_corresponding_tree(pos, target_forest).ok_or(MmrError::InvalidPosition(pos))?; + leaf_to_corresponding_tree(pos, forest).ok_or(MmrError::InvalidPosition(pos))?; // isolate the trees before the target - let forest_before = target_forest & high_bitmask(tree_bit + 1); + let forest_before = forest & high_bitmask(tree_bit + 1); let index_offset = nodes_in_forest(forest_before); // update the value position from global to the target tree @@ -94,7 +111,7 @@ impl Mmr { let (_, path) = self.collect_merkle_path_and_value(tree_bit, relative_pos, index_offset); Ok(MmrProof { - forest: target_forest, + forest, position: pos, merkle_path: MerklePath::new(path), }) @@ -145,8 +162,16 @@ impl Mmr { self.forest += 1; } - /// Returns an peaks of the MMR for the version specified by `forest`. - pub fn peaks(&self, forest: usize) -> Result { + /// Returns the current peaks of the MMR. + pub fn peaks(&self) -> MmrPeaks { + self.peaks_at(self.forest).expect("failed to get peaks at current forest") + } + + /// Returns the peaks of the MMR at the state specified by `forest`. + /// + /// # Errors + /// Returns an error if the specified `forest` value is not valid for this MMR. + pub fn peaks_at(&self, forest: usize) -> Result { if forest > self.forest { return Err(MmrError::InvalidPeaks); } diff --git a/src/merkle/mmr/partial.rs b/src/merkle/mmr/partial.rs index c2c4464..a3543c6 100644 --- a/src/merkle/mmr/partial.rs +++ b/src/merkle/mmr/partial.rs @@ -688,18 +688,18 @@ mod tests { // build an MMR with 10 nodes (2 peaks) and a partial MMR based on it let mut mmr = Mmr::default(); (0..10).for_each(|i| mmr.add(int_to_node(i))); - let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into(); + let mut partial_mmr: PartialMmr = mmr.peaks().into(); // add authentication path for position 1 and 8 { let node = mmr.get(1).unwrap(); - let proof = mmr.open(1, mmr.forest()).unwrap(); + let proof = mmr.open(1).unwrap(); partial_mmr.track(1, node, &proof.merkle_path).unwrap(); } { let node = mmr.get(8).unwrap(); - let proof = mmr.open(8, mmr.forest()).unwrap(); + let proof = mmr.open(8).unwrap(); partial_mmr.track(8, node, &proof.merkle_path).unwrap(); } @@ -712,7 +712,7 @@ mod tests { validate_apply_delta(&mmr, &mut partial_mmr); { let node = mmr.get(12).unwrap(); - let proof = mmr.open(12, mmr.forest()).unwrap(); + let proof = mmr.open(12).unwrap(); partial_mmr.track(12, node, &proof.merkle_path).unwrap(); assert!(partial_mmr.track_latest); } @@ -737,7 +737,7 @@ mod tests { let nodes_delta = partial.apply(delta).unwrap(); // new peaks were computed correctly - assert_eq!(mmr.peaks(mmr.forest()).unwrap(), partial.peaks()); + assert_eq!(mmr.peaks(), partial.peaks()); let mut expected_nodes = nodes_before; for (key, value) in nodes_delta { @@ -753,7 +753,7 @@ mod tests { let index_value: u64 = index.into(); let pos = index_value / 2; let proof1 = partial.open(pos as usize).unwrap().unwrap(); - let proof2 = mmr.open(pos as usize, mmr.forest()).unwrap(); + let proof2 = mmr.open(pos as usize).unwrap(); assert_eq!(proof1, proof2); } } @@ -762,16 +762,16 @@ mod tests { fn test_partial_mmr_inner_nodes_iterator() { // build the MMR let mmr: Mmr = LEAVES.into(); - let first_peak = mmr.peaks(mmr.forest).unwrap().peaks()[0]; + let first_peak = mmr.peaks().peaks()[0]; // -- test single tree ---------------------------- // get path and node for position 1 let node1 = mmr.get(1).unwrap(); - let proof1 = mmr.open(1, mmr.forest()).unwrap(); + let proof1 = mmr.open(1).unwrap(); // create partial MMR and add authentication path to node at position 1 - let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into(); + let mut partial_mmr: PartialMmr = mmr.peaks().into(); partial_mmr.track(1, node1, &proof1.merkle_path).unwrap(); // empty iterator should have no nodes @@ -789,13 +789,13 @@ mod tests { // -- test no duplicates -------------------------- // build the partial MMR - let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into(); + let mut partial_mmr: PartialMmr = mmr.peaks().into(); let node0 = mmr.get(0).unwrap(); - let proof0 = mmr.open(0, mmr.forest()).unwrap(); + let proof0 = mmr.open(0).unwrap(); let node2 = mmr.get(2).unwrap(); - let proof2 = mmr.open(2, mmr.forest()).unwrap(); + let proof2 = mmr.open(2).unwrap(); partial_mmr.track(0, node0, &proof0.merkle_path).unwrap(); partial_mmr.track(1, node1, &proof1.merkle_path).unwrap(); @@ -826,10 +826,10 @@ mod tests { // -- test multiple trees ------------------------- // build the partial MMR - let mut partial_mmr: PartialMmr = mmr.peaks(mmr.forest()).unwrap().into(); + let mut partial_mmr: PartialMmr = mmr.peaks().into(); let node5 = mmr.get(5).unwrap(); - let proof5 = mmr.open(5, mmr.forest()).unwrap(); + let proof5 = mmr.open(5).unwrap(); partial_mmr.track(1, node1, &proof1.merkle_path).unwrap(); partial_mmr.track(5, node5, &proof5.merkle_path).unwrap(); @@ -841,7 +841,7 @@ mod tests { let index1 = NodeIndex::new(2, 1).unwrap(); let index5 = NodeIndex::new(1, 1).unwrap(); - let second_peak = mmr.peaks(mmr.forest).unwrap().peaks()[1]; + let second_peak = mmr.peaks().peaks()[1]; let path1 = store.get_path(first_peak, index1).unwrap().path; let path5 = store.get_path(second_peak, index5).unwrap().path; @@ -860,8 +860,7 @@ mod tests { mmr.add(el); partial_mmr.add(el, false); - let mmr_peaks = mmr.peaks(mmr.forest()).unwrap(); - assert_eq!(mmr_peaks, partial_mmr.peaks()); + assert_eq!(mmr.peaks(), partial_mmr.peaks()); assert_eq!(mmr.forest(), partial_mmr.forest()); } } @@ -877,12 +876,11 @@ mod tests { mmr.add(el); partial_mmr.add(el, true); - let mmr_peaks = mmr.peaks(mmr.forest()).unwrap(); - assert_eq!(mmr_peaks, partial_mmr.peaks()); + assert_eq!(mmr.peaks(), partial_mmr.peaks()); assert_eq!(mmr.forest(), partial_mmr.forest()); for pos in 0..i { - let mmr_proof = mmr.open(pos as usize, mmr.forest()).unwrap(); + let mmr_proof = mmr.open(pos as usize).unwrap(); let partialmmr_proof = partial_mmr.open(pos as usize).unwrap().unwrap(); assert_eq!(mmr_proof, partialmmr_proof); } @@ -894,8 +892,8 @@ mod tests { let mut mmr = Mmr::from((0..7).map(int_to_node)); // derive a partial Mmr from it which tracks authentication path to leaf 5 - let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks(mmr.forest()).unwrap()); - let path_to_5 = mmr.open(5, mmr.forest()).unwrap().merkle_path; + let mut partial_mmr = PartialMmr::from_peaks(mmr.peaks()); + let path_to_5 = mmr.open(5).unwrap().merkle_path; let leaf_at_5 = mmr.get(5).unwrap(); partial_mmr.track(5, leaf_at_5, &path_to_5).unwrap(); @@ -905,6 +903,6 @@ mod tests { partial_mmr.add(leaf_at_7, false); // the openings should be the same - assert_eq!(mmr.open(5, mmr.forest()).unwrap(), partial_mmr.open(5).unwrap().unwrap()); + assert_eq!(mmr.open(5).unwrap(), partial_mmr.open(5).unwrap().unwrap()); } } diff --git a/src/merkle/mmr/tests.rs b/src/merkle/mmr/tests.rs index 509f83c..02785f9 100644 --- a/src/merkle/mmr/tests.rs +++ b/src/merkle/mmr/tests.rs @@ -138,7 +138,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 1); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 1); assert_eq!(acc.peaks(), &[postorder[0]]); @@ -147,7 +147,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 3); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 2); assert_eq!(acc.peaks(), &[postorder[2]]); @@ -156,7 +156,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 4); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 3); assert_eq!(acc.peaks(), &[postorder[2], postorder[3]]); @@ -165,7 +165,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 7); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 4); assert_eq!(acc.peaks(), &[postorder[6]]); @@ -174,7 +174,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 8); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 5); assert_eq!(acc.peaks(), &[postorder[6], postorder[7]]); @@ -183,7 +183,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 10); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 6); assert_eq!(acc.peaks(), &[postorder[6], postorder[9]]); @@ -192,7 +192,7 @@ fn test_mmr_simple() { assert_eq!(mmr.nodes.len(), 11); assert_eq!(mmr.nodes.as_slice(), &postorder[0..mmr.nodes.len()]); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); assert_eq!(acc.num_leaves(), 7); assert_eq!(acc.peaks(), &[postorder[6], postorder[9], postorder[10]]); } @@ -204,95 +204,92 @@ fn test_mmr_open() { let h23 = merge(LEAVES[2], LEAVES[3]); // node at pos 7 is the root - assert!( - mmr.open(7, mmr.forest()).is_err(), - "Element 7 is not in the tree, result should be None" - ); + assert!(mmr.open(7).is_err(), "Element 7 is not in the tree, result should be None"); // node at pos 6 is the root let empty: MerklePath = MerklePath::new(vec![]); let opening = mmr - .open(6, mmr.forest()) + .open(6) .expect("Element 6 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, empty); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 6); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[6], opening), + mmr.peaks().verify(LEAVES[6], opening), "MmrProof should be valid for the current accumulator." ); // nodes 4,5 are depth 1 let root_to_path = MerklePath::new(vec![LEAVES[4]]); let opening = mmr - .open(5, mmr.forest()) + .open(5) .expect("Element 5 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 5); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[5], opening), + mmr.peaks().verify(LEAVES[5], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[5]]); let opening = mmr - .open(4, mmr.forest()) + .open(4) .expect("Element 4 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 4); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[4], opening), + mmr.peaks().verify(LEAVES[4], opening), "MmrProof should be valid for the current accumulator." ); // nodes 0,1,2,3 are detph 2 let root_to_path = MerklePath::new(vec![LEAVES[2], h01]); let opening = mmr - .open(3, mmr.forest()) + .open(3) .expect("Element 3 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 3); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[3], opening), + mmr.peaks().verify(LEAVES[3], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[3], h01]); let opening = mmr - .open(2, mmr.forest()) + .open(2) .expect("Element 2 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 2); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[2], opening), + mmr.peaks().verify(LEAVES[2], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[0], h23]); let opening = mmr - .open(1, mmr.forest()) + .open(1) .expect("Element 1 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 1); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[1], opening), + mmr.peaks().verify(LEAVES[1], opening), "MmrProof should be valid for the current accumulator." ); let root_to_path = MerklePath::new(vec![LEAVES[1], h23]); let opening = mmr - .open(0, mmr.forest()) + .open(0) .expect("Element 0 is contained in the tree, expected an opening result."); assert_eq!(opening.merkle_path, root_to_path); assert_eq!(opening.forest, mmr.forest); assert_eq!(opening.position, 0); assert!( - mmr.peaks(mmr.forest()).unwrap().verify(LEAVES[0], opening), + mmr.peaks().verify(LEAVES[0], opening), "MmrProof should be valid for the current accumulator." ); } @@ -308,7 +305,7 @@ fn test_mmr_open_older_version() { // merkle path of a node is empty if there are no elements to pair with it for pos in (0..mmr.forest()).filter(is_even) { let forest = pos + 1; - let proof = mmr.open(pos, forest).unwrap(); + let proof = mmr.open_at(pos, forest).unwrap(); assert_eq!(proof.forest, forest); assert_eq!(proof.merkle_path.nodes(), []); assert_eq!(proof.position, pos); @@ -320,7 +317,7 @@ fn test_mmr_open_older_version() { for pos in 0..4 { let idx = NodeIndex::new(2, pos).unwrap(); let path = mtree.get_path(idx).unwrap(); - let proof = mmr.open(pos as usize, forest).unwrap(); + let proof = mmr.open_at(pos as usize, forest).unwrap(); assert_eq!(path, proof.merkle_path); } } @@ -331,7 +328,7 @@ fn test_mmr_open_older_version() { let path = mtree.get_path(idx).unwrap(); // account for the bigger tree with 4 elements let mmr_pos = (pos + 4) as usize; - let proof = mmr.open(mmr_pos, forest).unwrap(); + let proof = mmr.open_at(mmr_pos, forest).unwrap(); assert_eq!(path, proof.merkle_path); } } @@ -357,49 +354,49 @@ fn test_mmr_open_eight() { let root = mtree.root(); let position = 0; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 1; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 2; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 3; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 4; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 5; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 6; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); let position = 7; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path = mtree.get_path(NodeIndex::new(3, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(position as u64, leaves[position]).unwrap(), root); @@ -415,47 +412,47 @@ fn test_mmr_open_seven() { let mmr: Mmr = LEAVES.into(); let position = 0; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[0]).unwrap(), mtree1.root()); let position = 1; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(1, LEAVES[1]).unwrap(), mtree1.root()); let position = 2; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(2, LEAVES[2]).unwrap(), mtree1.root()); let position = 3; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree1.get_path(NodeIndex::new(2, position as u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(3, LEAVES[3]).unwrap(), mtree1.root()); let position = 4; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 0u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[4]).unwrap(), mtree2.root()); let position = 5; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = mtree2.get_path(NodeIndex::new(1, 1u64).unwrap()).unwrap(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(1, LEAVES[5]).unwrap(), mtree2.root()); let position = 6; - let proof = mmr.open(position, mmr.forest()).unwrap(); + let proof = mmr.open(position).unwrap(); let merkle_path: MerklePath = [].as_ref().into(); assert_eq!(proof, MmrProof { forest, position, merkle_path }); assert_eq!(proof.merkle_path.compute_root(0, LEAVES[6]).unwrap(), LEAVES[6]); @@ -479,7 +476,7 @@ fn test_mmr_invariants() { let mut mmr = Mmr::new(); for v in 1..=1028 { mmr.add(int_to_node(v)); - let accumulator = mmr.peaks(mmr.forest()).unwrap(); + let accumulator = mmr.peaks(); assert_eq!(v as usize, mmr.forest(), "MMR leaf count must increase by one on every add"); assert_eq!( v as usize, @@ -565,37 +562,37 @@ fn test_mmr_peaks() { let mmr: Mmr = LEAVES.into(); let forest = 0b0001; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[0]]); let forest = 0b0010; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[2]]); let forest = 0b0011; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[2], mmr.nodes[3]]); let forest = 0b0100; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[6]]); let forest = 0b0101; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[7]]); let forest = 0b0110; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9]]); let forest = 0b0111; - let acc = mmr.peaks(forest).unwrap(); + let acc = mmr.peaks_at(forest).unwrap(); assert_eq!(acc.num_leaves(), forest); assert_eq!(acc.peaks(), &[mmr.nodes[6], mmr.nodes[9], mmr.nodes[10]]); } @@ -603,7 +600,7 @@ fn test_mmr_peaks() { #[test] fn test_mmr_hash_peaks() { let mmr: Mmr = LEAVES.into(); - let peaks = mmr.peaks(mmr.forest()).unwrap(); + let peaks = mmr.peaks(); let first_peak = Rpo256::merge(&[ Rpo256::merge(&[LEAVES[0], LEAVES[1]]), @@ -657,7 +654,7 @@ fn test_mmr_peaks_hash_odd() { #[test] fn test_mmr_delta() { let mmr: Mmr = LEAVES.into(); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); // original_forest can't have more elements assert!( @@ -757,7 +754,7 @@ fn test_mmr_delta_old_forest() { #[test] fn test_partial_mmr_simple() { let mmr: Mmr = LEAVES.into(); - let peaks = mmr.peaks(mmr.forest()).unwrap(); + let peaks = mmr.peaks(); let mut partial: PartialMmr = peaks.clone().into(); // check initial state of the partial mmr @@ -768,7 +765,7 @@ fn test_partial_mmr_simple() { assert_eq!(partial.nodes.len(), 0); // check state after adding tracking one element - let proof1 = mmr.open(0, mmr.forest()).unwrap(); + let proof1 = mmr.open(0).unwrap(); let el1 = mmr.get(proof1.position).unwrap(); partial.track(proof1.position, el1, &proof1.merkle_path).unwrap(); @@ -780,7 +777,7 @@ fn test_partial_mmr_simple() { let idx = idx.parent(); assert_eq!(partial.nodes[&idx.sibling()], proof1.merkle_path[1]); - let proof2 = mmr.open(1, mmr.forest()).unwrap(); + let proof2 = mmr.open(1).unwrap(); let el2 = mmr.get(proof2.position).unwrap(); partial.track(proof2.position, el2, &proof2.merkle_path).unwrap(); @@ -798,9 +795,9 @@ fn test_partial_mmr_update_single() { let mut full = Mmr::new(); let zero = int_to_node(0); full.add(zero); - let mut partial: PartialMmr = full.peaks(full.forest()).unwrap().into(); + let mut partial: PartialMmr = full.peaks().into(); - let proof = full.open(0, full.forest()).unwrap(); + let proof = full.open(0).unwrap(); partial.track(proof.position, zero, &proof.merkle_path).unwrap(); for i in 1..100 { @@ -810,9 +807,9 @@ fn test_partial_mmr_update_single() { partial.apply(delta).unwrap(); assert_eq!(partial.forest(), full.forest()); - assert_eq!(partial.peaks(), full.peaks(full.forest()).unwrap()); + assert_eq!(partial.peaks(), full.peaks()); - let proof1 = full.open(i as usize, full.forest()).unwrap(); + let proof1 = full.open(i as usize).unwrap(); partial.track(proof1.position, node, &proof1.merkle_path).unwrap(); let proof2 = partial.open(proof1.position).unwrap().unwrap(); assert_eq!(proof1.merkle_path, proof2.merkle_path); @@ -822,7 +819,7 @@ fn test_partial_mmr_update_single() { #[test] fn test_mmr_add_invalid_odd_leaf() { let mmr: Mmr = LEAVES.into(); - let acc = mmr.peaks(mmr.forest()).unwrap(); + let acc = mmr.peaks(); let mut partial: PartialMmr = acc.clone().into(); let empty = MerklePath::new(Vec::new()); From 762c821217c97a8fd21c60963d658a58dd3a55df Mon Sep 17 00:00:00 2001 From: Qyriad Date: Fri, 9 Aug 2024 17:26:29 -0600 Subject: [PATCH 2/5] refactor: make Smt's node recomputation pure And do mutations in its callers instead. --- src/merkle/smt/mod.rs | 44 +++++++++++++++++++++++++++++------- src/merkle/smt/simple/mod.rs | 12 +++++++++- 2 files changed, 47 insertions(+), 9 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 52ed1d2..bd6a6e9 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -104,23 +104,41 @@ pub(crate) trait SparseMerkleTree { leaf_index.into() }; - self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); + let mut mutations = + self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); + for index in mutations.removals.drain(..) { + self.remove_inner_node(index); + } + + for (index, new_node) in mutations.additions.drain(..) { + self.insert_inner_node(index, new_node); + } + + self.set_root(mutations.new_root); old_value } /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. + /// + /// This method is pure, and only computes the mutations to apply. fn recompute_nodes_from_index_to_root( - &mut self, + &self, mut index: NodeIndex, node_hash_at_index: RpoDigest, - ) { + ) -> Mutations { let mut node_hash = node_hash_at_index; + + let mut removals: Vec = Vec::new(); + let mut additions: Vec<(NodeIndex, InnerNode)> = Vec::new(); + for node_depth in (0..index.depth()).rev() { let is_right = index.is_value_odd(); index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let (left, right) = if is_right { (left, node_hash) } else { @@ -129,14 +147,15 @@ pub(crate) trait SparseMerkleTree { node_hash = Rpo256::merge(&[left, right]); if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { - // If a subtree is empty, when can remove the inner node, since it's equal to the - // default value - self.remove_inner_node(index) + // If a subtree is empty, we can remove the inner node, since it's equal to the + // default value. + removals.push(index); } else { - self.insert_inner_node(index, InnerNode { left, right }); + additions.push((index, InnerNode { left, right })); } } - self.set_root(node_hash); + + Mutations { removals, additions, new_root: node_hash } } // REQUIRED METHODS @@ -243,3 +262,12 @@ impl TryFrom for LeafIndex { Self::new(node_index.value()) } } + +// MUTATIONS +// ================================================================================================ + +pub(crate) struct Mutations { + removals: Vec, + additions: Vec<(NodeIndex, InnerNode)>, + new_root: RpoDigest, +} diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index 2fa5ae4..ef26322 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -242,7 +242,17 @@ impl SimpleSmt { // recompute nodes starting from subtree root // -------------- - self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root); + let mut mutations = + self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root); + for index in mutations.removals.drain(..) { + self.remove_inner_node(index); + } + + for (index, new_node) in mutations.additions.drain(..) { + self.insert_inner_node(index, new_node); + } + + self.set_root(mutations.new_root); Ok(self.root) } From 65e8f536d7a293a897f939891c0103b2793d768c Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 12 Aug 2024 15:56:07 -0600 Subject: [PATCH 3/5] WIP: implement hash_prospective_leaf() --- src/merkle/smt/full/leaf.rs | 4 +- src/merkle/smt/full/mod.rs | 45 ++++++++++++++++++ src/merkle/smt/full/tests.rs | 92 +++++++++++++++++++++++++++++++++++- src/merkle/smt/mod.rs | 8 ++++ src/merkle/smt/simple/mod.rs | 4 ++ 5 files changed, 150 insertions(+), 3 deletions(-) diff --git a/src/merkle/smt/full/leaf.rs b/src/merkle/smt/full/leaf.rs index 23e1ee4..e82a3f3 100644 --- a/src/merkle/smt/full/leaf.rs +++ b/src/merkle/smt/full/leaf.rs @@ -349,7 +349,7 @@ impl Deserializable for SmtLeaf { // ================================================================================================ /// Converts a key-value tuple to an iterator of `Felt`s -fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { +pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator { let key_elements = key.into_iter(); let value_elements = value.into_iter(); @@ -358,7 +358,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator /// Compares two keys, compared element-by-element using their integer representations starting with /// the most significant element. -fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { +pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() { let v1 = v1.as_int(); let v2 = v2.as_int(); diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 52b416f..1f43634 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -262,6 +262,29 @@ impl SparseMerkleTree for Smt { leaf.hash() } + fn hash_prospective_leaf(&self, key: &RpoDigest, value: &Word) -> RpoDigest { + // If this key already has a value, then the hash will be based off a + // prospective mutation on the leaf. + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + match self.leaves.get(&leaf_index.value()) { + Some(existing_leaf) => { + if value == &Self::EMPTY_VALUE { + // A leaf with an empty value is conceptually a removal the + // value in that leaf with this key. + // TODO: avoid cloning the leaf. + let mut cloned = existing_leaf.clone(); + cloned.remove(*key); + return cloned.hash(); + } + // TODO: avoid cloning the leaf. + let mut cloned = existing_leaf.clone(); + cloned.insert(*key, *value); + cloned.hash() + }, + None => SmtLeaf::new_single(*key, *value).hash(), + } + } + fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex { let most_significant_felt = key[3]; LeafIndex::new_max_depth(most_significant_felt.as_int()) @@ -356,3 +379,25 @@ fn test_smt_serialization_deserialization() { let bytes = smt.to_bytes(); assert_eq!(smt, Smt::read_from_bytes(&bytes).unwrap()); } + +#[test] +fn test_prospective_hash() { + // Smt with values + let smt_leaves_2: [(RpoDigest, Word); 2] = [ + ( + RpoDigest::new([Felt::new(101), Felt::new(102), Felt::new(103), Felt::new(104)]), + [Felt::new(1_u64), Felt::new(2_u64), Felt::new(3_u64), Felt::new(4_u64)], + ), + ( + RpoDigest::new([Felt::new(105), Felt::new(106), Felt::new(107), Felt::new(108)]), + [Felt::new(5_u64), Felt::new(6_u64), Felt::new(7_u64), Felt::new(8_u64)], + ), + ]; + let smt = Smt::with_entries(smt_leaves_2).unwrap(); + + for (key, value) in &smt_leaves_2 { + let expected = smt.get_leaf(key).hash(); + let actual = smt.hash_prospective_leaf(key, value); + assert_eq!(expected, actual); + } +} diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index e852811..27d24bd 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -1,6 +1,6 @@ use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ - merkle::{EmptySubtreeRoots, MerkleStore}, + merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, utils::{Deserializable, Serializable}, Word, ONE, WORD_SIZE, }; @@ -257,6 +257,96 @@ fn test_smt_removal() { } } +#[test] +fn test_prospective_hash() { + let mut smt = Smt::default(); + + let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; + + let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); + let key_2: RpoDigest = + RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + let key_3: RpoDigest = + RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]); + + let value_1 = [ONE; WORD_SIZE]; + let value_2 = [2_u32.into(); WORD_SIZE]; + let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + + // insert key-value 1 + { + let prospective = smt.hash_prospective_leaf(&key_1, &value_1); + let old_value_1 = smt.insert(key_1, value_1); + assert_eq!(old_value_1, EMPTY_WORD); + + assert_eq!(smt.get_leaf(&key_1).hash(), prospective); + + assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1))); + + } + + // insert key-value 2 + { + let prospective = smt.hash_prospective_leaf(&key_2, &value_2); + let old_value_2 = smt.insert(key_2, value_2); + assert_eq!(old_value_2, EMPTY_WORD); + + assert_eq!(smt.get_leaf(&key_2).hash(), prospective); + + assert_eq!( + smt.get_leaf(&key_2), + SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) + ); + } + + // insert key-value 3 + { + let prospective_hash = smt.hash_prospective_leaf(&key_3, &value_3); + let old_value_3 = smt.insert(key_3, value_3); + assert_eq!(old_value_3, EMPTY_WORD); + + assert_eq!(smt.get_leaf(&key_3).hash(), prospective_hash); + + assert_eq!( + smt.get_leaf(&key_3), + SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)]) + ); + } + + // remove key 3 + { + let old_hash = smt.get_leaf(&key_3).hash(); + let old_value_3 = smt.insert(key_3, EMPTY_WORD); + assert_eq!(old_value_3, value_3); + assert_eq!(old_hash, smt.hash_prospective_leaf(&key_3, &old_value_3)); + + assert_eq!( + smt.get_leaf(&key_3), + SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) + ); + } + + // remove key 2 + { + let old_hash = smt.get_leaf(&key_2).hash(); + let old_value_2 = smt.insert(key_2, EMPTY_WORD); + assert_eq!(old_value_2, value_2); + assert_eq!(old_hash, smt.hash_prospective_leaf(&key_2, &old_value_2)); + + assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1))); + } + + // remove key 1 + { + let old_hash = smt.get_leaf(&key_1).hash(); + let old_value_1 = smt.insert(key_1, EMPTY_WORD); + assert_eq!(old_value_1, value_1); + assert_eq!(old_hash, smt.hash_prospective_leaf(&key_1, &old_value_1)); + + assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into())); + } +} + /// Tests that 2 key-value pairs stored in the same leaf have the same path #[test] fn test_smt_path_to_keys_in_same_leaf_are_equal() { diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index bd6a6e9..bcfdb7b 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -185,6 +185,14 @@ pub(crate) trait SparseMerkleTree { /// Returns the hash of a leaf fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest; + /// Returns the hash of a leaf if the leaf WERE inserted into the tree, + /// without performing any insertion or other mutation. + /// + /// Note: calling this function after actually performing an insert with + /// the same arguments will *not* return the same result, as inserting + /// multiple times with the same key mutates the leaf each time. + fn hash_prospective_leaf(&self, key: &Self::Key, value: &Self::Value) -> RpoDigest; + /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index ef26322..a6fe794 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -311,6 +311,10 @@ impl SparseMerkleTree for SimpleSmt { leaf.into() } + fn hash_prospective_leaf(&self, _key: &LeafIndex, value: &Word) -> RpoDigest { + Self::hash_leaf(value) + } + fn key_to_leaf_index(key: &LeafIndex) -> LeafIndex { *key } From bd1a6fcd825a1a6da2c0d56baca79a56179b2c1e Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 21 Aug 2024 13:22:11 -0600 Subject: [PATCH 4/5] WIP: add and implement get_value() to Smt trait --- src/merkle/smt/full/mod.rs | 16 ++++++++++------ src/merkle/smt/mod.rs | 2 ++ src/merkle/smt/simple/mod.rs | 8 ++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 1f43634..74ed99a 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -120,12 +120,7 @@ impl Smt { /// Returns the value associated with `key` pub fn get_value(&self, key: &RpoDigest) -> Word { - let leaf_pos = LeafIndex::::from(*key).value(); - - match self.leaves.get(&leaf_pos) { - Some(leaf) => leaf.get_value(key).unwrap_or_default(), - None => EMPTY_WORD, - } + >::get_value(self, key) } /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle @@ -249,6 +244,15 @@ impl SparseMerkleTree for Smt { } } + fn get_value(&self, key: &Self::Key) -> Self::Value { + let leaf_pos = LeafIndex::::from(*key).value(); + + match self.leaves.get(&leaf_pos) { + Some(leaf) => leaf.get_value(key).unwrap_or_default(), + None => EMPTY_WORD, + } + } + fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf { let leaf_pos = LeafIndex::::from(*key).value(); diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index bcfdb7b..9330605 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -179,6 +179,8 @@ pub(crate) trait SparseMerkleTree { /// Inserts a leaf node, and returns the value at the key if already exists fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option; + fn get_value(&self, key: &Self::Key) -> Self::Value; + /// Returns the leaf at the specified index. fn get_leaf(&self, key: &Self::Key) -> Self::Leaf; diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index a6fe794..a267045 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -298,6 +298,14 @@ impl SparseMerkleTree for SimpleSmt { } } + fn get_value(&self, key: &Self::Key) -> Self::Value { + let leaf_pos = key.value(); + match self.leaves.get(&leaf_pos) { + Some(word) => *word, + None => Self::EMPTY_VALUE, + } + } + fn get_leaf(&self, key: &LeafIndex) -> Word { let leaf_pos = key.value(); match self.leaves.get(&leaf_pos) { From bf2ca7ab4df6756127ca7af80f87a9241302b6ca Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 21 Aug 2024 14:49:47 -0600 Subject: [PATCH 5/5] WIP: smt: implement root-checked insertion --- src/merkle/smt/full/mod.rs | 19 +++++++ src/merkle/smt/full/tests.rs | 76 +++++++++++++++++++++----- src/merkle/smt/mod.rs | 100 +++++++++++++++++++++++++++++++++++ src/merkle/smt/simple/mod.rs | 19 +++++++ 4 files changed, 201 insertions(+), 13 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 74ed99a..7b85a11 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -166,6 +166,25 @@ impl Smt { >::insert(self, key, value) } + /// Like [`Self::insert()`], but only performs the insert if the the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is actual root. + /// + /// No mutations are performed if the roots do no match. + pub fn insert_ensure_root( + &mut self, + key: RpoDigest, + value: Word, + expected_root: RpoDigest, + ) -> Result + { + >::insert_ensure_root(self, key, value, expected_root) + } + // HELPERS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 27d24bd..c4e9caa 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -1,6 +1,6 @@ use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ - merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, + merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleError, MerkleStore}, utils::{Deserializable, Serializable}, Word, ONE, WORD_SIZE, }; @@ -258,8 +258,11 @@ fn test_smt_removal() { } #[test] -fn test_prospective_hash() { +fn test_checked_insertion() { + use MerkleError::ConflictingRoots; + let mut smt = Smt::default(); + let smt_empty = smt.clone(); let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; @@ -273,50 +276,91 @@ fn test_prospective_hash() { let value_2 = [2_u32.into(); WORD_SIZE]; let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + let root_empty = smt.root(); + // insert key-value 1 - { + let root_1 = { let prospective = smt.hash_prospective_leaf(&key_1, &value_1); let old_value_1 = smt.insert(key_1, value_1); assert_eq!(old_value_1, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_1).hash(), prospective); - + assert_eq!(prospective, smt.get_leaf(&key_1).hash()); assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1))); + smt.root() + }; + + { + // Trying to insert something else into key_1 with the existing root should fail, and + // should not modify the tree at all. + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_1, value_2, root_1), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + // And inserting an empty word should bring us back to where we were. + assert_eq!(smt.insert_ensure_root(key_1, EMPTY_WORD, root_empty), Ok(value_1)); + assert_eq!(smt, smt_empty); + + smt.insert_ensure_root(key_1, value_1, root_1).unwrap(); + assert_eq!(smt, smt_before); } // insert key-value 2 - { + let root_2 = { let prospective = smt.hash_prospective_leaf(&key_2, &value_2); let old_value_2 = smt.insert(key_2, value_2); assert_eq!(old_value_2, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_2).hash(), prospective); + assert_eq!(prospective, smt.get_leaf(&key_2).hash()); assert_eq!( smt.get_leaf(&key_2), SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) ); + + smt.root() + }; + + { + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_2, value_1, root_2), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + assert_eq!(smt.insert_ensure_root(key_2, EMPTY_WORD, root_1), Ok(value_2)); + smt.insert_ensure_root(key_2, value_2, root_2).unwrap(); + assert_eq!(smt, smt_before); } // insert key-value 3 - { - let prospective_hash = smt.hash_prospective_leaf(&key_3, &value_3); + let root_3 = { + let prospective = smt.hash_prospective_leaf(&key_3, &value_3); let old_value_3 = smt.insert(key_3, value_3); assert_eq!(old_value_3, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_3).hash(), prospective_hash); + assert_eq!(prospective, smt.get_leaf(&key_3).hash()); assert_eq!( smt.get_leaf(&key_3), SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)]) ); + + smt.root() + }; + + { + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_3, value_1, root_3), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + assert_eq!(smt.insert_ensure_root(key_3, EMPTY_WORD, root_2), Ok(value_3)); + smt.insert_ensure_root(key_3, value_3, root_3).unwrap(); + assert_eq!(smt, smt_before); } // remove key 3 { let old_hash = smt.get_leaf(&key_3).hash(); - let old_value_3 = smt.insert(key_3, EMPTY_WORD); + let old_value_3 = smt.insert_ensure_root(key_3, EMPTY_WORD, root_2).unwrap(); assert_eq!(old_value_3, value_3); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_3, &old_value_3)); @@ -324,26 +368,32 @@ fn test_prospective_hash() { smt.get_leaf(&key_3), SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) ); + + assert_eq!(smt.root(), root_2); } // remove key 2 { let old_hash = smt.get_leaf(&key_2).hash(); - let old_value_2 = smt.insert(key_2, EMPTY_WORD); + let old_value_2 = smt.insert_ensure_root(key_2, EMPTY_WORD, root_1).unwrap(); assert_eq!(old_value_2, value_2); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_2, &old_value_2)); assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1))); + + assert_eq!(smt.root(), root_1); } // remove key 1 { let old_hash = smt.get_leaf(&key_1).hash(); - let old_value_1 = smt.insert(key_1, EMPTY_WORD); + let old_value_1 = smt.insert_ensure_root(key_1, EMPTY_WORD, root_empty).unwrap(); assert_eq!(old_value_1, value_1); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_1, &old_value_1)); assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into())); + + assert_eq!(smt.root(), root_empty); } } diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 9330605..ebc2f5d 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -119,6 +119,106 @@ pub(crate) trait SparseMerkleTree { old_value } + /// Like [`Self::insert()`], but only performs the insert if the the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is actual root. + /// + /// No mutations are performed if the roots do no match. + fn insert_ensure_root( + &mut self, + key: Self::Key, + value: Self::Value, + expected_root: RpoDigest, + ) -> Result + { + + let old_value = self.get_value(&key); + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return Ok(value); + } + + // Compute the nodes we'll need to make and remove. + let mut removals: Vec = Vec::with_capacity(DEPTH as usize); + let mut additions: Vec<(NodeIndex, InnerNode)> = Vec::with_capacity(DEPTH as usize); + + let (mut node_index, mut parent_node) = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + let node_index = NodeIndex::from(leaf_index); + + let mut parent_index = node_index.clone(); + parent_index.move_up(); + + (node_index, Some(self.get_inner_node(parent_index))) + }; + + let mut new_child_hash = self.hash_prospective_leaf(&key, &value); + for node_depth in (0..node_index.depth()).rev() { + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = match parent_node.take() { + // On the first iteration, the 'old node' is the parent of the + // perspective leaf. + Some(parent_node) => parent_node, + // Otherwise it's a regular existing node. + None => self.get_inner_node(node_index), + }; + + //let new_node = new_node_from(is_right, old_node, new_child_hash); + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this node's new hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + if new_child_hash == equivalent_empty_hash { + // If a subtree is empty, we can remove the inner node, since it's equal to the + // default value. + removals.push(node_index); + } else { + additions.push((node_index, new_node)); + } + } + + // Once we're at depth 0, the last node we made is the new root. + let new_root = new_child_hash; + + if expected_root != new_root { + return Err(MerkleError::ConflictingRoots(vec![expected_root, new_root])); + } + + // Actual mutations start here. + + self.insert_value(key, value); + + for index in removals.drain(..) { + self.remove_inner_node(index); + } + + for (index, new_node) in additions.drain(..) { + self.insert_inner_node(index, new_node); + } + + self.set_root(new_root); + + Ok(old_value) + } + /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. /// diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index a267045..14307a9 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -187,6 +187,25 @@ impl SimpleSmt { >::insert(self, key, value) } + /// Like [`Self::insert()`], but only performs the insert if the the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is actual root. + /// + /// No mutations are performed if the roots do no match. + pub fn insert_ensure_root( + &mut self, + key: LeafIndex, + value: Word, + expected_root: RpoDigest, + ) -> Result + { + >::insert_ensure_root(self, key, value, expected_root) + } + /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is /// computed as `DEPTH - SUBTREE_DEPTH`. ///