Compare commits

...

9 commits

Author SHA1 Message Date
6addcd0226 WIP(smt): add simple benchmark for single subtree computation 2024-10-16 09:24:17 -06:00
a35c11abfe WIP(smt): allow leaves to be wrapped in an Arc for async 2024-10-11 13:44:19 -06:00
e7ee6b53ba WIP: add many helper methods on NodeIndex 2024-10-04 13:26:03 -06:00
b289e7ed73 feat(smt): impl lowest common ancestor for leaf indices 2024-09-24 10:39:43 -06:00
c414a875f3 feat(merkle): impl constructing NodeIndex from scalar index 2024-09-24 10:39:43 -06:00
0e7e6705d8 fix(merkle): fix overflow in to_scalar_index for nodes at depth 64 2024-09-24 10:39:43 -06:00
e5dd7c6d6a WIP(smt): allow inner_nodes to be wrapped in an Arc for async 2024-09-24 10:39:43 -06:00
Bobbin Threadbare 913384600d chore: fix typos 2024-09-11 16:52:21 -07:00
Qyriad ae807a47ae feat: implement transactional Smt insertion (#327)
* feat(smt): impl constructing leaves that don't yet exist

This commit implements 'prospective leaf construction' -- computing
sparse Merkle tree leaves for a key-value insertion without actually
performing that insertion.

For SimpleSmt, this is trivial, since the leaf type is simply the value
being inserted.

For the full Smt, the new leaf payload depends on the existing payload
in that leaf. Since almost all leaves are very small, we can just clone
the leaf and modify a copy.

This will allow us to perform more general prospective changes on Merkle
trees.
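
A rough sketch of what this enables, mirroring the tests added further down (the key and value here are just placeholders):

    use miden_crypto::{hash::rpo::RpoDigest, merkle::{Smt, SparseMerkleTree}, Felt, ONE};

    let smt = Smt::new();
    let key = RpoDigest::new([ONE, ONE, ONE, Felt::new(42)]);
    let value = [ONE, ONE, ONE, ONE];

    // Build the leaf this insertion *would* produce; `smt` itself is not modified.
    let prospective = smt.construct_prospective_leaf(smt.get_leaf(&key), &key, &value);
    let _prospective_hash = prospective.hash();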

* feat(smt): export get_value() in the trait

* feat(smt): implement generic prospective insertions

This commit adds two methods to SparseMerkleTree: compute_mutations()
and apply_mutations(), which respectively create and consume a new
MutationSet type. This type represents a set of changes to a
SparseMerkleTree that haven't happened yet, and can be queried to
ensure a set of insertions results in the correct tree root before
finalizing and committing the mutations.

This is a direct step towards issue 222, and will directly enable
removing Merkle tree clones in miden-node InnerState::apply_block().

As part of this change, SparseMerkleTree now requires its Key to be Ord
and its Leaf to be Clone (both bounds which were already met by existing
implementations). The Ord bound could instead be changed to Eq + Hash,
if MutationSet were changed to use a HashMap instead of a BTreeMap.
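
A hedged usage sketch of the two new methods (the key-value pair is a placeholder, error handling kept minimal):

    use miden_crypto::{hash::rpo::RpoDigest, merkle::Smt, Felt, Word, ONE};

    let mut smt = Smt::new();
    let pairs: Vec<(RpoDigest, Word)> = vec![
        (RpoDigest::new([ONE, ONE, ONE, Felt::new(7)]), [ONE, ONE, ONE, ONE]),
    ];

    // Stage the insertions without modifying the tree yet.
    let mutations = smt.compute_mutations(pairs);

    // Validate the prospective root before committing anything.
    let expected_root = mutations.root();

    // Only now is the tree actually changed.
    smt.apply_mutations(mutations).unwrap();
    assert_eq!(smt.root(), expected_root);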

* chore(smt): refactor empty node construction to helper function
2024-09-11 16:49:57 -07:00
13 changed files with 1169 additions and 50 deletions

View file

@ -3,6 +3,7 @@
- [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234).
- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks_at()` respectively (#234).
- Standardised CI and Makefile across Miden repos (#323).
- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
## 0.10.0 (2024-08-06)

View file

@ -31,8 +31,12 @@ harness = false
name = "store"
harness = false
[[bench]]
name = "subtree"
harness = false
[features]
default = ["std"]
default = ["std", "async"]
executable = ["dep:clap", "dep:rand-utils", "std"]
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [
@ -44,6 +48,7 @@ std = [
"winter-math/std",
"winter-utils/std",
]
async = ["serde?/rc"]
[dependencies]
blake3 = { version = "1.5", default-features = false }

benches/subtree.rs (new file, 66 lines)
View file

@ -0,0 +1,66 @@
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, NodeSubtreeComputer, Smt, SparseMerkleTree},
Felt, Word, ONE,
};
const SUBTREE_INTERVAL: u8 = 8;
fn setup_subtree8(tree_size: u64) -> (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest) {
let entries: BTreeMap<RpoDigest, Word> = (0..tree_size)
.into_iter()
.map(|i| {
let leaf_index = u64::MAX / (i + 1);
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect();
let control = Smt::with_entries(entries.clone()).unwrap();
let subtree = entries
.keys()
.map(|key| {
let index_for_key = NodeIndex::from(Smt::key_to_leaf_index(key));
index_for_key.parent_n(SUBTREE_INTERVAL)
})
.next()
.unwrap();
let control_hash = control.get_inner_node(subtree).hash();
(Smt::new(), subtree, Arc::new(entries), control_hash)
}
fn bench_subtree8(
(smt, subtree, entries, control_hash): (
Smt,
NodeIndex,
Arc<BTreeMap<RpoDigest, Word>>,
RpoDigest,
),
) {
let mut state = NodeSubtreeComputer::with_smt(&smt, Default::default(), entries);
let hash = state.get_or_make_hash(subtree);
assert_eq!(control_hash, hash);
}
fn smt_subtree8(c: &mut Criterion) {
let mut group = c.benchmark_group("subtree8");
group.measurement_time(Duration::from_secs(120));
group.sample_size(30);
for &tree_size in [32, 128, 512, 1024].iter() {
let bench_id = BenchmarkId::from_parameter(tree_size);
//group.throughput(Throughput::Elements(tree_size));
group.bench_with_input(bench_id, &tree_size, |bench, &tree_size| {
bench.iter_batched(|| setup_subtree8(tree_size), bench_subtree8, BatchSize::SmallInput);
});
}
group.finish();
}
criterion_group!(subtree_group, smt_subtree8);
criterion_main!(subtree_group);
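Assuming the `[[bench]]` entry added to Cargo.toml above and the default `async` feature, this harness would typically be run with `cargo bench --bench subtree`; the exact feature flags may still change while the async gating is WIP.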

View file

@ -74,7 +74,7 @@ where
rev
}
/// Computes the first n powers of the 2nth root of unity, and put them in bit-reversed order.
/// Computes the first n powers of the 2nd root of unity, and put them in bit-reversed order.
#[allow(dead_code)]
fn bitreversed_powers(n: usize) -> Vec<Self> {
let psi = Self::primitive_root_of_unity(2 * n);
@ -88,7 +88,7 @@ where
array
}
/// Computes the first n powers of the 2nth root of unity, invert them, and put them in
/// Computes the first n powers of the 2nd root of unity, invert them, and put them in
/// bit-reversed order.
#[allow(dead_code)]
fn bitreversed_powers_inverse(n: usize) -> Vec<Self> {

View file

@ -35,6 +35,7 @@ pub fn benchmark_smt() {
let mut tree = construction(entries, tree_size).unwrap();
insertion(&mut tree, tree_size).unwrap();
batched_insertion(&mut tree, tree_size).unwrap();
proof_generation(&mut tree, tree_size).unwrap();
}
@ -82,6 +83,54 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
Ok(())
}
pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");
let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new(size + i)];
(key, value)
})
.collect();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed();
let now = Instant::now();
tree.apply_mutations(mutations).unwrap();
let apply_elapsed = now.elapsed();
println!(
"An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
compute_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
compute_elapsed.as_secs_f32(),
);
println!(
"An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
apply_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
apply_elapsed.as_secs_f32(),
);
println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
size,
(compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
);
println!();
Ok(())
}
/// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a proof generation benchmark:");

View file

@ -1,6 +1,6 @@
use core::slice;
use super::{Felt, RpoDigest, EMPTY_WORD};
use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD};
// EMPTY NODES SUBTREES
// ================================================================================================
@ -25,6 +25,17 @@ impl EmptySubtreeRoots {
let pos = 255 - tree_depth + node_depth;
&EMPTY_SUBTREES[pos as usize]
}
/// Returns a sparse Merkle tree [`InnerNode`] with two empty children.
///
/// # Note
/// `node_depth` is the depth of the **parent** node whose children are empty. That is, `node_depth`
/// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for
/// subtrees of depth `node_depth + 1`.
pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode {
let &child = Self::entry(tree_depth, node_depth + 1);
InnerNode { left: child, right: child }
}
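// Hedged illustration of the note above: for a tree of depth 64, `get_inner_node(64, 63)`
// returns an InnerNode whose `left` and `right` are both `EmptySubtreeRoots::entry(64, 64)`,
// i.e. the empty-subtree root one level below the depth-63 parent.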
}
const EMPTY_SUBTREES: [RpoDigest; 256] = [

View file

@ -1,4 +1,4 @@
use core::fmt::Display;
use core::{fmt::Display, num::NonZero};
use super::{Felt, MerkleError, RpoDigest};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@ -72,6 +72,53 @@ impl NodeIndex {
Self::new(depth, value)
}
/// Converts a scalar representation of a depth/value pair to a [`NodeIndex`].
///
/// This is the inverse operation of [`NodeIndex::to_scalar_index()`]. As `1` represents the
/// root node, `index` cannot be zero.
///
/// # Errors
/// Returns the same errors under the same conditions as [`NodeIndex::new()`].
///
/// # Panics
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
/// indicated by `index` does not fit in a [`u64`].
pub fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
let index = index.get() - 1;
if index == 0 {
return Ok(Self::root());
}
// The log of 1 is always 0.
if index == 1 {
return Ok(Self::root().left_child());
}
let depth = {
let depth = u128::ilog2(index + 1);
assert!(depth <= u8::MAX as u32);
//let depth = f64::log2(index as f64).round();
//std::eprintln!("depth for scalar index {index} is {depth}");
//assert!(depth <= u8::MAX as f64);
depth as u8
};
let max_value_for_depth = (1 << depth) - 1;
assert!(
max_value_for_depth <= u64::MAX as u128,
"max_value ({max_value_for_depth}) does not fit in u64",
);
let value = {
let value = index - max_value_for_depth;
assert!(value <= u64::MAX as u128);
value as u64
};
Self::new(depth, value)
}
/// Creates a new node index pointing to the root of the tree.
pub const fn root() -> Self {
Self { depth: 0, value: 0 }
@ -97,6 +144,55 @@ impl NodeIndex {
self
}
/// Returns the parent of the current node.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
/// Returns the `n`th parent of the current node.
pub const fn parent_n(mut self, n: u8) -> Self {
debug_assert!(n < self.depth);
let delta = self.depth.saturating_sub(n);
self.depth = self.depth.saturating_sub(delta);
self.value >>= delta as u32;
self
}
/// Returns `true` if and only if `other` is the current node or one of its descendants.
pub const fn contains(&self, mut other: Self) -> bool {
loop {
if self.depth == other.depth && self.value == other.value {
return true;
}
if other.is_root() {
return false;
}
other = other.parent();
}
}
/// Returns the right-most descendant of the current node for a tree of depth `DEPTH`.
pub const fn rightmost_descendent<const DEPTH: u8>(mut self) -> Self {
while self.depth() < DEPTH {
self = self.right_child();
}
self
}
/// Returns the left-most descendant of the current node for a tree of depth `DEPTH`.
pub const fn leftmost_descendent<const DEPTH: u8>(mut self) -> Self {
while self.depth() < DEPTH {
self = self.left_child();
}
self
}
// PROVIDERS
// --------------------------------------------------------------------------------------------
@ -114,8 +210,8 @@ impl NodeIndex {
/// Returns the scalar representation of the depth/value pair.
///
/// It is computed as `2^depth + value`.
pub const fn to_scalar_index(&self) -> u64 {
(1 << self.depth as u64) + self.value
pub const fn to_scalar_index(&self) -> u128 {
(1 << self.depth as u64) + (self.value as u128)
}
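// Hedged worked example of the `2^depth + value` convention and its inverse above:
// NodeIndex { depth: 2, value: 3 } has scalar index 2^2 + 3 = 7; going back,
// `from_scalar_index` subtracts 1 to get 6, takes depth = ilog2(6 + 1) = 2, and
// recovers value = 6 - (2^2 - 1) = 3.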
/// Returns the depth of the current instance.
@ -210,6 +306,52 @@ mod tests {
assert!(NodeIndex::new(64, u64::MAX).is_ok());
}
//#[test]
//fn test_traversal_roundtrip() {
// // Arbitrary value that's at the bottom and not in a corner.
// let start = NodeIndex::make(64, u64::MAX - 8);
//
// let mut index = start;
// while !index.is_root() {
// std::dbg!(&index);
// let as_traversal = index.to_traversal_index();
// let as_scalar = index.to_scalar_index() - 1;
// assert_eq!(as_traversal, as_scalar as u128);
// let round_trip = NodeIndex::from_traversal_index(as_traversal).unwrap();
// assert_eq!(index, round_trip, "{:?} did not round-trip as a traversal index", index);
// index.move_up();
// }
// assert!(index.is_root());
// let root_control = NodeIndex::root();
// assert_eq!(index, root_control);
//
// // Traversal index 0 should be root.
// assert_eq!(index, NodeIndex::from_traversal_index(0).unwrap());
//}
#[test]
fn test_scalar_roundtrip() {
// Arbitrary value that's at the bottom and not in a corner.
let start = NodeIndex::make(64, u64::MAX - 8);
let mut index = start;
while !index.is_root() {
let as_scalar = index.to_scalar_index();
let round_trip =
NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap();
assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index");
index.move_up();
}
//let start = NodeIndex::root().left_child().to_scalar_index();
//let max = u64::MAX as u128;
//for scalar in start..max {
// let index = NodeIndex::from_scalar_index(NonZero::new(scalar).unwrap()).unwrap();
// let round_trip = index.to_scalar_index();
// assert_eq!(scalar, round_trip, "scalar index {scalar} ({index:?}) did not round-trip");
//}
}
prop_compose! {
fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
// unwrap never panics because the range of depth is 0..u64::BITS

View file

@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath};
mod smt;
pub use smt::{
LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
InnerNode, LeafIndex, MutationSet, NodeSubtreeComputer, SimpleSmt, Smt, SmtLeaf, SmtLeafError,
SmtProof, SmtProofError, SparseMerkleTree, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
mod mmr;

View file

@ -350,7 +350,7 @@ impl Deserializable for SmtLeaf {
// ================================================================================================
/// Converts a key-value tuple to an iterator of `Felt`s
fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
let key_elements = key.into_iter();
let value_elements = value.into_iter();
@ -359,7 +359,7 @@ fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt>
/// Compares two keys element-by-element using their integer representations, starting with
/// the most significant element.
fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
let v1 = v1.as_int();
let v2 = v2.as_int();

View file

@ -1,3 +1,6 @@
#[cfg(feature = "async")]
use std::{collections::HashMap, sync::Arc};
use alloc::{
collections::{BTreeMap, BTreeSet},
string::ToString,
@ -6,9 +9,12 @@ use alloc::{
use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
};
#[cfg(feature = "async")]
use super::NodeMutation;
mod error;
pub use error::{SmtLeafError, SmtProofError};
@ -43,8 +49,16 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt {
root: RpoDigest,
#[cfg(not(feature = "async"))]
leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(feature = "async")]
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
}
impl Smt {
@ -64,8 +78,8 @@ impl Smt {
Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
leaves: Default::default(),
inner_nodes: Default::default(),
}
}
@ -101,6 +115,11 @@ impl Smt {
Ok(tree)
}
#[cfg(feature = "async")]
pub fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
Arc::clone(&self.leaves)
}
// PUBLIC ACCESSORS
// --------------------------------------------------------------------------------------------
@ -121,12 +140,7 @@ impl Smt {
/// Returns the value associated with `key`
pub fn get_value(&self, key: &RpoDigest) -> Word {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
<Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key)
}
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
@ -159,6 +173,40 @@ impl Smt {
})
}
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
/// Gets a mutable reference to this structure's inner leaf mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn leaves_mut(&mut self) -> &mut BTreeMap<u64, SmtLeaf> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.leaves).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.leaves
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -172,6 +220,47 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
}
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
/// tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt = Smt::new();
/// let pair = (RpoDigest::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
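// Hedged sketch of the stale-mutations guard described above (hypothetical `key_a`/`key_b`,
// with `value_b` a new, non-empty value so the root actually changes):
//
//     let mutations = smt.compute_mutations(vec![(key_a, value_a)]);
//     smt.insert(key_b, value_b); // the tree moves on in the meantime...
//     assert!(smt.apply_mutations(mutations).is_err()); // ...so the stale set is rejected.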
// HELPERS
// --------------------------------------------------------------------------------------------
@ -182,10 +271,12 @@ impl Smt {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
match self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
match leaves.get_mut(&leaf_index.value()) {
Some(leaf) => leaf.insert(key, value),
None => {
self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
None
},
@ -196,10 +287,12 @@ impl Smt {
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
let leaves = self.leaves_mut();
if let Some(leaf) = leaves.get_mut(&leaf_index.value()) {
let (old_value, is_empty) = leaf.remove(key);
if is_empty {
self.leaves.remove(&leaf_index.value());
leaves.remove(&leaf_index.value());
}
old_value
} else {
@ -207,6 +300,27 @@ impl Smt {
None
}
}
fn construct_prospective_leaf(
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
}
impl SparseMerkleTree<SMT_DEPTH> for Smt {
@ -226,19 +340,18 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);
InnerNode { left: *node, right: *node }
})
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
self.inner_nodes_mut().insert(index, inner_node);
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
let _ = self.inner_nodes_mut().remove(&index);
}
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
@ -250,6 +363,15 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
}
}
fn get_value(&self, key: &Self::Key) -> Self::Value {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
}
fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
@ -263,6 +385,28 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
leaf.hash()
}
fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
let most_significant_felt = key[3];
LeafIndex::new_max_depth(most_significant_felt.as_int())
@ -279,6 +423,141 @@ impl Default for Smt {
}
}
/// Just a [`NodeMutation`] with its hash already computed and stored.
#[cfg(feature = "async")]
pub struct ComputedNodeMutation {
pub mutation: NodeMutation,
pub hash: RpoDigest,
}
#[cfg(feature = "async")]
pub struct NodeSubtreeComputer {
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
new_mutations: HashMap<NodeIndex, ComputedNodeMutation>,
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
/// Cache of dirtiness results for indices we have already checked.
dirtied_indices: HashMap<NodeIndex, bool>,
cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
}
#[cfg(feature = "async")]
impl NodeSubtreeComputer {
pub fn with_smt(
smt: &Smt,
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
) -> Self {
Self {
inner_nodes: Arc::clone(&smt.inner_nodes),
leaves: Arc::clone(&smt.leaves),
existing_mutations,
new_mutations: Default::default(),
new_pairs,
dirtied_indices: Default::default(),
cached_leaf_hashes: Default::default(),
}
}
pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
if let Some(cached) = self.dirtied_indices.get(&index_to_check) {
return *cached;
}
// An index is dirty if there is a new pair at it, a known existing mutation at it, or if it
// is an ancestor of one of those.
let is_dirty = self
.existing_mutations
.iter()
.map(|(index, _)| *index)
.chain(self.new_pairs.iter().map(|(key, _v)| Smt::key_to_leaf_index(key).index))
.filter(|&dirtied_index| index_to_check.contains(dirtied_index))
.next()
.is_some();
// This is somewhat expensive to compute, so cache this.
self.dirtied_indices.insert(index_to_check, is_dirty);
is_dirty
}
pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
let pairs_at_index = self
.new_pairs
.iter()
.filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index);
let existing_leaf = self
.leaves
.get(&index.index.value())
.cloned()
.unwrap_or_else(|| SmtLeaf::new_empty(index));
pairs_at_index.fold(existing_leaf, |acc, (k, v)| {
let existing_leaf = acc.clone();
Smt::construct_prospective_leaf(existing_leaf, k, v)
})
}
/// Does NOT check `new_mutations`.
pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
self.existing_mutations
.get(&index)
.map(|ComputedNodeMutation { hash, .. }| *hash)
.or_else(|| self.inner_nodes.get(&index).map(|inner_node| InnerNode::hash(&inner_node)))
}
/// Retrieve a cached hash, or recursively compute it.
pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
use NodeMutation::*;
// If this is a leaf, then only do leaf stuff.
if index.depth() == SMT_DEPTH {
let index = LeafIndex::new(index.value()).unwrap();
return match self.cached_leaf_hashes.get(&index) {
Some(cached_hash) => cached_hash.clone(),
None => {
let leaf = self.get_effective_leaf(index);
let hash = Smt::hash_leaf(&leaf);
self.cached_leaf_hashes.insert(index, hash);
hash
},
};
}
// If we already computed this one earlier as a mutation, just return it.
if let Some(ComputedNodeMutation { hash, .. }) = self.new_mutations.get(&index) {
return *hash;
}
// Otherwise, we need to know if this node is one of the nodes we're in the process of
// recomputing, or if we can safely use the node already in the Merkle tree.
if !self.is_index_dirty(index) {
return self
.get_clean_hash(index)
.unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
}
// If we got here, then we have to make, rather than get, this hash.
// Make sure we mark this index as now dirty.
self.dirtied_indices.insert(index, true);
// Recurse for the left and right sides.
let left = self.get_or_make_hash(index.left_child());
let right = self.get_or_make_hash(index.right_child());
let node = InnerNode { left, right };
let hash = node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
let is_removal = hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(node) };
self.new_mutations
.insert(index, ComputedNodeMutation { hash, mutation: new_entry });
hash
}
}
// CONVERSIONS
// ================================================================================================

View file

@ -2,7 +2,7 @@ use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{
merkle::{EmptySubtreeRoots, MerkleStore},
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore},
utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE,
};
@ -258,6 +258,195 @@ fn test_smt_removal() {
}
}
/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
// insert key-value 1
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
smt.insert(key_1, value_1);
let leaf = smt.get_leaf(&key_1);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 2
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
smt.insert(key_2, value_2);
let leaf = smt.get_leaf(&key_2);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 3
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
smt.insert(key_3, value_3);
let leaf = smt.get_leaf(&key_3);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// remove key 3
{
let old_leaf = smt.get_leaf(&key_3);
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 2
{
let old_leaf = smt.get_leaf(&key_2);
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 1
{
let old_leaf = smt.get_leaf(&key_1);
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
}
/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
let root_empty = smt.root();
let root_1 = {
smt.insert(key_1, value_1);
smt.root()
};
let root_2 = {
smt.insert(key_2, value_2);
smt.root()
};
let root_3 = {
smt.insert(key_3, value_3);
smt.root()
};
// Test incremental updates.
let mut smt = Smt::default();
let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
smt.apply_mutations(mutations).unwrap();
// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
// Test batch updates, and that the order doesn't matter.
let pairs =
vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(
mutations.root(),
root_empty,
"prospective root for batch removal did not match actual root",
);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() {

View file

@ -1,4 +1,4 @@
use alloc::vec::Vec;
use alloc::{collections::BTreeMap, vec::Vec};
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{
@ -7,7 +7,9 @@ use crate::{
};
mod full;
pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
pub use full::{
NodeSubtreeComputer, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
};
mod simple;
pub use simple::SimpleSmt;
@ -43,13 +45,13 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// must accommodate all keys that map to the same leaf.
///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
pub trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key
type Key: Clone;
type Key: Clone + Ord;
/// The type for a value
type Value: Clone + PartialEq;
/// The type for a leaf
type Leaf;
type Leaf: Clone;
/// The type for an opening (i.e. a "proof") of a leaf
type Opening;
@ -140,6 +142,149 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
self.set_root(node_hash);
}
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()]));
}
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
}
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
// REQUIRED METHODS
// ---------------------------------------------------------------------------------------------
@ -161,12 +306,34 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
/// updated is associated with [`Self::EMPTY_VALUE`].
fn get_value(&self, key: &Self::Key) -> Self::Value;
/// Returns the leaf at the specified index.
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
/// Returns the hash of a leaf
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
/// mutating the tree itself. The existing leaf can be empty.
///
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
/// as the argument for `existing_leaf`. The return value from this function can be chained back
/// into this function as the first argument to continue making prospective changes.
///
/// # Invariants
/// Because this method is for a prospective key-value insertion into a specific leaf,
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
/// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
@ -181,7 +348,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub(crate) struct InnerNode {
pub struct InnerNode {
pub left: RpoDigest,
pub right: RpoDigest,
}
@ -214,6 +381,48 @@ impl<const DEPTH: u8> LeafIndex<DEPTH> {
pub fn value(&self) -> u64 {
self.index.value()
}
/// Lowest common ancestor — finds the lowest (highest depth) [`NodeIndex`] that is an ancestor
/// of both `self` and `rhs`.
///
/// The general-case algorithm is `O(n)`; however, leaf indexes are always at the same depth,
/// and we only need to find the depth of the lowest common ancestor (since we can trivially get
/// its horizontal position from either child's position), so we can reduce this to
/// `O(log n)`.
pub fn lca(&self, other: &Self) -> NodeIndex {
let mut self_scalar = self.index.to_scalar_index();
let mut other_scalar = other.index.to_scalar_index();
while self_scalar != other_scalar {
self_scalar >>= 1;
other_scalar >>= 1;
}
// Once we've shifted them enough to be equal, we've found a scalar index with the depth of
// the lowest common ancestor. Time to convert that scalar index to a depth, and apply that
// depth to either of our `NodeIndex`s to get the full position of that ancestor.
// In general, we can get the depth of a binary tree's scalar index by taking the binary
// logarithm of that index. However, for the root node, the scalar index is 0, and the
// logarithm is undefined for 0, so we trivially special-case the root index.
if self_scalar == 0 {
return NodeIndex::root();
}
let depth = {
let depth = u128::ilog2(self_scalar);
// The scalar index should not be able to exceed `u8::MAX + u64::MAX` (as those are the
// maximum values `NodeIndex` can hold), and the binary logarithm of `u8::MAX +
// u64::MAX` is 64, which fits in a u8. In other words, this assert should only be
// possible to fail if `to_scalar_index()` is wildly incorrect.
debug_assert!(depth <= u8::MAX as u32);
depth as u8
};
let mut lca = self.index;
lca.move_up_to(depth);
lca
}
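// Hedged worked example on a small tree: two `LeafIndex<3>` values 2 and 3 have scalar
// indices 2^3 + 2 = 10 and 2^3 + 3 = 11. One right-shift makes both equal to 5,
// ilog2(5) = 2, so the lowest common ancestor is the node at depth 2, value 1.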
}
impl LeafIndex<SMT_MAX_DEPTH> {
@ -244,3 +453,86 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
Self::new(node_index.value())
}
}
// MUTATIONS
// ================================================================================================
/// A change to an inner node of a [`SparseMerkleTree`] that hasn't yet been applied.
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
Addition(InnerNode),
}
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
old_root: RpoDigest,
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`] call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use crate::merkle::{LeafIndex, NodeIndex, SMT_DEPTH};
prop_compose! {
fn leaf_index()(value in 0..2u64.pow(u64::BITS - 1)) -> LeafIndex<SMT_DEPTH> {
LeafIndex::new(value).unwrap()
}
}
proptest! {
/// Tests that the O(log n) algorithm has the same results as the naïve version.
#[test]
fn test_leaf_lca(left in leaf_index(), right in leaf_index()) {
let control: NodeIndex = {
let mut left = left.index;
let mut right = right.index;
loop {
if left == right {
break left;
}
left.move_up();
right.move_up();
}
};
let actual: NodeIndex = left.lca(&right);
assert_eq!(actual, control);
}
}
}

View file

@ -1,9 +1,11 @@
use alloc::collections::{BTreeMap, BTreeSet};
#[cfg(feature = "async")]
use std::sync::Arc;
use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH,
SMT_MIN_DEPTH,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH,
};
#[cfg(test)]
@ -20,7 +22,10 @@ mod tests;
pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest,
leaves: BTreeMap<u64, Word>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
}
impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -52,7 +57,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self {
root,
leaves: BTreeMap::new(),
inner_nodes: BTreeMap::new(),
inner_nodes: Default::default(),
})
}
@ -175,6 +180,23 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
})
}
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS
// --------------------------------------------------------------------------------------------
@ -188,6 +210,48 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
}
/// Computes what changes are necessary to insert the specified key-value pairs into this
/// Merkle tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
/// Merkle tree, or [`drop()`] to discard them.
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
/// let pair = (LeafIndex::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`.
///
@ -229,7 +293,16 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// add subtree's branch nodes (which includes the root)
// --------------
for (branch_idx, branch_node) in subtree.inner_nodes {
let subtree_nodes;
#[cfg(feature = "async")]
{
subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap();
}
#[cfg(not(feature = "async"))]
{
subtree_nodes = subtree.inner_nodes
}
for (branch_idx, branch_node) in subtree_nodes {
let new_branch_idx = {
let new_depth = subtree_root_insertion_depth + branch_idx.depth();
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
@ -238,7 +311,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
};
self.inner_nodes.insert(new_branch_idx, branch_node);
self.inner_nodes_mut().insert(new_branch_idx, branch_node);
}
// recompute nodes starting from subtree root
@ -266,19 +339,18 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}
fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);
InnerNode { left: *node, right: *node }
})
self.inner_nodes
.get(&index)
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth()))
}
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes.insert(index, inner_node);
self.inner_nodes_mut().insert(index, inner_node);
}
fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes.remove(&index);
let _ = self.inner_nodes_mut().remove(&index);
}
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
@ -289,6 +361,10 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
}
}
fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
self.get_leaf(key)
}
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
let leaf_pos = key.value();
match self.leaves.get(&leaf_pos) {
@ -302,6 +378,15 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
leaf.into()
}
fn construct_prospective_leaf(
&self,
_existing_leaf: Word,
_key: &LeafIndex<DEPTH>,
value: &Word,
) -> Word {
*value
}
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
*key
}