factor out subtree-append logic

This commit is contained in:
Qyriad 2024-11-04 12:53:27 -07:00
parent 2b04a93a15
commit 60f4dd2161

View file

@ -68,28 +68,28 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
/// path to the leaf, as well as the leaf itself.
fn open(&self, key: &Self::Key) -> Self::Opening {
let leaf = self.get_leaf(key);
let leaf = self.get_leaf(key);
let mut index: NodeIndex = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
leaf_index.into()
};
let mut index: NodeIndex = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(key);
leaf_index.into()
};
let merkle_path = {
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let value = if is_right { left } else { right };
path.push(value);
}
let merkle_path = {
let mut path = Vec::with_capacity(index.depth() as usize);
for _ in 0..index.depth() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let value = if is_right { left } else { right };
path.push(value);
}
MerklePath::new(path)
};
MerklePath::new(path)
};
Self::path_and_leaf_to_opening(merkle_path, leaf)
}
Self::path_and_leaf_to_opening(merkle_path, leaf)
}
/// Inserts a value at the specified key, returning the previous value associated with that key.
/// Recall that by definition, any key that hasn't been updated is associated with
@ -98,53 +98,53 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// This also recomputes all hashes between the leaf (associated with the key) and the root,
/// updating the root itself.
fn insert(&mut self, key: Self::Key, value: Self::Value) -> Self::Value {
let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
let old_value = self.insert_value(key.clone(), value.clone()).unwrap_or(Self::EMPTY_VALUE);
// if the old value and new value are the same, there is nothing to update
if value == old_value {
return value;
// if the old value and new value are the same, there is nothing to update
if value == old_value {
return value;
}
let leaf = self.get_leaf(&key);
let node_index = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
leaf_index.into()
};
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
old_value
}
let leaf = self.get_leaf(&key);
let node_index = {
let leaf_index: LeafIndex<DEPTH> = Self::key_to_leaf_index(&key);
leaf_index.into()
};
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
old_value
}
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
/// `node_hash_at_index` is the hash of the node stored at index.
fn recompute_nodes_from_index_to_root(
&mut self,
mut index: NodeIndex,
node_hash_at_index: RpoDigest,
) {
let mut node_hash = node_hash_at_index;
for node_depth in (0..index.depth()).rev() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let (left, right) = if is_right {
(left, node_hash)
} else {
(node_hash, right)
};
node_hash = Rpo256::merge(&[left, right]);
&mut self,
mut index: NodeIndex,
node_hash_at_index: RpoDigest,
) {
let mut node_hash = node_hash_at_index;
for node_depth in (0..index.depth()).rev() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let (left, right) = if is_right {
(left, node_hash)
} else {
(node_hash, right)
};
node_hash = Rpo256::merge(&[left, right]);
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, when can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index)
} else {
self.insert_inner_node(index, InnerNode { left, right });
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, when can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index)
} else {
self.insert_inner_node(index, InnerNode { left, right });
}
}
self.set_root(node_hash);
}
self.set_root(node_hash);
}
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
@ -155,95 +155,95 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
@ -253,42 +253,42 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()]));
}
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()]));
}
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
}
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------
@ -332,11 +332,11 @@ where
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
/// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
@ -383,7 +383,7 @@ where
let hash = Self::hash_leaf(&leaf);
accumulator.nodes.insert(col, leaf);
accumulator.add_leaf(SubtreeLeaf { col, hash });
add_subtree_leaf(&mut accumulator.leaves, SubtreeLeaf { col, hash });
debug_assert!(current_leaf_buffer.is_empty());
}
@ -631,6 +631,7 @@ impl SubtreeLeaf {
}
}
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping.
@ -679,6 +680,36 @@ impl<K, L> Default for PairComputations<K, L> {
}
}
/// Handles the logic for figuring out whether the new leaf starts a new subtree or not.
fn add_subtree_leaf(subtrees: &mut Vec<Vec<SubtreeLeaf>>, leaf: SubtreeLeaf) {
let last_subtree = match subtrees.last_mut() {
// Base case.
None => {
subtrees.push(vec![leaf]);
return;
},
Some(last_subtree) => last_subtree,
};
debug_assert!(!last_subtree.is_empty());
debug_assert!(last_subtree.len() <= COLS_PER_SUBTREE as usize);
// The multiple of 256 after 0 is 1, but 0 and 1 do not belong to different subtrees.
let last_subtree_col = u64::max(1, last_subtree.last().unwrap().col);
let next_subtree_col = if Integer::is_multiple_of(&last_subtree_col, &COLS_PER_SUBTREE) {
u64::next_multiple_of(last_subtree_col + 1, COLS_PER_SUBTREE)
} else {
last_subtree_col.next_multiple_of(COLS_PER_SUBTREE)
};
if leaf.col < next_subtree_col {
last_subtree.push(leaf);
} else {
let next_subtree = vec![leaf];
subtrees.push(next_subtree);
}
}
// TESTS
// ================================================================================================
#[cfg(test)]
@ -687,9 +718,7 @@ mod test {
use alloc::{collections::BTreeMap, vec::Vec};
use num::Integer;
use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf, COLS_PER_SUBTREE};
use super::{InnerNode, PairComputations, SparseMerkleTree, SubtreeLeaf};
use crate::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, Smt, SmtLeaf, SMT_DEPTH},
@ -904,25 +933,7 @@ mod test {
accumulated_nodes.extend(nodes);
for subtree_leaf in next_leaves {
if leaf_subtrees.is_empty() {
leaf_subtrees.push(vec![subtree_leaf]);
continue;
}
let buffer_max_col =
u64::max(1, leaf_subtrees.last().unwrap().last().unwrap().col);
let next_subtree_col =
if Integer::is_multiple_of(&buffer_max_col, &COLS_PER_SUBTREE) {
u64::next_multiple_of(buffer_max_col + 1, COLS_PER_SUBTREE)
} else {
buffer_max_col.next_multiple_of(COLS_PER_SUBTREE)
};
if subtree_leaf.col < next_subtree_col {
leaf_subtrees.last_mut().unwrap().push(subtree_leaf);
} else {
leaf_subtrees.push(vec![subtree_leaf]);
}
super::add_subtree_leaf(&mut leaf_subtrees, subtree_leaf);
}
}