Compare commits

..

No commits in common. "6addcd02267135d2b254dc4973c6bf06e88eff50" and "f4a9d5b027a671dfef1c81590675e14bcc8ab326" have entirely different histories.

13 changed files with 50 additions and 1169 deletions

View file

@ -3,7 +3,6 @@
- [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234). - [BREAKING]: renamed `Mmr::open()` into `Mmr::open_at()` and `Mmr::peaks()` into `Mmr::peaks_at()` (#234).
- Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234). - Added `Mmr::open()` and `Mmr::peaks()` which rely on `Mmr::open_at()` and `Mmr::peaks()` respectively (#234).
- Standardised CI and Makefile across Miden repos (#323). - Standardised CI and Makefile across Miden repos (#323).
- Added `Smt::compute_mutations()` and `Smt::apply_mutations()` for validation-checked insertions (#327).
## 0.10.0 (2024-08-06) ## 0.10.0 (2024-08-06)

View file

@ -31,12 +31,8 @@ harness = false
name = "store" name = "store"
harness = false harness = false
[[bench]]
name = "subtree"
harness = false
[features] [features]
default = ["std", "async"] default = ["std"]
executable = ["dep:clap", "dep:rand-utils", "std"] executable = ["dep:clap", "dep:rand-utils", "std"]
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [ std = [
@ -48,7 +44,6 @@ std = [
"winter-math/std", "winter-math/std",
"winter-utils/std", "winter-utils/std",
] ]
async = ["serde?/rc"]
[dependencies] [dependencies]
blake3 = { version = "1.5", default-features = false } blake3 = { version = "1.5", default-features = false }

View file

@ -1,66 +0,0 @@
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use criterion::{criterion_group, criterion_main, BatchSize, BenchmarkId, Criterion};
use miden_crypto::{
hash::rpo::RpoDigest,
merkle::{NodeIndex, NodeSubtreeComputer, Smt, SparseMerkleTree},
Felt, Word, ONE,
};
const SUBTREE_INTERVAL: u8 = 8;
fn setup_subtree8(tree_size: u64) -> (Smt, NodeIndex, Arc<BTreeMap<RpoDigest, Word>>, RpoDigest) {
let entries: BTreeMap<RpoDigest, Word> = (0..tree_size)
.into_iter()
.map(|i| {
let leaf_index = u64::MAX / (i + 1);
let key = RpoDigest::new([ONE, ONE, Felt::new(i), Felt::new(leaf_index)]);
let value = [ONE, ONE, ONE, Felt::new(i)];
(key, value)
})
.collect();
let control = Smt::with_entries(entries.clone()).unwrap();
let subtree = entries
.keys()
.map(|key| {
let index_for_key = NodeIndex::from(Smt::key_to_leaf_index(key));
index_for_key.parent_n(SUBTREE_INTERVAL)
})
.next()
.unwrap();
let control_hash = control.get_inner_node(subtree).hash();
(Smt::new(), subtree, Arc::new(entries), control_hash)
}
fn bench_subtree8(
(smt, subtree, entries, control_hash): (
Smt,
NodeIndex,
Arc<BTreeMap<RpoDigest, Word>>,
RpoDigest,
),
) {
let mut state = NodeSubtreeComputer::with_smt(&smt, Default::default(), entries);
let hash = state.get_or_make_hash(subtree);
assert_eq!(control_hash, hash);
}
fn smt_subtree8(c: &mut Criterion) {
let mut group = c.benchmark_group("subtree8");
group.measurement_time(Duration::from_secs(120));
group.sample_size(30);
for &tree_size in [32, 128, 512, 1024].iter() {
let bench_id = BenchmarkId::from_parameter(tree_size);
//group.throughput(Throughput::Elements(tree_size));
group.bench_with_input(bench_id, &tree_size, |bench, &tree_size| {
bench.iter_batched(|| setup_subtree8(tree_size), bench_subtree8, BatchSize::SmallInput);
});
}
group.finish();
}
criterion_group!(subtree_group, smt_subtree8);
criterion_main!(subtree_group);

View file

@ -74,7 +74,7 @@ where
rev rev
} }
/// Computes the first n powers of the 2nd root of unity, and put them in bit-reversed order. /// Computes the first n powers of the 2nth root of unity, and put them in bit-reversed order.
#[allow(dead_code)] #[allow(dead_code)]
fn bitreversed_powers(n: usize) -> Vec<Self> { fn bitreversed_powers(n: usize) -> Vec<Self> {
let psi = Self::primitive_root_of_unity(2 * n); let psi = Self::primitive_root_of_unity(2 * n);
@ -88,7 +88,7 @@ where
array array
} }
/// Computes the first n powers of the 2nd root of unity, invert them, and put them in /// Computes the first n powers of the 2nth root of unity, invert them, and put them in
/// bit-reversed order. /// bit-reversed order.
#[allow(dead_code)] #[allow(dead_code)]
fn bitreversed_powers_inverse(n: usize) -> Vec<Self> { fn bitreversed_powers_inverse(n: usize) -> Vec<Self> {

View file

@ -35,7 +35,6 @@ pub fn benchmark_smt() {
let mut tree = construction(entries, tree_size).unwrap(); let mut tree = construction(entries, tree_size).unwrap();
insertion(&mut tree, tree_size).unwrap(); insertion(&mut tree, tree_size).unwrap();
batched_insertion(&mut tree, tree_size).unwrap();
proof_generation(&mut tree, tree_size).unwrap(); proof_generation(&mut tree, tree_size).unwrap();
} }
@ -83,54 +82,6 @@ pub fn insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
Ok(()) Ok(())
} }
pub fn batched_insertion(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a batched insertion benchmark:");
let new_pairs: Vec<(RpoDigest, Word)> = (0..1000)
.map(|i| {
let key = Rpo256::hash(&rand_value::<u64>().to_be_bytes());
let value = [ONE, ONE, ONE, Felt::new(size + i)];
(key, value)
})
.collect();
let now = Instant::now();
let mutations = tree.compute_mutations(new_pairs);
let compute_elapsed = now.elapsed();
let now = Instant::now();
tree.apply_mutations(mutations).unwrap();
let apply_elapsed = now.elapsed();
println!(
"An average batch computation time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
compute_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
compute_elapsed.as_secs_f32(),
);
println!(
"An average batch application time measured by a 1k-batch into an SMT with {} key-value pairs over {:.3} milliseconds is {:.3} milliseconds",
size,
apply_elapsed.as_secs_f32() * 1000f32,
// Dividing by the number of iterations, 1000, and then multiplying by 1000 to get
// milliseconds, cancels out.
apply_elapsed.as_secs_f32(),
);
println!(
"An average batch insertion time measured by a 1k-batch into an SMT with {} key-value pairs totals to {:.3} milliseconds",
size,
(compute_elapsed + apply_elapsed).as_secs_f32() * 1000f32,
);
println!();
Ok(())
}
/// Runs the proof generation benchmark for the [`Smt`]. /// Runs the proof generation benchmark for the [`Smt`].
pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> { pub fn proof_generation(tree: &mut Smt, size: u64) -> Result<(), MerkleError> {
println!("Running a proof generation benchmark:"); println!("Running a proof generation benchmark:");

View file

@ -1,6 +1,6 @@
use core::slice; use core::slice;
use super::{smt::InnerNode, Felt, RpoDigest, EMPTY_WORD}; use super::{Felt, RpoDigest, EMPTY_WORD};
// EMPTY NODES SUBTREES // EMPTY NODES SUBTREES
// ================================================================================================ // ================================================================================================
@ -25,17 +25,6 @@ impl EmptySubtreeRoots {
let pos = 255 - tree_depth + node_depth; let pos = 255 - tree_depth + node_depth;
&EMPTY_SUBTREES[pos as usize] &EMPTY_SUBTREES[pos as usize]
} }
/// Returns a sparse Merkle tree [`InnerNode`] with two empty children.
///
/// # Note
/// `node_depth` is the depth of the **parent** to have empty children. That is, `node_depth`
/// and the depth of the returned [`InnerNode`] are the same, and thus the empty hashes are for
/// subtrees of depth `node_depth + 1`.
pub(crate) const fn get_inner_node(tree_depth: u8, node_depth: u8) -> InnerNode {
let &child = Self::entry(tree_depth, node_depth + 1);
InnerNode { left: child, right: child }
}
} }
const EMPTY_SUBTREES: [RpoDigest; 256] = [ const EMPTY_SUBTREES: [RpoDigest; 256] = [

View file

@ -1,4 +1,4 @@
use core::{fmt::Display, num::NonZero}; use core::fmt::Display;
use super::{Felt, MerkleError, RpoDigest}; use super::{Felt, MerkleError, RpoDigest};
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@ -72,53 +72,6 @@ impl NodeIndex {
Self::new(depth, value) Self::new(depth, value)
} }
/// Converts a scalar representation of a depth/value pair to a [`NodeIndex`].
///
/// This is the inverse operation of [`NodeIndex::to_scalar_index()`]. As `1` represents the
/// root node, `index` cannot be zero.
///
/// # Errors
/// Returns the same errors under the same conditions as [`NodeIndex::new()`].
///
/// # Panics
/// Panics if the depth indicated by `index` does not fit in a [`u8`], or if the row-value
/// indicated by `index` does not fit in a [`u64`].
pub fn from_scalar_index(index: NonZero<u128>) -> Result<Self, MerkleError> {
let index = index.get() - 1;
if index == 0 {
return Ok(Self::root());
}
// The log of 1 is always 0.
if index == 1 {
return Ok(Self::root().left_child());
}
let depth = {
let depth = u128::ilog2(index + 1);
assert!(depth <= u8::MAX as u32);
//let depth = f64::log2(index as f64).round();
//std::eprintln!("depth for scalar index {index} is {depth}");
//assert!(depth <= u8::MAX as f64);
depth as u8
};
let max_value_for_depth = (1 << depth) - 1;
assert!(
max_value_for_depth <= u64::MAX as u128,
"max_value ({max_value_for_depth}) does not fit in u64",
);
let value = {
let value = index - max_value_for_depth;
assert!(value <= u64::MAX as u128);
value as u64
};
Self::new(depth, value)
}
/// Creates a new node index pointing to the root of the tree. /// Creates a new node index pointing to the root of the tree.
pub const fn root() -> Self { pub const fn root() -> Self {
Self { depth: 0, value: 0 } Self { depth: 0, value: 0 }
@ -144,55 +97,6 @@ impl NodeIndex {
self self
} }
/// Returns the parent of the current node.
pub const fn parent(mut self) -> Self {
self.depth = self.depth.saturating_sub(1);
self.value >>= 1;
self
}
/// Returns the `n`th parent of the current node.
pub const fn parent_n(mut self, n: u8) -> Self {
debug_assert!(n < self.depth);
let delta = self.depth.saturating_sub(n);
self.depth = self.depth.saturating_sub(delta);
self.value >>= delta as u32;
self
}
/// Returns `true` if and only if `other` is a child of the current node.
pub const fn contains(&self, mut other: Self) -> bool {
loop {
if self.depth == other.depth && self.value == other.value {
return true;
}
if other.is_root() {
return false;
}
other = other.parent();
}
}
/// Returns the right-most descendent of the current node for a tree of `DEPTH` depth.
pub const fn rightmost_descendent<const DEPTH: u8>(mut self) -> Self {
while self.depth() < DEPTH {
self = self.right_child();
}
self
}
/// Returns the left-most descendent of the current node for a tree of `DEPTH` depth.
pub const fn leftmost_descendent<const DEPTH: u8>(mut self) -> Self {
while self.depth() < DEPTH {
self = self.left_child();
}
self
}
// PROVIDERS // PROVIDERS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -210,8 +114,8 @@ impl NodeIndex {
/// Returns the scalar representation of the depth/value pair. /// Returns the scalar representation of the depth/value pair.
/// ///
/// It is computed as `2^depth + value`. /// It is computed as `2^depth + value`.
pub const fn to_scalar_index(&self) -> u128 { pub const fn to_scalar_index(&self) -> u64 {
(1 << self.depth as u64) + (self.value as u128) (1 << self.depth as u64) + self.value
} }
/// Returns the depth of the current instance. /// Returns the depth of the current instance.
@ -306,52 +210,6 @@ mod tests {
assert!(NodeIndex::new(64, u64::MAX).is_ok()); assert!(NodeIndex::new(64, u64::MAX).is_ok());
} }
//#[test]
//fn test_traversal_roundtrip() {
// // Arbitrary value that's at the bottom and not in a corner.
// let start = NodeIndex::make(64, u64::MAX - 8);
//
// let mut index = start;
// while !index.is_root() {
// std::dbg!(&index);
// let as_traversal = index.to_traversal_index();
// let as_scalar = index.to_scalar_index() - 1;
// assert_eq!(as_traversal, as_scalar as u128);
// let round_trip = NodeIndex::from_traversal_index(as_traversal).unwrap();
// assert_eq!(index, round_trip, "{:?} did not round-trip as a traversal index", index);
// index.move_up();
// }
// assert!(index.is_root());
// let root_control = NodeIndex::root();
// assert_eq!(index, root_control);
//
// // Traversal index 0 should be root.
// assert_eq!(index, NodeIndex::from_traversal_index(0).unwrap());
//}
#[test]
fn test_scalar_roundtrip() {
// Arbitrary value that's at the bottom and not in a corner.
let start = NodeIndex::make(64, u64::MAX - 8);
let mut index = start;
while !index.is_root() {
let as_scalar = index.to_scalar_index();
let round_trip =
NodeIndex::from_scalar_index(NonZero::new(as_scalar).unwrap()).unwrap();
assert_eq!(index, round_trip, "{index:?} did not round-trip as a scalar index");
index.move_up();
}
//let start = NodeIndex::root().left_child().to_scalar_index();
//let max = u64::MAX as u128;
//for scalar in start..max {
// let index = NodeIndex::from_scalar_index(NonZero::new(scalar).unwrap()).unwrap();
// let round_trip = index.to_scalar_index();
// assert_eq!(scalar, round_trip, "scalar index {scalar} ({index:?}) did not round-trip");
//}
}
prop_compose! { prop_compose! {
fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex { fn node_index()(value in 0..2u64.pow(u64::BITS - 1)) -> NodeIndex {
// unwrap never panics because the range of depth is 0..u64::BITS // unwrap never panics because the range of depth is 0..u64::BITS

View file

@ -22,8 +22,8 @@ pub use path::{MerklePath, RootPath, ValuePath};
mod smt; mod smt;
pub use smt::{ pub use smt::{
InnerNode, LeafIndex, MutationSet, NodeSubtreeComputer, SimpleSmt, Smt, SmtLeaf, SmtLeafError, LeafIndex, SimpleSmt, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
SmtProof, SmtProofError, SparseMerkleTree, SMT_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
}; };
mod mmr; mod mmr;

View file

@ -350,7 +350,7 @@ impl Deserializable for SmtLeaf {
// ================================================================================================ // ================================================================================================
/// Converts a key-value tuple to an iterator of `Felt`s /// Converts a key-value tuple to an iterator of `Felt`s
pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> { fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<Item = Felt> {
let key_elements = key.into_iter(); let key_elements = key.into_iter();
let value_elements = value.into_iter(); let value_elements = value.into_iter();
@ -359,7 +359,7 @@ pub(crate) fn kv_to_elements((key, value): (RpoDigest, Word)) -> impl Iterator<I
/// Compares two keys, compared element-by-element using their integer representations starting with /// Compares two keys, compared element-by-element using their integer representations starting with
/// the most significant element. /// the most significant element.
pub(crate) fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering { fn cmp_keys(key_1: RpoDigest, key_2: RpoDigest) -> Ordering {
for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() { for (v1, v2) in key_1.iter().zip(key_2.iter()).rev() {
let v1 = v1.as_int(); let v1 = v1.as_int();
let v2 = v2.as_int(); let v2 = v2.as_int();

View file

@ -1,6 +1,3 @@
#[cfg(feature = "async")]
use std::{collections::HashMap, sync::Arc};
use alloc::{ use alloc::{
collections::{BTreeMap, BTreeSet}, collections::{BTreeMap, BTreeSet},
string::ToString, string::ToString,
@ -9,12 +6,9 @@ use alloc::{
use super::{ use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
}; };
#[cfg(feature = "async")]
use super::NodeMutation;
mod error; mod error;
pub use error::{SmtLeafError, SmtProofError}; pub use error::{SmtLeafError, SmtProofError};
@ -49,16 +43,8 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt { pub struct Smt {
root: RpoDigest, root: RpoDigest,
#[cfg(not(feature = "async"))]
leaves: BTreeMap<u64, SmtLeaf>, leaves: BTreeMap<u64, SmtLeaf>,
#[cfg(feature = "async")]
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
} }
impl Smt { impl Smt {
@ -78,8 +64,8 @@ impl Smt {
Self { Self {
root, root,
leaves: Default::default(), leaves: BTreeMap::new(),
inner_nodes: Default::default(), inner_nodes: BTreeMap::new(),
} }
} }
@ -115,11 +101,6 @@ impl Smt {
Ok(tree) Ok(tree)
} }
#[cfg(feature = "async")]
pub fn get_leaves(&self) -> Arc<BTreeMap<u64, SmtLeaf>> {
Arc::clone(&self.leaves)
}
// PUBLIC ACCESSORS // PUBLIC ACCESSORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -140,7 +121,12 @@ impl Smt {
/// Returns the value associated with `key` /// Returns the value associated with `key`
pub fn get_value(&self, key: &RpoDigest) -> Word { pub fn get_value(&self, key: &RpoDigest) -> Word {
<Self as SparseMerkleTree<SMT_DEPTH>>::get_value(self, key) let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
} }
/// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle /// Returns an opening of the leaf associated with `key`. Conceptually, an opening is a Merkle
@ -173,40 +159,6 @@ impl Smt {
}) })
} }
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
/// Gets a mutable reference to this structure's inner leaf mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn leaves_mut(&mut self) -> &mut BTreeMap<u64, SmtLeaf> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.leaves).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.leaves
}
}
// STATE MUTATORS // STATE MUTATORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -220,47 +172,6 @@ impl Smt {
<Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value) <Self as SparseMerkleTree<SMT_DEPTH>>::insert(self, key, value)
} }
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`Smt::apply_mutations()`] can be called in order to commit these changes to the Merkle
/// tree, or [`drop()`] to discard them.
///
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{Smt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt = Smt::new();
/// let pair = (RpoDigest::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(SMT_DEPTH, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> MutationSet<SMT_DEPTH, RpoDigest, Word> {
<Self as SparseMerkleTree<SMT_DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`Smt::compute_mutations()`] to this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<SMT_DEPTH, RpoDigest, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<SMT_DEPTH>>::apply_mutations(self, mutations)
}
// HELPERS // HELPERS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -271,12 +182,10 @@ impl Smt {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key); let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
let leaves = self.leaves_mut(); match self.leaves.get_mut(&leaf_index.value()) {
match leaves.get_mut(&leaf_index.value()) {
Some(leaf) => leaf.insert(key, value), Some(leaf) => leaf.insert(key, value),
None => { None => {
leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value))); self.leaves.insert(leaf_index.value(), SmtLeaf::Single((key, value)));
None None
}, },
@ -287,12 +196,10 @@ impl Smt {
fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> { fn perform_remove(&mut self, key: RpoDigest) -> Option<Word> {
let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key); let leaf_index: LeafIndex<SMT_DEPTH> = Self::key_to_leaf_index(&key);
let leaves = self.leaves_mut(); if let Some(leaf) = self.leaves.get_mut(&leaf_index.value()) {
if let Some(leaf) = leaves.get_mut(&leaf_index.value()) {
let (old_value, is_empty) = leaf.remove(key); let (old_value, is_empty) = leaf.remove(key);
if is_empty { if is_empty {
leaves.remove(&leaf_index.value()); self.leaves.remove(&leaf_index.value());
} }
old_value old_value
} else { } else {
@ -300,27 +207,6 @@ impl Smt {
None None
} }
} }
fn construct_prospective_leaf(
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
} }
impl SparseMerkleTree<SMT_DEPTH> for Smt { impl SparseMerkleTree<SMT_DEPTH> for Smt {
@ -340,18 +226,19 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
} }
fn get_inner_node(&self, index: NodeIndex) -> InnerNode { fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
.get(&index) let node = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth() + 1);
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(SMT_DEPTH, index.depth())) InnerNode { left: *node, right: *node }
})
} }
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes_mut().insert(index, inner_node); self.inner_nodes.insert(index, inner_node);
} }
fn remove_inner_node(&mut self, index: NodeIndex) { fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes_mut().remove(&index); let _ = self.inner_nodes.remove(&index);
} }
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> { fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value> {
@ -363,15 +250,6 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
} }
} }
fn get_value(&self, key: &Self::Key) -> Self::Value {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
match self.leaves.get(&leaf_pos) {
Some(leaf) => leaf.get_value(key).unwrap_or_default(),
None => EMPTY_WORD,
}
}
fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf { fn get_leaf(&self, key: &RpoDigest) -> Self::Leaf {
let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value(); let leaf_pos = LeafIndex::<SMT_DEPTH>::from(*key).value();
@ -385,28 +263,6 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
leaf.hash() leaf.hash()
} }
fn construct_prospective_leaf(
&self,
mut existing_leaf: SmtLeaf,
key: &RpoDigest,
value: &Word,
) -> SmtLeaf {
debug_assert_eq!(existing_leaf.index(), Self::key_to_leaf_index(key));
match existing_leaf {
SmtLeaf::Empty(_) => SmtLeaf::new_single(*key, *value),
_ => {
if *value != EMPTY_WORD {
existing_leaf.insert(*key, *value);
} else {
existing_leaf.remove(*key);
}
existing_leaf
},
}
}
fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> { fn key_to_leaf_index(key: &RpoDigest) -> LeafIndex<SMT_DEPTH> {
let most_significant_felt = key[3]; let most_significant_felt = key[3];
LeafIndex::new_max_depth(most_significant_felt.as_int()) LeafIndex::new_max_depth(most_significant_felt.as_int())
@ -423,141 +279,6 @@ impl Default for Smt {
} }
} }
/// Just a [`NodeMutation`] with its hash already computed and stored.
#[cfg(feature = "async")]
pub struct ComputedNodeMutation {
pub mutation: NodeMutation,
pub hash: RpoDigest,
}
#[cfg(feature = "async")]
pub struct NodeSubtreeComputer {
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
leaves: Arc<BTreeMap<u64, SmtLeaf>>,
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
new_mutations: HashMap<NodeIndex, ComputedNodeMutation>,
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
/// Cache indices we know to be dirty.
dirtied_indices: HashMap<NodeIndex, bool>,
cached_leaf_hashes: HashMap<LeafIndex<SMT_DEPTH>, RpoDigest>,
}
#[cfg(feature = "async")]
impl NodeSubtreeComputer {
pub fn with_smt(
smt: &Smt,
existing_mutations: Arc<HashMap<NodeIndex, ComputedNodeMutation>>,
new_pairs: Arc<BTreeMap<RpoDigest, Word>>,
) -> Self {
Self {
inner_nodes: Arc::clone(&smt.inner_nodes),
leaves: Arc::clone(&smt.leaves),
existing_mutations,
new_mutations: Default::default(),
new_pairs,
dirtied_indices: Default::default(),
cached_leaf_hashes: Default::default(),
}
}
pub(crate) fn is_index_dirty(&mut self, index_to_check: NodeIndex) -> bool {
if let Some(cached) = self.dirtied_indices.get(&index_to_check) {
return *cached;
}
// An index is dirty if there is a new pair at it, an known existing mutation at it, or an
// ancestor of one of those.
let is_dirty = self
.existing_mutations
.iter()
.map(|(index, _)| *index)
.chain(self.new_pairs.iter().map(|(key, _v)| Smt::key_to_leaf_index(key).index))
.filter(|&dirtied_index| index_to_check.contains(dirtied_index))
.next()
.is_some();
// This is somewhat expensive to compute, so cache this.
self.dirtied_indices.insert(index_to_check, is_dirty);
is_dirty
}
pub(crate) fn get_effective_leaf(&self, index: LeafIndex<SMT_DEPTH>) -> SmtLeaf {
let pairs_at_index = self
.new_pairs
.iter()
.filter(|&(new_key, _)| Smt::key_to_leaf_index(new_key) == index);
let existing_leaf = self
.leaves
.get(&index.index.value())
.cloned()
.unwrap_or_else(|| SmtLeaf::new_empty(index));
pairs_at_index.fold(existing_leaf, |acc, (k, v)| {
let existing_leaf = acc.clone();
Smt::construct_prospective_leaf(existing_leaf, k, v)
})
}
/// Does NOT check `new_mutations`.
pub(crate) fn get_clean_hash(&self, index: NodeIndex) -> Option<RpoDigest> {
self.existing_mutations
.get(&index)
.map(|ComputedNodeMutation { hash, .. }| *hash)
.or_else(|| self.inner_nodes.get(&index).map(|inner_node| InnerNode::hash(&inner_node)))
}
/// Retrieve a cached hash, or recursively compute it.
pub fn get_or_make_hash(&mut self, index: NodeIndex) -> RpoDigest {
use NodeMutation::*;
// If this is a leaf, then only do leaf stuff.
if index.depth() == SMT_DEPTH {
let index = LeafIndex::new(index.value()).unwrap();
return match self.cached_leaf_hashes.get(&index) {
Some(cached_hash) => cached_hash.clone(),
None => {
let leaf = self.get_effective_leaf(index);
let hash = Smt::hash_leaf(&leaf);
self.cached_leaf_hashes.insert(index, hash);
hash
},
};
}
// If we already computed this one earlier as a mutation, just return it.
if let Some(ComputedNodeMutation { hash, .. }) = self.new_mutations.get(&index) {
return *hash;
}
// Otherwise, we need to know if this node is one of the nodes we're in the process of
// recomputing, or if we can safely use the node already in the Merkle tree.
if !self.is_index_dirty(index) {
return self
.get_clean_hash(index)
.unwrap_or_else(|| *EmptySubtreeRoots::entry(SMT_DEPTH, index.depth()));
}
// If we got here, then we have to make, rather than get, this hash.
// Make sure we mark this index as now dirty.
self.dirtied_indices.insert(index, true);
// Recurse for the left and right sides.
let left = self.get_or_make_hash(index.left_child());
let right = self.get_or_make_hash(index.right_child());
let node = InnerNode { left, right };
let hash = node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(SMT_DEPTH, index.depth());
let is_removal = hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(node) };
self.new_mutations
.insert(index, ComputedNodeMutation { hash, mutation: new_entry });
hash
}
}
// CONVERSIONS // CONVERSIONS
// ================================================================================================ // ================================================================================================

