diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bb9395b..2194ef6 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -17,7 +17,7 @@ jobs: matrix: toolchain: [stable, nightly] os: [ubuntu] - args: [default, no-std] + args: [default, smt-hashmaps, no-std] timeout-minutes: 30 steps: - uses: actions/checkout@main diff --git a/CHANGELOG.md b/CHANGELOG.md index 9bd0eb0..eb77cf6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ - Fixed a bug in the implementation of `draw_integers` for `RpoRandomCoin` (#343). - [BREAKING] Refactor error messages and use `thiserror` to derive errors (#344). - [BREAKING] Updated Winterfell dependency to v0.11 (#346). +- Added support for hashmaps in `Smt` and `SimpleSmt` which gives up to 10x boost in some operations (#363). ## 0.12.0 (2024-10-30) diff --git a/Cargo.lock b/Cargo.lock index 62e9fb9..b227f27 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anes" version = "0.1.6" @@ -349,6 +355,12 @@ version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + [[package]] name = "errno" version = "0.3.10" @@ -371,6 +383,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f81ec6369c545a7d40e4589b5597581fa1c441fe1cce96dd1de43159910a36a2" + [[package]] name = "generic-array" version = "0.14.7" @@ -410,6 +428,18 @@ dependencies = [ "crunchy", ] +[[package]] +name = "hashbrown" +version = "0.15.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", + "serde", +] + [[package]] name = "heck" version = "0.5.0" @@ -535,6 +565,7 @@ dependencies = [ "criterion", "getrandom", "glob", + "hashbrown", "hex", "num", "num-complex", diff --git a/Cargo.toml b/Cargo.toml index fd750e8..3469c08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ harness = false concurrent = ["dep:rayon"] default = ["std", "concurrent"] executable = ["dep:clap", "dep:rand-utils", "std"] +smt_hashmaps = ["dep:hashbrown"] internal = [] serde = ["dep:serde", "serde?/alloc", "winter-math/serde"] std = [ @@ -63,6 +64,7 @@ std = [ [dependencies] blake3 = { version = "1.5", default-features = false } clap = { version = "4.5", optional = true, features = ["derive"] } +hashbrown = { version = "0.15", optional = true, features = ["serde"] } num = { version = "0.4", default-features = false, features = ["alloc", "libm"] } num-complex = { version = "0.4", default-features = false } rand = { version = "0.8", default-features = false } diff --git a/Makefile b/Makefile index 233419d..6ab285d 100644 --- a/Makefile +++ b/Makefile @@ -46,6 +46,9 @@ doc: ## Generate and check documentation test-default: ## Run tests with default features $(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --all-features +.PHONY: test-smt-hashmaps +test-smt-hashmaps: ## Run tests with `smt_hashmaps` feature enabled + $(DEBUG_OVERFLOW_INFO) cargo nextest run --profile default --release --features smt_hashmaps .PHONY: test-no-std test-no-std: ## Run tests with `no-default-features` (std) @@ -53,7 +56,7 @@ test-no-std: ## Run tests with `no-default-features` (std) .PHONY: test -test: test-default test-no-std ## Run all tests +test: test-default test-smt-hashmaps test-no-std ## Run all tests # --- checking ------------------------------------------------------------------------------------ diff --git a/README.md b/README.md index c9a758a..7933e0b 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,7 @@ This crate can be compiled with the following features: - `concurrent`- enabled by default; enables multi-threaded implementation of `Smt::with_entries()` which significantly improves performance on multi-core CPUs. - `std` - enabled by default and relies on the Rust standard library. - `no_std` does not rely on the Rust standard library and enables compilation to WebAssembly. +- `smt_hashmaps` - uses hashbrown hashmaps in SMT implementation which significantly improves performance of SMT updating. Keys ordering in SMT iterators is not guarantied when this feature is enabled. All of these features imply the use of [alloc](https://doc.rust-lang.org/alloc/) to support heap-allocated collections. diff --git a/src/hash/rescue/rpo/digest.rs b/src/hash/rescue/rpo/digest.rs index 4466369..a525892 100644 --- a/src/hash/rescue/rpo/digest.rs +++ b/src/hash/rescue/rpo/digest.rs @@ -1,5 +1,11 @@ use alloc::string::String; -use core::{cmp::Ordering, fmt::Display, ops::Deref, slice}; +use core::{ + cmp::Ordering, + fmt::Display, + hash::{Hash, Hasher}, + ops::Deref, + slice, +}; use thiserror::Error; @@ -55,6 +61,12 @@ impl RpoDigest { } } +impl Hash for RpoDigest { + fn hash(&self, state: &mut H) { + state.write(&self.as_bytes()); + } +} + impl Digest for RpoDigest { fn as_bytes(&self) -> [u8; DIGEST_BYTES] { let mut result = [0; DIGEST_BYTES]; diff --git a/src/merkle/node.rs b/src/merkle/node.rs index bf18d38..b821903 100644 --- a/src/merkle/node.rs +++ b/src/merkle/node.rs @@ -3,6 +3,7 @@ use super::RpoDigest; /// Representation of a node with two children used for iterating over containers. #[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +#[cfg_attr(test, derive(PartialOrd, Ord))] pub struct InnerNodeInfo { pub value: RpoDigest, pub left: RpoDigest, diff --git a/src/merkle/smt/full/mod.rs b/src/merkle/smt/full/mod.rs index f90dd7e..5cd641e 100644 --- a/src/merkle/smt/full/mod.rs +++ b/src/merkle/smt/full/mod.rs @@ -1,12 +1,8 @@ -use alloc::{ - collections::{BTreeMap, BTreeSet}, - string::ToString, - vec::Vec, -}; +use alloc::{collections::BTreeSet, string::ToString, vec::Vec}; use super::{ - EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, MerklePath, - MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + EmptySubtreeRoots, Felt, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, MerkleError, + MerklePath, MutationSet, NodeIndex, Rpo256, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, }; mod error; @@ -30,6 +26,8 @@ pub const SMT_DEPTH: u8 = 64; // SMT // ================================================================================================ +type Leaves = super::Leaves; + /// Sparse Merkle tree mapping 256-bit keys to 256-bit values. Both keys and values are represented /// by 4 field elements. /// @@ -43,8 +41,8 @@ pub const SMT_DEPTH: u8 = 64; #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct Smt { root: RpoDigest, - leaves: BTreeMap, - inner_nodes: BTreeMap, + inner_nodes: InnerNodes, + leaves: Leaves, } impl Smt { @@ -64,8 +62,8 @@ impl Smt { Self { root, - leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + inner_nodes: Default::default(), + leaves: Default::default(), } } @@ -148,11 +146,7 @@ impl Smt { /// # Panics /// With debug assertions on, this function panics if `root` does not match the root node in /// `inner_nodes`. - pub fn from_raw_parts( - inner_nodes: BTreeMap, - leaves: BTreeMap, - root: RpoDigest, - ) -> Self { + pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self { // Our particular implementation of `from_raw_parts()` never returns `Err`. >::from_raw_parts(inner_nodes, leaves, root).unwrap() } @@ -339,8 +333,8 @@ impl SparseMerkleTree for Smt { const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(SMT_DEPTH, 0); fn from_raw_parts( - inner_nodes: BTreeMap, - leaves: BTreeMap, + inner_nodes: InnerNodes, + leaves: Leaves, root: RpoDigest, ) -> Result { if cfg!(debug_assertions) { diff --git a/src/merkle/smt/full/tests.rs b/src/merkle/smt/full/tests.rs index 6404f29..787c01a 100644 --- a/src/merkle/smt/full/tests.rs +++ b/src/merkle/smt/full/tests.rs @@ -1,9 +1,9 @@ -use alloc::{collections::BTreeMap, vec::Vec}; +use alloc::vec::Vec; use super::{Felt, LeafIndex, NodeIndex, Rpo256, RpoDigest, Smt, SmtLeaf, EMPTY_WORD, SMT_DEPTH}; use crate::{ merkle::{ - smt::{NodeMutation, SparseMerkleTree}, + smt::{NodeMutation, SparseMerkleTree, UnorderedMap}, EmptySubtreeRoots, MerkleStore, MutationSet, }, utils::{Deserializable, Serializable}, @@ -420,7 +420,7 @@ fn test_prospective_insertion() { assert_eq!(revert.root(), root_empty, "reverse mutations new root did not match"); assert_eq!( revert.new_pairs, - BTreeMap::from_iter([(key_1, EMPTY_WORD)]), + UnorderedMap::from_iter([(key_1, EMPTY_WORD)]), "reverse mutations pairs did not match" ); assert_eq!( @@ -440,7 +440,7 @@ fn test_prospective_insertion() { assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!( revert.new_pairs, - BTreeMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]), + UnorderedMap::from_iter([(key_2, EMPTY_WORD), (key_3, EMPTY_WORD)]), "reverse mutations pairs did not match" ); @@ -454,7 +454,7 @@ fn test_prospective_insertion() { assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!( revert.new_pairs, - BTreeMap::from_iter([(key_3, value_3)]), + UnorderedMap::from_iter([(key_3, value_3)]), "reverse mutations pairs did not match" ); @@ -474,7 +474,7 @@ fn test_prospective_insertion() { assert_eq!(revert.root(), old_root, "reverse mutations new root did not match"); assert_eq!( revert.new_pairs, - BTreeMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]), + UnorderedMap::from_iter([(key_1, value_1), (key_2, value_2), (key_3, value_3)]), "reverse mutations pairs did not match" ); @@ -603,21 +603,21 @@ fn test_smt_get_value() { /// Tests that `entries()` works as expected #[test] fn test_smt_entries() { - let key_1: RpoDigest = RpoDigest::from([ONE, ONE, ONE, ONE]); - let key_2: RpoDigest = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]); + let key_1 = RpoDigest::from([ONE, ONE, ONE, ONE]); + let key_2 = RpoDigest::from([2_u32, 2_u32, 2_u32, 2_u32]); let value_1 = [ONE; WORD_SIZE]; let value_2 = [2_u32.into(); WORD_SIZE]; + let entries = [(key_1, value_1), (key_2, value_2)]; - let smt = Smt::with_entries([(key_1, value_1), (key_2, value_2)]).unwrap(); + let smt = Smt::with_entries(entries).unwrap(); - let mut entries = smt.entries(); + let mut expected = Vec::from_iter(entries); + expected.sort_by_key(|(k, _)| *k); + let mut actual: Vec<_> = smt.entries().cloned().collect(); + actual.sort_by_key(|(k, _)| *k); - // Note: for simplicity, we assume the order `(k1,v1), (k2,v2)`. If a new implementation - // switches the order, it is OK to modify the order here as well. - assert_eq!(&(key_1, value_1), entries.next().unwrap()); - assert_eq!(&(key_2, value_2), entries.next().unwrap()); - assert!(entries.next().is_none()); + assert_eq!(actual, expected); } /// Tests that `EMPTY_ROOT` constant generated in the `Smt` equals to the root of the empty tree of diff --git a/src/merkle/smt/mod.rs b/src/merkle/smt/mod.rs index de501b8..ec43957 100644 --- a/src/merkle/smt/mod.rs +++ b/src/merkle/smt/mod.rs @@ -1,5 +1,5 @@ use alloc::{collections::BTreeMap, vec::Vec}; -use core::mem; +use core::{hash::Hash, mem}; use num::Integer; use winter_utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable}; @@ -28,6 +28,15 @@ pub const SMT_MAX_DEPTH: u8 = 64; // SPARSE MERKLE TREE // ================================================================================================ +/// A map whose keys are not guarantied to be ordered. +#[cfg(feature = "smt_hashmaps")] +type UnorderedMap = hashbrown::HashMap; +#[cfg(not(feature = "smt_hashmaps"))] +type UnorderedMap = alloc::collections::BTreeMap; +type InnerNodes = UnorderedMap; +type Leaves = UnorderedMap; +type NodeMutations = UnorderedMap; + /// An abstract description of a sparse Merkle tree. /// /// A sparse Merkle tree is a key-value map which also supports proving that a given value is indeed @@ -49,7 +58,7 @@ pub const SMT_MAX_DEPTH: u8 = 64; /// [SparseMerkleTree] currently doesn't support optimizations that compress Merkle proofs. pub(crate) trait SparseMerkleTree { /// The type for a key - type Key: Clone + Ord; + type Key: Clone + Ord + Eq + Hash; /// The type for a value type Value: Clone + PartialEq; /// The type for a leaf @@ -173,8 +182,8 @@ pub(crate) trait SparseMerkleTree { use NodeMutation::*; let mut new_root = self.root(); - let mut new_pairs: BTreeMap = Default::default(); - let mut node_mutations: BTreeMap = Default::default(); + let mut new_pairs: UnorderedMap = Default::default(); + let mut node_mutations: NodeMutations = Default::default(); for (key, value) in kv_pairs { // If the old value and the new value are the same, there is nothing to update. @@ -341,7 +350,7 @@ pub(crate) trait SparseMerkleTree { }); } - let mut reverse_mutations = BTreeMap::new(); + let mut reverse_mutations = NodeMutations::new(); for (index, mutation) in node_mutations { match mutation { Removal => { @@ -359,7 +368,7 @@ pub(crate) trait SparseMerkleTree { } } - let mut reverse_pairs = BTreeMap::new(); + let mut reverse_pairs = UnorderedMap::new(); for (key, value) in new_pairs { if let Some(old_value) = self.insert_value(key.clone(), value) { reverse_pairs.insert(key, old_value); @@ -384,8 +393,8 @@ pub(crate) trait SparseMerkleTree { /// Construct this type from already computed leaves and nodes. The caller ensures passed /// arguments are correct and consistent with each other. fn from_raw_parts( - inner_nodes: BTreeMap, - leaves: BTreeMap, + inner_nodes: InnerNodes, + leaves: Leaves, root: RpoDigest, ) -> Result where @@ -516,7 +525,7 @@ pub(crate) trait SparseMerkleTree { #[cfg(feature = "concurrent")] fn build_subtrees( mut entries: Vec<(Self::Key, Self::Value)>, - ) -> (BTreeMap, BTreeMap) { + ) -> (InnerNodes, Leaves) { entries.sort_by_key(|item| { let index = Self::key_to_leaf_index(&item.0); index.value() @@ -531,10 +540,10 @@ pub(crate) trait SparseMerkleTree { #[cfg(feature = "concurrent")] fn build_subtrees_from_sorted_entries( entries: Vec<(Self::Key, Self::Value)>, - ) -> (BTreeMap, BTreeMap) { + ) -> (InnerNodes, Leaves) { use rayon::prelude::*; - let mut accumulated_nodes: BTreeMap = Default::default(); + let mut accumulated_nodes: InnerNodes = Default::default(); let PairComputations { leaves: mut leaf_subtrees, @@ -651,8 +660,8 @@ pub enum NodeMutation { /// Represents a group of prospective mutations to a `SparseMerkleTree`, created by /// `SparseMerkleTree::compute_mutations()`, and that can be applied with /// `SparseMerkleTree::apply_mutations()`. -#[derive(Debug, Clone, PartialEq, Eq, Default)] -pub struct MutationSet { +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct MutationSet { /// The root of the Merkle tree this MutationSet is for, recorded at the time /// [`SparseMerkleTree::compute_mutations()`] was called. Exists to guard against applying /// mutations to the wrong tree or applying stale mutations to a tree that has since changed. @@ -662,18 +671,18 @@ pub struct MutationSet { /// index overlayed, if any. Each [`NodeMutation::Addition`] corresponds to a /// [`SparseMerkleTree::insert_inner_node()`] call, and each [`NodeMutation::Removal`] /// corresponds to a [`SparseMerkleTree::remove_inner_node()`] call. - node_mutations: BTreeMap, + node_mutations: NodeMutations, /// The set of top-level key-value pairs we're prospectively adding to the tree, including /// adding empty values. The "effective" value for a key is the value in this BTreeMap, falling /// back to the existing value in the Merkle tree. Each entry corresponds to a /// [`SparseMerkleTree::insert_value()`] call. - new_pairs: BTreeMap, + new_pairs: UnorderedMap, /// The calculated root for the Merkle tree, given these mutations. Publicly retrievable with /// [`MutationSet::root()`]. Corresponds to a [`SparseMerkleTree::set_root()`]. call. new_root: RpoDigest, } -impl MutationSet { +impl MutationSet { /// Returns the SMT root that was calculated during `SparseMerkleTree::compute_mutations()`. See /// that method for more information. pub fn root(&self) -> RpoDigest { @@ -686,13 +695,13 @@ impl MutationSet { } /// Returns the set of inner nodes that need to be removed or added. - pub fn node_mutations(&self) -> &BTreeMap { + pub fn node_mutations(&self) -> &NodeMutations { &self.node_mutations } /// Returns the set of top-level key-value pairs that need to be added, updated or deleted /// (i.e. set to `EMPTY_WORD`). - pub fn new_pairs(&self) -> &BTreeMap { + pub fn new_pairs(&self) -> &UnorderedMap { &self.new_pairs } } @@ -702,8 +711,8 @@ impl MutationSet { impl Serializable for InnerNode { fn write_into(&self, target: &mut W) { - self.left.write_into(target); - self.right.write_into(target); + target.write(self.left); + target.write(self.right); } } @@ -739,23 +748,57 @@ impl Deserializable for NodeMutation { } } -impl Serializable for MutationSet { +impl Serializable + for MutationSet +{ fn write_into(&self, target: &mut W) { target.write(self.old_root); target.write(self.new_root); - self.node_mutations.write_into(target); - self.new_pairs.write_into(target); + + let inner_removals: Vec<_> = self + .node_mutations + .iter() + .filter(|(_, value)| matches!(value, NodeMutation::Removal)) + .map(|(key, _)| key) + .collect(); + let inner_additions: Vec<_> = self + .node_mutations + .iter() + .filter_map(|(key, value)| match value { + NodeMutation::Addition(node) => Some((key, node)), + _ => None, + }) + .collect(); + + target.write(inner_removals); + target.write(inner_additions); + + target.write_usize(self.new_pairs.len()); + target.write_many(&self.new_pairs); } } -impl Deserializable +impl Deserializable for MutationSet { fn read_from(source: &mut R) -> Result { let old_root = source.read()?; let new_root = source.read()?; - let node_mutations = source.read()?; - let new_pairs = source.read()?; + + let inner_removals: Vec = source.read()?; + let inner_additions: Vec<(NodeIndex, InnerNode)> = source.read()?; + + let node_mutations = NodeMutations::from_iter( + inner_removals.into_iter().map(|index| (index, NodeMutation::Removal)).chain( + inner_additions + .into_iter() + .map(|(index, node)| (index, NodeMutation::Addition(node))), + ), + ); + + let num_new_pairs = source.read_usize()?; + let new_pairs = source.read_many(num_new_pairs)?; + let new_pairs = UnorderedMap::from_iter(new_pairs); Ok(Self { old_root, @@ -768,6 +811,7 @@ impl Deserializable // SUBTREES // ================================================================================================ + /// A subtree is of depth 8. const SUBTREE_DEPTH: u8 = 8; @@ -787,10 +831,10 @@ pub struct SubtreeLeaf { } /// Helper struct to organize the return value of [`SparseMerkleTree::sorted_pairs_to_leaves()`]. -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] pub(crate) struct PairComputations { /// Literal leaves to be added to the sparse Merkle tree's internal mapping. - pub nodes: BTreeMap, + pub nodes: UnorderedMap, /// "Conceptual" leaves that will be used for computations. pub leaves: Vec>, } @@ -818,7 +862,7 @@ impl<'s> SubtreeLeavesIter<'s> { Self { leaves: leaves.drain(..).peekable() } } } -impl core::iter::Iterator for SubtreeLeavesIter<'_> { +impl Iterator for SubtreeLeavesIter<'_> { type Item = Vec; /// Each `next()` collects an entire subtree. diff --git a/src/merkle/smt/simple/mod.rs b/src/merkle/smt/simple/mod.rs index e1e3bd8..166cc98 100644 --- a/src/merkle/smt/simple/mod.rs +++ b/src/merkle/smt/simple/mod.rs @@ -1,11 +1,8 @@ -use alloc::{ - collections::{BTreeMap, BTreeSet}, - vec::Vec, -}; +use alloc::{collections::BTreeSet, vec::Vec}; use super::{ - super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, LeafIndex, MerkleError, - MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, + super::ValuePath, EmptySubtreeRoots, InnerNode, InnerNodeInfo, InnerNodes, LeafIndex, + MerkleError, MerklePath, MutationSet, NodeIndex, RpoDigest, SparseMerkleTree, Word, EMPTY_WORD, SMT_MAX_DEPTH, SMT_MIN_DEPTH, }; @@ -15,6 +12,8 @@ mod tests; // SPARSE MERKLE TREE // ================================================================================================ +type Leaves = super::Leaves; + /// A sparse Merkle tree with 64-bit keys and 4-element leaf values, without compaction. /// /// The root of the tree is recomputed on each new leaf update. @@ -22,8 +21,8 @@ mod tests; #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] pub struct SimpleSmt { root: RpoDigest, - leaves: BTreeMap, - inner_nodes: BTreeMap, + inner_nodes: InnerNodes, + leaves: Leaves, } impl SimpleSmt { @@ -54,8 +53,8 @@ impl SimpleSmt { Ok(Self { root, - leaves: BTreeMap::new(), - inner_nodes: BTreeMap::new(), + inner_nodes: Default::default(), + leaves: Default::default(), }) } @@ -108,11 +107,7 @@ impl SimpleSmt { /// # Panics /// With debug assertions on, this function panics if `root` does not match the root node in /// `inner_nodes`. - pub fn from_raw_parts( - inner_nodes: BTreeMap, - leaves: BTreeMap, - root: RpoDigest, - ) -> Self { + pub fn from_raw_parts(inner_nodes: InnerNodes, leaves: Leaves, root: RpoDigest) -> Self { // Our particular implementation of `from_raw_parts()` never returns `Err`. >::from_raw_parts(inner_nodes, leaves, root).unwrap() } @@ -344,8 +339,8 @@ impl SparseMerkleTree for SimpleSmt { const EMPTY_ROOT: RpoDigest = *EmptySubtreeRoots::entry(DEPTH, 0); fn from_raw_parts( - inner_nodes: BTreeMap, - leaves: BTreeMap, + inner_nodes: InnerNodes, + leaves: Leaves, root: RpoDigest, ) -> Result { if cfg!(debug_assertions) { diff --git a/src/merkle/smt/simple/tests.rs b/src/merkle/smt/simple/tests.rs index 84bad47..9078c52 100644 --- a/src/merkle/smt/simple/tests.rs +++ b/src/merkle/smt/simple/tests.rs @@ -141,12 +141,15 @@ fn test_inner_node_iterator() -> Result<(), MerkleError> { let l2n2 = tree.get_node(NodeIndex::make(2, 2))?; let l2n3 = tree.get_node(NodeIndex::make(2, 3))?; - let nodes: Vec = tree.inner_nodes().collect(); - let expected = vec![ + let mut nodes: Vec = tree.inner_nodes().collect(); + let mut expected = [ InnerNodeInfo { value: root, left: l1n0, right: l1n1 }, InnerNodeInfo { value: l1n0, left: l2n0, right: l2n1 }, InnerNodeInfo { value: l1n1, left: l2n2, right: l2n3 }, ]; + nodes.sort(); + expected.sort(); + assert_eq!(nodes, expected); Ok(())