feat: optimized duplicate key detection in concurrent SMT construction (#395)
This commit is contained in:
parent
0df69679e9
commit
222197d08f
3 changed files with 161 additions and 57 deletions
|
@ -11,6 +11,7 @@
|
|||
- Sort keys in a leaf in the concurrent implementation of `Smt::with_entries`, ensuring consistency with the sequential version (#385).
|
||||
- Skip unchanged leaves in the concurrent implementation of `Smt::compute_mutations` (#385).
|
||||
- Add range checks to `ntru_gen` for Falcon DSA (#391).
|
||||
- Optimized duplicate key detection in `Smt::with_entries_concurrent` (#395).
|
||||
|
||||
## 0.13.3 (2025-02-18)
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
use alloc::{collections::BTreeSet, vec::Vec};
|
||||
use alloc::vec::Vec;
|
||||
use core::mem;
|
||||
|
||||
use num::Integer;
|
||||
use rayon::prelude::*;
|
||||
|
||||
use super::{
|
||||
leaf, EmptySubtreeRoots, InnerNode, InnerNodes, LeafIndex, Leaves, MerkleError, MutationSet,
|
||||
NodeIndex, RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
|
||||
leaf, EmptySubtreeRoots, InnerNode, InnerNodes, Leaves, MerkleError, MutationSet, NodeIndex,
|
||||
RpoDigest, Smt, SmtLeaf, SparseMerkleTree, Word, SMT_DEPTH,
|
||||
};
|
||||
use crate::merkle::smt::{NodeMutation, NodeMutations, UnorderedMap};
|
||||
|
||||
|
@ -33,28 +34,25 @@ impl Smt {
|
|||
/// 3. These subtree roots are recursively merged to become the "leaves" for the next iteration,
|
||||
/// which processes the next 8 levels up. This continues until the final root of the tree is
|
||||
/// computed at depth 0.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||
pub(crate) fn with_entries_concurrent(
|
||||
entries: impl IntoIterator<Item = (RpoDigest, Word)>,
|
||||
) -> Result<Self, MerkleError> {
|
||||
let mut seen_keys = BTreeSet::new();
|
||||
let entries: Vec<_> = entries
|
||||
.into_iter()
|
||||
// Filter out key-value pairs whose value is empty.
|
||||
.filter(|(_key, value)| *value != Self::EMPTY_VALUE)
|
||||
.map(|(key, value)| {
|
||||
if seen_keys.insert(key) {
|
||||
Ok((key, value))
|
||||
} else {
|
||||
Err(MerkleError::DuplicateValuesForIndex(
|
||||
LeafIndex::<SMT_DEPTH>::from(key).value(),
|
||||
))
|
||||
}
|
||||
})
|
||||
.collect::<Result<_, _>>()?;
|
||||
let entries: Vec<(RpoDigest, Word)> = entries.into_iter().collect();
|
||||
|
||||
if entries.is_empty() {
|
||||
return Ok(Self::default());
|
||||
}
|
||||
let (inner_nodes, leaves) = Self::build_subtrees(entries);
|
||||
|
||||
let (inner_nodes, leaves) = Self::build_subtrees(entries)?;
|
||||
|
||||
// All the leaves are empty
|
||||
if inner_nodes.is_empty() {
|
||||
return Ok(Self::default());
|
||||
}
|
||||
|
||||
let root = inner_nodes.get(&NodeIndex::root()).unwrap().hash();
|
||||
<Self as SparseMerkleTree<SMT_DEPTH>>::from_raw_parts(inner_nodes, leaves, root)
|
||||
}
|
||||
|
@ -80,8 +78,6 @@ impl Smt {
|
|||
where
|
||||
Self: Sized + Sync,
|
||||
{
|
||||
use rayon::prelude::*;
|
||||
|
||||
// Collect and sort key-value pairs by their corresponding leaf index
|
||||
let mut sorted_kv_pairs: Vec<_> = kv_pairs.into_iter().collect();
|
||||
sorted_kv_pairs.par_sort_unstable_by_key(|(key, _)| Self::key_to_leaf_index(key).value());
|
||||
|
@ -206,9 +202,14 @@ impl Smt {
|
|||
|
||||
/// Computes the raw parts for a new sparse Merkle tree from a set of key-value pairs.
|
||||
///
|
||||
/// `entries` need not be sorted. This function will sort them.
|
||||
fn build_subtrees(mut entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
|
||||
entries.sort_by_key(|item| {
|
||||
/// `entries` need not be sorted. This function will sort them using parallel sorting.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||
fn build_subtrees(
|
||||
mut entries: Vec<(RpoDigest, Word)>,
|
||||
) -> Result<(InnerNodes, Leaves), MerkleError> {
|
||||
entries.par_sort_unstable_by_key(|item| {
|
||||
let index = Self::key_to_leaf_index(&item.0);
|
||||
index.value()
|
||||
});
|
||||
|
@ -219,15 +220,23 @@ impl Smt {
|
|||
///
|
||||
/// This function is mostly an implementation detail of
|
||||
/// [`Smt::with_entries_concurrent()`].
|
||||
fn build_subtrees_from_sorted_entries(entries: Vec<(RpoDigest, Word)>) -> (InnerNodes, Leaves) {
|
||||
use rayon::prelude::*;
|
||||
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided entries contain multiple values for the same key.
|
||||
fn build_subtrees_from_sorted_entries(
|
||||
entries: Vec<(RpoDigest, Word)>,
|
||||
) -> Result<(InnerNodes, Leaves), MerkleError> {
|
||||
let mut accumulated_nodes: InnerNodes = Default::default();
|
||||
|
||||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: initial_leaves,
|
||||
} = Self::sorted_pairs_to_leaves(entries);
|
||||
} = Self::sorted_pairs_to_leaves(entries)?;
|
||||
|
||||
// If there are no leaves, we can return early
|
||||
if initial_leaves.is_empty() {
|
||||
return Ok((accumulated_nodes, initial_leaves));
|
||||
}
|
||||
|
||||
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) =
|
||||
|
@ -247,7 +256,7 @@ impl Smt {
|
|||
|
||||
debug_assert!(!leaf_subtrees.is_empty());
|
||||
}
|
||||
(accumulated_nodes, initial_leaves)
|
||||
Ok((accumulated_nodes, initial_leaves))
|
||||
}
|
||||
|
||||
// LEAF NODE CONSTRUCTION
|
||||
|
@ -260,31 +269,46 @@ impl Smt {
|
|||
/// `pairs` *must* already be sorted **by leaf index column**, not simply sorted by key. If
|
||||
/// `pairs` is not correctly sorted, the returned computations will be incorrect.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the provided pairs contain multiple values for the same key.
|
||||
///
|
||||
/// # Panics
|
||||
/// With debug assertions on, this function panics if it detects that `pairs` is not correctly
|
||||
/// sorted. Without debug assertions, the returned computations will be incorrect.
|
||||
fn sorted_pairs_to_leaves(pairs: Vec<(RpoDigest, Word)>) -> PairComputations<u64, SmtLeaf> {
|
||||
Self::process_sorted_pairs_to_leaves(pairs, |leaf_pairs| {
|
||||
Some(Self::pairs_to_leaf(leaf_pairs))
|
||||
})
|
||||
fn sorted_pairs_to_leaves(
|
||||
pairs: Vec<(RpoDigest, Word)>,
|
||||
) -> Result<PairComputations<u64, SmtLeaf>, MerkleError> {
|
||||
Self::process_sorted_pairs_to_leaves(pairs, Self::pairs_to_leaf)
|
||||
}
|
||||
|
||||
/// Constructs a single leaf from an arbitrary amount of key-value pairs.
|
||||
/// Those pairs must all have the same leaf index.
|
||||
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> SmtLeaf {
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns a `MerkleError::DuplicateValuesForIndex` if the provided pairs contain multiple
|
||||
/// values for the same key.
|
||||
///
|
||||
/// # Returns
|
||||
/// - `Ok(Some(SmtLeaf))` if a valid leaf is constructed.
|
||||
/// - `Ok(None)` if the only provided value is `Self::EMPTY_VALUE`.
|
||||
fn pairs_to_leaf(mut pairs: Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError> {
|
||||
assert!(!pairs.is_empty());
|
||||
|
||||
if pairs.len() > 1 {
|
||||
pairs.sort_by(|(key_1, _), (key_2, _)| leaf::cmp_keys(*key_1, *key_2));
|
||||
SmtLeaf::new_multiple(pairs).unwrap()
|
||||
// Check for duplicates in a sorted list by comparing adjacent pairs
|
||||
if let Some(window) = pairs.windows(2).find(|window| window[0].0 == window[1].0) {
|
||||
// If we find a duplicate, return an error
|
||||
let col = Self::key_to_leaf_index(&window[0].0).index.value();
|
||||
return Err(MerkleError::DuplicateValuesForIndex(col));
|
||||
}
|
||||
Ok(Some(SmtLeaf::new_multiple(pairs).unwrap()))
|
||||
} else {
|
||||
let (key, value) = pairs.pop().unwrap();
|
||||
// TODO: should we ever be constructing empty leaves from pairs?
|
||||
if value == Self::EMPTY_VALUE {
|
||||
let index = Self::key_to_leaf_index(&key);
|
||||
SmtLeaf::new_empty(index)
|
||||
Ok(None)
|
||||
} else {
|
||||
SmtLeaf::new_single(key, value)
|
||||
Ok(Some(SmtLeaf::new_single(key, value)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -322,13 +346,19 @@ impl Smt {
|
|||
|
||||
if leaf_changed {
|
||||
// Only return the leaf if it actually changed
|
||||
Some(leaf)
|
||||
Ok(Some(leaf))
|
||||
} else {
|
||||
// Return None if leaf hasn't changed
|
||||
None
|
||||
Ok(None)
|
||||
}
|
||||
});
|
||||
(accumulator.leaves, new_pairs)
|
||||
// The closure is the only possible source of errors.
|
||||
// Since it never returns an error - only `Ok(Some(_))` or `Ok(None)` - we can safely assume
|
||||
// `accumulator` is always `Ok(_)`.
|
||||
(
|
||||
accumulator.expect("process_sorted_pairs_to_leaves never fails").leaves,
|
||||
new_pairs,
|
||||
)
|
||||
}
|
||||
|
||||
/// Processes sorted key-value pairs to compute leaves for a subtree.
|
||||
|
@ -352,16 +382,18 @@ impl Smt {
|
|||
/// - `leaves`: A collection of `SubtreeLeaf` structures representing the processed leaves. Each
|
||||
/// `SubtreeLeaf` includes the column index and the hash of the corresponding leaf.
|
||||
///
|
||||
/// # Errors
|
||||
/// Returns an error if the `process_leaf` callback fails.
|
||||
///
|
||||
/// # Panics
|
||||
/// This function will panic in debug mode if the input `pairs` are not sorted by column index.
|
||||
fn process_sorted_pairs_to_leaves<F>(
|
||||
pairs: Vec<(RpoDigest, Word)>,
|
||||
mut process_leaf: F,
|
||||
) -> PairComputations<u64, SmtLeaf>
|
||||
) -> Result<PairComputations<u64, SmtLeaf>, MerkleError>
|
||||
where
|
||||
F: FnMut(Vec<(RpoDigest, Word)>) -> Option<SmtLeaf>,
|
||||
F: FnMut(Vec<(RpoDigest, Word)>) -> Result<Option<SmtLeaf>, MerkleError>,
|
||||
{
|
||||
use rayon::prelude::*;
|
||||
debug_assert!(pairs.is_sorted_by_key(|(key, _)| Self::key_to_leaf_index(key).value()));
|
||||
|
||||
let mut accumulator: PairComputations<u64, SmtLeaf> = Default::default();
|
||||
|
@ -392,8 +424,16 @@ impl Smt {
|
|||
// Otherwise, the next pair is a different column, or there is no next pair. Either way
|
||||
// it's time to swap out our buffer.
|
||||
let leaf_pairs = mem::take(&mut current_leaf_buffer);
|
||||
if let Some(leaf) = process_leaf(leaf_pairs) {
|
||||
|
||||
// Process leaf and propagate any errors
|
||||
match process_leaf(leaf_pairs) {
|
||||
Ok(Some(leaf)) => {
|
||||
accumulator.nodes.insert(col, leaf);
|
||||
},
|
||||
Ok(None) => {
|
||||
// No leaf was constructed for this column. The column will be skipped.
|
||||
},
|
||||
Err(e) => return Err(e),
|
||||
}
|
||||
|
||||
debug_assert!(current_leaf_buffer.is_empty());
|
||||
|
@ -415,7 +455,7 @@ impl Smt {
|
|||
// subtree boundaries as we go. Either way this function is only used at the beginning of a
|
||||
// parallel construction, so it should not be a critical path.
|
||||
accumulator.leaves = SubtreeLeavesIter::from_leaves(&mut accumulated_leaves).collect();
|
||||
accumulator
|
||||
Ok(accumulator)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -539,6 +579,7 @@ fn build_subtree(
|
|||
// construction enforces uniqueness. However, when testing or benchmarking
|
||||
// `build_subtree()` in isolation, duplicate columns can appear if input
|
||||
// constraints are not enforced.
|
||||
use alloc::collections::BTreeSet;
|
||||
let mut seen_cols = BTreeSet::new();
|
||||
for leaf in &leaves {
|
||||
assert!(seen_cols.insert(leaf.col), "Duplicate column found in subtree: {}", leaf.col);
|
||||
|
|
|
@ -3,15 +3,19 @@ use alloc::{
|
|||
vec::Vec,
|
||||
};
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use proptest::prelude::*;
|
||||
use rand::{prelude::IteratorRandom, thread_rng, Rng};
|
||||
|
||||
use super::{
|
||||
build_subtree, InnerNode, LeafIndex, NodeIndex, NodeMutations, PairComputations, RpoDigest,
|
||||
Smt, SmtLeaf, SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE,
|
||||
SMT_DEPTH, SUBTREE_DEPTH,
|
||||
build_subtree, InnerNode, NodeIndex, NodeMutations, PairComputations, RpoDigest, Smt, SmtLeaf,
|
||||
SparseMerkleTree, SubtreeLeaf, SubtreeLeavesIter, UnorderedMap, COLS_PER_SUBTREE, SMT_DEPTH,
|
||||
SUBTREE_DEPTH,
|
||||
};
|
||||
use crate::{
|
||||
merkle::{smt::Felt, LeafIndex, MerkleError},
|
||||
Word, EMPTY_WORD, ONE, ZERO,
|
||||
};
|
||||
use crate::{merkle::smt::Felt, Word, EMPTY_WORD, ONE, ZERO};
|
||||
|
||||
fn smtleaf_to_subtree_leaf(leaf: &SmtLeaf) -> SubtreeLeaf {
|
||||
SubtreeLeaf {
|
||||
|
@ -72,7 +76,7 @@ fn test_sorted_pairs_to_leaves() {
|
|||
control_subtree_leaves
|
||||
};
|
||||
|
||||
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries);
|
||||
let subtrees: PairComputations<u64, SmtLeaf> = Smt::sorted_pairs_to_leaves(entries).unwrap();
|
||||
// This will check that the hashes, columns, and subtree assignments all match.
|
||||
assert_eq!(subtrees.leaves, control_subtree_leaves);
|
||||
// Flattening and re-separating out the leaves into subtrees should have the same result.
|
||||
|
@ -140,7 +144,7 @@ fn test_single_subtree() {
|
|||
let entries = generate_entries(PAIR_COUNT);
|
||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||
// `entries` should already be sorted by nature of how we constructed it.
|
||||
let leaves = Smt::sorted_pairs_to_leaves(entries).leaves;
|
||||
let leaves = Smt::sorted_pairs_to_leaves(entries).unwrap().leaves;
|
||||
let leaves = leaves.into_iter().next().unwrap();
|
||||
let (first_subtree, subtree_root) = build_subtree(leaves, SMT_DEPTH, SMT_DEPTH);
|
||||
assert!(!first_subtree.is_empty());
|
||||
|
@ -172,7 +176,7 @@ fn test_two_subtrees() {
|
|||
const PAIR_COUNT: u64 = COLS_PER_SUBTREE * 2;
|
||||
let entries = generate_entries(PAIR_COUNT);
|
||||
let control = Smt::with_entries_sequential(entries.clone()).unwrap();
|
||||
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries);
|
||||
let PairComputations { leaves, .. } = Smt::sorted_pairs_to_leaves(entries).unwrap();
|
||||
// With two subtrees' worth of leaves, we should have exactly two subtrees.
|
||||
let [first, second]: [Vec<_>; 2] = leaves.try_into().unwrap();
|
||||
assert_eq!(first.len() as u64, PAIR_COUNT / 2);
|
||||
|
@ -220,7 +224,7 @@ fn test_singlethreaded_subtrees() {
|
|||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: test_leaves,
|
||||
} = Smt::sorted_pairs_to_leaves(entries);
|
||||
} = Smt::sorted_pairs_to_leaves(entries).unwrap();
|
||||
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||
// There's no flat_map_unzip(), so this is the best we can do.
|
||||
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
||||
|
@ -304,7 +308,7 @@ fn test_multithreaded_subtrees() {
|
|||
let PairComputations {
|
||||
leaves: mut leaf_subtrees,
|
||||
nodes: test_leaves,
|
||||
} = Smt::sorted_pairs_to_leaves(entries);
|
||||
} = Smt::sorted_pairs_to_leaves(entries).unwrap();
|
||||
for current_depth in (SUBTREE_DEPTH..=SMT_DEPTH).step_by(SUBTREE_DEPTH as usize).rev() {
|
||||
let (nodes, mut subtree_roots): (Vec<UnorderedMap<_, _>>, Vec<SubtreeLeaf>) = leaf_subtrees
|
||||
.into_par_iter()
|
||||
|
@ -479,6 +483,64 @@ fn test_smt_construction_with_entries_unsorted() {
|
|||
assert_eq!(smt, control);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smt_construction_with_entries_duplicate_keys() {
|
||||
let entries = [
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
|
||||
(RpoDigest::new([ONE; 4]), [ONE; 4]),
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(16)]), [ONE; 4]),
|
||||
];
|
||||
let expected_col = Smt::key_to_leaf_index(&entries[0].0).index.value();
|
||||
let err = Smt::with_entries(entries).unwrap_err();
|
||||
assert_matches!(err, MerkleError::DuplicateValuesForIndex(col) if col == expected_col);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smt_construction_with_some_empty_values() {
|
||||
let entries = [
|
||||
(RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE),
|
||||
(RpoDigest::new([ONE, ONE, ONE, Felt::new(2)]), [ONE; 4]),
|
||||
];
|
||||
|
||||
let result = Smt::with_entries(entries);
|
||||
assert!(result.is_ok(), "SMT construction failed with mixed empty values");
|
||||
|
||||
let smt = result.unwrap();
|
||||
let control = Smt::with_entries_sequential(entries).unwrap();
|
||||
|
||||
assert_eq!(smt.num_leaves(), 1);
|
||||
assert_eq!(smt.root(), control.root(), "Root hashes do not match");
|
||||
assert_eq!(smt, control, "SMTs are not equal");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smt_construction_with_all_empty_values() {
|
||||
let entries = [(RpoDigest::new([ONE, ONE, ONE, ONE]), Smt::EMPTY_VALUE)];
|
||||
|
||||
let result = Smt::with_entries(entries);
|
||||
assert!(result.is_ok(), "SMT construction failed with all empty values");
|
||||
|
||||
let smt = result.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
smt.root(),
|
||||
Smt::default().root(),
|
||||
"SMT with all empty values should have the same root as the default SMT"
|
||||
);
|
||||
assert_eq!(smt, Smt::default(), "SMT with all empty values should be empty");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_smt_construction_with_no_entries() {
|
||||
let entries: [(RpoDigest, Word); 0] = [];
|
||||
|
||||
let result = Smt::with_entries(entries);
|
||||
assert!(result.is_ok(), "SMT construction failed with no entries");
|
||||
|
||||
let smt = result.unwrap();
|
||||
assert_eq!(smt, Smt::default(), "SMT with no entries should be empty");
|
||||
}
|
||||
|
||||
fn arb_felt() -> impl Strategy<Value = Felt> {
|
||||
prop_oneof![any::<u64>().prop_map(Felt::new), Just(ZERO), Just(ONE),]
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue