From 6bf895f027ed0246269786b8a55638d424667de9 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Mon, 4 Nov 2024 12:53:27 -0700 Subject: [PATCH] factor out subtree-append logic --- src/merkle/smt/mod.rs | 395 ++++++++++++++++++++++-------------------- 1 file changed, 203 insertions(+), 192 deletions(-) diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 69fe901..99fdfa8 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -68,28 +68,28 @@ pub(crate) trait SparseMerkleTree { /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// path to the leaf, as well as the leaf itself. fn open(&self, key: &Self::Key) -> Self::Opening { - let leaf = self.get_leaf(key); + let leaf = self.get_leaf(key); - let mut index: NodeIndex = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(key); - leaf_index.into() - }; + let mut index: NodeIndex = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(key); + leaf_index.into() + }; - let merkle_path = { - let mut path = Vec::with_capacity(index.depth() as usize); - for _ in 0..index.depth() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let value = if is_right { left } else { right }; - path.push(value); - } + let merkle_path = { + let mut path = Vec::with_capacity(index.depth() as usize); + for _ in 0..index.depth() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let value = if is_right { left } else { right }; + path.push(value); + } - MerklePath::new(path) - }; + MerklePath::new(path) + }; - Self::path_and_leaf_to_opening(merkle_path, leaf) -} + Self::path_and_leaf_to_opening(merkle_path, leaf) + } /// Inserts a value at the specified key, returning the previous value associated with that key. /// Recall that by definition, any key that hasn't been updated is associated with @@ -98,53 +98,53 @@ pub(crate) trait SparseMerkleTree { /// This also recomputes all hashes between the leaf (associated with the key) and the root, /// updating the root itself. fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value { - let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); + let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE); - // if the old value and new value are the same, there is nothing to update - if value == old_value { - return value; + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return value; + } + + let leaf = self.get_leaf(&key); + let node_index = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + leaf_index.into() + }; + + self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); + + old_value } - let leaf = self.get_leaf(&key); - let node_index = { - let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); - leaf_index.into() - }; - - self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf)); - - old_value -} - /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. fn recompute_nodes_from_index_to_root( - &mut self, - mut index: NodeIndex, - node_hash_at_index: RpoDigest, -) { - let mut node_hash = node_hash_at_index; - for node_depth in (0..index.depth()).rev() { - let is_right = index.is_value_odd(); - index.move_up(); - let InnerNode { left, right } = self.get_inner_node(index); - let (left, right) = if is_right { - (left, node_hash) - } else { - (node_hash, right) - }; - node_hash = Rpo256::merge(&[left, right]); + &mut self, + mut index: NodeIndex, + node_hash_at_index: RpoDigest, + ) { + let mut node_hash = node_hash_at_index; + for node_depth in (0..index.depth()).rev() { + let is_right = index.is_value_odd(); + index.move_up(); + let InnerNode { left, right } = self.get_inner_node(index); + let (left, right) = if is_right { + (left, node_hash) + } else { + (node_hash, right) + }; + node_hash = Rpo256::merge(&[left, right]); - if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { - // If a subtree is empty, when can remove the inner node, since it's equal to the - // default value - self.remove_inner_node(index) - } else { - self.insert_inner_node(index, InnerNode { left, right }); + if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) { + // If a subtree is empty, when can remove the inner node, since it's equal to the + // default value + self.remove_inner_node(index) + } else { + self.insert_inner_node(index, InnerNode { left, right }); + } } + self.set_root(node_hash); } - self.set_root(node_hash); -} /// Computes what changes are necessary to insert the specified key-value pairs into this Merkle /// tree, allowing for validation before applying those changes. @@ -155,95 +155,95 @@ pub(crate) trait SparseMerkleTree { /// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to /// the Merkle tree, or [`drop()`] to discard them. fn compute_mutations( - &self, - kv_pairs: impl IntoIterator, -) -> MutationSet { - use NodeMutation::*; + &self, + kv_pairs: impl IntoIterator, + ) -> MutationSet { + use NodeMutation::*; - let mut new_root = self.root(); - let mut new_pairs: BTreeMap = Default::default(); - let mut node_mutations: BTreeMap = Default::default(); + let mut new_root = self.root(); + let mut new_pairs: BTreeMap = Default::default(); + let mut node_mutations: BTreeMap = Default::default(); - for (key, value) in kv_pairs { - // If the old value and the new value are the same, there is nothing to update. - // For the unusual case that kv_pairs has multiple values at the same key, we'll have - // to check the key-value pairs we've already seen to get the "effective" old value. - let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); - if value == old_value { - continue; - } + for (key, value) in kv_pairs { + // If the old value and the new value are the same, there is nothing to update. + // For the unusual case that kv_pairs has multiple values at the same key, we'll have + // to check the key-value pairs we've already seen to get the "effective" old value. + let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key)); + if value == old_value { + continue; + } - let leaf_index = Self::key_to_leaf_index(&key); - let mut node_index = NodeIndex::from(leaf_index); + let leaf_index = Self::key_to_leaf_index(&key); + let mut node_index = NodeIndex::from(leaf_index); - // We need the current leaf's hash to calculate the new leaf, but in the rare case that - // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also - // part of the "current leaf". - let old_leaf = { - let pairs_at_index = new_pairs - .iter() - .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); + // We need the current leaf's hash to calculate the new leaf, but in the rare case that + // `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also + // part of the "current leaf". + let old_leaf = { + let pairs_at_index = new_pairs + .iter() + .filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index); - pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { - // Most of the time `pairs_at_index` should only contain a single entry (or - // none at all), as multi-leaves should be really rare. - let existing_leaf = acc.clone(); - self.construct_prospective_leaf(existing_leaf, k, v) - }) - }; - - let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); - - let mut new_child_hash = Self::hash_leaf(&new_leaf); - - for node_depth in (0..node_index.depth()).rev() { - // Whether the node we're replacing is the right child or the left child. - let is_right = node_index.is_value_odd(); - node_index.move_up(); - - let old_node = node_mutations - .get(&node_index) - .map(|mutation| match mutation { - Addition(node) => node.clone(), - Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| { + // Most of the time `pairs_at_index` should only contain a single entry (or + // none at all), as multi-leaves should be really rare. + let existing_leaf = acc.clone(); + self.construct_prospective_leaf(existing_leaf, k, v) }) - .unwrap_or_else(|| self.get_inner_node(node_index)); - - let new_node = if is_right { - InnerNode { - left: old_node.left, - right: new_child_hash, - } - } else { - InnerNode { - left: new_child_hash, - right: old_node.right, - } }; - // The next iteration will operate on this new node's hash. - new_child_hash = new_node.hash(); + let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value); - let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); - let is_removal = new_child_hash == equivalent_empty_hash; - let new_entry = if is_removal { Removal } else { Addition(new_node) }; - node_mutations.insert(node_index, new_entry); + let mut new_child_hash = Self::hash_leaf(&new_leaf); + + for node_depth in (0..node_index.depth()).rev() { + // Whether the node we're replacing is the right child or the left child. + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = node_mutations + .get(&node_index) + .map(|mutation| match mutation { + Addition(node) => node.clone(), + Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth), + }) + .unwrap_or_else(|| self.get_inner_node(node_index)); + + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this new node's hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + let is_removal = new_child_hash == equivalent_empty_hash; + let new_entry = if is_removal { Removal } else { Addition(new_node) }; + node_mutations.insert(node_index, new_entry); + } + + // Once we're at depth 0, the last node we made is the new root. + new_root = new_child_hash; + // And then we're done with this pair; on to the next one. + new_pairs.insert(key, value); } - // Once we're at depth 0, the last node we made is the new root. - new_root = new_child_hash; - // And then we're done with this pair; on to the next one. - new_pairs.insert(key, value); + MutationSet { + old_root: self.root(), + new_root, + node_mutations, + new_pairs, + } } - MutationSet { - old_root: self.root(), - new_root, - node_mutations, - new_pairs, - } -} - /// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to /// this tree. /// @@ -253,42 +253,42 @@ pub(crate) trait SparseMerkleTree { /// the `mutations` were computed against, and the second item is the actual current root of /// this tree. fn apply_mutations( - &mut self, - mutations: MutationSet, -) -> Result<(), MerkleError> -where - Self: Sized, -{ - use NodeMutation::*; - let MutationSet { - old_root, - node_mutations, - new_pairs, - new_root, - } = mutations; + &mut self, + mutations: MutationSet, + ) -> Result<(), MerkleError> + where + Self: Sized, + { + use NodeMutation::*; + let MutationSet { + old_root, + node_mutations, + new_pairs, + new_root, + } = mutations; - // Guard against accidentally trying to apply mutations that were computed against a - // different tree, including a stale version of this tree. - if old_root != self.root() { - return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); - } - - for (index, mutation) in node_mutations { - match mutation { - Removal => self.remove_inner_node(index), - Addition(node) => self.insert_inner_node(index, node), + // Guard against accidentally trying to apply mutations that were computed against a + // different tree, including a stale version of this tree. + if old_root != self.root() { + return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()])); } + + for (index, mutation) in node_mutations { + match mutation { + Removal => self.remove_inner_node(index), + Addition(node) => self.insert_inner_node(index, node), + } + } + + for (key, value) in new_pairs { + self.insert_value(key, value); + } + + self.set_root(new_root); + + Ok(()) } - for (key, value) in new_pairs { - self.insert_value(key, value); - } - - self.set_root(new_root); - - Ok(()) -} - // REQUIRED METHODS // --------------------------------------------------------------------------------------------- @@ -332,11 +332,11 @@ where /// `existing_leaf` must have the same leaf index as `key` (as determined by /// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless. fn construct_prospective_leaf( - &self, - existing_leaf: Self::Leaf, - key: &Self::Key, - value: &Self::Value, -) -> Self::Leaf; + &self, + existing_leaf: Self::Leaf, + key: &Self::Key, + value: &Self::Value, + ) -> Self::Leaf; /// Maps a key to a leaf index fn key_to_leaf_index(key: &Self::Key) -> LeafIndex; @@ -383,7 +383,7 @@ where let hash = Self::hash_leaf(&leaf); accumulator.nodes.insert(col, leaf); - accumulator.add_leaf(SubtreeLeaf { col, hash }); + add_subtree_leaf(&mut accumulator.leaves, SubtreeLeaf { col, hash }); debug_assert!(current_leaf_buffer.is_empty()); } @@ -631,6 +631,7 @@ impl SubtreeLeaf { } } +/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. #[derive(Debug, Clone, PartialEq, Eq)] pub(crate) struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. @@ -724,6 +725,36 @@ impl<'s> core::iter::Iterator for SubtreeLeavesIter<'s> { Some(subtree) } } +/// Handles the logic for figuring out whether the new leaf starts a new subtree or not. + +fn add_subtree_leaf(subtrees: &mut Vec>, leaf: SubtreeLeaf) { + let last_subtree = match subtrees.last_mut() { + // Base case. + None => { + subtrees.push(vec![leaf]); + return; + }, + Some(last_subtree) => last_subtree, + }; + + debug_assert!(!last_subtree.is_empty()); + debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize); + + // The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees. + let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col); + let next_subtree_col = if Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE) { + u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE) + } else { + last_subtree_col.next_multiple_of(COLS_PER_SUBTREE) + }; + + if leaf.col < next_subtree_col { + last_subtree.push(leaf); + } else { + let next_subtree = vec![leaf]; + subtrees.push(next_subtree); + } +} // TESTS // ================================================================================================ @@ -733,9 +764,7 @@ mod test { use alloc::{collections::BTreeMap, vec::Vec}; - use num::Integer; - - use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf, COLS_PER_SUBTREE}; + use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf}; use crate::{ hash::rpo::RpoDigest, merkle::{NodeIndex, Smt, SmtLeaf, SMT_DEPTH}, @@ -950,25 +979,7 @@ mod test { accumulated_nodes.extend(nodes); for subtree_leaf in next_leaves { - if leaf_subtrees.is_empty() { - leaf_subtrees.push(vec![subtree_leaf]); - continue; - } - - let buffer_max_col = - u64::max(1, leaf_subtrees.last().unwrap().last().unwrap().col); - let next_subtree_col = - if Integer::is_multiple_of(&buffer_max_col, &COLS_PER_SUBTREE) { - u64::next_multiple_of(buffer_max_col + 1, COLS_PER_SUBTREE) - } else { - buffer_max_col.next_multiple_of(COLS_PER_SUBTREE) - }; - - if subtree_leaf.col < next_subtree_col { - leaf_subtrees.last_mut().unwrap().push(subtree_leaf); - } else { - leaf_subtrees.push(vec![subtree_leaf]); - } + super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf); } }