From fdca917a43acadf0aca6e384ea3e446ef577f8fa Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 21 Aug 2024 14:49:47 -0600 Subject: [PATCH 1/2] WIP: smt: implement root-checked insertion --- src/merkle/smt/full/mod.rs | 19 +++++++ src/merkle/smt/full/tests.rs | 81 ++++++++++++++++++++++------ src/merkle/smt/mod.rs | 100 +++++++++++++++++++++++++++++++++++ src/merkle/smt/simple/mod.rs | 19 +++++++ 4 files changed, 204 insertions(+), 15 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 74ed99a..7b85a11 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -166,6 +166,25 @@ impl Smt { >::insert(self, key, value) } + /// Like [`Self::insert()`], but only performs the insert if the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is the actual root. + /// + /// No mutations are performed if the roots do not match. 
+ pub fn insert_ensure_root( + &mut self, + key: RpoDigest, + value: Word, + expected_root: RpoDigest, + ) -> Result + { + >::insert_ensure_root(self, key, value, expected_root) + } + // HELPERS // -------------------------------------------------------------------------------------------- diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 27d24bd..cddc75e 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -1,6 +1,6 @@ use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ - merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, + merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleError, MerkleStore}, utils::{Deserializable, Serializable}, Word, ONE, WORD_SIZE, }; @@ -258,65 +258,110 @@ fn test_smt_removal() { } #[test] -fn test_prospective_hash() { +fn test_checked_insertion() { + use MerkleError::ConflictingRoots; + let mut smt = Smt::default(); + let smt_empty = smt.clone(); let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64; let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]); let key_2: RpoDigest = RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]); + // Sort key_3 before key_1, to test non-append insertion. 
let key_3: RpoDigest = - RpoDigest::from([3_u32.into(), 3_u32.into(), 3_u32.into(), Felt::new(raw)]); + RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]); let value_1 = [ONE; WORD_SIZE]; let value_2 = [2_u32.into(); WORD_SIZE]; let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE]; + let root_empty = smt.root(); + // insert key-value 1 - { + let root_1 = { let prospective = smt.hash_prospective_leaf(&key_1, &value_1); let old_value_1 = smt.insert(key_1, value_1); assert_eq!(old_value_1, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_1).hash(), prospective); - + assert_eq!(prospective, smt.get_leaf(&key_1).hash()); assert_eq!(smt.get_leaf(&key_1), SmtLeaf::Single((key_1, value_1))); + smt.root() + }; + + { + // Trying to insert something else into key_1 with the existing root should fail, and + // should not modify the tree at all. + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_1, value_2, root_1), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + // And inserting an empty word should bring us back to where we were. 
+ assert_eq!(smt.insert_ensure_root(key_1, EMPTY_WORD, root_empty), Ok(value_1)); + assert_eq!(smt, smt_empty); + + smt.insert_ensure_root(key_1, value_1, root_1).unwrap(); + assert_eq!(smt, smt_before); } // insert key-value 2 - { + let root_2 = { let prospective = smt.hash_prospective_leaf(&key_2, &value_2); let old_value_2 = smt.insert(key_2, value_2); assert_eq!(old_value_2, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_2).hash(), prospective); + assert_eq!(prospective, smt.get_leaf(&key_2).hash()); assert_eq!( smt.get_leaf(&key_2), SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) ); + + smt.root() + }; + + { + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_2, value_1, root_2), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + assert_eq!(smt.insert_ensure_root(key_2, EMPTY_WORD, root_1), Ok(value_2)); + smt.insert_ensure_root(key_2, value_2, root_2).unwrap(); + assert_eq!(smt, smt_before); } // insert key-value 3 - { - let prospective_hash = smt.hash_prospective_leaf(&key_3, &value_3); + let root_3 = { + let prospective = smt.hash_prospective_leaf(&key_3, &value_3); let old_value_3 = smt.insert(key_3, value_3); assert_eq!(old_value_3, EMPTY_WORD); - assert_eq!(smt.get_leaf(&key_3).hash(), prospective_hash); + assert_eq!(prospective, smt.get_leaf(&key_3).hash()); assert_eq!( smt.get_leaf(&key_3), - SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2), (key_3, value_3)]) + SmtLeaf::Multiple(vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)]) ); + + smt.root() + }; + + { + let smt_before = smt.clone(); + assert!(matches!(smt.insert_ensure_root(key_3, value_1, root_3), Err(ConflictingRoots(_)))); + assert_eq!(smt, smt_before); + + assert_eq!(smt.insert_ensure_root(key_3, EMPTY_WORD, root_2), Ok(value_3)); + smt.insert_ensure_root(key_3, value_3, root_3).unwrap(); + assert_eq!(smt, smt_before); } // remove key 3 { let old_hash = smt.get_leaf(&key_3).hash(); - let old_value_3 = smt.insert(key_3, 
EMPTY_WORD); + let old_value_3 = smt.insert_ensure_root(key_3, EMPTY_WORD, root_2).unwrap(); assert_eq!(old_value_3, value_3); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_3, &old_value_3)); @@ -324,26 +369,32 @@ fn test_prospective_hash() { smt.get_leaf(&key_3), SmtLeaf::Multiple(vec![(key_1, value_1), (key_2, value_2)]) ); + + assert_eq!(smt.root(), root_2); } // remove key 2 { let old_hash = smt.get_leaf(&key_2).hash(); - let old_value_2 = smt.insert(key_2, EMPTY_WORD); + let old_value_2 = smt.insert_ensure_root(key_2, EMPTY_WORD, root_1).unwrap(); assert_eq!(old_value_2, value_2); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_2, &old_value_2)); assert_eq!(smt.get_leaf(&key_2), SmtLeaf::Single((key_1, value_1))); + + assert_eq!(smt.root(), root_1); } // remove key 1 { let old_hash = smt.get_leaf(&key_1).hash(); - let old_value_1 = smt.insert(key_1, EMPTY_WORD); + let old_value_1 = smt.insert_ensure_root(key_1, EMPTY_WORD, root_empty).unwrap(); assert_eq!(old_value_1, value_1); assert_eq!(old_hash, smt.hash_prospective_leaf(&key_1, &old_value_1)); assert_eq!(smt.get_leaf(&key_1), SmtLeaf::new_empty(key_1.into())); + + assert_eq!(smt.root(), root_empty); } } diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index 53e5792..9bf1cea 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -109,6 +109,106 @@ pub(crate) trait SparseMerkleTree { old_value } + /// Like [`Self::insert()`], but only performs the insert if the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is the actual root. + /// + /// No mutations are performed if the roots do not match. 
+ fn insert_ensure_root( + &mut self, + key: Self::Key, + value: Self::Value, + expected_root: RpoDigest, + ) -> Result + { + + let old_value = self.get_value(&key); + // if the old value and new value are the same, there is nothing to update + if value == old_value { + return Ok(value); + } + + // Compute the nodes we'll need to make and remove. + let mut removals: Vec = Vec::with_capacity(DEPTH as usize); + let mut additions: Vec<(NodeIndex, InnerNode)> = Vec::with_capacity(DEPTH as usize); + + let (mut node_index, mut parent_node) = { + let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + let node_index = NodeIndex::from(leaf_index); + + let mut parent_index = node_index.clone(); + parent_index.move_up(); + + (node_index, Some(self.get_inner_node(parent_index))) + }; + + let mut new_child_hash = self.hash_prospective_leaf(&key, &value); + for node_depth in (0..node_index.depth()).rev() { + let is_right = node_index.is_value_odd(); + node_index.move_up(); + + let old_node = match parent_node.take() { + // On the first iteration, the 'old node' is the parent of the + // perspective leaf. + Some(parent_node) => parent_node, + // Otherwise it's a regular existing node. + None => self.get_inner_node(node_index), + }; + + //let new_node = new_node_from(is_right, old_node, new_child_hash); + let new_node = if is_right { + InnerNode { + left: old_node.left, + right: new_child_hash, + } + } else { + InnerNode { + left: new_child_hash, + right: old_node.right, + } + }; + + // The next iteration will operate on this node's new hash. + new_child_hash = new_node.hash(); + + let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth); + if new_child_hash == equivalent_empty_hash { + // If a subtree is empty, we can remove the inner node, since it's equal to the + // default value. + removals.push(node_index); + } else { + additions.push((node_index, new_node)); + } + } + + // Once we're at depth 0, the last node we made is the new root. 
+ let new_root = new_child_hash; + + if expected_root != new_root { + return Err(MerkleError::ConflictingRoots(vec![expected_root, new_root])); + } + + // Actual mutations start here. + + self.insert_value(key, value); + + for index in removals.drain(..) { + self.remove_inner_node(index); + } + + for (index, new_node) in additions.drain(..) { + self.insert_inner_node(index, new_node); + } + + self.set_root(new_root); + + Ok(old_value) + } + /// Recomputes the branch nodes (including the root) from `index` all the way to the root. /// `node_hash_at_index` is the hash of the node stored at index. fn recompute_nodes_from_index_to_root( diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index ee1ded6..00b3b76 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -187,6 +187,25 @@ impl SimpleSmt { >::insert(self, key, value) } + /// Like [`Self::insert()`], but only performs the insert if the new tree's root + /// hash would be equal to the hash given in `expected_root`. + /// + /// # Errors + /// Returns [`MerkleError::ConflictingRoots`] with a two-item [Vec] if the new root of the tree is + /// different from the expected root. The first item of the vector is the expected root, and the + /// second is the actual root. + /// + /// No mutations are performed if the roots do not match. + pub fn insert_ensure_root( + &mut self, + key: LeafIndex, + value: Word, + expected_root: RpoDigest, + ) -> Result + { + >::insert_ensure_root(self, key, value, expected_root) + } + /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is /// computed as `DEPTH - SUBTREE_DEPTH`. 
/// From 7c7a35e8879ebde61b24b21a29aee3229a92d072 Mon Sep 17 00:00:00 2001 From: Qyriad Date: Wed, 21 Aug 2024 20:25:50 -0600 Subject: [PATCH 2/2] WIP: remove a *bunch* of allocations and clones in hash_prospective_leaf --- src/merkle/smt/full/mod.rs | 98 +++++++++++++++++++++++++++++++------- 1 file changed, 82 insertions(+), 16 deletions(-) diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index 7b85a11..8ece2e0 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -7,6 +7,7 @@ use alloc::{ string::ToString, vec::Vec, }; +use core::iter; mod error; pub use error::{SmtLeafError, SmtProofError}; @@ -286,25 +287,90 @@ impl SparseMerkleTree for Smt { } fn hash_prospective_leaf(&self, key: &RpoDigest, value: &Word) -> RpoDigest { - // If this key already has a value, then the hash will be based off a - // prospective mutation on the leaf. - let leaf_index: LeafIndex = Self::key_to_leaf_index(&key); + // This function combines logic from SmtLeaf::insert() and SmtLeaf::hash() to determine what + // the hash of a leaf would be with the `(key, value)` pair inserted into it, without simply + // cloning the leaf which could be expensive for some leaves, and is easily avoidable when + // we can combine the insertion and hashing operations. + let new_pair = (*key, *value); + let is_removal: bool = *value == EMPTY_WORD; + + let leaf_index: LeafIndex = Self::key_to_leaf_index(key); match self.leaves.get(&leaf_index.value()) { - Some(existing_leaf) => { - if value == &Self::EMPTY_VALUE { - // A leaf with an empty value is conceptually a removal the - // value in that leaf with this key. - // TODO: avoid cloning the leaf. - let mut cloned = existing_leaf.clone(); - cloned.remove(*key); - return cloned.hash(); + // If this key doesn't have a value, our job is very simple. + None => SmtLeaf::Single(new_pair).hash(), + + // If this key already has a value, then the hash will be based off a prospective + // mutation on the leaf. 
+ Some(existing_leaf) => match existing_leaf { + // Inserting an empty value into an empty leaf or a single leaf both do the same + // thing. + SmtLeaf::Empty(_) | SmtLeaf::Single(_) if is_removal => { + SmtLeaf::new_empty(key.into()).hash() + }, + + SmtLeaf::Empty(_) => SmtLeaf::Single(new_pair).hash(), + + SmtLeaf::Single(pair) => { + if pair.0 == *key { + SmtLeaf::Single(new_pair).hash() + } else { + // Inserting a non-empty value into a new key would change this to a + // multi-leaf. + // TODO: mini-optimization: use an array with each key's and value's Felts + // flattened inline to avoid the Vec allocation. + let elements: Vec = [*pair, new_pair] + .into_iter() + .flat_map(leaf::kv_to_elements) + .collect(); + + Rpo256::hash_elements(&elements) + } + }, + + SmtLeaf::Multiple(pairs) => { + match pairs.binary_search_by(|&(cur_key, _)| leaf::cmp_keys(cur_key, *key)) { + Ok(pos) => { + if is_removal && pairs.len() == 2 { + // This removal would convert this Multi into a Single, so we can + // just stop here. + return SmtLeaf::Single(pairs[0]).hash(); + } + + let (before_pos, rest) = pairs.split_at(pos); + let with_pos_removed = rest.iter().copied().skip(1); + let middle = iter::once(new_pair).filter(|_| !is_removal); + let elements: Vec = before_pos + .iter() + .copied() + .chain(middle) + .chain(with_pos_removed) + .flat_map(leaf::kv_to_elements) + .collect(); + + Rpo256::hash_elements(&elements) + } + Err(pos_for_insert) => { + if is_removal { + // The only values are at other keys, so we just hash the leaf + // as-is. + return existing_leaf.hash(); + } + + let (before_pos, rest) = pairs.split_at(pos_for_insert); + let middle = iter::once(new_pair); + let elements: Vec = before_pos + .iter() + .copied() + .chain(middle) + .chain(rest.iter().copied()) + .flat_map(leaf::kv_to_elements) + .collect(); + + Rpo256::hash_elements(&elements) + } + } } - // TODO: avoid cloning the leaf. 
- let mut cloned = existing_leaf.clone(); - cloned.insert(*key, *value); - cloned.hash() }, - None => SmtLeaf::new_single(*key, *value).hash(), } }