feat: optimized duplicate key detection in concurrent SMT construction (#395)

This commit is contained in:
Krushimir 2025-03-13 09:57:27 +01:00 committed by GitHub
parent 0df69679e9
commit 222197d08f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 161 additions and 57 deletions

View file

@ -11,6 +11,7 @@
- Sort keys in a leaf in the concurrent implementation of `Smt::with_entries`, ensuring consistency with the sequential version (#385).
- Skip unchanged leaves in the concurrent implementation of `Smt::compute_mutations` (#385).
- Add range checks to `ntru_gen` for Falcon DSA (#391).
- Optimize duplicate key detection in `Smt::with_entries_concurrent` (#395).
## 0.13.3 (2025-02-18)

View file

@ -1,11 +1,12 @@
use alloc::{collections::BTreeSet, vec::Vec};
use alloc::vec::Vec;
use core::mem;
use num::Integer;
use rayon::prelude::*;
use super::{
leaf, EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet,
NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
leaf, EmptySubtreeRoots, InnerNode, InnerNodes, Leaves, MerkleError, MutationSet, NodeIndex,
RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
};
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};
@ -33,28 +34,25 @@ impl Smt {
/// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration,
/// which processes the next 8 levels up. This continues until the final root of the tree is
/// computed at depth 0.
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
pub(crate) fn with_entries_concurrent(
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
) -> Result<Self, MerkleError> {
let mut seen_keys = BTreeSet::new();
let entries: Vec<_> = entries
.into_iter()
// Filter out key-value pairs whose value is empty.
.filter(|(_key, value)| *value != Self::EMPTY_VALUE)
.map(|(key, value)| {
if seen_keys.insert(key) {
Ok((key, value))
} else {
Err(MerkleError::DuplicateValuesForIndex(
LeafIndex::<SMT_DEPTH>::from(key).value(),
))
}
})
.collect::<Result<_, _>>()?;
let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
if entries.is_empty() {
return Ok(Self::default());
}
let (inner_nodes, leaves) = Self::build_subtrees(entries);
let (inner_nodes, leaves) = Self::build_subtrees(entries)?;
// All the leaves are empty
if inner_nodes.is_empty() {
return Ok(Self::default());
}
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
}
@ -80,8 +78,6 @@ impl Smt {
where
Self: Sized + Sync,
{
use rayon::prelude::*;
// Collect and sort key-value pairs by their corresponding leaf index
let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());
@ -206,9 +202,14 @@ impl Smt {
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
///
/// `entries` need not be sorted. This function will sort them.
fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
entries.sort_by_key(|item| {
/// `entries` need not be sorted. This function will sort them using parallel sorting.
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
fn build_subtrees(
mut entries: Vec<(RpoDigest, Word)>,
) -> Result<(InnerNodes, Leaves), MerkleError> {
entries.par_sort_unstable_by_key(|item| {
let index = Self::key_to_leaf_index(&item.0);
index.value()
});
@ -219,15 +220,23 @@ impl Smt {
///
/// This function is mostly an implementation detail of
/// [`Smt::with_entries_concurrent()`].
fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
use rayon::prelude::*;
///
/// # Errors
/// Returns an error if the provided entries contain multiple values for the same key.
fn build_subtrees_from_sorted_entries(
entries: Vec<(RpoDigest, Word)>,
) -> Result<(InnerNodes, Leaves), MerkleError> {
let mut accumulated_nodes: InnerNodes = Default::default();
let PairComputations {
leaves: mut leaf_subtrees,
nodes: initial_leaves,
} = Self::sorted_pairs_to_leaves(entries);
} = Self::sorted_pairs_to_leaves(entries)?;
// If there are no leaves, we can return early
if initial_leaves.is_empty() {
return Ok((accumulated_nodes, initial_leaves));
}
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
@ -247,7 +256,7 @@ impl Smt {
debug_assert!(!leaf_subtrees.is_empty());
}
(accumulated_nodes, initial_leaves)
Ok((accumulated_nodes, initial_leaves))
}
// LEAF NODE CONSTRUCTION
@ -260,31 +269,46 @@ impl Smt {
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
///
/// # Errors
/// Returns an error if the provided pairs contain multiple values for the same key.
///
/// # Panics
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
/// sorted. Without debug assertions, the returned computations will be incorrect.
fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> {
Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
Some(Self::pairs_to_leaf(leaf_pairs))
})
fn sorted_pairs_to_leaves(
pairs: Vec<(RpoDigest, Word)>,
) -> Result<PairComputations<u64, SmtLeaf>, MerkleError> {
Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
}
/// Constructs a single leaf from an arbitrary amount of key-value pairs.
/// Those pairs must all have the same leaf index.
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
///
/// # Errors
/// Returns a `MerkleError::DuplicateValuesForIndex` if the provided pairs contain multiple
/// values for the same key.
///
/// # Returns
/// - `Ok(Some(SmtLeaf))` if a valid leaf is constructed.
/// - `Ok(None)` if the only provided value is `Self::EMPTY_VALUE`.
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError> {
assert!(!pairs.is_empty());
if pairs.len() > 1 {
pairs.sort_by(|(key_1, _), (key_2, _)| leaf::cmp_keys(*key_1, *key_2));
SmtLeaf::new_multiple(pairs).unwrap()
// Check for duplicates in a sorted list by comparing adjacent pairs
if let Some(window) = pairs.windows(2).find(|window| window[0].0 == window[1].0) {
// If we find a duplicate, return an error
let col = Self::key_to_leaf_index(&window[0].0).index.value();
return Err(MerkleError::DuplicateValuesForIndex(col));
}
Ok(Some(SmtLeaf::new_multiple(pairs).unwrap()))
} else {
let (key, value) = pairs.pop().unwrap();
// TODO: should we ever be constructing empty leaves from pairs?
if value == Self::EMPTY_VALUE {
let index = Self::key_to_leaf_index(&key);
SmtLeaf::new_empty(index)
Ok(None)
} else {
SmtLeaf::new_single(key, value)
Ok(Some(SmtLeaf::new_single(key, value)))
}
}
}
@ -322,13 +346,19 @@ impl Smt {
if leaf_changed {
// Only return the leaf if it actually changed
Some(leaf)
Ok(Some(leaf))
} else {
// Return None if leaf hasn't changed
None
Ok(None)
}
});
(accumulator.leaves, new_pairs)
// The closure is the only possible source of errors.
// Since it never returns an error - only `Ok(Some(_))` or `Ok(None)` - we can safely assume
// `accumulator` is always `Ok(_)`.
(
accumulator.expect("process_sorted_pairs_to_leaves never fails").leaves,
new_pairs,
)
}
/// Processes sorted key-value pairs to compute leaves for a subtree.
@ -352,16 +382,18 @@ impl Smt {
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
///
/// # Errors
/// Returns an error if the `process_leaf` callback fails.
///
/// # Panics
/// This function will panic in debug mode if the input `pairs` are not sorted by column index.
fn process_sorted_pairs_to_leaves<F>(
pairs: Vec<(RpoDigest, Word)>,
mut process_leaf: F,
) -> PairComputations<u64, SmtLeaf>
) -> Result<PairComputations<u64, SmtLeaf>, MerkleError>
where
F: FnMut(Vec<(RpoDigest, Word)>) -> Option<SmtLeaf>,
F: FnMut(Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError>,
{
use rayon::prelude::*;
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();
@ -392,8 +424,16 @@ impl Smt {
// Otherwise, the next pair is a different column, or there is no next pair. Either way
// it's time to swap out our buffer.
let leaf_pairs = mem::take(&mut current_leaf_buffer);
if let Some(leaf) = process_leaf(leaf_pairs) {
accumulator.nodes.insert(col, leaf);
// Process leaf and propagate any errors
match process_leaf(leaf_pairs) {
Ok(Some(leaf)) => {
accumulator.nodes.insert(col, leaf);
},
Ok(None) => {
// No leaf was constructed for this column. The column will be skipped.
},
Err(e) => return Err(e),
}
debug_assert!(current_leaf_buffer.is_empty());
@ -415,7 +455,7 @@ impl Smt {
// subtree boundaries as we go. Either way this function is only used at the beginning of a
// parallel construction, so it should not be a critical path.
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
accumulator
Ok(accumulator)
}
}
@ -539,6 +579,7 @@ fn build_subtree(
// construction enforces uniqueness. However, when testing or benchmarking
// `build_subtree()` in isolation, duplicate columns can appear if input
// constraints are not enforced.
use alloc::collections::BTreeSet;
let mut seen_cols = BTreeSet::new();
for leaf in &leaves {
assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col);

View file

@ -3,15 +3,19 @@ use alloc::{
vec::Vec,
};
use assert_matches::assert_matches;
use proptest::prelude::*;
use rand::{prelude::IteratorRandom, thread_rng, Rng};
use super::{
build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest,
Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE,
SMT_DEPTH, SUBTREE_DEPTH,
build_subtree, InnerNode, NodeIndex, NodeMutations, PairComputations, RpoDigest, Smt, SmtLeaf,
SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, SMT_DEPTH,
SUBTREE_DEPTH,
};
use crate::{
merkle::{smt::Felt, LeafIndex, MerkleError},
Word, EMPTY_WORD, ONE, ZERO,
};
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE, ZERO};
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
SubtreeLeaf {
@ -72,7 +76,7 @@ fn test_sorted_pairs_to_leaves() {
control_subtree_leaves
};
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries).unwrap();
// This will check that the hashes, columns, and subtree assignments all match.
assert_eq!(subtrees.leaves, control_subtree_leaves);
// Flattening and re-separating out the leaves into subtrees should have the same result.
@ -140,7 +144,7 @@ fn test_single_subtree() {
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
// `entries` should already be sorted by nature of how we constructed it.
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
let leaves = Smt::sorted_pairs_to_leaves(entries).unwrap().leaves;
let leaves = leaves.into_iter().next().unwrap();
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
assert!(!first_subtree.is_empty());
@ -172,7 +176,7 @@ fn test_two_subtrees() {
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
let entries = generate_entries(PAIR_COUNT);
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries).unwrap();
// With two subtrees' worth of leaves, we should have exactly two subtrees.
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
@ -220,7 +224,7 @@ fn test_singlethreaded_subtrees() {
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
} = Smt::sorted_pairs_to_leaves(entries).unwrap();
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
// There's no flat_map_unzip(), so this is the best we can do.
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
@ -304,7 +308,7 @@ fn test_multithreaded_subtrees() {
let PairComputations {
leaves: mut leaf_subtrees,
nodes: test_leaves,
} = Smt::sorted_pairs_to_leaves(entries);
} = Smt::sorted_pairs_to_leaves(entries).unwrap();
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
.into_par_iter()
@ -479,6 +483,64 @@ fn test_smt_construction_with_entries_unsorted() {
assert_eq!(smt, control);
}
/// Building an SMT from entries that repeat a key must fail with
/// `MerkleError::DuplicateValuesForIndex`, reporting the leaf index value of the repeated key.
#[test]
fn test_smt_construction_with_entries_duplicate_keys() {
    // The same key is supplied twice, with an unrelated key in between.
    let repeated_key = RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]);
    let other_key = RpoDigest::new([ONE; 4]);
    let value = [ONE; 4];
    let entries = [(repeated_key, value), (other_key, value), (repeated_key, value)];

    // The error is expected to carry the index of the leaf the repeated key maps to.
    let expected_col = Smt::key_to_leaf_index(&repeated_key).index.value();

    let err = Smt::with_entries(entries).unwrap_err();
    assert_matches!(err, MerkleError::DuplicateValuesForIndex(col) if col == expected_col);
}
/// With one empty-valued entry and one non-empty entry, construction must succeed, produce a
/// single leaf, and match the tree built by the sequential constructor from the same input.
#[test]
fn test_smt_construction_with_some_empty_values() {
    let empty_entry = (RpoDigest::new([ONE; 4]), Smt::EMPTY_VALUE);
    let filled_entry = (RpoDigest::new([ONE, ONE, ONE, Felt::new(2)]), [ONE; 4]);
    let entries = [empty_entry, filled_entry];

    let Ok(smt) = Smt::with_entries(entries) else {
        panic!("SMT construction failed with mixed empty values");
    };
    // Reference tree built from the same input by the sequential implementation.
    let control = Smt::with_entries_sequential(entries).unwrap();

    assert_eq!(smt.num_leaves(), 1);
    assert_eq!(smt.root(), control.root(), "Root hashes do not match");
    assert_eq!(smt, control, "SMTs are not equal");
}
/// An input whose only entry carries the empty value must produce an SMT equal to the default
/// one, both by root and by full comparison.
#[test]
fn test_smt_construction_with_all_empty_values() {
    let only_entry = (RpoDigest::new([ONE; 4]), Smt::EMPTY_VALUE);

    let Ok(smt) = Smt::with_entries([only_entry]) else {
        panic!("SMT construction failed with all empty values");
    };

    let empty_tree = Smt::default();
    assert_eq!(
        smt.root(),
        empty_tree.root(),
        "SMT with all empty values should have the same root as the default SMT"
    );
    assert_eq!(smt, empty_tree, "SMT with all empty values should be empty");
}
/// Constructing an SMT from an empty collection of entries must succeed and yield the default
/// tree.
#[test]
fn test_smt_construction_with_no_entries() {
    let no_entries: [(RpoDigest, Word); 0] = [];

    let Ok(smt) = Smt::with_entries(no_entries) else {
        panic!("SMT construction failed with no entries");
    };

    assert_eq!(smt, Smt::default(), "SMT with no entries should be empty");
}
fn arb_felt() -> impl Strategy<Value = Felt> {
prop_oneof![any::<u64>().prop_map(Felt::new), Just(ZERO), Just(ONE),]
}