feat: optimized duplicate key detection in concurrent SMT construction (#395)

This commit is contained in:
Krushimir 2025-03-13 09:57:27 +01:00 committed by GitHub
parent 0df69679e9
commit 222197d08f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 161 additions and 57 deletions

View file

@ -11,6 +11,7 @@
- Sort keys in a leaf in the concurrent implementation of `Smt::with_entries`, ensuring consistency with the sequential version (#385). - Sort keys in a leaf in the concurrent implementation of `Smt::with_entries`, ensuring consistency with the sequential version (#385).
- Skip unchanged leaves in the concurrent implementation of `Smt::compute_mutations` (#385). - Skip unchanged leaves in the concurrent implementation of `Smt::compute_mutations` (#385).
- Add range checks to `ntru_gen` for Falcon DSA (#391). - Add range checks to `ntru_gen` for Falcon DSA (#391).
- Optimize duplicate key detection in `Smt::with_entries_concurrent` (#395).
## 0.13.3 (2025-02-18) ## 0.13.3 (2025-02-18)

View file

@ -1,11 +1,12 @@
use alloc::{collections::BTreeSet, vec::Vec}; use alloc::vec::Vec;
use core::mem; use core::mem;
use num::Integer; use num::Integer;
use rayon::prelude::*;
use super::{ use super::{
leaf, EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet, leaf, EmptySubtreeRoots, InnerNode, InnerNodes, Leaves, MerkleError, MutationSet, NodeIndex,
NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
}; };
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap}; use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};
@ -33,28 +34,25 @@ impl Smt {
/// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration, /// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration,
/// which processes the next 8 levels up. This continues until the final root of the tree is /// which processes the next 8 levels up. This continues until the final root of the tree is
/// computed at depth 0. /// computed at depth 0.
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub(crate) fn with_entries_concurrent( pub(crate) fn with_entries_concurrent(
entries: impl IntoIterator<Item = (RpoDigest, Word)>, entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> { ) -> Result<Self, MerkleError> {
let mut seen_keys = BTreeSet::new(); let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
let entries: Vec<_> = entries
.into_iter()
// Filter out key-value pairs whose value is empty.
.filter(|(_key, value)| *value != Self::EMPTY_VALUE)
.map(|(key, value)| {
if seen_keys.insert(key) {
Ok((key, value))
} else {
Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
))
}
})
.collect::<Result<_, _>>()?;
if entries.is_empty() { if entries.is_empty() {
return Ok(Self::default()); return Ok(Self::default());
} }
let (inner_nodes, leaves) = Self::build_subtrees(entries);
let (inner_nodes, leaves) = Self::build_subtrees(entries)?;
// All the leaves are empty
if inner_nodes.is_empty() {
return Ok(Self::default());
}
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash(); let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root) <Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
} }
@ -80,8 +78,6 @@ impl Smt {
where where
Self: Sized + Sync, Self: Sized + Sync,
{ {
use rayon::prelude::*;
// Collect and sort key-value pairs by their corresponding leaf index // Collect and sort key-value pairs by their corresponding leaf index
let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect(); let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value()); sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());
@ -206,9 +202,14 @@ impl Smt {
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs. /// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
/// ///
/// `entries` need not be sorted. This function will sort them. /// `entries` need not be sorted. This function will sort them using parallel sorting.
fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { ///
entries.sort_by_key(|item| { /// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
fn build_subtrees(
mut entries: Vec<(RpoDigest, Word)>,
) -> Result<(InnerNodes, Leaves), MerkleError> {
entries.par_sort_unstable_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0); let index = Self::key_to_leaf_index(&item.0);
index.value() index.value()
}); });
@ -219,15 +220,23 @@ impl Smt {
/// ///
/// This function is mostly an implementation detail of /// This function is mostly an implementation detail of
/// [`Smt::with_entries_concurrent()`]. /// [`Smt::with_entries_concurrent()`].
fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) { ///
use rayon::prelude::*; /// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
fn build_subtrees_from_sorted_entries(
entries: Vec<(RpoDigest, Word)>,
) -> Result<(InnerNodes, Leaves), MerkleError> {
let mut accumulated_nodes: InnerNodes = Default::default(); let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations { let PairComputations {
leaves: mut leaf_subtrees, leaves: mut leaf_subtrees,
nodes: initial_leaves, nodes: initial_leaves,
} = Self::sorted_pairs_to_leaves(entries); } = Self::sorted_pairs_to_leaves(entries)?;
// If there are no leaves, we can return early
if initial_leaves.is_empty() {
return Ok((accumulated_nodes, initial_leaves));
}
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
@ -247,7 +256,7 @@ impl Smt {
debug_assert!(!leaf_subtrees.is_empty()); debug_assert!(!leaf_subtrees.is_empty());
} }
(accumulated_nodes, initial_leaves) Ok((accumulated_nodes, initial_leaves))
} }
// LEAF NODE CONSTRUCTION // LEAF NODE CONSTRUCTION
@ -260,31 +269,46 @@ impl Smt {
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If /// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect. /// `pairs` is not correctly sorted, the returned computations will be incorrect.
/// ///
/// # Errors
/// Returns an error if the provided pairs contain multiple values for the same key.
///
/// # Panics /// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly /// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect. /// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> { fn sorted_pairs_to_leaves(
Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| { pairs: Vec<(RpoDigest, Word)>,
Some(Self::pairs_to_leaf(leaf_pairs)) ) -> Result<PairComputations<u64, SmtLeaf>, MerkleError> {
}) Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
} }
/// Constructs a single leaf from an arbitrary amount of key-value pairs. /// Constructs a single leaf from an arbitrary amount of key-value pairs.
/// Those pairs must all have the same leaf index. /// Those pairs must all have the same leaf index.
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf { ///
/// # Errors
/// Returns a `MerkleError::DuplicateValuesForIndex` if the provided pairs contain multiple
/// values for the same key.
///
/// # Returns
/// - `Ok(Some(SmtLeaf))` if a valid leaf is constructed.
/// - `Ok(None)` if the only provided value is `Self::EMPTY_VALUE`.
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError> {
assert!(!pairs.is_empty()); assert!(!pairs.is_empty());
if pairs.len() > 1 { if pairs.len() > 1 {
pairs.sort_by(|(key_1, _), (key_2, _)| leaf::cmp_keys(*key_1, *key_2)); pairs.sort_by(|(key_1, _), (key_2, _)| leaf::cmp_keys(*key_1, *key_2));
SmtLeaf::new_multiple(pairs).unwrap() // Check for duplicates in a sorted list by comparing adjacent pairs
if let Some(window) = pairs.windows(2).find(|window| window[0].0 == window[1].0) {
// If we find a duplicate, return an error
let col = Self::key_to_leaf_index(&window[0].0).index.value();
return Err(MerkleError::DuplicateValuesForIndex(col));
}
Ok(Some(SmtLeaf::new_multiple(pairs).unwrap()))
} else { } else {
let (key, value) = pairs.pop().unwrap(); let (key, value) = pairs.pop().unwrap();
// TODO: should we ever be constructing empty leaves from pairs?
if value == Self::EMPTY_VALUE { if value == Self::EMPTY_VALUE {
let index = Self::key_to_leaf_index(&key); Ok(None)
SmtLeaf::new_empty(index)
} else { } else {
SmtLeaf::new_single(key, value) Ok(Some(SmtLeaf::new_single(key, value)))
} }
} }
} }
@ -322,13 +346,19 @@ impl Smt {
if leaf_changed { if leaf_changed {
// Only return the leaf if it actually changed // Only return the leaf if it actually changed
Some(leaf) Ok(Some(leaf))
} else { } else {
// Return None if leaf hasn't changed // Return None if leaf hasn't changed
None Ok(None)
} }
}); });
(accumulator.leaves, new_pairs) // The closure is the only possible source of errors.
// Since it never returns an error - only `Ok(Some(_))` or `Ok(None)` - we can safely assume
// `accumulator` is always `Ok(_)`.
(
accumulator.expect("process_sorted_pairs_to_leaves never fails").leaves,
new_pairs,
)
} }
/// Processes sorted key-value pairs to compute leaves for a subtree. /// Processes sorted key-value pairs to compute leaves for a subtree.
@ -352,16 +382,18 @@ impl Smt {
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each /// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf. /// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
/// ///
/// # Errors
/// Returns an error if the `process_leaf` callback fails.
///
/// # Panics /// # Panics
/// This function will panic in debug mode if the input `pairs` are not sorted by column index. /// This function will panic in debug mode if the input `pairs` are not sorted by column index.
fn process_sorted_pairs_to_leaves<F>( fn process_sorted_pairs_to_leaves<F>(
pairs: Vec<(RpoDigest, Word)>, pairs: Vec<(RpoDigest, Word)>,
mut process_leaf: F, mut process_leaf: F,
) -> PairComputations<u64, SmtLeaf> ) -> Result<PairComputations<u64, SmtLeaf>, MerkleError>
where where
F: FnMut(Vec<(RpoDigest, Word)>) -> Option<SmtLeaf>, F: FnMut(Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError>,
{ {
use rayon::prelude::*;
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value())); debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default(); let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();
@ -392,8 +424,16 @@ impl Smt {
// Otherwise, the next pair is a different column, or there is no next pair. Either way // Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer. // it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer); let leaf_pairs = mem::take(&mut current_leaf_buffer);
if let Some(leaf) = process_leaf(leaf_pairs) {
accumulator.nodes.insert(col, leaf); // Process leaf and propagate any errors
match process_leaf(leaf_pairs) {
Ok(Some(leaf)) => {
accumulator.nodes.insert(col, leaf);
},
Ok(None) => {
// No leaf was constructed for this column. The column will be skipped.
},
Err(e) => return Err(e),
} }
debug_assert!(current_leaf_buffer.is_empty()); debug_assert!(current_leaf_buffer.is_empty());
@ -415,7 +455,7 @@ impl Smt {
// subtree boundaries as we go. Either way this function is only used at the beginning of a // subtree boundaries as we go. Either way this function is only used at the beginning of a
// parallel construction, so it should not be a critical path. // parallel construction, so it should not be a critical path.
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect(); accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator Ok(accumulator)
} }
} }
@ -539,6 +579,7 @@ fn build_subtree(
// construction enforces uniqueness. However, when testing or benchmarking // construction enforces uniqueness. However, when testing or benchmarking
// `build_subtree()` in isolation, duplicate columns can appear if input // `build_subtree()` in isolation, duplicate columns can appear if input
// constraints are not enforced. // constraints are not enforced.
use alloc::collections::BTreeSet;
let mut seen_cols = BTreeSet::new(); let mut seen_cols = BTreeSet::new();
for leaf in &leaves { for leaf in &leaves {
assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col); assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col);

View file

@ -3,15 +3,19 @@ use alloc::{
vec::Vec, vec::Vec,
}; };
use assert_matches::assert_matches;
use proptest::prelude::*; use proptest::prelude::*;
use rand::{prelude::IteratorRandom, thread_rng, Rng}; use rand::{prelude::IteratorRandom, thread_rng, Rng};
use super::{ use super::{
build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest, build_subtree, InnerNode, NodeIndex, NodeMutations, PairComputations, RpoDigest, Smt, SmtLeaf,
Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, SMT_DEPTH,
SMT_DEPTH, SUBTREE_DEPTH, SUBTREE_DEPTH,
};
use crate::{
merkle::{smt::Felt, LeafIndex, MerkleError},
Word, EMPTY_WORD, ONE, ZERO,
}; };
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE, ZERO};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf { fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf { SubtreeLeaf {
@ -72,7 +76,7 @@ fn test_sorted_pairs_to_leaves() {
control_subtree_leaves control_subtree_leaves
}; };
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries); let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries).unwrap();
// This will check that the hashes, columns, and subtree assignments all match. // This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves); assert_eq!(subtrees.leaves, control_subtree_leaves);
// Flattening and re-separating out the leaves into subtrees should have the same result. // Flattening and re-separating out the leaves into subtrees should have the same result.
@ -140,7 +144,7 @@ fn test_single_subtree() {
let entries = generate_entries(PAIR_COUNT); let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap(); let control = Smt::with_entries_sequential(entries.clone()).unwrap();
// `entries` should already be sorted by nature of how we constructed it. // `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves; let leaves = Smt::sorted_pairs_to_leaves(entries).unwrap().leaves;
let leaves = leaves.into_iter().next().unwrap(); let leaves = leaves.into_iter().next().unwrap();
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH); let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
assert!(!first_subtree.is_empty()); assert!(!first_subtree.is_empty());
@ -172,7 +176,7 @@ fn test_two_subtrees() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2; const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
let entries = generate_entries(PAIR_COUNT); let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap(); let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries); let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries).unwrap();
// With two subtrees' worth of leaves, we should have exactly two subtrees. // With two subtrees' worth of leaves, we should have exactly two subtrees.
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap(); let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
assert_eq!(first.len() as u64, PAIR_COUNT / 2); assert_eq!(first.len() as u64, PAIR_COUNT / 2);
@ -220,7 +224,7 @@ fn test_singlethreaded_subtrees() {
let PairComputations { let PairComputations {
leaves: mut leaf_subtrees, leaves: mut leaf_subtrees,
nodes: test_leaves, nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries); } = Smt::sorted_pairs_to_leaves(entries).unwrap();
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do. // There's no flat_map_unzip(), so this is the best we can do.
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
@ -304,7 +308,7 @@ fn test_multithreaded_subtrees() {
let PairComputations { let PairComputations {
leaves: mut leaf_subtrees, leaves: mut leaf_subtrees,
nodes: test_leaves, nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries); } = Smt::sorted_pairs_to_leaves(entries).unwrap();
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() { for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter() .into_par_iter()
@ -479,6 +483,64 @@ fn test_smt_construction_with_entries_unsorted() {
assert_eq!(smt, control); assert_eq!(smt, control);
} }
#[test]
fn test_smt_construction_with_entries_duplicate_keys() {
let entries = [
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
(RpoDigest::new([ONE; 4]), [ONE; 4]),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
];
let expected_col = Smt::key_to_leaf_index(&entries[0].0).index.value();
let err = Smt::with_entries(entries).unwrap_err();
assert_matches!(err, MerkleError::DuplicateValuesForIndex(col) if col == expected_col);
}
#[test]
fn test_smt_construction_with_some_empty_values() {
let entries = [
(RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE),
(RpoDigest::new([ONE, ONE, ONE, Felt::new(2)]), [ONE; 4]),
];
let result = Smt::with_entries(entries);
assert!(result.is_ok(), "SMT construction failed with mixed empty values");
let smt = result.unwrap();
let control = Smt::with_entries_sequential(entries).unwrap();
assert_eq!(smt.num_leaves(), 1);
assert_eq!(smt.root(), control.root(), "Root hashes do not match");
assert_eq!(smt, control, "SMTs are not equal");
}
#[test]
fn test_smt_construction_with_all_empty_values() {
let entries = [(RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE)];
let result = Smt::with_entries(entries);
assert!(result.is_ok(), "SMT construction failed with all empty values");
let smt = result.unwrap();
assert_eq!(
smt.root(),
Smt::default().root(),
"SMT with all empty values should have the same root as the default SMT"
);
assert_eq!(smt, Smt::default(), "SMT with all empty values should be empty");
}
#[test]
fn test_smt_construction_with_no_entries() {
let entries: [(RpoDigest, Word); 0] = [];
let result = Smt::with_entries(entries);
assert!(result.is_ok(), "SMT construction failed with no entries");
let smt = result.unwrap();
assert_eq!(smt, Smt::default(), "SMT with no entries should be empty");
}
fn arb_felt() -> impl Strategy<Value = Felt> { fn arb_felt() -> impl Strategy<Value = Felt> {
prop_oneof![any::<u64>().prop_map(Felt::new), Just(ZERO), Just(ONE),] prop_oneof![any::<u64>().prop_map(Felt::new), Just(ZERO), Just(ONE),]
} }