Merge pull request #171 from 0xPolygonMiden/frisitano-recording-map-finalizer
Introduce TryApplyDiff and refactor RecordingMap finalizer
This commit is contained in:
commit
809b572a40
7 changed files with 276 additions and 35 deletions
153
src/merkle/delta.rs
Normal file
153
src/merkle/delta.rs
Normal file
|
@ -0,0 +1,153 @@
|
|||
use super::{
|
||||
BTreeMap, KvMap, MerkleError, MerkleStore, NodeIndex, RpoDigest, StoreNode, Vec, Word,
|
||||
};
|
||||
use crate::utils::collections::Diff;
|
||||
|
||||
#[cfg(test)]
|
||||
use super::{empty_roots::EMPTY_WORD, Felt, SimpleSmt};
|
||||
|
||||
// MERKLE STORE DELTA
|
||||
// ================================================================================================
|
||||
|
||||
/// [MerkleStoreDelta] stores a vector of ([RpoDigest], [MerkleTreeDelta]) tuples where the
|
||||
/// [RpoDigest] represents the root of the Merkle tree and [MerkleTreeDelta] represents the
|
||||
/// differences between the initial and final Merkle tree states.
|
||||
#[derive(Default, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MerkleStoreDelta(pub Vec<(RpoDigest, MerkleTreeDelta)>);
|
||||
|
||||
// MERKLE TREE DELTA
|
||||
// ================================================================================================
|
||||
|
||||
/// [MerkleDelta] stores the differences between the initial and final Merkle tree states.
|
||||
///
|
||||
/// The differences are represented as follows:
|
||||
/// - depth: the depth of the merkle tree.
|
||||
/// - cleared_slots: indexes of slots where values were set to [ZERO; 4].
|
||||
/// - updated_slots: index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
#[cfg(not(test))]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MerkleTreeDelta {
|
||||
depth: u8,
|
||||
cleared_slots: Vec<u64>,
|
||||
updated_slots: Vec<(u64, Word)>,
|
||||
}
|
||||
|
||||
impl MerkleTreeDelta {
|
||||
// CONSTRUCTOR
|
||||
// --------------------------------------------------------------------------------------------
|
||||
pub fn new(depth: u8) -> Self {
|
||||
Self {
|
||||
depth,
|
||||
cleared_slots: Vec::new(),
|
||||
updated_slots: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
// ACCESSORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Returns the depth of the Merkle tree the [MerkleDelta] is associated with.
|
||||
pub fn depth(&self) -> u8 {
|
||||
self.depth
|
||||
}
|
||||
|
||||
/// Returns the indexes of slots where values were set to [ZERO; 4].
|
||||
pub fn cleared_slots(&self) -> &[u64] {
|
||||
&self.cleared_slots
|
||||
}
|
||||
|
||||
/// Returns the index-value pairs of slots where values were set to non [ZERO; 4] values.
|
||||
pub fn updated_slots(&self) -> &[(u64, Word)] {
|
||||
&self.updated_slots
|
||||
}
|
||||
|
||||
// MODIFIERS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
/// Adds a slot index to the list of cleared slots.
|
||||
pub fn add_cleared_slot(&mut self, index: u64) {
|
||||
self.cleared_slots.push(index);
|
||||
}
|
||||
|
||||
/// Adds a slot index and a value to the list of updated slots.
|
||||
pub fn add_updated_slot(&mut self, index: u64, value: Word) {
|
||||
self.updated_slots.push((index, value));
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts a [MerkleDelta] object by comparing the leaves of two Merkle trees specifies by
|
||||
/// their roots and depth.
|
||||
pub fn merkle_tree_delta<T: KvMap<RpoDigest, StoreNode>>(
|
||||
tree_root_1: RpoDigest,
|
||||
tree_root_2: RpoDigest,
|
||||
depth: u8,
|
||||
merkle_store: &MerkleStore<T>,
|
||||
) -> Result<MerkleTreeDelta, MerkleError> {
|
||||
if tree_root_1 == tree_root_2 {
|
||||
return Ok(MerkleTreeDelta::new(depth));
|
||||
}
|
||||
|
||||
let tree_1_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
merkle_store.non_empty_leaves(tree_root_1, depth).collect();
|
||||
let tree_2_leaves: BTreeMap<NodeIndex, RpoDigest> =
|
||||
merkle_store.non_empty_leaves(tree_root_2, depth).collect();
|
||||
let diff = tree_1_leaves.diff(&tree_2_leaves);
|
||||
|
||||
// TODO: Refactor this diff implementation to prevent allocation of both BTree and Vec.
|
||||
Ok(MerkleTreeDelta {
|
||||
depth,
|
||||
cleared_slots: diff.removed.into_iter().map(|index| index.value()).collect(),
|
||||
updated_slots: diff
|
||||
.updated
|
||||
.into_iter()
|
||||
.map(|(index, leaf)| (index.value(), *leaf))
|
||||
.collect(),
|
||||
})
|
||||
}
|
||||
|
||||
// INTERNALS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
#[cfg(test)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct MerkleTreeDelta {
|
||||
pub depth: u8,
|
||||
pub cleared_slots: Vec<u64>,
|
||||
pub updated_slots: Vec<(u64, Word)>,
|
||||
}
|
||||
|
||||
// MERKLE DELTA
|
||||
// ================================================================================================
|
||||
#[test]
|
||||
fn test_compute_merkle_delta() {
|
||||
let entries = vec![
|
||||
(10, [Felt::new(0), Felt::new(1), Felt::new(2), Felt::new(3)]),
|
||||
(15, [Felt::new(4), Felt::new(5), Felt::new(6), Felt::new(7)]),
|
||||
(20, [Felt::new(8), Felt::new(9), Felt::new(10), Felt::new(11)]),
|
||||
(31, [Felt::new(12), Felt::new(13), Felt::new(14), Felt::new(15)]),
|
||||
];
|
||||
let simple_smt = SimpleSmt::with_leaves(30, entries.clone()).unwrap();
|
||||
let mut store: MerkleStore = (&simple_smt).into();
|
||||
let root = simple_smt.root();
|
||||
|
||||
// add a new node
|
||||
let new_value = [Felt::new(16), Felt::new(17), Felt::new(18), Felt::new(19)];
|
||||
let new_index = NodeIndex::new(simple_smt.depth(), 32).unwrap();
|
||||
let root = store.set_node(root, new_index, new_value.into()).unwrap().root;
|
||||
|
||||
// update an existing node
|
||||
let update_value = [Felt::new(20), Felt::new(21), Felt::new(22), Felt::new(23)];
|
||||
let update_idx = NodeIndex::new(simple_smt.depth(), entries[0].0).unwrap();
|
||||
let root = store.set_node(root, update_idx, update_value.into()).unwrap().root;
|
||||
|
||||
// remove a node
|
||||
let remove_idx = NodeIndex::new(simple_smt.depth(), entries[1].0).unwrap();
|
||||
let root = store.set_node(root, remove_idx, EMPTY_WORD.into()).unwrap().root;
|
||||
|
||||
let merkle_delta =
|
||||
merkle_tree_delta(simple_smt.root(), root, simple_smt.depth(), &store).unwrap();
|
||||
let expected_merkle_delta = MerkleTreeDelta {
|
||||
depth: simple_smt.depth(),
|
||||
cleared_slots: vec![remove_idx.value()],
|
||||
updated_slots: vec![(update_idx.value(), update_value), (new_index.value(), new_value)],
|
||||
};
|
||||
|
||||
assert_eq!(merkle_delta, expected_merkle_delta);
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
hash::rpo::{Rpo256, RpoDigest},
|
||||
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, Vec},
|
||||
utils::collections::{vec, BTreeMap, BTreeSet, KvMap, RecordingMap, TryApplyDiff, Vec},
|
||||
Felt, StarkField, Word, WORD_SIZE, ZERO,
|
||||
};
|
||||
use core::fmt;
|
||||
|
@ -11,6 +11,9 @@ use core::fmt;
|
|||
mod empty_roots;
|
||||
pub use empty_roots::EmptySubtreeRoots;
|
||||
|
||||
mod delta;
|
||||
pub use delta::{merkle_tree_delta, MerkleStoreDelta, MerkleTreeDelta};
|
||||
|
||||
mod index;
|
||||
pub use index::NodeIndex;
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, NodeIndex,
|
||||
Rpo256, RpoDigest, Vec, Word,
|
||||
BTreeMap, BTreeSet, EmptySubtreeRoots, InnerNodeInfo, MerkleError, MerklePath, MerkleTreeDelta,
|
||||
NodeIndex, Rpo256, RpoDigest, StoreNode, TryApplyDiff, Vec, Word,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -275,3 +275,29 @@ impl BranchNode {
|
|||
Rpo256::merge(&[self.left, self.right])
|
||||
}
|
||||
}
|
||||
|
||||
// TRY APPLY DIFF
|
||||
// ================================================================================================
|
||||
impl TryApplyDiff<RpoDigest, StoreNode> for SimpleSmt {
|
||||
type Error = MerkleError;
|
||||
type DiffType = MerkleTreeDelta;
|
||||
|
||||
fn try_apply(&mut self, diff: MerkleTreeDelta) -> Result<(), MerkleError> {
|
||||
if diff.depth() != self.depth() {
|
||||
return Err(MerkleError::InvalidDepth {
|
||||
expected: self.depth(),
|
||||
provided: diff.depth(),
|
||||
});
|
||||
}
|
||||
|
||||
for slot in diff.cleared_slots() {
|
||||
self.update_leaf(*slot, Self::EMPTY_VALUE)?;
|
||||
}
|
||||
|
||||
for (slot, value) in diff.updated_slots() {
|
||||
self.update_leaf(*slot, *value)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,12 +1,9 @@
|
|||
use super::{
|
||||
mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap, MerkleError, MerklePath,
|
||||
MerklePathSet, MerkleTree, NodeIndex, RecordingMap, RootPath, Rpo256, RpoDigest, SimpleSmt,
|
||||
TieredSmt, ValuePath, Vec,
|
||||
};
|
||||
use crate::utils::{
|
||||
collections::{ApplyDiff, Diff, KvMapDiff},
|
||||
ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable,
|
||||
empty_roots::EMPTY_WORD, mmr::Mmr, BTreeMap, EmptySubtreeRoots, InnerNodeInfo, KvMap,
|
||||
MerkleError, MerklePath, MerklePathSet, MerkleStoreDelta, MerkleTree, NodeIndex, RecordingMap,
|
||||
RootPath, Rpo256, RpoDigest, SimpleSmt, TieredSmt, TryApplyDiff, ValuePath, Vec,
|
||||
};
|
||||
use crate::utils::{ByteReader, ByteWriter, Deserializable, DeserializationError, Serializable};
|
||||
use core::borrow::Borrow;
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -280,6 +277,37 @@ impl<T: KvMap<RpoDigest, StoreNode>> MerkleStore<T> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Iterator over the non-empty leaves of the Merkle tree associated with the specified `root`
|
||||
/// and `max_depth`.
|
||||
pub fn non_empty_leaves(
|
||||
&self,
|
||||
root: RpoDigest,
|
||||
max_depth: u8,
|
||||
) -> impl Iterator<Item = (NodeIndex, RpoDigest)> + '_ {
|
||||
let empty_roots = EmptySubtreeRoots::empty_hashes(max_depth);
|
||||
let mut stack = Vec::new();
|
||||
stack.push((NodeIndex::new_unchecked(0, 0), root));
|
||||
|
||||
core::iter::from_fn(move || {
|
||||
while let Some((index, node_hash)) = stack.pop() {
|
||||
if index.depth() == max_depth {
|
||||
return Some((index, node_hash));
|
||||
}
|
||||
|
||||
if let Some(node) = self.nodes.get(&node_hash) {
|
||||
if !empty_roots.contains(&node.left) {
|
||||
stack.push((index.left_child(), node.left));
|
||||
}
|
||||
if !empty_roots.contains(&node.right) {
|
||||
stack.push((index.right_child(), node.right));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
// STATE MUTATORS
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
|
@ -462,7 +490,6 @@ impl<T: KvMap<RpoDigest, StoreNode>> FromIterator<(RpoDigest, StoreNode)> for Me
|
|||
|
||||
// ITERATORS
|
||||
// ================================================================================================
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
|
||||
fn extend<I: IntoIterator<Item = InnerNodeInfo>>(&mut self, iter: I) {
|
||||
self.nodes.extend(iter.into_iter().map(|info| {
|
||||
|
@ -479,19 +506,34 @@ impl<T: KvMap<RpoDigest, StoreNode>> Extend<InnerNodeInfo> for MerkleStore<T> {
|
|||
|
||||
// DiffT & ApplyDiffT TRAIT IMPLEMENTATION
|
||||
// ================================================================================================
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> Diff<RpoDigest, StoreNode> for MerkleStore<T> {
|
||||
type DiffType = KvMapDiff<RpoDigest, StoreNode>;
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> TryApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
|
||||
type Error = MerkleError;
|
||||
type DiffType = MerkleStoreDelta;
|
||||
|
||||
fn diff(&self, other: &Self) -> Self::DiffType {
|
||||
self.nodes.diff(&other.nodes)
|
||||
}
|
||||
}
|
||||
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), MerkleError> {
|
||||
for (root, delta) in diff.0 {
|
||||
let mut root = root;
|
||||
for cleared_slot in delta.cleared_slots() {
|
||||
root = self
|
||||
.set_node(
|
||||
root,
|
||||
NodeIndex::new(delta.depth(), *cleared_slot)?,
|
||||
EMPTY_WORD.into(),
|
||||
)?
|
||||
.root;
|
||||
}
|
||||
for (updated_slot, updated_value) in delta.updated_slots() {
|
||||
root = self
|
||||
.set_node(
|
||||
root,
|
||||
NodeIndex::new(delta.depth(), *updated_slot)?,
|
||||
(*updated_value).into(),
|
||||
)?
|
||||
.root;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: KvMap<RpoDigest, StoreNode>> ApplyDiff<RpoDigest, StoreNode> for MerkleStore<T> {
|
||||
type DiffType = KvMapDiff<RpoDigest, StoreNode>;
|
||||
|
||||
fn apply(&mut self, diff: Self::DiffType) {
|
||||
self.nodes.apply(diff);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -847,7 +847,7 @@ fn test_recorder() {
|
|||
|
||||
// construct the proof
|
||||
let rec_map = recorder.into_inner();
|
||||
let proof = rec_map.into_proof();
|
||||
let (_, proof) = rec_map.finalize();
|
||||
let merkle_store: MerkleStore = proof.into();
|
||||
|
||||
// make sure the proof contains all nodes from both trees
|
||||
|
|
|
@ -1,16 +1,31 @@
|
|||
/// A trait for computing the difference between two objects.
|
||||
pub trait Diff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// Returns a `Self::DiffType` object that represents the difference between this object and
|
||||
/// Returns a [Self::DiffType] object that represents the difference between this object and
|
||||
/// other.
|
||||
fn diff(&self, other: &Self) -> Self::DiffType;
|
||||
}
|
||||
|
||||
/// A trait for applying the difference between two objects.
|
||||
pub trait ApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// Applies the provided changes described by [DiffType] to the object implementing this trait.
|
||||
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
fn apply(&mut self, diff: Self::DiffType);
|
||||
}
|
||||
|
||||
/// A trait for applying the difference between two objects with the possibility of failure.
|
||||
pub trait TryApplyDiff<K: Ord + Clone, V: Clone> {
|
||||
/// The type that describes the difference between two objects.
|
||||
type DiffType;
|
||||
|
||||
/// An error type that can be returned if the changes cannot be applied.
|
||||
type Error;
|
||||
|
||||
/// Applies the provided changes described by [Self::DiffType] to the object implementing this trait.
|
||||
/// Returns an error if the changes cannot be applied.
|
||||
fn try_apply(&mut self, diff: Self::DiffType) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
|
|
@ -97,10 +97,12 @@ impl<K: Ord + Clone, V: Clone> RecordingMap<K, V> {
|
|||
// FINALIZER
|
||||
// --------------------------------------------------------------------------------------------
|
||||
|
||||
/// Consumes the [RecordingMap] and returns a [BTreeMap] containing the key-value pairs from
|
||||
/// the initial data set that were read during recording.
|
||||
pub fn into_proof(self) -> BTreeMap<K, V> {
|
||||
self.trace.take()
|
||||
/// Consumes the [RecordingMap] and returns a ([BTreeMap], [BTreeMap]) tuple. The first
|
||||
/// element of the tuple is a map that represents the state of the map at the time `.finalize()`
|
||||
/// is called. The second element contains the key-value pairs from the initial data set that
|
||||
/// were read during recording.
|
||||
pub fn finalize(self) -> (BTreeMap<K, V>, BTreeMap<K, V>) {
|
||||
(self.data, self.trace.take())
|
||||
}
|
||||
|
||||
// TEST HELPERS
|
||||
|
@ -217,8 +219,8 @@ impl<K: Clone + Ord, V: Clone> IntoIterator for RecordingMap<K, V> {
|
|||
/// - `removed` - a set of keys that were removed from the second map compared to the first map.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct KvMapDiff<K, V> {
|
||||
updated: BTreeMap<K, V>,
|
||||
removed: BTreeSet<K>,
|
||||
pub updated: BTreeMap<K, V>,
|
||||
pub removed: BTreeSet<K>,
|
||||
}
|
||||
|
||||
impl<K, V> KvMapDiff<K, V> {
|
||||
|
@ -296,7 +298,7 @@ mod tests {
|
|||
}
|
||||
|
||||
// convert the map into a proof
|
||||
let proof = map.into_proof();
|
||||
let (_, proof) = map.finalize();
|
||||
|
||||
// check that the proof contains the expected values
|
||||
for (key, value) in ITEMS.iter() {
|
||||
|
@ -319,7 +321,7 @@ mod tests {
|
|||
}
|
||||
|
||||
// convert the map into a proof
|
||||
let proof = map.into_proof();
|
||||
let (_, proof) = map.finalize();
|
||||
|
||||
// check that the proof contains the expected values
|
||||
for (key, _) in ITEMS.iter() {
|
||||
|
@ -383,7 +385,7 @@ mod tests {
|
|||
|
||||
// Note: The length reported by the proof will be different to the length originally
|
||||
// reported by the map.
|
||||
let proof = map.into_proof();
|
||||
let (_, proof) = map.finalize();
|
||||
|
||||
// length of the proof should be equal to get_items + 1. The extra item is the original
|
||||
// value at key = 4u64
|
||||
|
@ -458,7 +460,7 @@ mod tests {
|
|||
assert_eq!(map.updates_len(), 2);
|
||||
|
||||
// convert the map into a proof
|
||||
let proof = map.into_proof();
|
||||
let (_, proof) = map.finalize();
|
||||
|
||||
// check that the proof contains the expected values
|
||||
for (key, value) in ITEMS.iter() {
|
||||
|
|
Loading…
Add table
Reference in a new issue