View file

@ -2,7 +2,7 @@ use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{ use crate::{
merkle::{smt::SparseMerkleTree, EmptySubtreeRoots, MerkleStore}, merkle::{EmptySubtreeRoots, MerkleStore},
utils::{Deserializable, Serializable}, utils::{Deserializable, Serializable},
Word, ONE, WORD_SIZE, Word, ONE, WORD_SIZE,
}; };
@ -258,195 +258,6 @@ fn test_smt_removal() {
} }
} }
/// This tests that we can correctly calculate prospective leaves -- that is, we can construct
/// correct [`SmtLeaf`] values for a theoretical insertion on a Merkle tree without mutating or
/// cloning the tree.
#[test]
fn test_prospective_hash() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
// insert key-value 1
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &value_1).hash();
smt.insert(key_1, value_1);
let leaf = smt.get_leaf(&key_1);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 2
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &value_2).hash();
smt.insert(key_2, value_2);
let leaf = smt.get_leaf(&key_2);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// insert key-value 3
{
let prospective =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &value_3).hash();
smt.insert(key_3, value_3);
let leaf = smt.get_leaf(&key_3);
assert_eq!(
prospective,
leaf.hash(),
"prospective hash for leaf {leaf:?} did not match actual hash",
);
}
// remove key 3
{
let old_leaf = smt.get_leaf(&key_3);
let old_value_3 = smt.insert(key_3, EMPTY_WORD);
assert_eq!(old_value_3, value_3);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_3), &key_3, &old_value_3);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 2
{
let old_leaf = smt.get_leaf(&key_2);
let old_value_2 = smt.insert(key_2, EMPTY_WORD);
assert_eq!(old_value_2, value_2);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_2), &key_2, &old_value_2);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
// remove key 1
{
let old_leaf = smt.get_leaf(&key_1);
let old_value_1 = smt.insert(key_1, EMPTY_WORD);
assert_eq!(old_value_1, value_1);
let prospective_leaf =
smt.construct_prospective_leaf(smt.get_leaf(&key_1), &key_1, &old_value_1);
assert_eq!(
old_leaf.hash(),
prospective_leaf.hash(),
"removing and prospectively re-adding a leaf didn't yield the original leaf:\
\n original leaf: {old_leaf:?}\
\n prospective leaf: {prospective_leaf:?}",
);
}
}
/// This tests that we can perform prospective changes correctly.
#[test]
fn test_prospective_insertion() {
let mut smt = Smt::default();
let raw = 0b_01101001_01101100_00011111_11111111_10010110_10010011_11100000_00000000_u64;
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, Felt::new(raw)]);
let key_2: RpoDigest =
RpoDigest::from([2_u32.into(), 2_u32.into(), 2_u32.into(), Felt::new(raw)]);
// Sort key_3 before key_1, to test non-append insertion.
let key_3: RpoDigest =
RpoDigest::from([0_u32.into(), 0_u32.into(), 0_u32.into(), Felt::new(raw)]);
let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE];
let value_3: [Felt; 4] = [3_u32.into(); WORD_SIZE];
let root_empty = smt.root();
let root_1 = {
smt.insert(key_1, value_1);
smt.root()
};
let root_2 = {
smt.insert(key_2, value_2);
smt.root()
};
let root_3 = {
smt.insert(key_3, value_3);
smt.root()
};
// Test incremental updates.
let mut smt = Smt::default();
let mutations = smt.compute_mutations(vec![(key_1, value_1)]);
assert_eq!(mutations.root(), root_1, "prospective root 1 did not match actual root 1");
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_1, "mutations before and after apply did not match");
let mutations = smt.compute_mutations(vec![(key_2, value_2)]);
assert_eq!(mutations.root(), root_2, "prospective root 2 did not match actual root 2");
let mutations =
smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_2, value_2), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3, "mutations before and after apply did not match");
smt.apply_mutations(mutations).unwrap();
// Edge case: multiple values at the same key, where a later pair restores the original value.
let mutations = smt.compute_mutations(vec![(key_3, EMPTY_WORD), (key_3, value_3)]);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
// Test batch updates, and that the order doesn't matter.
let pairs =
vec![(key_3, value_2), (key_2, EMPTY_WORD), (key_1, EMPTY_WORD), (key_3, EMPTY_WORD)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(
mutations.root(),
root_empty,
"prospective root for batch removal did not match actual root",
);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_empty, "mutations before and after apply did not match");
let pairs = vec![(key_3, value_3), (key_1, value_1), (key_2, value_2)];
let mutations = smt.compute_mutations(pairs);
assert_eq!(mutations.root(), root_3);
smt.apply_mutations(mutations).unwrap();
assert_eq!(smt.root(), root_3);
}
/// Tests that 2 key-value pairs stored in the same leaf have the same path /// Tests that 2 key-value pairs stored in the same leaf have the same path
#[test] #[test]
fn test_smt_path_to_keys_in_same_leaf_are_equal() { fn test_smt_path_to_keys_in_same_leaf_are_equal() {

View file

@ -1,4 +1,4 @@
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::vec::Vec;
use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex}; use super::{EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex};
use crate::{ use crate::{
@ -7,9 +7,7 @@ use crate::{
}; };
mod full; mod full;
pub use full::{ pub use full::{Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH};
NodeSubtreeComputer, Smt, SmtLeaf, SmtLeafError, SmtProof, SmtProofError, SMT_DEPTH,
};
mod simple; mod simple;
pub use simple::SimpleSmt; pub use simple::SimpleSmt;
@ -45,13 +43,13 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// must accomodate all keys that map to the same leaf. /// must accomodate all keys that map to the same leaf.
/// ///
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub trait SparseMerkleTree<const DEPTH: u8> { pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key /// The type for a key
type Key: Clone + Ord; type Key: Clone;
/// The type for a value /// The type for a value
type Value: Clone + PartialEq; type Value: Clone + PartialEq;
/// The type for a leaf /// The type for a leaf
type Leaf: Clone; type Leaf;
/// The type for an opening (i.e. a "proof") of a leaf /// The type for an opening (i.e. a "proof") of a leaf
type Opening; type Opening;
@ -142,149 +140,6 @@ pub trait SparseMerkleTree<const DEPTH: u8> {
self.set_root(node_hash); self.set_root(node_hash);
} }
/// Computes what changes are necessary to insert the specified key-value pairs into this Merkle
/// tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SparseMerkleTree::apply_mutations()`] can be called in order to commit these changes to
/// the Merkle tree, or [`drop()`] to discard them.
fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (Self::Key, Self::Value)>,
) -> MutationSet<DEPTH, Self::Key, Self::Value> {
use NodeMutation::*;
let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default();
for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update.
// For the unusual case that kv_pairs has multiple values at the same key, we'll have
// to check the key-value pairs we've already seen to get the "effective" old value.
let old_value = new_pairs.get(&key).cloned().unwrap_or_else(|| self.get_value(&key));
if value == old_value {
continue;
}
let leaf_index = Self::key_to_leaf_index(&key);
let mut node_index = NodeIndex::from(leaf_index);
// We need the current leaf's hash to calculate the new leaf, but in the rare case that
// `kv_pairs` has multiple pairs that go into the same leaf, then those pairs are also
// part of the "current leaf".
let old_leaf = {
let pairs_at_index = new_pairs
.iter()
.filter(|&(new_key, _)| Self::key_to_leaf_index(new_key) == leaf_index);
pairs_at_index.fold(self.get_leaf(&key), |acc, (k, v)| {
// Most of the time `pairs_at_index` should only contain a single entry (or
// none at all), as multi-leaves should be really rare.
let existing_leaf = acc.clone();
self.construct_prospective_leaf(existing_leaf, k, v)
})
};
let new_leaf = self.construct_prospective_leaf(old_leaf, &key, &value);
let mut new_child_hash = Self::hash_leaf(&new_leaf);
for node_depth in (0..node_index.depth()).rev() {
// Whether the node we're replacing is the right child or the left child.
let is_right = node_index.is_value_odd();
node_index.move_up();
let old_node = node_mutations
.get(&node_index)
.map(|mutation| match mutation {
Addition(node) => node.clone(),
Removal => EmptySubtreeRoots::get_inner_node(DEPTH, node_depth),
})
.unwrap_or_else(|| self.get_inner_node(node_index));
let new_node = if is_right {
InnerNode {
left: old_node.left,
right: new_child_hash,
}
} else {
InnerNode {
left: new_child_hash,
right: old_node.right,
}
};
// The next iteration will operate on this new node's hash.
new_child_hash = new_node.hash();
let &equivalent_empty_hash = EmptySubtreeRoots::entry(DEPTH, node_depth);
let is_removal = new_child_hash == equivalent_empty_hash;
let new_entry = if is_removal { Removal } else { Addition(new_node) };
node_mutations.insert(node_index, new_entry);
}
// Once we're at depth 0, the last node we made is the new root.
new_root = new_child_hash;
// And then we're done with this pair; on to the next one.
new_pairs.insert(key, value);
}
MutationSet {
old_root: self.root(),
new_root,
node_mutations,
new_pairs,
}
}
/// Apply the prospective mutations computed with [`SparseMerkleTree::compute_mutations()`] to
/// this tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`Vec`]. The first item is the root hash
/// the `mutations` were computed against, and the second item is the actual current root of
/// this tree.
fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, Self::Key, Self::Value>,
) -> Result<(), MerkleError>
where
Self: Sized,
{
use NodeMutation::*;
let MutationSet {
old_root,
node_mutations,
new_pairs,
new_root,
} = mutations;
// Guard against accidentally trying to apply mutations that were computed against a
// different tree, including a stale version of this tree.
if old_root != self.root() {
return Err(MerkleError::ConflictingRoots(vec![old_root, self.root()]));
}
for (index, mutation) in node_mutations {
match mutation {
Removal => self.remove_inner_node(index),
Addition(node) => self.insert_inner_node(index, node),
}
}
for (key, value) in new_pairs {
self.insert_value(key, value);
}
self.set_root(new_root);
Ok(())
}
// REQUIRED METHODS // REQUIRED METHODS
// --------------------------------------------------------------------------------------------- // ---------------------------------------------------------------------------------------------
@ -306,34 +161,12 @@ pub trait SparseMerkleTree<const DEPTH: u8> {
/// Inserts a leaf node, and returns the value at the key if already exists /// Inserts a leaf node, and returns the value at the key if already exists
fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>; fn insert_value(&mut self, key: Self::Key, value: Self::Value) -> Option<Self::Value>;
/// Returns the value at the specified key. Recall that by definition, any key that hasn't been
/// updated is associated with [`Self::EMPTY_VALUE`].
fn get_value(&self, key: &Self::Key) -> Self::Value;
/// Returns the leaf at the specified index. /// Returns the leaf at the specified index.
fn get_leaf(&self, key: &Self::Key) -> Self::Leaf; fn get_leaf(&self, key: &Self::Key) -> Self::Leaf;
/// Returns the hash of a leaf /// Returns the hash of a leaf
fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest; fn hash_leaf(leaf: &Self::Leaf) -> RpoDigest;
/// Returns what a leaf would look like if a key-value pair were inserted into the tree, without
/// mutating the tree itself. The existing leaf can be empty.
///
/// To get a prospective leaf based on the current state of the tree, use `self.get_leaf(key)`
/// as the argument for `existing_leaf`. The return value from this function can be chained back
/// into this function as the first argument to continue making prospective changes.
///
/// # Invariants
/// Because this method is for a prospective key-value insertion into a specific leaf,
/// `existing_leaf` must have the same leaf index as `key` (as determined by
/// [`SparseMerkleTree::key_to_leaf_index()`]), or the result will be meaningless.
fn construct_prospective_leaf(
&self,
existing_leaf: Self::Leaf,
key: &Self::Key,
value: &Self::Value,
) -> Self::Leaf;
/// Maps a key to a leaf index /// Maps a key to a leaf index
fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>; fn key_to_leaf_index(key: &Self::Key) -> LeafIndex<DEPTH>;
@ -348,7 +181,7 @@ pub trait SparseMerkleTree<const DEPTH: u8> {
#[derive(Debug, Default, Clone, PartialEq, Eq)] #[derive(Debug, Default, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct InnerNode { pub(crate) struct InnerNode {
pub left: RpoDigest, pub left: RpoDigest,
pub right: RpoDigest, pub right: RpoDigest,
} }
@ -381,48 +214,6 @@ impl<const DEPTH: u8> LeafIndex<DEPTH> {
pub fn value(&self) -> u64 { pub fn value(&self) -> u64 {
self.index.value() self.index.value()
} }
/// Lowest common ancestor — finds the lowest (highest depth) [`NodeIndex`] that is an ancestor
/// of both `self` and `rhs`.
///
/// The general case algorithm is `O(n)`, however leaf indexes are always at the same depth,
/// and we only need find the depth of the lowest-common ancestor (since we can trivially get
/// its horizontal position based on either child's position), so we can reduce this to
/// `O(log n)`.
pub fn lca(&self, other: &Self) -> NodeIndex {
let mut self_scalar = self.index.to_scalar_index();
let mut other_scalar = other.index.to_scalar_index();
while self_scalar != other_scalar {
self_scalar >>= 1;
other_scalar >>= 1;
}
// Once we've shifted them enough to be equal, we've found a scalar index with the depth of
// the lowest common ancestor. Time to convert that scalar index to a depth, and apply that
// depth to either of our `NodeIndex`s to get the full position of that ancestor.
// In general, we can get the depth of a binary tree's scalar index by taking the binary
// logarithm of that index. However, for the root node, the scalar index is 0, and the
// logarithm is undefined for 0, so we trivally special case the root index.
if self_scalar == 0 {
return NodeIndex::root();
}
let depth = {
let depth = u128::ilog2(self_scalar);
// The scalar index should not be able to exceed `u8::MAX + u64::MAX` (as those are the
// maximum values `NodeIndex` can hold), and the binary logarithm of `u8::MAX +
// u64::MAX` is 64, which fits in a u8. In other words, this assert should only be
// possible to fail if `to_scalar_index()` is wildly incorrect.
debug_assert!(depth <= u8::MAX as u32);
depth as u8
};
let mut lca = self.index;
lca.move_up_to(depth);
lca
}
} }
impl LeafIndex<SMT_MAX_DEPTH> { impl LeafIndex<SMT_MAX_DEPTH> {
@ -453,86 +244,3 @@ impl<const DEPTH: u8> TryFrom<NodeIndex> for LeafIndex<DEPTH> {
Self::new(node_index.value()) Self::new(node_index.value())
} }
} }
// MUTATIONS
// ================================================================================================
/// A change to an inner node of a [`SparseMerkleTree`] that hasn't yet been applied.
/// [`MutationSet`] stores this type in relation to a [`NodeIndex`] to keep track of what changes
/// need to occur at which node indices.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeMutation {
/// Corresponds to [`SparseMerkleTree::remove_inner_node()`].
Removal,
/// Corresponds to [`SparseMerkleTree::insert_inner_node()`].
Addition(InnerNode),
}
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MutationSet<const DEPTH: u8, K, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
old_root: RpoDigest,
/// The set of nodes that need to be removed or added. The "effective" node at an index is the
/// Merkle tree's existing node at that index, with the [`NodeMutation`] in this map at that
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root: RpoDigest,
}
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
/// Queries the root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information.
pub fn root(&self) -> RpoDigest {
self.new_root
}
}
#[cfg(test)]
mod tests {
use proptest::prelude::*;
use crate::merkle::{LeafIndex, NodeIndex, SMT_DEPTH};
prop_compose! {
fn leaf_index()(value in 0..2u64.pow(u64::BITS - 1)) -> LeafIndex<SMT_DEPTH> {
LeafIndex::new(value).unwrap()
}
}
proptest! {
/// Tests that the O(log n) algorithm has the same results as the naïve version.
#[test]
fn test_leaf_lca(left in leaf_index(), right in leaf_index()) {
let control: NodeIndex = {
let mut left = left.index;
let mut right = right.index;
loop {
if left == right {
break left;
}
left.move_up();
right.move_up();
}
};
let actual: NodeIndex = left.lca(&right);
assert_eq!(actual, control);
}
}
}

