refactor: make Smt's node recomputation pure

And do mutations in its callers instead.
This commit is contained in:
Qyriad 2024-08-09 17:26:29 -06:00
parent d92fae7f82
commit 77ea774e59
2 changed files with 47 additions and 9 deletions

View file

@ -104,23 +104,41 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
leaf_index.into()
};
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
let mut mutations =
self.recompute_nodes_from_index_to_root(node_index, Self::hash_leaf(&leaf));
for index in mutations.removals.drain(..) {
self.remove_inner_node(index);
}
for (index, new_node) in mutations.additions.drain(..) {
self.insert_inner_node(index, new_node);
}
self.set_root(mutations.new_root);
old_value
}
/// Recomputes the branch nodes (including the root) from `index` all the way to the root.
/// `node_hash_at_index` is the hash of the node stored at index.
///
/// This method is pure, and only computes the mutations to apply.
fn recompute_nodes_from_index_to_root(
&mut self,
&self,
mut index: NodeIndex,
node_hash_at_index: RpoDigest,
) {
) -> Mutations {
let mut node_hash = node_hash_at_index;
let mut removals: Vec<NodeIndex> = Vec::new();
let mut additions: Vec<(NodeIndex, InnerNode)> = Vec::new();
for node_depth in (0..index.depth()).rev() {
let is_right = index.is_value_odd();
index.move_up();
let InnerNode { left, right } = self.get_inner_node(index);
let (left, right) = if is_right {
(left, node_hash)
} else {
@ -129,14 +147,15 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
node_hash = Rpo256::merge(&[left, right]);
if node_hash == *EmptySubtreeRoots::entry(DEPTH, node_depth) {
// If a subtree is empty, when can remove the inner node, since it's equal to the
// default value
self.remove_inner_node(index)
// If a subtree is empty, we can remove the inner node, since it's equal to the
// default value.
removals.push(index);
} else {
self.insert_inner_node(index, InnerNode { left, right });
additions.push((index, InnerNode { left, right }));
}
}
self.set_root(node_hash);
Mutations { removals, additions, new_root: node_hash }
}
// REQUIRED METHODS
@ -243,3 +262,12 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
Self::new(node_index.value())
}
}
// MUTATIONS
// ================================================================================================
pub(crate) struct Mutations {
removals: Vec<NodeIndex>,
additions: Vec<(NodeIndex, InnerNode)>,
new_root: RpoDigest,
}

View file

@ -242,7 +242,17 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// recompute nodes starting from subtree root
// --------------
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
let mut mutations =
self.recompute_nodes_from_index_to_root(subtree_root_index, subtree.root);
for index in mutations.removals.drain(..) {
self.remove_inner_node(index);
}
for (index, new_node) in mutations.additions.drain(..) {
self.insert_inner_node(index, new_node);
}
self.set_root(mutations.new_root);
Ok(self.root)
}