feat: add support for hashmaps in Smt and SimpleSmt (#363)

This commit is contained in:
polydez 2025-01-02 23:23:12 +05:00 committed by GitHub
parent e4373e54c9
commit 7ee6d7fb93
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 171 additions and 84 deletions

View file

@ -17,7 +17,7 @@ jobs:
matrix: matrix:
toolchain: [stable, nightly] toolchain: [stable, nightly]
os: [ubuntu] os: [ubuntu]
args: [default, no-std] args: [default, smt-hashmaps, no-std]
timeout-minutes: 30 timeout-minutes: 30
steps: steps:
- uses: actions/checkout@main - uses: actions/checkout@main

View file

@ -7,6 +7,7 @@
- Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343). - Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343).
- [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344). - [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344).
- [BREAKING] Updated Winterfell dependency to v0.11 (#346). - [BREAKING] Updated Winterfell dependency to v0.11 (#346).
- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363).
## 0.12.0 (2024-10-30) ## 0.12.0 (2024-10-30)

31
Cargo.lock generated
View file

@ -11,6 +11,12 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "allocator-api2"
version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923"
[[package]] [[package]]
name = "anes" name = "anes"
version = "0.1.6" version = "0.1.6"
@ -349,6 +355,12 @@ version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0"
[[package]]
name = "equivalent"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]] [[package]]
name = "errno" name = "errno"
version = "0.3.10" version = "0.3.10"
@ -371,6 +383,12 @@ version = "1.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
[[package]]
name = "foldhash"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2"
[[package]] [[package]]
name = "generic-array" name = "generic-array"
version = "0.14.7" version = "0.14.7"
@ -410,6 +428,18 @@ dependencies = [
"crunchy", "crunchy",
] ]
[[package]]
name = "hashbrown"
version = "0.15.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289"
dependencies = [
"allocator-api2",
"equivalent",
"foldhash",
"serde",
]
[[package]] [[package]]
name = "heck" name = "heck"
version = "0.5.0" version = "0.5.0"
@ -535,6 +565,7 @@ dependencies = [
"criterion", "criterion",
"getrandom", "getrandom",
"glob", "glob",
"hashbrown",
"hex", "hex",
"num", "num",
"num-complex", "num-complex",

View file

@ -48,6 +48,7 @@ harness = false
concurrent = ["dep:rayon"] concurrent = ["dep:rayon"]
default = ["std", "concurrent"] default = ["std", "concurrent"]
executable = ["dep:clap", "dep:rand-utils", "std"] executable = ["dep:clap", "dep:rand-utils", "std"]
smt_hashmaps = ["dep:hashbrown"]
internal = [] internal = []
serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"]
std = [ std = [
@ -63,6 +64,7 @@ std = [
[dependencies] [dependencies]
blake3 = { version = "1.5", default-features = false } blake3 = { version = "1.5", default-features = false }
clap = { version = "4.5", optional = true, features = ["derive"] } clap = { version = "4.5", optional = true, features = ["derive"] }
hashbrown = { version = "0.15", optional = true, features = ["serde"] }
num = { version = "0.4", default-features = false, features = ["alloc", "libm"] } num = { version = "0.4", default-features = false, features = ["alloc", "libm"] }
num-complex = { version = "0.4", default-features = false } num-complex = { version = "0.4", default-features = false }
rand = { version = "0.8", default-features = false } rand = { version = "0.8", default-features = false }

View file

@ -46,6 +46,9 @@ doc: ## Generate and check documentation
test-default: ## Run tests with default features test-default: ## Run tests with default features
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features $(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features
.PHONY: test-smt-hashmaps
test-smt-hashmaps: ## Run tests with `smt_hashmaps` feature enabled
$(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --features smt_hashmaps
.PHONY: test-no-std .PHONY: test-no-std
test-no-std: ## Run tests with `no-default-features` (std) test-no-std: ## Run tests with `no-default-features` (std)
@ -53,7 +56,7 @@ test-no-std: ## Run tests with `no-default-features` (std)
.PHONY: test .PHONY: test
test: test-default test-no-std ## Run all tests test: test-default test-smt-hashmaps test-no-std ## Run all tests
# --- checking ------------------------------------------------------------------------------------ # --- checking ------------------------------------------------------------------------------------

View file

@ -63,6 +63,7 @@ This crate can be compiled with the following features:
- `concurrent`- enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs. - `concurrent`- enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs.
- `std` - enabled by default and relies on the Rust standard library. - `std` - enabled by default and relies on the Rust standard library.
- `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. - `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly.
- `smt_hashmaps` - uses hashbrown hashmaps in SMT implementation which significantly improves performance of SMT updating. Keys ordering in SMT iterators is not guarantied when this feature is enabled.
All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections. All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections.

View file

@ -1,5 +1,11 @@
use alloc::string::String; use alloc::string::String;
use core::{cmp::Ordering, fmt::Display, ops::Deref, slice}; use core::{
cmp::Ordering,
fmt::Display,
hash::{Hash, Hasher},
ops::Deref,
slice,
};
use thiserror::Error; use thiserror::Error;
@ -55,6 +61,12 @@ impl RpoDigest {
} }
} }
impl Hash for RpoDigest {
fn hash<H: Hasher>(&self, state: &mut H) {
state.write(&self.as_bytes());
}
}
impl Digest for RpoDigest { impl Digest for RpoDigest {
fn as_bytes(&self) -> [u8; DIGEST_BYTES] { fn as_bytes(&self) -> [u8; DIGEST_BYTES] {
let mut result = [0; DIGEST_BYTES]; let mut result = [0; DIGEST_BYTES];

View file

@ -3,6 +3,7 @@ use super::RpoDigest;
/// Representation of a node with two children used for iterating over containers. /// Representation of a node with two children used for iterating over containers.
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[cfg_attr(test, derive(PartialOrd, Ord))]
pub struct InnerNodeInfo { pub struct InnerNodeInfo {
pub value: RpoDigest, pub value: RpoDigest,
pub left: RpoDigest, pub left: RpoDigest,

View file

@ -1,12 +1,8 @@
use alloc::{ use alloc::{collections::BTreeSet, string::ToString, vec::Vec};
collections::{BTreeMap, BTreeSet},
string::ToString,
vec::Vec,
};
use super::{ use super::{
EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError,
MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
}; };
mod error; mod error;
@ -30,6 +26,8 @@ pub const SMT_DEPTH: u8 = 64;
// SMT // SMT
// ================================================================================================ // ================================================================================================
type Leaves = super::Leaves<SmtLeaf>;
/// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented /// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented
/// by 4 field elements. /// by 4 field elements.
/// ///
@ -43,8 +41,8 @@ pub const SMT_DEPTH: u8 = 64;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Smt { pub struct Smt {
root: RpoDigest, root: RpoDigest,
leaves: BTreeMap<u64, SmtLeaf>, inner_nodes: InnerNodes,
inner_nodes: BTreeMap<NodeIndex, InnerNode>, leaves: Leaves,
} }
impl Smt { impl Smt {
@ -64,8 +62,8 @@ impl Smt {
Self { Self {
root, root,
leaves: BTreeMap::new(), inner_nodes: Default::default(),
inner_nodes: BTreeMap::new(), leaves: Default::default(),
} }
} }
@ -148,11 +146,7 @@ impl Smt {
/// # Panics /// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in /// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`. /// `inner_nodes`.
pub fn from_raw_parts( pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, SmtLeaf>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`. // Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap() <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
} }
@ -339,8 +333,8 @@ impl SparseMerkleTree<SMT_DEPTH> for Smt {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0);
fn from_raw_parts( fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: InnerNodes,
leaves: BTreeMap<u64, SmtLeaf>, leaves: Leaves,
root: RpoDigest, root: RpoDigest,
) -> Result<Self, MerkleError> { ) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) { if cfg!(debug_assertions) {

View file

@ -1,9 +1,9 @@
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::vec::Vec;
use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH};
use crate::{ use crate::{
merkle::{ merkle::{
smt::{NodeMutation, SparseMerkleTree}, smt::{NodeMutation, SparseMerkleTree, UnorderedMap},
EmptySubtreeRoots, MerkleStore, MutationSet, EmptySubtreeRoots, MerkleStore, MutationSet,
}, },
utils::{Deserializable, Serializable}, utils::{Deserializable, Serializable},
@ -420,7 +420,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match"); assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match");
assert_eq!( assert_eq!(
revert.new_pairs, revert.new_pairs,
BTreeMap::from_iter([(key_1, EMPTY_WORD)]), UnorderedMap::from_iter([(key_1, EMPTY_WORD)]),
"reverse mutations pairs did not match" "reverse mutations pairs did not match"
); );
assert_eq!( assert_eq!(
@ -440,7 +440,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!( assert_eq!(
revert.new_pairs, revert.new_pairs,
BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]), UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]),
"reverse mutations pairs did not match" "reverse mutations pairs did not match"
); );
@ -454,7 +454,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!( assert_eq!(
revert.new_pairs, revert.new_pairs,
BTreeMap::from_iter([(key_3, value_3)]), UnorderedMap::from_iter([(key_3, value_3)]),
"reverse mutations pairs did not match" "reverse mutations pairs did not match"
); );
@ -474,7 +474,7 @@ fn test_prospective_insertion() {
assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!(revert.root(), old_root, "reverse mutations new root did not match");
assert_eq!( assert_eq!(
revert.new_pairs, revert.new_pairs,
BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]), UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]),
"reverse mutations pairs did not match" "reverse mutations pairs did not match"
); );
@ -603,21 +603,21 @@ fn test_smt_get_value() {
/// Tests that `entries()` works as expected /// Tests that `entries()` works as expected
#[test] #[test]
fn test_smt_entries() { fn test_smt_entries() {
let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]); let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]);
let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]); let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]);
let value_1 = [ONE; WORD_SIZE]; let value_1 = [ONE; WORD_SIZE];
let value_2 = [2_u32.into(); WORD_SIZE]; let value_2 = [2_u32.into(); WORD_SIZE];
let entries = [(key_1, value_1), (key_2, value_2)];
let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap(); let smt = Smt::with_entries(entries).unwrap();
let mut entries = smt.entries(); let mut expected = Vec::from_iter(entries);
expected.sort_by_key(|(k, _)| *k);
let mut actual: Vec<_> = smt.entries().cloned().collect();
actual.sort_by_key(|(k, _)| *k);
// Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation assert_eq!(actual, expected);
// switches the order, it is OK to modify the order here as well.
assert_eq!(&(key_1, value_1), entries.next().unwrap());
assert_eq!(&(key_2, value_2), entries.next().unwrap());
assert!(entries.next().is_none());
} }
/// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of /// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of

View file

@ -1,5 +1,5 @@
use alloc::{collections::BTreeMap, vec::Vec}; use alloc::{collections::BTreeMap, vec::Vec};
use core::mem; use core::{hash::Hash, mem};
use num::Integer; use num::Integer;
use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
@ -28,6 +28,15 @@ pub const SMT_MAX_DEPTH: u8 = 64;
// SPARSE MERKLE TREE // SPARSE MERKLE TREE
// ================================================================================================ // ================================================================================================
/// A map whose keys are not guarantied to be ordered.
#[cfg(feature = "smt_hashmaps")]
type UnorderedMap<K, V> = hashbrown::HashMap<K, V>;
#[cfg(not(feature = "smt_hashmaps"))]
type UnorderedMap<K, V> = alloc::collections::BTreeMap<K, V>;
type InnerNodes = UnorderedMap<NodeIndex, InnerNode>;
type Leaves<T> = UnorderedMap<u64, T>;
type NodeMutations = UnorderedMap<NodeIndex, NodeMutation>;
/// An abstract description of a sparse Merkle tree. /// An abstract description of a sparse Merkle tree.
/// ///
/// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed /// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed
@ -49,7 +58,7 @@ pub const SMT_MAX_DEPTH: u8 = 64;
/// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs.
pub(crate) trait SparseMerkleTree<const DEPTH: u8> { pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// The type for a key /// The type for a key
type Key: Clone + Ord; type Key: Clone + Ord + Eq + Hash;
/// The type for a value /// The type for a value
type Value: Clone + PartialEq; type Value: Clone + PartialEq;
/// The type for a leaf /// The type for a leaf
@ -173,8 +182,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
use NodeMutation::*; use NodeMutation::*;
let mut new_root = self.root(); let mut new_root = self.root();
let mut new_pairs: BTreeMap<Self::Key, Self::Value> = Default::default(); let mut new_pairs: UnorderedMap<Self::Key, Self::Value> = Default::default();
let mut node_mutations: BTreeMap<NodeIndex, NodeMutation> = Default::default(); let mut node_mutations: NodeMutations = Default::default();
for (key, value) in kv_pairs { for (key, value) in kv_pairs {
// If the old value and the new value are the same, there is nothing to update. // If the old value and the new value are the same, there is nothing to update.
@ -341,7 +350,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
}); });
} }
let mut reverse_mutations = BTreeMap::new(); let mut reverse_mutations = NodeMutations::new();
for (index, mutation) in node_mutations { for (index, mutation) in node_mutations {
match mutation { match mutation {
Removal => { Removal => {
@ -359,7 +368,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
} }
} }
let mut reverse_pairs = BTreeMap::new(); let mut reverse_pairs = UnorderedMap::new();
for (key, value) in new_pairs { for (key, value) in new_pairs {
if let Some(old_value) = self.insert_value(key.clone(), value) { if let Some(old_value) = self.insert_value(key.clone(), value) {
reverse_pairs.insert(key, old_value); reverse_pairs.insert(key, old_value);
@ -384,8 +393,8 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
/// Construct this type from already computed leaves and nodes. The caller ensures passed /// Construct this type from already computed leaves and nodes. The caller ensures passed
/// arguments are correct and consistent with each other. /// arguments are correct and consistent with each other.
fn from_raw_parts( fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: InnerNodes,
leaves: BTreeMap<u64, Self::Leaf>, leaves: Leaves<Self::Leaf>,
root: RpoDigest, root: RpoDigest,
) -> Result<Self, MerkleError> ) -> Result<Self, MerkleError>
where where
@ -516,7 +525,7 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[cfg(feature = "concurrent")] #[cfg(feature = "concurrent")]
fn build_subtrees( fn build_subtrees(
mut entries: Vec<(Self::Key, Self::Value)>, mut entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) { ) -> (InnerNodes, Leaves<Self::Leaf>) {
entries.sort_by_key(|item| { entries.sort_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0); let index = Self::key_to_leaf_index(&item.0);
index.value() index.value()
@ -531,10 +540,10 @@ pub(crate) trait SparseMerkleTree<const DEPTH: u8> {
#[cfg(feature = "concurrent")] #[cfg(feature = "concurrent")]
fn build_subtrees_from_sorted_entries( fn build_subtrees_from_sorted_entries(
entries: Vec<(Self::Key, Self::Value)>, entries: Vec<(Self::Key, Self::Value)>,
) -> (BTreeMap<NodeIndex, InnerNode>, BTreeMap<u64, Self::Leaf>) { ) -> (InnerNodes, Leaves<Self::Leaf>) {
use rayon::prelude::*; use rayon::prelude::*;
let mut accumulated_nodes: BTreeMap<NodeIndex, InnerNode> = Default::default(); let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations { let PairComputations {
leaves: mut leaf_subtrees, leaves: mut leaf_subtrees,
@ -651,8 +660,8 @@ pub enum NodeMutation {
/// Represents a group of prospective mutations to a `SparseMerkleTree`, created by /// Represents a group of prospective mutations to a `SparseMerkleTree`, created by
/// `SparseMerkleTree::compute_mutations()`, and that can be applied with /// `SparseMerkleTree::compute_mutations()`, and that can be applied with
/// `SparseMerkleTree::apply_mutations()`. /// `SparseMerkleTree::apply_mutations()`.
#[derive(Debug, Clone, PartialEq, Eq, Default)] #[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MutationSet<const DEPTH: u8, K, V> { pub struct MutationSet<const DEPTH: u8, K: Eq + Hash, V> {
/// The root of the Merkle tree this MutationSet is for, recorded at the time /// The root of the Merkle tree this MutationSet is for, recorded at the time
/// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying
/// mutations to the wrong tree or applying stale mutations to a tree that has since changed. /// mutations to the wrong tree or applying stale mutations to a tree that has since changed.
@ -662,18 +671,18 @@ pub struct MutationSet<const DEPTH: u8, K, V> {
/// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a /// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a
/// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`] /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`]
/// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call. /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call.
node_mutations: BTreeMap<NodeIndex, NodeMutation>, node_mutations: NodeMutations,
/// The set of top-level key-value pairs we're prospectively adding to the tree, including /// The set of top-level key-value pairs we're prospectively adding to the tree, including
/// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling
/// back to the existing value in the Merkle tree. Each entry corresponds to a /// back to the existing value in the Merkle tree. Each entry corresponds to a
/// [`SparseMerkleTree::insert_value()`] call. /// [`SparseMerkleTree::insert_value()`] call.
new_pairs: BTreeMap<K, V>, new_pairs: UnorderedMap<K, V>,
/// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with
/// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call. /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call.
new_root: RpoDigest, new_root: RpoDigest,
} }
impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> { impl<const DEPTH: u8, K: Eq + Hash, V> MutationSet<DEPTH, K, V> {
/// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See /// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See
/// that method for more information. /// that method for more information.
pub fn root(&self) -> RpoDigest { pub fn root(&self) -> RpoDigest {
@ -686,13 +695,13 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
} }
/// Returns the set of inner nodes that need to be removed or added. /// Returns the set of inner nodes that need to be removed or added.
pub fn node_mutations(&self) -> &BTreeMap<NodeIndex, NodeMutation> { pub fn node_mutations(&self) -> &NodeMutations {
&self.node_mutations &self.node_mutations
} }
/// Returns the set of top-level key-value pairs that need to be added, updated or deleted /// Returns the set of top-level key-value pairs that need to be added, updated or deleted
/// (i.e. set to `EMPTY_WORD`). /// (i.e. set to `EMPTY_WORD`).
pub fn new_pairs(&self) -> &BTreeMap<K, V> { pub fn new_pairs(&self) -> &UnorderedMap<K, V> {
&self.new_pairs &self.new_pairs
} }
} }
@ -702,8 +711,8 @@ impl<const DEPTH: u8, K, V> MutationSet<DEPTH, K, V> {
impl Serializable for InnerNode { impl Serializable for InnerNode {
fn write_into<W: ByteWriter>(&self, target: &mut W) { fn write_into<W: ByteWriter>(&self, target: &mut W) {
self.left.write_into(target); target.write(self.left);
self.right.write_into(target); target.write(self.right);
} }
} }
@ -739,23 +748,57 @@ impl Deserializable for NodeMutation {
} }
} }
impl<const DEPTH: u8, K: Serializable, V: Serializable> Serializable for MutationSet<DEPTH, K, V> { impl<const DEPTH: u8, K: Serializable + Eq + Hash, V: Serializable> Serializable
for MutationSet<DEPTH, K, V>
{
fn write_into<W: ByteWriter>(&self, target: &mut W) { fn write_into<W: ByteWriter>(&self, target: &mut W) {
target.write(self.old_root); target.write(self.old_root);
target.write(self.new_root); target.write(self.new_root);
self.node_mutations.write_into(target);
self.new_pairs.write_into(target); let inner_removals: Vec<_> = self
.node_mutations
.iter()
.filter(|(_, value)| matches!(value, NodeMutation::Removal))
.map(|(key, _)| key)
.collect();
let inner_additions: Vec<_> = self
.node_mutations
.iter()
.filter_map(|(key, value)| match value {
NodeMutation::Addition(node) => Some((key, node)),
_ => None,
})
.collect();
target.write(inner_removals);
target.write(inner_additions);
target.write_usize(self.new_pairs.len());
target.write_many(&self.new_pairs);
} }
} }
impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable impl<const DEPTH: u8, K: Deserializable + Ord + Eq + Hash, V: Deserializable> Deserializable
for MutationSet<DEPTH, K, V> for MutationSet<DEPTH, K, V>
{ {
fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> { fn read_from<R: ByteReader>(source: &mut R) -> Result<Self, DeserializationError> {
let old_root = source.read()?; let old_root = source.read()?;
let new_root = source.read()?; let new_root = source.read()?;
let node_mutations = source.read()?;
let new_pairs = source.read()?; let inner_removals: Vec<NodeIndex> = source.read()?;
let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?;
let node_mutations = NodeMutations::from_iter(
inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain(
inner_additions
.into_iter()
.map(|(index, node)| (index, NodeMutation::Addition(node))),
),
);
let num_new_pairs = source.read_usize()?;
let new_pairs = source.read_many(num_new_pairs)?;
let new_pairs = UnorderedMap::from_iter(new_pairs);
Ok(Self { Ok(Self {
old_root, old_root,
@ -768,6 +811,7 @@ impl<const DEPTH: u8, K: Deserializable + Ord, V: Deserializable> Deserializable
// SUBTREES // SUBTREES
// ================================================================================================ // ================================================================================================
/// A subtree is of depth 8. /// A subtree is of depth 8.
const SUBTREE_DEPTH: u8 = 8; const SUBTREE_DEPTH: u8 = 8;
@ -787,10 +831,10 @@ pub struct SubtreeLeaf {
} }
/// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. /// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`].
#[derive(Debug, Clone, PartialEq, Eq)] #[derive(Debug, Clone)]
pub(crate) struct PairComputations<K, L> { pub(crate) struct PairComputations<K, L> {
/// Literal leaves to be added to the sparse Merkle tree's internal mapping. /// Literal leaves to be added to the sparse Merkle tree's internal mapping.
pub nodes: BTreeMap<K, L>, pub nodes: UnorderedMap<K, L>,
/// "Conceptual" leaves that will be used for computations. /// "Conceptual" leaves that will be used for computations.
pub leaves: Vec<Vec<SubtreeLeaf>>, pub leaves: Vec<Vec<SubtreeLeaf>>,
} }
@ -818,7 +862,7 @@ impl<'s> SubtreeLeavesIter<'s> {
Self { leaves: leaves.drain(..).peekable() } Self { leaves: leaves.drain(..).peekable() }
} }
} }
impl core::iter::Iterator for SubtreeLeavesIter<'_> { impl Iterator for SubtreeLeavesIter<'_> {
type Item = Vec<SubtreeLeaf>; type Item = Vec<SubtreeLeaf>;
/// Each `next()` collects an entire subtree. /// Each `next()` collects an entire subtree.

View file

@ -1,11 +1,8 @@
use alloc::{ use alloc::{collections::BTreeSet, vec::Vec};
collections::{BTreeMap, BTreeSet},
vec::Vec,
};
use super::{ use super::{
super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex,
MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD,
SMT_MAX_DEPTH, SMT_MIN_DEPTH, SMT_MAX_DEPTH, SMT_MIN_DEPTH,
}; };
@ -15,6 +12,8 @@ mod tests;
// SPARSE MERKLE TREE // SPARSE MERKLE TREE
// ================================================================================================ // ================================================================================================
type Leaves = super::Leaves<Word>;
/// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction. /// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction.
/// ///
/// The root of the tree is recomputed on each new leaf update. /// The root of the tree is recomputed on each new leaf update.
@ -22,8 +21,8 @@ mod tests;
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct SimpleSmt<const DEPTH: u8> { pub struct SimpleSmt<const DEPTH: u8> {
root: RpoDigest, root: RpoDigest,
leaves: BTreeMap<u64, Word>, inner_nodes: InnerNodes,
inner_nodes: BTreeMap<NodeIndex, InnerNode>, leaves: Leaves,
} }
impl<const DEPTH: u8> SimpleSmt<DEPTH> { impl<const DEPTH: u8> SimpleSmt<DEPTH> {
@ -54,8 +53,8 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
Ok(Self { Ok(Self {
root, root,
leaves: BTreeMap::new(), inner_nodes: Default::default(),
inner_nodes: BTreeMap::new(), leaves: Default::default(),
}) })
} }
@ -108,11 +107,7 @@ impl<const DEPTH: u8> SimpleSmt<DEPTH> {
/// # Panics /// # Panics
/// With debug assertions on, this function panics if `root` does not match the root node in /// With debug assertions on, this function panics if `root` does not match the root node in
/// `inner_nodes`. /// `inner_nodes`.
pub fn from_raw_parts( pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self {
inner_nodes: BTreeMap<NodeIndex, InnerNode>,
leaves: BTreeMap<u64, Word>,
root: RpoDigest,
) -> Self {
// Our particular implementation of `from_raw_parts()` never returns `Err`. // Our particular implementation of `from_raw_parts()` never returns `Err`.
<Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap() <Self as SparseMerkleTree<DEPTH>>::from_raw_parts(inner_nodes, leaves, root).unwrap()
} }
@ -344,8 +339,8 @@ impl<const DEPTH: u8> SparseMerkleTree<DEPTH> for SimpleSmt<DEPTH> {
const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0);
fn from_raw_parts( fn from_raw_parts(
inner_nodes: BTreeMap<NodeIndex, InnerNode>, inner_nodes: InnerNodes,
leaves: BTreeMap<u64, Word>, leaves: Leaves,
root: RpoDigest, root: RpoDigest,
) -> Result<Self, MerkleError> { ) -> Result<Self, MerkleError> {
if cfg!(debug_assertions) { if cfg!(debug_assertions) {

View file

@ -141,12 +141,15 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> {
let l2n2 = tree.get_node(NodeIndex::make(2, 2))?; let l2n2 = tree.get_node(NodeIndex::make(2, 2))?;
let l2n3 = tree.get_node(NodeIndex::make(2, 3))?; let l2n3 = tree.get_node(NodeIndex::make(2, 3))?;
let nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect(); let mut nodes: Vec<InnerNodeInfo> = tree.inner_nodes().collect();
let expected = vec![ let mut expected = [
InnerNodeInfo { value: root, left: l1n0, right: l1n1 }, InnerNodeInfo { value: root, left: l1n0, right: l1n1 },
InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 }, InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 },
InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 }, InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 },
]; ];
nodes.sort();
expected.sort();
assert_eq!(nodes, expected); assert_eq!(nodes, expected);
Ok(()) Ok(())