diff --git a/CHANGELOG.md b/CHANGELOG.md index fe493fd..1b39f3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ - Sort keys in a leaf in the concurrent implementation of `Smt::with_entries`, ensuring consistency with the sequential version (#385). - Skip unchanged leaves in the concurrent implementation of `Smt::compute_mutations` (#385). - Add range checks to `ntru_gen` for Falcon DSA (#391). +- Optimized duplicate key detection in `Smt::with_entries_concurrent` (#395). ## 0.13.3 (2025-02-18) diff --git a/src/merkle/smt/full/concurrent/mod.rs b/src/merkle/smt/full/concurrent/mod.rs index 14730e6..2862b2a 100644 --- a/src/merkle/smt/full/concurrent/mod.rs +++ b/src/merkle/smt/full/concurrent/mod.rs @@ -1,11 +1,12 @@ -use alloc::{collections::BTreeSet, vec::Vec}; +use alloc::vec::Vec; use core::mem; use num::Integer; +use rayon::prelude::*; use super::{ - leaf, EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet, - NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH, + leaf, EmptySubtreeRoots, InnerNode, InnerNodes, Leaves, MerkleError, MutationSet, NodeIndex, + RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH, }; use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap}; @@ -33,28 +34,25 @@ impl Smt { /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, /// which processes the next 8 levels up. This continues until the final root of the tree is /// computed at depth 0. + /// + /// # Errors + /// Returns an error if the provided entries contain multiple values for the same key. pub(crate) fn with_entries_concurrent( entries: impl IntoIterator, ) -> Result { - let mut seen_keys = BTreeSet::new(); - let entries: Vec<_> = entries - .into_iter() - // Filter out key-value pairs whose value is empty. - .filter(|(_key, value)| *value != Self::EMPTY_VALUE) - .map(|(key, value)| { - if seen_keys.insert(key) { - Ok((key, value)) - } else { - Err(MerkleError::DuplicateValuesForIndex( - LeafIndex::::from(key).value(), - )) - } - }) - .collect::>()?; + let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect(); + if entries.is_empty() { return Ok(Self::default()); } - let (inner_nodes, leaves) = Self::build_subtrees(entries); + + let (inner_nodes, leaves) = Self::build_subtrees(entries)?; + + // All the leaves are empty + if inner_nodes.is_empty() { + return Ok(Self::default()); + } + let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); >::from_raw_parts(inner_nodes, leaves, root) } @@ -80,8 +78,6 @@ impl Smt { where Self: Sized + Sync, { - use rayon::prelude::*; - // Collect and sort key-value pairs by their corresponding leaf index let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); @@ -206,9 +202,14 @@ impl Smt { /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. /// - /// `entries` need not be sorted. This function will sort them. - fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { - entries.sort_by_key(|item| { + /// `entries` need not be sorted. This function will sort them using parallel sorting. + /// + /// # Errors + /// Returns an error if the provided entries contain multiple values for the same key. + fn build_subtrees( + mut entries: Vec<(RpoDigest, Word)>, + ) -> Result<(InnerNodes, Leaves), MerkleError> { + entries.par_sort_unstable_by_key(|item| { let index = Self::key_to_leaf_index(&item.0); index.value() }); @@ -219,15 +220,23 @@ impl Smt { /// /// This function is mostly an implementation detail of /// [`Smt::with_entries_concurrent()`]. - fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { - use rayon::prelude::*; - + /// + /// # Errors + /// Returns an error if the provided entries contain multiple values for the same key. + fn build_subtrees_from_sorted_entries( + entries: Vec<(RpoDigest, Word)>, + ) -> Result<(InnerNodes, Leaves), MerkleError> { let mut accumulated_nodes: InnerNodes = Default::default(); let PairComputations { leaves: mut leaf_subtrees, nodes: initial_leaves, - } = Self::sorted_pairs_to_leaves(entries); + } = Self::sorted_pairs_to_leaves(entries)?; + + // If there are no leaves, we can return early + if initial_leaves.is_empty() { + return Ok((accumulated_nodes, initial_leaves)); + } for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { let (nodes, mut subtree_roots): (Vec>, Vec) = @@ -247,7 +256,7 @@ impl Smt { debug_assert!(!leaf_subtrees.is_empty()); } - (accumulated_nodes, initial_leaves) + Ok((accumulated_nodes, initial_leaves)) } // LEAF NODE CONSTRUCTION @@ -260,31 +269,46 @@ impl Smt { /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If /// `pairs` is not correctly sorted, the returned computations will be incorrect. /// + /// # Errors + /// Returns an error if the provided pairs contain multiple values for the same key. + /// /// # Panics /// With debug assertions on, this function panics if it detects that `pairs` is not correctly /// sorted. Without debug assertions, the returned computations will be incorrect. - fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations { - Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { - Some(Self::pairs_to_leaf(leaf_pairs)) - }) + fn sorted_pairs_to_leaves( + pairs: Vec<(RpoDigest, Word)>, + ) -> Result, MerkleError> { + Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf) } /// Constructs a single leaf from an arbitrary amount of key-value pairs. /// Those pairs must all have the same leaf index. - fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf { + /// + /// # Errors + /// Returns a `MerkleError::DuplicateValuesForIndex` if the provided pairs contain multiple + /// values for the same key. + /// + /// # Returns + /// - `Ok(Some(SmtLeaf))` if a valid leaf is constructed. + /// - `Ok(None)` if the only provided value is `Self::EMPTY_VALUE`. + fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> Result, MerkleError> { assert!(!pairs.is_empty()); if pairs.len() > 1 { pairs.sort_by(|(key_1, _), (key_2, _)| leaf::cmp_keys(*key_1, *key_2)); - SmtLeaf::new_multiple(pairs).unwrap() + // Check for duplicates in a sorted list by comparing adjacent pairs + if let Some(window) = pairs.windows(2).find(|window| window[0].0 == window[1].0) { + // If we find a duplicate, return an error + let col = Self::key_to_leaf_index(&window[0].0).index.value(); + return Err(MerkleError::DuplicateValuesForIndex(col)); + } + Ok(Some(SmtLeaf::new_multiple(pairs).unwrap())) } else { let (key, value) = pairs.pop().unwrap(); - // TODO: should we ever be constructing empty leaves from pairs? if value == Self::EMPTY_VALUE { - let index = Self::key_to_leaf_index(&key); - SmtLeaf::new_empty(index) + Ok(None) } else { - SmtLeaf::new_single(key, value) + Ok(Some(SmtLeaf::new_single(key, value))) } } } @@ -322,13 +346,19 @@ impl Smt { if leaf_changed { // Only return the leaf if it actually changed - Some(leaf) + Ok(Some(leaf)) } else { // Return None if leaf hasn't changed - None + Ok(None) } }); - (accumulator.leaves, new_pairs) + // The closure is the only possible source of errors. + // Since it never returns an error - only `Ok(Some(_))` or `Ok(None)` - we can safely assume + // `accumulator` is always `Ok(_)`. + ( + accumulator.expect("process_sorted_pairs_to_leaves never fails").leaves, + new_pairs, + ) } /// Processes sorted key-value pairs to compute leaves for a subtree. @@ -352,16 +382,18 @@ impl Smt { /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. /// + /// # Errors + /// Returns an error if the `process_leaf` callback fails. + /// /// # Panics /// This function will panic in debug mode if the input `pairs` are not sorted by column index. fn process_sorted_pairs_to_leaves( pairs: Vec<(RpoDigest, Word)>, mut process_leaf: F, - ) -> PairComputations + ) -> Result, MerkleError> where - F: FnMut(Vec<(RpoDigest, Word)>) -> Option, + F: FnMut(Vec<(RpoDigest, Word)>) -> Result, MerkleError>, { - use rayon::prelude::*; debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); let mut accumulator: PairComputations = Default::default(); @@ -392,8 +424,16 @@ impl Smt { // Otherwise, the next pair is a different column, or there is no next pair. Either way // it's time to swap out our buffer. let leaf_pairs = mem::take(&mut current_leaf_buffer); - if let Some(leaf) = process_leaf(leaf_pairs) { - accumulator.nodes.insert(col, leaf); + + // Process leaf and propagate any errors + match process_leaf(leaf_pairs) { + Ok(Some(leaf)) => { + accumulator.nodes.insert(col, leaf); + }, + Ok(None) => { + // No leaf was constructed for this column. The column will be skipped. + }, + Err(e) => return Err(e), } debug_assert!(current_leaf_buffer.is_empty()); @@ -415,7 +455,7 @@ impl Smt { // subtree boundaries as we go. Either way this function is only used at the beginning of a // parallel construction, so it should not be a critical path. accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); - accumulator + Ok(accumulator) } } @@ -539,6 +579,7 @@ fn build_subtree( // construction enforces uniqueness. However, when testing or benchmarking // `build_subtree()` in isolation, duplicate columns can appear if input // constraints are not enforced. + use alloc::collections::BTreeSet; let mut seen_cols = BTreeSet::new(); for leaf in &leaves { assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col); diff --git a/src/merkle/smt/full/concurrent/tests.rs b/src/merkle/smt/full/concurrent/tests.rs index 8e42052..5a05c44 100644 --- a/src/merkle/smt/full/concurrent/tests.rs +++ b/src/merkle/smt/full/concurrent/tests.rs @@ -3,15 +3,19 @@ use alloc::{ vec::Vec, }; +use assert_matches::assert_matches; use proptest::prelude::*; use rand::{prelude::IteratorRandom, thread_rng, Rng}; use super::{ - build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest, - Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, - SMT_DEPTH, SUBTREE_DEPTH, + build_subtree, InnerNode, NodeIndex, NodeMutations, PairComputations, RpoDigest, Smt, SmtLeaf, + SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, SMT_DEPTH, + SUBTREE_DEPTH, +}; +use crate::{ + merkle::{smt::Felt, LeafIndex, MerkleError}, + Word, EMPTY_WORD, ONE, ZERO, }; -use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE, ZERO}; fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { SubtreeLeaf { @@ -72,7 +76,7 @@ fn test_sorted_pairs_to_leaves() { control_subtree_leaves }; - let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries); + let subtrees: PairComputations = Smt::sorted_pairs_to_leaves(entries).unwrap(); // This will check that the hashes, columns, and subtree assignments all match. assert_eq!(subtrees.leaves, control_subtree_leaves); // Flattening and re-separating out the leaves into subtrees should have the same result. @@ -140,7 +144,7 @@ fn test_single_subtree() { let entries = generate_entries(PAIR_COUNT); let control = Smt::with_entries_sequential(entries.clone()).unwrap(); // `entries` should already be sorted by nature of how we constructed it. - let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; + let leaves = Smt::sorted_pairs_to_leaves(entries).unwrap().leaves; let leaves = leaves.into_iter().next().unwrap(); let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); assert!(!first_subtree.is_empty()); @@ -172,7 +176,7 @@ fn test_two_subtrees() { const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; let entries = generate_entries(PAIR_COUNT); let control = Smt::with_entries_sequential(entries.clone()).unwrap(); - let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); + let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries).unwrap(); // With two subtrees' worth of leaves, we should have exactly two subtrees. let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); assert_eq!(first.len() as u64, PAIR_COUNT / 2); @@ -220,7 +224,7 @@ fn test_singlethreaded_subtrees() { let PairComputations { leaves: mut leaf_subtrees, nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); + } = Smt::sorted_pairs_to_leaves(entries).unwrap(); for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { // There's no flat_map_unzip(), so this is the best we can do. let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees @@ -304,7 +308,7 @@ fn test_multithreaded_subtrees() { let PairComputations { leaves: mut leaf_subtrees, nodes: test_leaves, - } = Smt::sorted_pairs_to_leaves(entries); + } = Smt::sorted_pairs_to_leaves(entries).unwrap(); for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { let (nodes, mut subtree_roots): (Vec>, Vec) = leaf_subtrees .into_par_iter() @@ -479,6 +483,64 @@ fn test_smt_construction_with_entries_unsorted() { assert_eq!(smt, control); } +#[test] +fn test_smt_construction_with_entries_duplicate_keys() { + let entries = [ + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + (RpoDigest::new([ONE; 4]), [ONE; 4]), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]), + ]; + let expected_col = Smt::key_to_leaf_index(&entries[0].0).index.value(); + let err = Smt::with_entries(entries).unwrap_err(); + assert_matches!(err, MerkleError::DuplicateValuesForIndex(col) if col == expected_col); +} + +#[test] +fn test_smt_construction_with_some_empty_values() { + let entries = [ + (RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE), + (RpoDigest::new([ONE, ONE, ONE, Felt::new(2)]), [ONE; 4]), + ]; + + let result = Smt::with_entries(entries); + assert!(result.is_ok(), "SMT construction failed with mixed empty values"); + + let smt = result.unwrap(); + let control = Smt::with_entries_sequential(entries).unwrap(); + + assert_eq!(smt.num_leaves(), 1); + assert_eq!(smt.root(), control.root(), "Root hashes do not match"); + assert_eq!(smt, control, "SMTs are not equal"); +} + +#[test] +fn test_smt_construction_with_all_empty_values() { + let entries = [(RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE)]; + + let result = Smt::with_entries(entries); + assert!(result.is_ok(), "SMT construction failed with all empty values"); + + let smt = result.unwrap(); + + assert_eq!( + smt.root(), + Smt::default().root(), + "SMT with all empty values should have the same root as the default SMT" + ); + assert_eq!(smt, Smt::default(), "SMT with all empty values should be empty"); +} + +#[test] +fn test_smt_construction_with_no_entries() { + let entries: [(RpoDigest, Word); 0] = []; + + let result = Smt::with_entries(entries); + assert!(result.is_ok(), "SMT construction failed with no entries"); + + let smt = result.unwrap(); + assert_eq!(smt, Smt::default(), "SMT with no entries should be empty"); +} + fn arb_felt() -> impl Strategy { prop_oneof![any::().prop_map(Felt::new), Just(ZERO), Just(ONE),] }