View file

@ -1,11 +1,9 @@
use alloc::collections::{BTreeMap, BTreeSet}; use alloc::collections::{BTreeMap, BTreeSet};
#[cfg(feature = "async")]
use std::sync::Arc;
use super::{ use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, MerklePath, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH,
SMT_MAX_DEPTH, SMT_MIN_DEPTH, SMT_MIN_DEPTH,
}; };
#[cfg(test)] #[cfg(test)]
@ -22,10 +20,7 @@ mod tests;
pub struct SimpleSmt<const DEPTH: u8> { pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest, root: RpoDigest,
leaves: BTreeMap<u64, Word>, leaves: BTreeMap<u64, Word>,
#[cfg(not(feature = "async"))]
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: BTreeMap<NodeIndex, InnerNode>,
#[cfg(feature = "async")]
inner_nodes: Arc<BTreeMap<NodeIndex, InnerNode>>,
} }
impl<const DEPTH: u8> SimpleSmt<DEPTH> { impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -57,7 +52,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self { Ok(Self {
root, root,
leaves: BTreeMap::new(), leaves: BTreeMap::new(),
inner_nodes: Default::default(), inner_nodes: BTreeMap::new(),
}) })
} }
@ -180,23 +175,6 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
}) })
} }
/// Gets a mutable reference to this structure's inner node mapping.
///
/// # Panics
/// This will panic if we have violated our own invariants and try to mutate these nodes while
/// Self::compute_mutations_parallel() is still running.
fn inner_nodes_mut(&mut self) -> &mut BTreeMap<NodeIndex, InnerNode> {
#[cfg(feature = "async")]
{
Arc::get_mut(&mut self.inner_nodes).unwrap()
}
#[cfg(not(feature = "async"))]
{
&mut self.inner_nodes
}
}
// STATE MUTATORS // STATE MUTATORS
// -------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------
@ -210,48 +188,6 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
<Self as SparseMerkleTree<DEPTH>>::insert(self, key, value) <Self as SparseMerkleTree<DEPTH>>::insert(self, key, value)
} }
/// Computes what changes are necessary to insert the specified key-value pairs into this
/// Merkle tree, allowing for validation before applying those changes.
///
/// This method returns a [`MutationSet`], which contains all the information for inserting
/// `kv_pairs` into this Merkle tree already calculated, including the new root hash, which can
/// be queried with [`MutationSet::root()`]. Once a mutation set is returned,
/// [`SimpleSmt::apply_mutations()`] can be called in order to commit these changes to the
/// Merkle tree, or [`drop()`] to discard them.
/// # Example
/// ```
/// # use miden_crypto::{hash::rpo::RpoDigest, Felt, Word};
/// # use miden_crypto::merkle::{LeafIndex, SimpleSmt, EmptySubtreeRoots, SMT_DEPTH};
/// let mut smt: SimpleSmt<3> = SimpleSmt::new().unwrap();
/// let pair = (LeafIndex::default(), Word::default());
/// let mutations = smt.compute_mutations(vec![pair]);
/// assert_eq!(mutations.root(), *EmptySubtreeRoots::entry(3, 0));
/// smt.apply_mutations(mutations);
/// assert_eq!(smt.root(), *EmptySubtreeRoots::entry(3, 0));
/// ```
pub fn compute_mutations(
&self,
kv_pairs: impl IntoIterator<Item = (LeafIndex<DEPTH>, Word)>,
) -> MutationSet<DEPTH, LeafIndex<DEPTH>, Word> {
<Self as SparseMerkleTree<DEPTH>>::compute_mutations(self, kv_pairs)
}
/// Apply the prospective mutations computed with [`SimpleSmt::compute_mutations()`] to this
/// tree.
///
/// # Errors
/// If `mutations` was computed on a tree with a different root than this one, returns
/// [`MerkleError::ConflictingRoots`] with a two-item [`alloc::vec::Vec`]. The first item is the
/// root hash the `mutations` were computed against, and the second item is the actual
/// current root of this tree.
pub fn apply_mutations(
&mut self,
mutations: MutationSet<DEPTH, LeafIndex<DEPTH>, Word>,
) -> Result<(), MerkleError> {
<Self as SparseMerkleTree<DEPTH>>::apply_mutations(self, mutations)
}
/// Inserts a subtree at the specified index. The depth at which the subtree is inserted is /// Inserts a subtree at the specified index. The depth at which the subtree is inserted is
/// computed as `DEPTH - SUBTREE_DEPTH`. /// computed as `DEPTH - SUBTREE_DEPTH`.
/// ///
@ -293,16 +229,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
// add subtree's branch nodes (which includes the root) // add subtree's branch nodes (which includes the root)
// -------------- // --------------
let subtree_nodes; for (branch_idx, branch_node) in subtree.inner_nodes {
#[cfg(feature = "async")]
{
subtree_nodes = Arc::into_inner(subtree.inner_nodes).unwrap();
}
#[cfg(not(feature = "async"))]
{
subtree_nodes = subtree.inner_nodes
}
for (branch_idx, branch_node) in subtree_nodes {
let new_branch_idx = { let new_branch_idx = {
let new_depth = subtree_root_insertion_depth + branch_idx.depth(); let new_depth = subtree_root_insertion_depth + branch_idx.depth();
let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into()) let new_value = subtree_insertion_index * 2_u64.pow(branch_idx.depth().into())
@ -311,7 +238,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid") NodeIndex::new(new_depth, new_value).expect("index guaranteed to be valid")
}; };
self.inner_nodes_mut().insert(new_branch_idx, branch_node); self.inner_nodes.insert(new_branch_idx, branch_node);
} }
// recompute nodes starting from subtree root // recompute nodes starting from subtree root
@ -339,18 +266,19 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
} }
fn get_inner_node(&self, index: NodeIndex) -> InnerNode { fn get_inner_node(&self, index: NodeIndex) -> InnerNode {
self.inner_nodes self.inner_nodes.get(&index).cloned().unwrap_or_else(|| {
.get(&index) let node = EmptySubtreeRoots::entry(DEPTH, index.depth() + 1);
.cloned()
.unwrap_or_else(|| EmptySubtreeRoots::get_inner_node(DEPTH, index.depth())) InnerNode { left: *node, right: *node }
})
} }
fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) { fn insert_inner_node(&mut self, index: NodeIndex, inner_node: InnerNode) {
self.inner_nodes_mut().insert(index, inner_node); self.inner_nodes.insert(index, inner_node);
} }
fn remove_inner_node(&mut self, index: NodeIndex) { fn remove_inner_node(&mut self, index: NodeIndex) {
let _ = self.inner_nodes_mut().remove(&index); let _ = self.inner_nodes.remove(&index);
} }
fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> { fn insert_value(&mut self, key: LeafIndex<DEPTH>, value: Word) -> Option<Word> {
@ -361,10 +289,6 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
} }
} }
fn get_value(&self, key: &LeafIndex<DEPTH>) -> Word {
self.get_leaf(key)
}
fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word { fn get_leaf(&self, key: &LeafIndex<DEPTH>) -> Word {
let leaf_pos = key.value(); let leaf_pos = key.value();
match self.leaves.get(&leaf_pos) { match self.leaves.get(&leaf_pos) {
@ -378,15 +302,6 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
leaf.into() leaf.into()
} }
fn construct_prospective_leaf(
&self,
_existing_leaf: Word,
_key: &LeafIndex<DEPTH>,
value: &Word,
) -> Word {
*value
}
fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> { fn key_to_leaf_index(key: &LeafIndex<DEPTH>) -> LeafIndex<DEPTH> {
*key *key
} }